Commit c217436c authored by szr712

Fix errors encountered when using pinyin as tokens

parent a0e5b946
......@@ -27,3 +27,11 @@ CUDA_VISIBLE_DEVICES=1 nohup python train2.py -src_data data/pinyin_2.txt -trg_d
CUDA_VISIBLE_DEVICES=3 python translate_pkl.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_35_0.055335590355098246 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/pkl/test-pkl -result_dir data/pkl/test-pkl-result
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_split.txt -trg_data data/hanzi_split.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification_split -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
CUDA_VISIBLE_DEVICES=1 python translate2.py -load_weights weights/token_classification_split_2/11-19_17:16:18/token_classification_split_2_5_0.05776993067935109 -pkl_dir weights/token_classification_split_2/11-19_17:16:18 -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
CUDA_VISIBLE_DEVICES=4 nohup python translate_file2.py -load_weights weights/token_classification_split_3/11-22_21:56:11/token_classification_split_3_25_0.029638311734888702 -pkl_dir weights/token_classification_split_3/11-22_21:56:11 -test_dir data/test_data/pinyin_split -result_dir data/test_data/result_split -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
CUDA_VISIBLE_DEVICES=3 python eval_model.py -load_weights weights/token_classification_split_3/11-22_21:56:11/token_classification_split_3_1_0.09703897424042225 -pkl_dir weights/token_classification_split_3/11-22_21:56:11 -dev_dir data/dev -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_split.txt -trg_data data/hanzi_split.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification_split_4 -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
\ No newline at end of file
#-*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
#!/usr/bin/python2
"""
Before running this code, make sure that you've downloaded the Leipzig Chinese Corpus
......@@ -29,17 +29,26 @@ import itertools
def wenzi2pinyin(text):
pinyin_list = lazy_pinyin(text, style=Style.TONE3)
# print(pinyin_list)
# tones_list = [int(py[-1]) if py[-1].isdigit()
# else 0 for py in pinyin_list]
# pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
tones_list = [int(py[-1]) if py[-1].isdigit()
else 0 for py in pinyin_list]
pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
return pinyin_list
return pinyin_list, tones_list
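For orientation, a minimal sketch of what the revised wenzi2pinyin returns: the tone digit is read off the TONE3 form and the bare syllable off the NORMAL form (the sample sentence is illustrative only):

from pypinyin import lazy_pinyin, Style

text = "中国人"
tone3 = lazy_pinyin(text, style=Style.TONE3)   # ['zhong1', 'guo2', 'ren2']
tones = [int(py[-1]) if py[-1].isdigit() else 0 for py in tone3]
plain = lazy_pinyin(text, style=Style.NORMAL)  # ['zhong', 'guo', 'ren']
print(plain, tones)                            # ['zhong', 'guo', 'ren'] [1, 2, 2]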
def split_initials_finals(pinyin):
if get_initials(pinyin, False)!="":
return [get_initials(pinyin, False), get_finals(pinyin, False)]
def split_initials_finals(pinyin, tone, char):
strict = True
if not ('\u4e00' <= char <= '\u9fff'):
# return [a for a in get_initials(pinyin, strict)+get_finals(pinyin, strict)]
return [a for a in pinyin]
else:
pinyin = pinyin.replace("v", "ü")
if get_initials(pinyin, strict) != "":
return [get_initials(pinyin, strict), get_finals(pinyin, strict)+str(tone)]
else:
return [get_finals(pinyin, False)]
return [get_finals(pinyin, strict)+str(tone)]
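The switch from strict=False to strict=True (plus the v→ü replacement) changes how ü-finals split; a hedged check with pypinyin's helpers, with expected outputs assumed from pypinyin's documented strict-mode behavior:

from pypinyin.style._utils import get_initials, get_finals

print(get_initials("zhong", True), get_finals("zhong", True))  # zh ong
print(get_initials("ju", False), get_finals("ju", False))      # j u
print(get_initials("ju", True), get_finals("ju", True))        # j ü  (strict keeps the real final)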
def align(sent):
'''
......@@ -52,21 +61,21 @@ def align(sent):
pnyns = pinyin.get_pinyin(sent, " ").split()
hanzis = []
x=0
x = 0
for i in range(len(pnyns)):
if i+x<len(sent.replace(" ", "")):
char=sent[i+x]
p=pnyns[i]
if i+x < len(sent.replace(" ", "")):
char = sent[i+x]
p = pnyns[i]
if '\u4e00' <= char <= '\u9fa5':
hanzis.extend([char] + ["_"] * (len(p) - 1))
else:
while not '\u4e00' <= char <= '\u9fa5':
hanzis.extend([char])
x=x+1
if x+i>=len(sent.replace(" ", "")):
x = x+1
if x+i >= len(sent.replace(" ", "")):
break
char=sent[i+x]
x=x-1
char = sent[i+x]
x = x-1
# for char, p in zip(sent.replace(" ", ""), pnyns):
# if '\u4e00' <= char <= '\u9fa5':
......@@ -75,9 +84,11 @@ def align(sent):
pnyns = "".join(pnyns)
hanzis = "".join(hanzis)
assert len(pnyns) == len(hanzis), "The hanzis and the pinyins must be the same in length."
assert len(pnyns) == len(
hanzis), "The hanzis and the pinyins must be the same in length."
return pnyns, hanzis
def align2(sent):
'''
Args:
......@@ -87,21 +98,30 @@ def align2(sent):
A tuple of pinyin and chinese sentence.
'''
# pnyns = pinyin.get_pinyin(sent, " ").split()
pinyins=wenzi2pinyin(sent)
pnyns = [split_initials_finals(a) for a in pinyins]
pinyins, tones = wenzi2pinyin(sent)
pnyns = []
i = 0
for pinyin, tone in zip(pinyins, tones):
if '\u4e00' <= sent[i] <= '\u9fa5':
pnyns.append(split_initials_finals(pinyin, tone, sent[i]))
i += 1
else:
pnyns.append(split_initials_finals(
pinyin, tone, sent[i:i+len(pinyin)]))
i += len(pinyin)
hanzis = []
x=0
x = 0
for i in range(len(pnyns)):
if i+x<len(sent.replace(" ", "")):
char=sent[i+x]
p=pnyns[i]
if i+x < len(sent.replace(" ", "")):
char = sent[i+x]
p = pnyns[i]
if '\u4e00' <= char <= '\u9fa5':
hanzis.extend([char] + ["_"] * (len(p) - 1))
else:
for q in p:
hanzis.append(q)
x+=len(q)
x-=1
x += len(q)
x -= 1
# word=""
# while not '\u4e00' <= char <= '\u9fa5':
......@@ -129,34 +149,38 @@ def align2(sent):
# assert len(pnyns.split(" ")) == len(hanzis.split(" ")), "The hanzis and the pinyins must be the same in length."
return pnyns, hanzis
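As a hedged sketch of the intended output format (compare the src/trg samples in the training log below), the hanzi side is padded with one "_" per extra pinyin token so both streams stay the same length:

pnyns = [["sh", "i4"], ["j", "ie4"], ["sh", "ang4"]]  # per-character output of split_initials_finals
hanzis = []
for char, p in zip("世界上", pnyns):
    hanzis.extend([char] + ["_"] * (len(p) - 1))      # one "_" per extra token
print(" ".join(sum(pnyns, [])))  # sh i4 j ie4 sh ang4
print(" ".join(hanzis))          # 世 _ 界 _ 上 _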
def clean(text):
# if regex.search("[A-Za-z0-9]", text) is not None: # For simplicity, roman alphanumeric characters are removed.
# # if regex.search("[A-Za-z0-9]", text) is not None: # For simplicity, roman alphanumeric characters are removed.
# return ""
    text = regex.sub(u"[^ \u4e00-\u9fa5。,!?]", "", text)  # keep spaces, Chinese chars, and 。,!?
text_new=""
flag=0
text_new = ""
flag = 0
for char in text:
if char !="。" and char !="," and char !="!" and char!="?":
flag=0
text_new=text_new+char
if char != "。" and char != "," and char != "!" and char != "?":
flag = 0
text_new = text_new+char
elif not flag:
flag=1
text_new=text_new+char
flag = 1
text_new = text_new+char
# while ",," in text:
# text=text.replace(",,",",")
if len(text_new)<10: return ""
if len(text_new) < 10:
return ""
return text_new
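A hedged illustration of clean() (currently commented out in build_corpus below): it strips everything outside the kept character set and collapses runs of sentence punctuation to a single mark:

print(clean("Hello你好!!!今天天气不错,适合出门散步。"))
# -> "你好!今天天气不错,适合出门散步。"  (returns "" if fewer than 10 chars survive)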
def build_corpus(src_file,pinyin_file,hanzi_file):
pinyin_list=[]
hanzi_list=[]
def build_corpus(src_file, pinyin_file, hanzi_file):
pinyin_list = []
hanzi_list = []
# with codecs.open("data/zho_news_2007-2009_1M-sentences.txt", 'r', 'utf-8') as fin:
with codecs.open(src_file, 'r', 'utf-8') as fin:
i = 1
while 1:
line = fin.readline()
if not line: break
if not line:
break
try:
# idx, sent = line.strip().split("\t")
......@@ -167,7 +191,7 @@ def build_corpus(src_file,pinyin_file,hanzi_file):
# pnyns, hanzis = align(sent)
# fout.write(u"{}\t{}\t{}\n".format(idx, pnyns, hanzis))
sent = line.strip()
sent = sent.replace(" ","")
sent = sent.replace(" ", "")
# sent = clean(sent)
if len(sent) > 0:
pnyns, hanzis = align2(sent)
......@@ -177,12 +201,19 @@ def build_corpus(src_file,pinyin_file,hanzi_file):
traceback.print_exc()
continue # it's okay as we have a pretty big corpus!
if i % 10000 == 0: print(i, )
if i % 10000 == 0:
print(i, )
i += 1
with open(pinyin_file,'w',encoding='utf-8') as f:
with open(pinyin_file, 'w', encoding='utf-8') as f:
f.write("\n".join(pinyin_list))
with open(hanzi_file,'w',encoding='utf-8') as f:
with open(hanzi_file, 'w', encoding='utf-8') as f:
f.write("\n".join(hanzi_list))
if __name__ == "__main__":
build_corpus("./data/train_set_total.txt","./data/pinyin_spilit.txt","./data/hanzi_spilit.txt"); print("Done")
# with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
# yunmus=f.readlines()
# yunmus=[a.strip() for a in yunmus]
build_corpus("./data/train_set_total.txt",
"./data/pinyin_split.txt", "./data/hanzi_split.txt")
print("Done")
"""
Compute CER (character error rate).
Put the predicted texts in the preFile folder,
and the reference (ground-truth) texts in the textFile folder.
"""
import os
import re
from tqdm import tqdm
import multiprocess
import math
import time
import distance
def ishan(text):
    """Strip everything but Chinese characters from the input string.

    Args:
        text (str): input string

    Returns:
        str: the input with all non-Chinese characters removed
    """
    # for python 3.x
    # sample: ishan('一') == '一', ishan('我&&你') == '我你'
    result = [char if '\u4e00' <= char <= '\u9fff' else "" for char in text]
    return "".join(result)
def cer(preFile,textFile):
# preList=getList(preFile,"./preFile")
# textList=getList(textFile,"./textFile")
# for a,b in zip(textList,preList):
# print('pred: {}, gt: {}'.format(b, a))
for pre in os.listdir(preFile):
# text=pre[:-11]+".txt"
text=pre
print("filename:{}".format(pre))
preList = []
textList = []
with open(os.path.join(preFile, pre), "r", encoding="utf-8") as fw:
preList = fw.readlines()
with open(os.path.join(textFile, text), "r", encoding="utf-8") as fw:
textList = fw.readlines()
total_edit_distance, num_chars = 0, 0
for pred, expected in tqdm(zip(preList, textList)):
pred = ishan(pred)
expected = ishan(expected)
edit_distance = distance.levenshtein(expected, pred)
total_edit_distance += edit_distance
num_chars += len(expected)
print("Total CER: {}/{}={}\n".format(total_edit_distance,
num_chars,
round(float(total_edit_distance)/num_chars, 5)))
if __name__ == "__main__":
preFile = "./data/test_data/result_split"
textFile = "./data/test_data/hanzi"
cer(preFile,textFile)
\ No newline at end of file
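For reference, the per-line metric accumulated above is plain Levenshtein distance over the Chinese-only strings; a hedged one-liner with the same distance package (pip install distance; sample strings illustrative):

import distance

expected, pred = "今天天气很好", "今天天汽很好"
ed = distance.levenshtein(expected, pred)  # 1 (one substituted character)
print("CER = {}/{} = {}".format(ed, len(expected), round(float(ed) / len(expected), 5)))
# CER = 1/6 = 0.16667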
......@@ -101,7 +101,8 @@ def cer_multi(num_process,preFile,textFile):
# print('pred: {}, gt: {}'.format(b, a))
for pre in os.listdir(preFile):
text=pre[:-11]+".txt"
# text=pre[:-11]+".txt"
text=pre
preList = []
textList = []
with open(os.path.join(preFile, pre), "r", encoding="utf-8") as fw:
......@@ -144,7 +145,7 @@ def cer_multi(num_process,preFile,textFile):
if __name__ == "__main__":
num_process = 128
preFile = "./data/test_data/result_token_classification"
preFile = "./data/test_data/result_split"
textFile = "./data/test_data/hanzi"
cer_multi(num_process,preFile,textFile)
\ No newline at end of file
import itertools
import os
from tqdm import tqdm
from build_corpus import split_initials_finals, wenzi2pinyin
hanzi_dir="./data/test_data/hanzi"
pinyin_dir="./data/test_data/pinyin_split"
with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
yunmus=f.readlines()
yunmus=[a.strip() for a in yunmus]
for file in os.listdir(hanzi_dir):
print(file)
with open(os.path.join(hanzi_dir,file),'r',encoding="utf-8") as f:
contents=f.readlines()
result=[]
for line in tqdm(contents):
sent = line.strip()
sent = sent.replace(" ","")
pinyins,tones=wenzi2pinyin(sent)
pnyns=[]
i=0
for pinyin,tone in zip(pinyins,tones):
if '\u4e00' <= sent[i] <= '\u9fa5':
pnyns.append(split_initials_finals(pinyin,tone,sent[i]))
i+=1
else:
pnyns.append(split_initials_finals(pinyin,tone,sent[i:i+len(pinyin)]))
i+=len(pinyin)
pnyns = " ".join(list(itertools.chain.from_iterable(pnyns)))
result.append(pnyns)
with open(os.path.join(pinyin_dir,file),"w",encoding="utf-8") as f:
f.write("\n".join(result))
\ No newline at end of file
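The flattening step above just chains the per-character token lists; a hedged mini-check:

import itertools

pnyns = [["zh", "ong1"], ["g", "uo2"]]
print(" ".join(itertools.chain.from_iterable(pnyns)))  # zh ong1 g uo2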
......@@ -76,6 +76,8 @@ def main():
parser.add_argument('-no_cuda', action='store_true')
parser.add_argument('-floyd', action='store_true')
parser.add_argument("-dev_dir", type=str, required=True)
parser.add_argument('-src_voc')
parser.add_argument('-trg_voc')
opt = parser.parse_args()
......@@ -86,10 +88,10 @@ def main():
SRC, TRG = create_fields(opt)
i=35
i=1
while i<=60:
for model_name in os.listdir(opt.pkl_dir):
if "token_classification_"+str(i)+"_" in model_name:
if "token_classification_split_3_"+str(i)+"_" in model_name:
print("model_name:{}".format(model_name))
opt.load_weights=os.path.join(opt.pkl_dir,model_name)
......@@ -97,10 +99,10 @@ def main():
model = get_model_token_classification(
opt, len(SRC.vocab), len(TRG.vocab))
contents = open(os.path.join(opt.dev_dir, "dev_pinyin.txt")
contents = open(os.path.join(opt.dev_dir, "dev_pinyin_split.txt")
).read().strip().split('\n')
translates = [translate(i, opt, model, SRC, TRG)
for i in contents]
for i in tqdm(contents)]
# with open(os.path.join(opt.dev_dir,model_name),'w',encoding='utf-8') as f:
# f.write("\n".join(translates))
......
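For context, a minimal sketch of the checkpoint sweep this diff retargets (checkpoint files are named <model>_<epoch>_<loss>, as in the commands at the top of this commit; the directory path is illustrative):

import os

pkl_dir = "weights/token_classification_split_3/11-22_21:56:11"  # illustrative path
for epoch in range(1, 61):
    for model_name in os.listdir(pkl_dir):
        if "token_classification_split_3_{}_".format(epoch) in model_name:
            print("evaluating {}".format(os.path.join(pkl_dir, model_name)))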
[Three source diffs could not be displayed because they are too large. View the blobs instead.]
loading src_data
100%|██████████| 1859421/1859421 [00:00<00:00, 4065261.71it/s]
100%|██████████| 1859421/1859421 [00:00<00:00, 4022344.96it/s]
loading trg_data
100%|██████████| 1859421/1859421 [00:00<00:00, 4405676.46it/s]
\ No newline at end of file
100%|██████████| 1859421/1859421 [00:00<00:00, 3944662.75it/s]
/mnt/data2/shaozirui/anaconda3/envs/bert/lib/python3.7/site-packages/torchtext/data/field.py:150: UserWarning: Field class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.
warnings.warn('{} class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.'.format(self.__class__.__name__), UserWarning)
/mnt/data2/shaozirui/anaconda3/envs/bert/lib/python3.7/site-packages/torchtext/data/example.py:68: UserWarning: Example class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.
warnings.warn('Example class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.', UserWarning)
/mnt/data2/shaozirui/anaconda3/envs/bert/lib/python3.7/site-packages/torchtext/data/example.py:78: UserWarning: Example class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.
warnings.warn('Example class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.', UserWarning)
/mnt/data2/shaozirui/anaconda3/envs/bert/lib/python3.7/site-packages/torchtext/data/iterator.py:48: UserWarning: MyIterator class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.
warnings.warn('{} class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.'.format(self.__class__.__name__), UserWarning)
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
len of src_data:1859421 ; len of trg_data:1859421
loading spacy tokenizers...
creating dataset and iterator...
src trg
457130 sh i4 j ie4 sh ang4 m ei2 iou3 r en4 h e2 l i4... 世 _ 界 _ 上 _ 没 _ 有 任 _ 何 _ 力 _ 量 _ 敢 _ 于 向 _ 棉 ...
647959 z ai4 zh an3 h uei4 sh ang4 l iang4 x iang4 d ... 在 _ 展 _ 会 _ 上 _ 亮 _ 相 _ 的 _ 一 系 _ 列 _ 的 _ 手 _ ...
1599472 h u2 zh ao1 g uang3 、 x ü2 x i1 an1 d eng3 j i... 胡 _ 昭 _ 广 _ 、 徐 _ 锡 _ 安 等 _ 嘉 _ 宾 _ 为 获 _ 得 _ ...
736864 x ing2 r en2 n eng2 g ou4 an4 zh ao4 h ong2 l ... 行 _ 人 _ 能 _ 够 _ 按 照 _ 红 _ 绿 _ 灯 _ 信 _ 号 _ 给 _ ...
1306262 n ong2 ch an3 p in3 d e0 t ou2 z i1 e2 uei4 1 ... 农 _ 产 _ 品 _ 的 _ 投 _ 资 _ 额 为 1 亿 元 ,
733lines [00:00, 418744.87lines/s]
7857lines [00:00, 490491.41lines/s]
/mnt/data2/shaozirui/anaconda3/envs/bert/lib/python3.7/site-packages/torchtext/data/batch.py:23: UserWarning: Batch class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.
warnings.warn('{} class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.'.format(self.__class__.__name__), UserWarning)
train_len:29053
create ./weights/token_classification_split_4/11-23_18:39:31 dir
training model...
0m: epoch 1 [                    ]  0%  loss = ...
Traceback (most recent call last):
File "train_token_classification.py", line 228, in <module>
main()
File "train_token_classification.py", line 145, in main
train_model(model, opt, start_time)
File "train_token_classification.py", line 45, in train_model
preds = model(src, src_mask)
File "/mnt/data2/shaozirui/anaconda3/envs/bert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/mnt/data2/shaozirui/workspace/Transformer_model/Models.py", line 61, in forward
outputs = self.encoder(src, src_mask)
File "/mnt/data2/shaozirui/anaconda3/envs/bert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/mnt/data2/shaozirui/workspace/Transformer_model/Models.py", line 23, in forward
x = self.layers[i](x, mask)
File "/mnt/data2/shaozirui/anaconda3/envs/bert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/mnt/data2/shaozirui/workspace/Transformer_model/Layers.py", line 17, in forward
x = x + self.dropout_1(self.attn(x2,x2,x2,mask))
File "/mnt/data2/shaozirui/anaconda3/envs/bert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/mnt/data2/shaozirui/workspace/Transformer_model/Sublayers.py", line 71, in forward
scores = attention(q, k, v, self.d_k, mask, self.dropout)
File "/mnt/data2/shaozirui/workspace/Transformer_model/Sublayers.py", line 37, in attention
output = torch.matmul(scores, v)
RuntimeError: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 10.76 GiB total capacity; 2.01 GiB already allocated; 15.44 MiB free; 2.35 GiB reserved in total by PyTorch)
......
from pypinyin.standard import convert_finals
from pypinyin.style._utils import get_initials, get_finals
from pypinyin import Style, pinyin
from build_corpus import align2, split_initials_finals, wenzi2pinyin
from pypinyin.core import lazy_pinyin
a = 'ying1'
strict = False
with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
yunmus=f.readlines()
yunmus=[a.strip() for a in yunmus]
a = 'ying'
a=convert_finals(a)
print(a)
strict = True
if get_initials(a, strict)=="":
print("True")
text="s,但季节分布不均,汛期流量极大,而枯水期几乎河底朝天,旱突频发,居民吃水困难。"
text="一鸣惊人"
# ,Datsun新推出的轻型跑车可谓一呜惊人。
# pinyins,tones=wenzi2pinyin(text)
# print(lazy_pinyin(text, style=Style.TONE3))
......@@ -18,6 +26,7 @@ text="s,但季节分布不均,汛期流量极大,而枯水期几乎河底朝天,
# pnyns = [split_initials_finals(a) for a in lazy_pinyin(text, style=Style.TONE3)]
# print(pnyns)
text = text.replace(" ","")
# print(wenzi2pinyin(text))
pnyns, hanzis=align2(text)
print(pnyns)
......
......@@ -68,6 +68,8 @@ def main():
parser.add_argument('-dropout', type=int, default=0.1)
parser.add_argument('-no_cuda', action='store_true')
parser.add_argument('-floyd', action='store_true')
parser.add_argument('-src_voc')
parser.add_argument('-trg_voc')
opt = parser.parse_args()
......
......@@ -71,6 +71,8 @@ def main():
parser.add_argument('-floyd', action='store_true')
parser.add_argument('-test_dir', type=str, required=True)
parser.add_argument('-result_dir', type=str, required=True)
parser.add_argument('-src_voc')
parser.add_argument('-trg_voc')
opt = parser.parse_args()
......