Commit 636dc31e authored by szr712

Add code for model evaluation and for building the dataset

parent 3fc8f0c1
@@ -16,8 +16,12 @@ CUDA_VISIBLE_DEVICES=2 nohup python translate_file.py -load_weights weights/piny
CUDA_VISIBLE_DEVICES=3 nohup python translate_file.py -load_weights weights/pinyin_to_hanzi_onlyChinese/11-04_16:36:46/pinyin_to_hanzi_onlyChinese_27_0.0009592685928873834 -pkl_dir weights/pinyin_to_hanzi_onlyChinese/11-04_16:36:46 -test_dir data/test_data/pinyin_onlyChinese -result_dir data/test_data/result_onlyChinese >log2 2>&1 &
CUDA_VISIBLE_DEVICES=3 python translate_file2.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_17_0.06582485044375062 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/test_data/pinyin -result_dir data/test_data/result_token_classification
CUDA_VISIBLE_DEVICES=3 python translate_file2.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_35_0.055335590355098246 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/test_data/pinyin -result_dir data/test_data/result_token_classification
CUDA_VISIBLE_DEVICES=4 nohup python eval_model.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_17_0.06582485044375062 -pkl_dir weights/token_classification/11-09_22:00:55 -dev_dir data/dev >log 2>&1 &
CUDA_VISIBLE_DEVICES=1 python train2.py -src_data data/pinyin_2.txt -trg_data data/hanzi_2.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification
CUDA_VISIBLE_DEVICES=1 nohup python train2.py -src_data data/pinyin_2.txt -trg_data data/hanzi_2.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification
\ No newline at end of file
CUDA_VISIBLE_DEVICES=1 nohup python train2.py -src_data data/pinyin_2.txt -trg_data data/hanzi_2.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification
CUDA_VISIBLE_DEVICES=3 python translate_pkl.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_35_0.055335590355098246 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/pkl/test-pkl -result_dir data/pkl/test-pkl-result
\ No newline at end of file
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Before running this code, make sure that you've downloaded Leipzig Chinese Corpus
(http://corpora2.informatik.uni-leipzig.de/downloads/zho_news_2007-2009_1M-text.tar.gz)
Extract and copy the `zho_news_2007-2009_1M-sentences.txt` to `data/` folder.
This code should generate a file which looks like this:
2[Tab]zhegeyemianxianzaiyijingzuofei...。[Tab]这__个_页_面___现___在__已_经___作__废__...。
In each line, the id, pinyin, and a chinese sentence are separated by a tab.
Note that _ means blanks.
Created in Aug. 2017, kyubyong. kbpark.linguist@gmail.com
"""
from __future__ import print_function
import codecs
import os
from threading import Semaphore
import regex # pip install regex
from xpinyin import Pinyin # pip install xpinyin
import traceback
pinyin = Pinyin()
def align(sent):
    '''
    Args:
      sent: A string. A sentence.

    Returns:
      A tuple of the pinyin string and the padded Chinese sentence.
    '''
    pnyns = pinyin.get_pinyin(sent, " ").split()
    hanzis = []
    x = 0  # extra offset caused by non-hanzi characters
    for i in range(len(pnyns)):
        if i + x < len(sent.replace(" ", "")):
            char = sent[i + x]
            p = pnyns[i]
            if '\u4e00' <= char <= '\u9fa5':
                # Pad each hanzi with "_" so it spans the same width as its pinyin.
                hanzis.extend([char] + ["_"] * (len(p) - 1))
            else:
                # Copy non-hanzi characters through unchanged until the next hanzi.
                while not '\u4e00' <= char <= '\u9fa5':
                    hanzis.extend([char])
                    x = x + 1
                    if x + i >= len(sent.replace(" ", "")):
                        break
                    char = sent[i + x]
                x = x - 1
    # for char, p in zip(sent.replace(" ", ""), pnyns):
    #     if '\u4e00' <= char <= '\u9fa5':
    #         hanzis.extend([char] + ["_"] * (len(p) - 1))
    pnyns = "".join(pnyns)
    hanzis = "".join(hanzis)
    assert len(pnyns) == len(hanzis), "The hanzis and the pinyins must be the same in length."
    return pnyns, hanzis
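# A minimal usage sketch of align(), assuming xpinyin's default tone-less
# romanization ("ni hao" for "你好"):
#
#   pnyns, hanzis = align("你好")
#   # pnyns  == "nihao"
#   # hanzis == "你_好__"   (each hanzi padded with "_" to the width of its pinyin)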
def clean(text):
    # if regex.search("[A-Za-z0-9]", text) is not None: # For simplicity, roman alphanumeric characters are removed.
    #     return ""
    # Keep only spaces, hanzi, and the punctuation marks 。，！？.
    text = regex.sub(u"[^ \u4e00-\u9fa5。，！？]", "", text)
    # Collapse runs of consecutive punctuation into a single mark.
    text_new = ""
    flag = 0
    for char in text:
        if char not in "。，！？":
            flag = 0
            text_new = text_new + char
        elif not flag:
            flag = 1
            text_new = text_new + char
    # while ",," in text:
    #     text = text.replace(",,", ",")
    if len(text_new) < 10: return ""
    return text_new
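# A minimal sketch of clean()'s behaviour under the character class above
# (only hanzi, spaces, and 。，！？ survive; punctuation runs collapse; short
# results are dropped):
#
#   clean("今天天气不错！！明天也不错。。")  ->  "今天天气不错！明天也不错。"
#   clean("短句。。")                       ->  ""   (fewer than 10 characters)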
def build_corpus():
    pinyin_list = []
    hanzi_list = []
    with codecs.open("data/zh.tsv", 'w', 'utf-8') as fout:
        # with codecs.open("data/zho_news_2007-2009_1M-sentences.txt", 'r', 'utf-8') as fin:
        with codecs.open("data/train_set_total.txt", 'r', 'utf-8') as fin:
            i = 1
            while 1:
                line = fin.readline()
                if not line: break
                try:
                    # idx, sent = line.strip().split("\t")
                    # if idx == "234":
                    #     print(sent)
                    # sent = clean(sent)
                    # if len(sent) > 0:
                    #     pnyns, hanzis = align(sent)
                    #     fout.write(u"{}\t{}\t{}\n".format(idx, pnyns, hanzis))
                    sent = line.strip()
                    sent = sent.replace(" ", "")
                    # sent = clean(sent)
                    if len(sent) > 0:
                        pnyns, hanzis = align(sent)
                        fout.write(u"{}\t{}\n".format(pnyns, hanzis))
                except:
                    traceback.print_exc()
                    continue  # it's okay as we have a pretty big corpus!
                if i % 10000 == 0: print(i, )
                i += 1

if __name__ == "__main__":
    build_corpus(); print("Done")
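The `train2.py` commands above read `-src_data data/pinyin_2.txt` and `-trg_data data/hanzi_2.txt`. Below is a minimal sketch of deriving such parallel files from the `data/zh.tsv` produced here; whether those two files are actually built this way is an assumption.

# Hypothetical helper: split data/zh.tsv into parallel pinyin / hanzi files.
import codecs

with codecs.open("data/zh.tsv", "r", "utf-8") as fin, \
     codecs.open("data/pinyin_2.txt", "w", "utf-8") as fpny, \
     codecs.open("data/hanzi_2.txt", "w", "utf-8") as fhz:
    for line in fin:
        line = line.strip()
        if not line:
            continue
        pnyns, hanzis = line.split("\t")
        fpny.write(pnyns + "\n")
        fhz.write(hanzis + "\n")  # keep the "_" padding so both files stay aligned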
@@ -91,10 +91,7 @@ def getList(fileList, dirPath):
    return result
if __name__ == "__main__":
    num_process = 128
    preFile = "./data/test_data/result_short"
    textFile = "./data/test_data/hanzi_short"
def cer_multi(num_process,preFile,textFile):
    lock = multiprocess.Lock()
    # preList=getList(preFile,"./preFile")
@@ -144,3 +141,10 @@ if __name__ == "__main__":
n+=value
# print(n)
print('{} \n total char:{} CER: {:.3f}'.format(pre[:-7],n,w/float(n)))
if __name__ == "__main__":
num_process = 128
preFile = "./data/test_data/result_token_classification"
textFile = "./data/test_data/hanzi"
cer_multi(num_process,preFile,textFile)
\ No newline at end of file
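The block above reports a total character count and a CER per prediction file. For reference, here is a minimal single-process sketch of the same metric, assuming `ishan` (defined elsewhere in cer_multi.py) filters a string down to its Han characters:

# Hypothetical single-process CER: total Levenshtein edits over total reference chars.
import distance              # pip install distance
from cer_multi import ishan  # assumed to keep only Han characters

def cer(pred_lines, ref_lines):
    errors, chars = 0, 0
    for pred, ref in zip(pred_lines, ref_lines):
        pred, ref = ishan(pred), ishan(ref)
        errors += distance.levenshtein(ref, pred)
        chars += len(ref)
    return errors / float(chars)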
import argparse
import time
import os
import torch
from Models import get_model, get_model_token_classification
from Process import *
import torch.nn.functional as F
from Optim import CosineWithRestarts
from Batch import create_masks
import pdb
import dill as pickle
from Beam import beam_search
# from nltk.corpus import wordnet
from torch.autograd import Variable
import re
import random
import distance
from tqdm import tqdm
from cer_multi import ishan
def get_result(src, model, SRC, TRG, opt):
    # src: (1, seq_len) LongTensor of source token indices.
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    output = model(src, src_mask)
    output = F.softmax(output, dim=-1)
    # Greedy decoding: pick the most likely target token at every position,
    # then drop the "_" padding characters.
    preds = torch.argmax(output, dim=-1)
    return ''.join([TRG.vocab.itos[tok] for tok in preds[0][:]]).replace("_", "")
def translate_sentence(sentence, model, opt, SRC, TRG):
    model.eval()
    indexed = []
    sentence = SRC.preprocess(sentence)
    for tok in sentence:
        if SRC.vocab.stoi[tok] != 0 or opt.floyd == True:
            indexed.append(SRC.vocab.stoi[tok])
        else:
            # Tokens outside the source vocabulary (index 0) are skipped.
            # indexed.append(get_synonym(tok, SRC))
            pass
    sentence = Variable(torch.LongTensor([indexed]))
    if opt.device == 0:
        sentence = sentence.cuda()
    sentence = get_result(sentence, model, SRC, TRG, opt)
    return sentence
def translate(i, opt, model, SRC, TRG):
    # sentences = opt.text.lower().split('.')
    sentence = i
    # sentences = [a for a in sentences if len(a) > 0]
    translated = []
    translated.append(translate_sentence(sentence, model, opt, SRC, TRG))
    return (' '.join(translated)[:])
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-load_weights', required=True)
    parser.add_argument('-pkl_dir', required=True)
    parser.add_argument('-k', type=int, default=3)
    parser.add_argument('-max_len', type=int, default=80)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-n_layers', type=int, default=6)
    # parser.add_argument('-src_lang', required=True)
    # parser.add_argument('-trg_lang', required=True)
    parser.add_argument('-heads', type=int, default=8)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-floyd', action='store_true')
    parser.add_argument("-dev_dir", type=str, required=True)

    opt = parser.parse_args()
    opt.device = 0 if opt.no_cuda is False else -1

    assert opt.k > 0
    assert opt.max_len > 10

    SRC, TRG = create_fields(opt)

    # Evaluate the checkpoints of epochs 35..60 on the dev set, one by one.
    i = 35
    while i <= 60:
        for model_name in os.listdir(opt.pkl_dir):
            if "token_classification_" + str(i) + "_" in model_name:
                print("model_name:{}".format(model_name))
                opt.load_weights = os.path.join(opt.pkl_dir, model_name)
                model = get_model_token_classification(
                    opt, len(SRC.vocab), len(TRG.vocab))
                contents = open(os.path.join(opt.dev_dir, "dev_pinyin.txt")
                                ).read().strip().split('\n')
                translates = [translate(i, opt, model, SRC, TRG)
                              for i in contents]
                # with open(os.path.join(opt.dev_dir, model_name), 'w', encoding='utf-8') as f:
                #     f.write("\n".join(translates))
                gt = open(os.path.join(opt.dev_dir, "dev_hanzi.txt")
                          ).read().strip().split('\n')
                total_edit_distance, num_chars = 0, 0
                for pred, expected in tqdm(zip(translates, gt)):
                    pred = ishan(pred)
                    expected = ishan(expected)
                    edit_distance = distance.levenshtein(expected, pred)
                    total_edit_distance += edit_distance
                    num_chars += len(expected)
                print("Total CER: {}/{}={}\n".format(total_edit_distance,
                                                     num_chars,
                                                     round(float(total_edit_distance) / num_chars, 5)))
                break
        i = i + 1


if __name__ == '__main__':
    main()