Commit 883266fd authored by szr712's avatar szr712

一些修改

parent 6442322f
# 训练
```
python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
```
parameters:
- -src_data 拼音文件
- -trg_data 汉字文件
- -epochs 训练epoch
- -model_name 模型名称
- -src_voc 拼音字典
- -trg_voc 汉字字典
- -batchsize 默认64
# 验证集验证模型
```
python eval_model.py -pkl_dir weights/token_classification_split_new/11-28_14:53:18/ -dev_dir data/dev -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt -model_name token_classification_split_new
```
parameters:
- -pkl_dir 存储pkl(Field对象)的目录
- -dev_dir 验证集存储目录
- -model_name 模型名称
- -src_voc 拼音字典
- -trg_voc 汉字字典
# 测试
```
python translate_file2.py -load_weights weights/token_classification_split_new/11-28_14:53:18/token_classification_split_new_26_0.025378503443207592 -pkl_dir weights/token_classification_split_new/11-28_14:53:18 -test_dir data/test_data/split_new_data_daoxuehao/pinyin_random_change_tones -result_dir data/test_data/split_new_data_daoxuehao/result_random_change_tones -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
```
parameters:
- -load_weights 模型存储路径
- -pkl_dir 存储pkl(Field对象)的目录
- -test_dir 测试集目录
- -result_dir 翻译结果存储目录
- -src_voc 拼音字典
- -trg_voc 汉字字典
\ No newline at end of file
......@@ -57,7 +57,7 @@ def cer(preFile,textFile):
round(float(total_edit_distance)/num_chars, 5)))
if __name__ == "__main__":
    # Score the model's transcriptions against the reference hanzi text.
    # (The stale preFile/textFile assignments that pointed at the old
    # result_split/hanzi paths were dead code -- immediately overwritten --
    # and have been removed.)
    preFile = "./data/test_data/split_new_data_daoxuehao/result_random_change_tones"
    textFile = "./data/test_data/hanzi_new"
    cer(preFile,textFile)
\ No newline at end of file
......@@ -3,9 +3,18 @@ import os
from tqdm import tqdm
from build_corpus import split_initials_finals, wenzi2pinyin
import random
def random_change_tones(tones, change_rate=30, seed=42):
    """Randomly corrupt pinyin tone labels, in place.

    Each element of *tones* has a ``change_rate``-percent chance of being
    replaced by a tone drawn uniformly from 0..4 (which may equal the
    original value). Used to build a noisy test set.

    A locally seeded ``random.Random`` is used instead of re-seeding the
    module-level RNG: with the default ``seed=42`` the output is
    byte-identical to the previous implementation (same Mersenne Twister
    state), but the global ``random`` state is no longer clobbered as a
    side effect of calling this function.

    Args:
        tones: list of int tone labels; MUTATED in place.
        change_rate: percent chance (0-100) that each tone is replaced.
        seed: RNG seed; the default reproduces the historical output.

    Returns:
        The same list object passed in, after mutation.
    """
    rng = random.Random(seed)
    options = [0, 1, 2, 3, 4]
    for i in range(len(tones)):
        if rng.randint(0, 99) < change_rate:
            tones[i] = rng.choice(options)
    return tones
# Directory of ground-truth hanzi test files.
hanzi_dir="./data/test_data/hanzi_new"
# NOTE(review): this first assignment is dead code -- it is overwritten
# by the line below (leftover from an earlier revision) and can be removed.
pinyin_dir="./data/test_data/pinyin_split_new"
# Output directory for pinyin test files with randomly perturbed tones.
pinyin_dir="./data/test_data/split_new_data_daoxuehao/pinyin_random_change_tones"
# Load the finals (yunmu) vocabulary, one entry per line.
with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
    yunmus=f.readlines()
......@@ -20,6 +29,7 @@ for file in os.listdir(hanzi_dir):
sent = line.strip()
sent = sent.replace(" ","")
pinyins,tones=wenzi2pinyin(sent)
# tones=random_change_tones(tones)
pnyns=[]
i=0
for pinyin,tone in zip(pinyins,tones):
......
......@@ -63,7 +63,7 @@ def translate(i, opt, model, SRC, TRG):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-load_weights', required=True)
parser.add_argument('-load_weights')
parser.add_argument('-pkl_dir', required=True)
parser.add_argument('-k', type=int, default=3)
parser.add_argument('-max_len', type=int, default=80)
......@@ -76,6 +76,7 @@ def main():
parser.add_argument('-no_cuda', action='store_true')
parser.add_argument('-floyd', action='store_true')
parser.add_argument("-dev_dir", type=str, required=True)
parser.add_argument('-model_name', required=True)
parser.add_argument('-src_voc')
parser.add_argument('-trg_voc')
......@@ -88,13 +89,13 @@ def main():
SRC, TRG = create_fields(opt)
i=1
while i<=60:
i = 1
while i <= 60:
for model_name in os.listdir(opt.pkl_dir):
if "token_classification_split_new_"+str(i)+"_" in model_name:
if opt.model_name+"_"+str(i)+"_" in model_name:
print("model_name:{}".format(model_name))
opt.load_weights=os.path.join(opt.pkl_dir,model_name)
opt.load_weights = os.path.join(opt.pkl_dir, model_name)
model = get_model_token_classification(
opt, len(SRC.vocab), len(TRG.vocab))
......@@ -102,13 +103,13 @@ def main():
contents = open(os.path.join(opt.dev_dir, "dev_pinyin_split.txt")
).read().strip().split('\n')
translates = [translate(i, opt, model, SRC, TRG)
for i in tqdm(contents)]
for i in tqdm(contents)]
# with open(os.path.join(opt.dev_dir,model_name),'w',encoding='utf-8') as f:
# f.write("\n".join(translates))
gt = open(os.path.join(opt.dev_dir, "dev_hanzi.txt")
).read().strip().split('\n')
).read().strip().split('\n')
total_edit_distance, num_chars = 0, 0
for pred, expected in tqdm(zip(translates, gt)):
......@@ -119,10 +120,10 @@ def main():
num_chars += len(expected)
print("Total CER: {}/{}={}\n".format(total_edit_distance,
num_chars,
round(float(total_edit_distance)/num_chars, 5)))
num_chars,
round(float(total_edit_distance)/num_chars, 5)))
break
i=i+1
i = i+1
if __name__ == '__main__':
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment