Commit 883266fd authored by szr712

Some changes

parent 6442322f
# Training
```
python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
```
Parameters:
- -src_data: pinyin file
- -trg_data: hanzi (Chinese character) file
- -epochs: number of training epochs
- -model_name: model name
- -src_voc: pinyin vocabulary file
- -trg_voc: hanzi vocabulary file
- -batchsize: batch size (default: 64)
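
The two data files are line-aligned: each line of `-src_data` is a pinyin sequence and the corresponding line of `-trg_data` is its hanzi transcription. A hypothetical sketch of how one such pair can be derived with the `wenzi2pinyin` helper from `build_corpus` (its `(pinyins, tones)` return shape is taken from the data-prep diff further down; the exact token-joining format is an assumption):
```
# Hypothetical data-prep sketch; wenzi2pinyin returns parallel lists of
# syllables and tone digits, as used in the data-generation script below.
from build_corpus import wenzi2pinyin

hanzi_line = "儿童节的设立"
pinyins, tones = wenzi2pinyin(hanzi_line)
# Assumed joining format: syllable plus tone digit, space-separated.
pinyin_line = " ".join(p + str(t) for p, t in zip(pinyins, tones))
```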
# Evaluating the model on the dev set
```
python eval_model.py -pkl_dir weights/token_classification_split_new/11-28_14:53:18/ -dev_dir data/dev -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt -model_name token_classification_split_new
```
Parameters:
- -pkl_dir: directory holding the saved .pkl files (Field objects)
- -dev_dir: directory holding the dev set
- -model_name: model name
- -src_voc: pinyin vocabulary file
- -trg_voc: hanzi vocabulary file
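
`eval_model.py` evaluates one checkpoint per epoch: checkpoints are named `<model_name>_<epoch>_<loss>` (e.g. `token_classification_split_new_26_0.025378503443207592`), and the script scans `-pkl_dir` for a matching file for each epoch from 1 to 60, as in the eval_model.py diff below. A minimal sketch of that lookup:
```
import os

def find_checkpoint(pkl_dir, model_name, epoch):
    """Return the path of the checkpoint saved at the given epoch, if any.

    Checkpoints follow the <model_name>_<epoch>_<loss> naming pattern.
    """
    prefix = "{}_{}_".format(model_name, epoch)
    for fname in os.listdir(pkl_dir):
        if prefix in fname:
            return os.path.join(pkl_dir, fname)
    return None
```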
# Testing
```
python translate_file2.py -load_weights weights/token_classification_split_new/11-28_14:53:18/token_classification_split_new_26_0.025378503443207592 -pkl_dir weights/token_classification_split_new/11-28_14:53:18 -test_dir data/test_data/split_new_data_daoxuehao/pinyin_random_change_tones -result_dir data/test_data/split_new_data_daoxuehao/result_random_change_tones -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
```
Parameters:
- -load_weights: path to the model checkpoint
- -pkl_dir: directory holding the saved .pkl files (Field objects)
- -test_dir: test set directory
- -result_dir: directory for the translation results
- -src_voc: pinyin vocabulary file
- -trg_voc: hanzi vocabulary file
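
The translations written to `-result_dir` are scored with `cer.py` against the reference hanzi files (see the `cer` diff below): CER is the total edit distance divided by the total number of reference characters. A minimal sketch of that metric, with a plain dynamic-programming Levenshtein distance standing in for whatever distance function `cer.py` actually uses:
```
def edit_distance(pred, expected):
    """Character-level Levenshtein distance via a rolling DP array."""
    m, n = len(pred), len(expected)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,       # deletion
                        dp[j - 1] + 1,   # insertion
                        prev + (pred[i - 1] != expected[j - 1]))  # substitution
            prev = cur
    return dp[n]

def cer(preds, refs):
    total_edit_distance = sum(edit_distance(p, r) for p, r in zip(preds, refs))
    num_chars = sum(len(r) for r in refs)
    return round(float(total_edit_distance) / num_chars, 5)
```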
@@ -57,7 +57,7 @@ def cer(preFile,textFile):
               round(float(total_edit_distance)/num_chars, 5)))

 if __name__ == "__main__":
-    preFile = "./data/test_data/result_split"
-    textFile = "./data/test_data/hanzi"
+    preFile = "./data/test_data/split_new_data_daoxuehao/result_random_change_tones"
+    textFile = "./data/test_data/hanzi_new"
     cer(preFile,textFile)
\ No newline at end of file
@@ -3,9 +3,18 @@ import os
 from tqdm import tqdm
 from build_corpus import split_initials_finals, wenzi2pinyin
+import random
+
+def random_change_tones(tones):
+    options=[0,1,2,3,4]
+    random.seed(42)
+    for i,x in enumerate(tones):
+        if random.randint(0,99) < 30:
+            tones[i]=random.choice(options)
+    return tones

 hanzi_dir="./data/test_data/hanzi_new"
-pinyin_dir="./data/test_data/pinyin_split_new"
+pinyin_dir="./data/test_data/split_new_data_daoxuehao/pinyin_random_change_tones"

 with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
     yunmus=f.readlines()
@@ -20,6 +29,7 @@ for file in os.listdir(hanzi_dir):
         sent = line.strip()
         sent = sent.replace(" ","")
         pinyins,tones=wenzi2pinyin(sent)
+        # tones=random_change_tones(tones)
         pnyns=[]
         i=0
         for pinyin,tone in zip(pinyins,tones):
...
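
The new `random_change_tones` helper corrupts tone labels: each position has a 30% chance of being replaced with a uniform draw from {0,1,2,3,4} (which may coincide with the original tone, so the effective corruption rate is somewhat below 30%), and it mutates the list in place. In this commit the call site is left commented out, so the perturbed pinyin files must have been generated with it enabled. A small usage sketch:
```
import random

def random_change_tones(tones):
    # As in the diff above: ~30% of positions get a random tone from 0-4.
    options = [0, 1, 2, 3, 4]
    random.seed(42)  # reseeding on every call makes the corruption deterministic
    for i, x in enumerate(tones):
        if random.randint(0, 99) < 30:
            tones[i] = random.choice(options)
    return tones

tones = [1, 2, 3, 4, 0, 2]
print(random_change_tones(tones))  # same output on every run, given the fixed seed
```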
@@ -63,7 +63,7 @@ def translate(i, opt, model, SRC, TRG):
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-load_weights', required=True)
+    parser.add_argument('-load_weights')
     parser.add_argument('-pkl_dir', required=True)
     parser.add_argument('-k', type=int, default=3)
     parser.add_argument('-max_len', type=int, default=80)
@@ -76,6 +76,7 @@ def main():
     parser.add_argument('-no_cuda', action='store_true')
     parser.add_argument('-floyd', action='store_true')
     parser.add_argument("-dev_dir", type=str, required=True)
+    parser.add_argument('-model_name', required=True)
     parser.add_argument('-src_voc')
     parser.add_argument('-trg_voc')
@@ -88,13 +89,13 @@ def main():
     SRC, TRG = create_fields(opt)
-    i=1
-    while i<=60:
+    i = 1
+    while i <= 60:
         for model_name in os.listdir(opt.pkl_dir):
-            if "token_classification_split_new_"+str(i)+"_" in model_name:
+            if opt.model_name+"_"+str(i)+"_" in model_name:
                 print("model_name:{}".format(model_name))
-                opt.load_weights=os.path.join(opt.pkl_dir,model_name)
+                opt.load_weights = os.path.join(opt.pkl_dir, model_name)
                 model = get_model_token_classification(
                     opt, len(SRC.vocab), len(TRG.vocab))
@@ -102,13 +103,13 @@ def main():
                 contents = open(os.path.join(opt.dev_dir, "dev_pinyin_split.txt")
                                 ).read().strip().split('\n')
                 translates = [translate(i, opt, model, SRC, TRG)
                               for i in tqdm(contents)]
                 # with open(os.path.join(opt.dev_dir,model_name),'w',encoding='utf-8') as f:
                 #     f.write("\n".join(translates))
                 gt = open(os.path.join(opt.dev_dir, "dev_hanzi.txt")
                           ).read().strip().split('\n')
                 total_edit_distance, num_chars = 0, 0
                 for pred, expected in tqdm(zip(translates, gt)):
@@ -119,10 +120,10 @@ def main():
                     num_chars += len(expected)
                 print("Total CER: {}/{}={}\n".format(total_edit_distance,
                                                      num_chars,
                                                      round(float(total_edit_distance)/num_chars, 5)))
                 break
-        i=i+1
+        i = i+1

 if __name__ == '__main__':
...
This source diff could not be displayed because it is too large.
@@ -30,11 +30,11 @@ CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data
 CUDA_VISIBLE_DEVICES=1 python translate2.py -load_weights weights/token_classification_split_4/11-23_22:02:06/token_classification_split_4_25_0.02742394618457183 -pkl_dir weights/token_classification_split_4/11-23_22:02:06 -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
-CUDA_VISIBLE_DEVICES=4 nohup python translate_file2.py -load_weights weights/token_classification_split_4/11-23_22:02:06/token_classification_split_4_25_0.02742394618457183 -pkl_dir weights/token_classification_split_4/11-23_22:02:06 -test_dir data/test_data/pinyin_split -result_dir data/test_data/result_split -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
-CUDA_VISIBLE_DEVICES=1 python eval_model.py -load_weights weights/token_classification_split_4/11-23_22:02:06/token_classification_split_4_1_0.09183966986835003 -pkl_dir weights/token_classification_split_4/11-23_22:02:06 -dev_dir data/dev -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
+CUDA_VISIBLE_DEVICES=4 nohup python translate_file2.py -load_weights weights/token_classification_split_new/11-28_14:53:18/token_classification_split_new_26_0.025378503443207592 -pkl_dir weights/token_classification_split_new/11-28_14:53:18 -test_dir data/test_data/split_new_data_daoxuehao/pinyin_random_change_tones -result_dir data/test_data/split_new_data_daoxuehao/result_random_change_tones -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log 2>&1 &
+CUDA_VISIBLE_DEVICES=4 python eval_model.py -load_weights weights/token_classification_split_new/11-28_14:53:18/token_classification_split_new_1_0.08814825743436813 -pkl_dir weights/token_classification_split_new/11-28_14:53:18/ -dev_dir data/dev -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log 2>&1 &
 CUDA_VISIBLE_DEVICES=6 nohup python train_token_classification.py -src_data data/pinyin_split.txt -trg_data data/hanzi_split.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification_split_4 -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
-CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
+CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
This source diff could not be displayed because it is too large.
This source diff could not be displayed because it is too large.
@@ -15,7 +15,7 @@ strict = True
 if get_initials(a, strict)=="":
     print("True")
-text="一鸣惊人"
+text="儿童节的设立"
 # ,Datsun新推出的轻型跑车可谓一呜惊人。
 # pinyins,tones=wenzi2pinyin(text)
...
@@ -81,8 +81,8 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-src_data', required=True)
     parser.add_argument('-trg_data', required=True)
-    parser.add_argument('-src_lang', required=True)
-    parser.add_argument('-trg_lang', required=True)
+    parser.add_argument('-src_lang', required=True,default="en_core_web_sm")
+    parser.add_argument('-trg_lang', required=True,default="fr_core_news_sm")
     parser.add_argument('-no_cuda', action='store_true')
     parser.add_argument('-SGDR', action='store_true')
     parser.add_argument('-epochs', type=int, default=2)
...
@@ -83,8 +83,8 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-src_data', required=True)
     parser.add_argument('-trg_data', required=True)
-    parser.add_argument('-src_lang', required=True)
-    parser.add_argument('-trg_lang', required=True)
+    parser.add_argument('-src_lang', required=True,default="en_core_web_sm")
+    parser.add_argument('-trg_lang', required=True,default="fr_core_news_sm")
     parser.add_argument('-no_cuda', action='store_true')
     parser.add_argument('-SGDR', action='store_true')
     parser.add_argument('-epochs', type=int, default=2)
...
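
A note on the two training-script diffs above: argparse ignores `default=` when `required=True` is also set, so `-src_lang` and `-trg_lang` must still be passed explicitly and the new defaults never take effect. Making them genuinely optional with those defaults would look like:
```
parser.add_argument('-src_lang', default="en_core_web_sm")
parser.add_argument('-trg_lang', default="fr_core_news_sm")
```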