Commit c217436c authored by szr712

Fix errors encountered when using pinyin as tokens

parent a0e5b946
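(Context for the change: the pinyin-as-token scheme splits every Chinese syllable into an initial token plus a final token carrying the tone digit, while non-Chinese characters stay as single-character tokens. Below is a minimal sketch of that split, assuming pypinyin is installed; syllable_to_tokens is a hypothetical helper name, not a function from this repository.)

# Sketch only: one initial token plus one final+tone token per syllable.
from pypinyin import lazy_pinyin, Style
from pypinyin.style._utils import get_initials, get_finals

def syllable_to_tokens(syllable):
    # TONE3 syllables look like "zhong1"; strip the trailing tone digit first.
    tone = syllable[-1] if syllable[-1].isdigit() else "0"
    base = syllable[:-1] if syllable[-1].isdigit() else syllable
    initial = get_initials(base, True)   # strict initials, e.g. "zh"
    final = get_finals(base, True)       # strict finals, e.g. "ong"
    return [initial, final + tone] if initial else [final + tone]

print([syllable_to_tokens(s) for s in lazy_pinyin("中国", style=Style.TONE3)])
# expected: [['zh', 'ong1'], ['g', 'uo2']]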
......@@ -26,4 +26,12 @@ CUDA_VISIBLE_DEVICES=1 nohup python train2.py -src_data data/pinyin_2.txt -trg_d
CUDA_VISIBLE_DEVICES=3 python translate_pkl.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_35_0.055335590355098246 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/pkl/test-pkl -result_dir data/pkl/test-pkl-result
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_split.txt -trg_data data/hanzi_split.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification_split -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
\ No newline at end of file
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_split.txt -trg_data data/hanzi_split.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification_split -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
CUDA_VISIBLE_DEVICES=1 python translate2.py -load_weights weights/token_classification_split_2/11-19_17:16:18/token_classification_split_2_5_0.05776993067935109 -pkl_dir weights/token_classification_split_2/11-19_17:16:18 -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
CUDA_VISIBLE_DEVICES=4 nohup python translate_file2.py -load_weights weights/token_classification_split_3/11-22_21:56:11/token_classification_split_3_25_0.029638311734888702 -pkl_dir weights/token_classification_split_3/11-22_21:56:11 -test_dir data/test_data/pinyin_split -result_dir data/test_data/result_split -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
CUDA_VISIBLE_DEVICES=3 python eval_model.py -load_weights weights/token_classification_split_3/11-22_21:56:11/token_classification_split_3_1_0.09703897424042225 -pkl_dir weights/token_classification_split_3/11-22_21:56:11 -dev_dir data/dev -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_split.txt -trg_data data/hanzi_split.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification_split_4 -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
\ No newline at end of file
#-*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
#!/usr/bin/python2
"""
Before running this code, make sure that you've downloaded the Leipzig Chinese Corpus
......@@ -17,8 +17,8 @@ from __future__ import print_function
import codecs
import os
from threading import Semaphore
import regex # pip install regex
from xpinyin import Pinyin # pip install xpinyin
import regex # pip install regex
from xpinyin import Pinyin # pip install xpinyin
import traceback
from pypinyin.style._utils import get_initials, get_finals
from pypinyin import Style, pinyin
......@@ -29,79 +29,99 @@ import itertools
def wenzi2pinyin(text):
pinyin_list = lazy_pinyin(text, style=Style.TONE3)
# print(pinyin_list)
# tones_list = [int(py[-1]) if py[-1].isdigit()
# else 0 for py in pinyin_list]
# pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
return pinyin_list
def split_initials_finals(pinyin):
if get_initials(pinyin, False)!="":
return [get_initials(pinyin, False), get_finals(pinyin, False)]
tones_list = [int(py[-1]) if py[-1].isdigit()
else 0 for py in pinyin_list]
pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
return pinyin_list, tones_list
def split_initials_finals(pinyin, tone, char):
strict = True
if not ('\u4e00' <= char <= '\u9fff'):
# return [a for a in get_initials(pinyin, strict)+get_finals(pinyin, strict)]
return [a for a in pinyin]
else:
return [get_finals(pinyin, False)]
pinyin = pinyin.replace("v", "ü")
if get_initials(pinyin, strict) != "":
return [get_initials(pinyin, strict), get_finals(pinyin, strict)+str(tone)]
else:
return [get_finals(pinyin, strict)+str(tone)]
def align(sent):
'''
Args:
sent: A string. A sentence.
Returns:
A tuple of pinyin and chinese sentence.
'''
'''
pnyns = pinyin.get_pinyin(sent, " ").split()
hanzis = []
x=0
x = 0
for i in range(len(pnyns)):
if i+x<len(sent.replace(" ", "")):
char=sent[i+x]
p=pnyns[i]
if i+x < len(sent.replace(" ", "")):
char = sent[i+x]
p = pnyns[i]
if '\u4e00' <= char <= '\u9fa5':
hanzis.extend([char] + ["_"] * (len(p) - 1))
else:
while not '\u4e00' <= char <= '\u9fa5':
hanzis.extend([char])
x=x+1
if x+i>=len(sent.replace(" ", "")):
x = x+1
if x+i >= len(sent.replace(" ", "")):
break
char=sent[i+x]
x=x-1
char = sent[i+x]
x = x-1
# for char, p in zip(sent.replace(" ", ""), pnyns):
# if '\u4e00' <= char <= '\u9fa5':
# hanzis.extend([char] + ["_"] * (len(p) - 1))
pnyns = "".join(pnyns)
hanzis = "".join(hanzis)
assert len(pnyns) == len(hanzis), "The hanzis and the pinyins must be the same in length."
assert len(pnyns) == len(
hanzis), "The hanzis and the pinyins must be the same in length."
return pnyns, hanzis
def align2(sent):
'''
Args:
sent: A string. A sentence.
Returns:
A tuple of pinyin and chinese sentence.
'''
'''
# pnyns = pinyin.get_pinyin(sent, " ").split()
pinyins=wenzi2pinyin(sent)
pnyns = [split_initials_finals(a) for a in pinyins]
pinyins, tones = wenzi2pinyin(sent)
pnyns = []
i = 0
for pinyin, tone in zip(pinyins, tones):
if '\u4e00' <= sent[i] <= '\u9fa5':
pnyns.append(split_initials_finals(pinyin, tone, sent[i]))
i += 1
else:
pnyns.append(split_initials_finals(
pinyin, tone, sent[i:i+len(pinyin)]))
i += len(pinyin)
hanzis = []
x=0
x = 0
for i in range(len(pnyns)):
if i+x<len(sent.replace(" ", "")):
char=sent[i+x]
p=pnyns[i]
if i+x < len(sent.replace(" ", "")):
char = sent[i+x]
p = pnyns[i]
if '\u4e00' <= char <= '\u9fa5':
hanzis.extend([char] + ["_"] * (len(p) - 1))
else:
for q in p:
hanzis.append(q)
x+=len(q)
x-=1
x += len(q)
x -= 1
# word=""
# while not '\u4e00' <= char <= '\u9fa5':
......@@ -113,11 +133,11 @@ def align2(sent):
# hanzis.extend([word])
# x=x-1
# for char, p in zip(sent.replace(" ", ""), pnyns):
# if '\u4e00' <= char <= '\u9fa5':
# hanzis.extend([char] + ["_"] * (len(p) - 1))
pnyns = " ".join(list(itertools.chain.from_iterable(pnyns)))
hanzis = " ".join(hanzis)
......@@ -125,39 +145,43 @@ def align2(sent):
print(sent)
print(pnyns)
print(hanzis)
# assert len(pnyns.split(" ")) == len(hanzis.split(" ")), "The hanzis and the pinyins must be the same in length."
return pnyns, hanzis
def clean(text):
# if regex.search("[A-Za-z0-9]", text) is not None: # For simplicity, roman alphanumeric characters are removed.
# # if regex.search("[A-Za-z0-9]", text) is not None: # For simplicity, roman alphanumeric characters are removed.
# return ""
text = regex.sub(u"[^ \p{\u4e00-\u9fa5}。,!?]", "", text)
text_new=""
flag=0
text_new = ""
flag = 0
for char in text:
if char !="。" and char !="," and char !="!" and char!="?":
flag=0
text_new=text_new+char
if char != "。" and char != "," and char != "!" and char != "?":
flag = 0
text_new = text_new+char
elif not flag:
flag=1
text_new=text_new+char
flag = 1
text_new = text_new+char
# while ",," in text:
# text=text.replace(",,",",")
if len(text_new)<10: return ""
if len(text_new) < 10:
return ""
return text_new
def build_corpus(src_file,pinyin_file,hanzi_file):
pinyin_list=[]
hanzi_list=[]
def build_corpus(src_file, pinyin_file, hanzi_file):
pinyin_list = []
hanzi_list = []
# with codecs.open("data/zho_news_2007-2009_1M-sentences.txt", 'r', 'utf-8') as fin:
with codecs.open(src_file, 'r', 'utf-8') as fin:
i = 1
while 1:
line = fin.readline()
if not line: break
if not line:
break
try:
# idx, sent = line.strip().split("\t")
# if idx == "234":
......@@ -167,7 +191,7 @@ def build_corpus(src_file,pinyin_file,hanzi_file):
# pnyns, hanzis = align(sent)
# fout.write(u"{}\t{}\t{}\n".format(idx, pnyns, hanzis))
sent = line.strip()
sent = sent.replace(" ","")
sent = sent.replace(" ", "")
# sent = clean(sent)
if len(sent) > 0:
pnyns, hanzis = align2(sent)
......@@ -175,14 +199,21 @@ def build_corpus(src_file,pinyin_file,hanzi_file):
hanzi_list.append(hanzis)
except:
traceback.print_exc()
continue # it's okay as we have a pretty big corpus!
if i % 10000 == 0: print(i, )
continue # it's okay as we have a pretty big corpus!
if i % 10000 == 0:
print(i, )
i += 1
with open(pinyin_file,'w',encoding='utf-8') as f:
with open(pinyin_file, 'w', encoding='utf-8') as f:
f.write("\n".join(pinyin_list))
with open(hanzi_file,'w',encoding='utf-8') as f:
with open(hanzi_file, 'w', encoding='utf-8') as f:
f.write("\n".join(hanzi_list))
if __name__ == "__main__":
build_corpus("./data/train_set_total.txt","./data/pinyin_spilit.txt","./data/hanzi_spilit.txt"); print("Done")
# with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
# yunmus=f.readlines()
# yunmus=[a.strip() for a in yunmus]
build_corpus("./data/train_set_total.txt",
"./data/pinyin_split.txt", "./data/hanzi_split.txt")
print("Done")
"""
Compute CER (character error rate).
Place the predicted text to compare in the preFile folder,
and the reference (ground-truth) text in the textFile folder.
"""
import os
import re
from tqdm import tqdm
import multiprocess
import math
import time
import distance
def ishan(text):
"""去除输入字符串中除中文字符串的内容
Args:
text (str): 字符串
Returns:
str: 去除非中文字符后的字符串
"""
# for python 3.x
# sample: ishan('一') == True, ishan('我&&你') == False
result= [char if '\u4e00' <= char and char<= '\u9fff' else "" for char in text]
return "".join(result)
def cer(preFile,textFile):
# preList=getList(preFile,"./preFile")
# textList=getList(textFile,"./textFile")
# for a,b in zip(textList,preList):
# print('pred: {}, gt: {}'.format(b, a))
for pre in os.listdir(preFile):
# text=pre[:-11]+".txt"
text=pre
print("filename:{}".format(pre))
preList = []
textList = []
with open(os.path.join(preFile, pre), "r", encoding="utf-8") as fw:
preList = fw.readlines()
with open(os.path.join(textFile, text), "r", encoding="utf-8") as fw:
textList = fw.readlines()
total_edit_distance, num_chars = 0, 0
for pred, expected in tqdm(zip(preList, textList)):
pred = ishan(pred)
expected = ishan(expected)
edit_distance = distance.levenshtein(expected, pred)
total_edit_distance += edit_distance
num_chars += len(expected)
print("Total CER: {}/{}={}\n".format(total_edit_distance,
num_chars,
round(float(total_edit_distance)/num_chars, 5)))
if __name__ == "__main__":
preFile = "./data/test_data/result_split"
textFile = "./data/test_data/hanzi"
cer(preFile,textFile)
\ No newline at end of file
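For a concrete sense of the metric the script above reports, here is a minimal single-pair CER computation using the same distance package; the two strings are made-up examples.

import distance  # pip install distance

ref = "今天天气很好"   # reference text (made-up example)
hyp = "今天天汽很好"   # prediction with one substituted character
edits = distance.levenshtein(ref, hyp)   # 1 edit
print("CER: {}/{} = {}".format(edits, len(ref), round(edits / len(ref), 5)))
# expected output: CER: 1/6 = 0.16667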
......@@ -101,7 +101,8 @@ def cer_multi(num_process,preFile,textFile):
# print('pred: {}, gt: {}'.format(b, a))
for pre in os.listdir(preFile):
text=pre[:-11]+".txt"
# text=pre[:-11]+".txt"
text=pre
preList = []
textList = []
with open(os.path.join(preFile, pre), "r", encoding="utf-8") as fw:
......@@ -144,7 +145,7 @@ def cer_multi(num_process,preFile,textFile):
if __name__ == "__main__":
num_process = 128
preFile = "./data/test_data/result_token_classification"
preFile = "./data/test_data/result_split"
textFile = "./data/test_data/hanzi"
cer_multi(num_process,preFile,textFile)
\ No newline at end of file
import itertools
import os
from tqdm import tqdm
from build_corpus import split_initials_finals, wenzi2pinyin
hanzi_dir="./data/test_data/hanzi"
pinyin_dir="./data/test_data/pinyin_split"
with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
yunmus=f.readlines()
yunmus=[a.strip() for a in yunmus]
for file in os.listdir(hanzi_dir):
print(file)
with open(os.path.join(hanzi_dir,file),'r',encoding="utf-8") as f:
contents=f.readlines()
result=[]
for line in tqdm(contents):
sent = line.strip()
sent = sent.replace(" ","")
pinyins,tones=wenzi2pinyin(sent)
pnyns=[]
i=0
for pinyin,tone in zip(pinyins,tones):
if '\u4e00' <= sent[i] <= '\u9fa5':
pnyns.append(split_initials_finals(pinyin,tone,sent[i]))
i+=1
else:
pnyns.append(split_initials_finals(pinyin,tone,sent[i:i+len(pinyin)]))
i+=len(pinyin)
pnyns = " ".join(list(itertools.chain.from_iterable(pnyns)))
result.append(pnyns)
with open(os.path.join(pinyin_dir,file),"w",encoding="utf-8") as f:
f.write("\n".join(result))
\ No newline at end of file
......@@ -76,6 +76,8 @@ def main():
parser.add_argument('-no_cuda', action='store_true')
parser.add_argument('-floyd', action='store_true')
parser.add_argument("-dev_dir", type=str, required=True)
parser.add_argument('-src_voc')
parser.add_argument('-trg_voc')
opt = parser.parse_args()
......@@ -86,10 +88,10 @@ def main():
SRC, TRG = create_fields(opt)
i=35
i=1
while i<=60:
for model_name in os.listdir(opt.pkl_dir):
if "token_classification_"+str(i)+"_" in model_name:
if "token_classification_split_3_"+str(i)+"_" in model_name:
print("model_name:{}".format(model_name))
opt.load_weights=os.path.join(opt.pkl_dir,model_name)
......@@ -97,10 +99,10 @@ def main():
model = get_model_token_classification(
opt, len(SRC.vocab), len(TRG.vocab))
contents = open(os.path.join(opt.dev_dir, "dev_pinyin.txt")
contents = open(os.path.join(opt.dev_dir, "dev_pinyin_split.txt")
).read().strip().split('\n')
translates = [translate(i, opt, model, SRC, TRG)
for i in contents]
for i in tqdm(contents)]
# with open(os.path.join(opt.dev_dir,model_name),'w',encoding='utf-8') as f:
# f.write("\n".join(translates))
......