Commit 3fc8f0c1 authored by szr712's avatar szr712

Completed the modifications to the token classification model

parent f93d31d8
...@@ -28,6 +28,11 @@ def create_masks(src, trg, opt):
        trg_mask = None
    return src_mask, trg_mask
def create_masks2(src, opt):
    # source-only mask for the encoder-only model: hide <pad> positions
    src_mask = (src != opt.src_pad).unsqueeze(-2)
    return src_mask
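# Usage sketch (illustrative values; assumes opt.src_pad is the <pad> index, here 1):
#   src = torch.tensor([[5, 9, 7, 1, 1]])
#   create_masks2(src, opt)  # -> tensor([[[True, True, True, False, False]]]), shape (1, 1, 5)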
# patch on Torchtext's batching process that makes it more efficient
# from http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks
...
...@@ -51,6 +51,19 @@ class Transformer(nn.Module):
        output = self.out(d_output)
        return output
class TransformerForTokenClassification(nn.Module):
    # Encoder-only variant of the Transformer: the decoder is removed and a
    # linear head maps each encoder state to target-vocabulary logits, one
    # prediction per source token.
    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
        # self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, trg_vocab)

    def forward(self, src, src_mask):
        outputs = self.encoder(src, src_mask)
        # d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        output = self.out(outputs)
        return output
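# Minimal forward-pass sketch (illustrative shapes):
#   src: (batch, seq_len) LongTensor; src_mask: (batch, 1, seq_len) bool
#   model(src, src_mask) -> logits of shape (batch, seq_len, trg_vocab),
#   i.e. one target-vocabulary distribution per source position.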
def get_model(opt, src_vocab, trg_vocab):
    assert opt.d_model % opt.heads == 0
...@@ -70,4 +83,25 @@ def get_model(opt, src_vocab, trg_vocab):
        model = model.cuda()
    return model
def get_model_token_classification(opt, src_vocab, trg_vocab):
    assert opt.d_model % opt.heads == 0
    assert opt.dropout < 1

    model = TransformerForTokenClassification(src_vocab, trg_vocab, opt.d_model, opt.n_layers, opt.heads, opt.dropout)

    if opt.load_weights is not None:
        print("loading pretrained weights...")
        model.load_state_dict(torch.load(f'{opt.load_weights}'))
    else:
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    if opt.device == 0:
        model = model.cuda()

    return model
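# Standalone construction sketch (hypothetical vocab sizes; the other values
# mirror the CLI defaults used elsewhere in this repo):
#   from argparse import Namespace
#   opt = Namespace(d_model=512, n_layers=6, heads=8, dropout=0.1,
#                   load_weights=None, device=-1)
#   model = get_model_token_classification(opt, src_vocab=4000, trg_vocab=6000)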
...@@ -60,7 +60,9 @@ def create_fields(opt):
    # TRG = data.Field(lower=True, tokenize=t_trg.tokenizer, init_token='<sos>', eos_token='<eos>')
    # SRC = data.Field(lower=True, tokenize=t_src.tokenizer)

    # TRG = data.Field(tokenize=my_tokenize, init_token='<sos>', eos_token='<eos>')
    TRG = data.Field(tokenize=my_tokenize)
    SRC = data.Field(tokenize=my_tokenize)

    if opt.pkl_dir is not None:
...@@ -92,17 +94,20 @@ def create_dataset(opt, SRC, TRG):
    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields)
    # train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
    #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
    #                         batch_size_fn=batch_size_fn, train=True, shuffle=True)
    train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=None, train=True, shuffle=True)
    os.remove('translate_transformer_temp.csv')

    if opt.load_weights is None:
        SRC.build_vocab(train)
        # print(SRC.vocab.stoi)
        TRG.build_vocab(train)
        # print(TRG.vocab.stoi)

    if opt.checkpoint > 0:
        try:
            os.mkdir("weights")
...
...@@ -2,12 +2,22 @@ CUDA_VISIBLE_DEVICES=5 nohup python train.py -src_data data/train_set_onlyChines
CUDA_VISIBLE_DEVICES=5 python train.py -src_data data/train_set_pinyin_total.txt -trg_data data/train_set_total.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name pinyin_to_hanzi_total -load_weights weights/pinyin_to_hanzi_total/10-29_18:51:57/pinyin_to_hanzi_total_10_0.1508198243379593 -pkl_dir weights/pinyin_to_hanzi_total/10-29_18:51:57
CUDA_VISIBLE_DEVICES=3 python translate.py -load_weights weights/pinyin_to_hanzi_total/10-29_18:51:57/pinyin_to_hanzi_total_10_0.1508198243379593 -pkl_dir weights/pinyin_to_hanzi_total/10-29_18:51:57
CUDA_VISIBLE_DEVICES=3 python translate2.py -load_weights weights/token_classification/11-09_17:20:17/token_classification_4_0.09534996442496776 -pkl_dir weights/token_classification/11-09_17:20:17
CUDA_VISIBLE_DEVICES=5 python translate_file.py -load_weights weights/pinyin_to_hanzi_total/10-30_21:22:49/pinyin_to_hanzi_total_9_0.12325442619621754 -pkl_dir weights/pinyin_to_hanzi_total/10-30_21:22:49 -test_dir data/test_data/pinyin_short -result_dir data/test_data/result_tmp
CUDA_VISIBLE_DEVICES=1 python translate_pkl.py -load_weights weights/pinyin_to_hanzi_total/10-30_21:22:49/pinyin_to_hanzi_total_9_0.12325442619621754 -pkl_dir weights/pinyin_to_hanzi_total/10-30_21:22:49 -test_dir data/pkl/1105-pinyin-pkl -result_dir data/pkl/1105-pinyin-pkl-result
CUDA_VISIBLE_DEVICES=2 nohup python translate_file.py -load_weights weights/pinyin_to_hanzi_total/10-30_21:22:49/pinyin_to_hanzi_total_59_0.07513352055102587 -pkl_dir weights/pinyin_to_hanzi_total/10-30_21:22:49 -test_dir data/test_data/pinyin_short -result_dir data/test_data/result_short >log2 2>&1 &
CUDA_VISIBLE_DEVICES=3 nohup python translate_file.py -load_weights weights/pinyin_to_hanzi_onlyChinese/11-04_16:36:46/pinyin_to_hanzi_onlyChinese_27_0.0009592685928873834 -pkl_dir weights/pinyin_to_hanzi_onlyChinese/11-04_16:36:46 -test_dir data/test_data/pinyin_onlyChinese -result_dir data/test_data/result_onlyChinese >log2 2>&1 &
CUDA_VISIBLE_DEVICES=3 python translate_file2.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_17_0.06582485044375062 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/test_data/pinyin -result_dir data/test_data/result_token_classification
CUDA_VISIBLE_DEVICES=1 python train2.py -src_data data/pinyin_2.txt -trg_data data/hanzi_2.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification
CUDA_VISIBLE_DEVICES=1 nohup python train2.py -src_data data/pinyin_2.txt -trg_data data/hanzi_2.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification
\ No newline at end of file
This source diff could not be displayed because it is too large.
import itertools
import os
import pickle

# Quick inspection of one pickled test file: each entry appears to be a list
# of lines (each line a list of tokens), flattened here into one sequence.
for file in os.listdir("./data/pkl/1105-pinyin-pkl")[2:3]:
    contents = pickle.load(open(os.path.join("data/pkl/1105-pinyin-pkl", file), "rb"))
    contents = [list(itertools.chain.from_iterable(lines)) for lines in contents]
    print(contents)
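# itertools.chain.from_iterable flattens exactly one level of nesting, e.g.:
#   list(itertools.chain.from_iterable([["ni", "hao"], ["ma"]]))  # -> ["ni", "hao", "ma"]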
import argparse
import time
import torch
from Models import get_model, get_model_token_classification
from Process import *
import torch.nn.functional as F
from Optim import CosineWithRestarts
from Batch import create_masks, create_masks2
import dill as pickle
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def train_model(model, opt, start_time):
    print("training model...")
    model.train()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()

    dst = "./weights/{}/{}".format(opt.model_name,
                                   time.strftime("%m-%d_%H:%M:%S", start_time))
    if not os.path.exists(dst):
        os.makedirs(dst)
        print("create {} dir".format(dst))

    for epoch in range(opt.epochs):
        total_loss = 0
        if opt.floyd is False:
            print(" %dm: epoch %d [%s] %d%% loss = %s" %
                  ((time.time() - start)//60, epoch + 1, "".join(' '*20), 0, '...'), end='\r')
        # if opt.checkpoint > 0:
        #     torch.save(model.state_dict(), 'weights/model_weights')

        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()
            trg = batch.trg.transpose(0, 1).cuda()
            # print("src shape:{} trg shape:{}".format(src.shape, trg.shape))
            # Token classification predicts one label per source position, so
            # the full target sequence serves as the labels; the shifted
            # decoder input (trg[:, :-1]) and the target mask are not needed.
            # src_mask, trg_mask = create_masks(src, trg_input, opt)
            src_mask = create_masks2(src, opt)
            preds = model(src, src_mask)
            # ys = trg[:, 1:].contiguous().view(-1)
            ys = trg.contiguous().view(-1)
            opt.optimizer.zero_grad()
            loss = F.cross_entropy(
                preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
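            # ignore_index=opt.trg_pad excludes <pad> positions from the loss,
            # so padding contributes no gradient. Shape sketch (illustrative):
            #   preds: (batch, seq_len, trg_vocab) -> (batch*seq_len, trg_vocab)
            #   ys:    (batch, seq_len)            -> (batch*seq_len,)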
            loss.backward()
            opt.optimizer.step()
            if opt.SGDR == True:
                opt.sched.step()

            total_loss += loss.item()

            if (i + 1) % opt.printevery == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss/opt.printevery
                if opt.floyd is False:
                    print(" %dm: epoch %d [%s%s] %d%% loss = %.3f" %
                          ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
                else:
                    print(" %dm: epoch %d [%s%s] %d%% loss = %.3f" %
                          ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
                total_loss = 0

            # if opt.checkpoint > 0 and ((time.time()-cptime)//60) // opt.checkpoint >= 1:
            #     torch.save(model.state_dict(), 'weights/model_weights')
            #     cptime = time.time()

        print("%dm: epoch %d [%s%s] %d%% loss = %.3f\nepoch %d complete, loss = %.03f" %
              ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))

        torch.save(model.state_dict(), os.path.join(
            dst, opt.model_name+"_{}_{}".format(epoch + 1, avg_loss)))
        print("model saved as {}".format(opt.model_name +
              "_{}_{}".format(epoch + 1, round(avg_loss, 3))))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-src_data', required=True)
    parser.add_argument('-trg_data', required=True)
    parser.add_argument('-src_lang', required=True)
    parser.add_argument('-trg_lang', required=True)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-SGDR', action='store_true')
    parser.add_argument('-epochs', type=int, default=2)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-heads', type=int, default=8)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-batchsize', type=int, default=64)
    parser.add_argument('-printevery', type=int, default=100)
    parser.add_argument('-lr', type=float, default=0.0001)
    parser.add_argument('-load_weights')
    parser.add_argument('-create_valset', action='store_true')
    parser.add_argument('-max_strlen', type=int, default=80)
    parser.add_argument('-floyd', action='store_true')
    parser.add_argument('-checkpoint', type=int, default=0)
    parser.add_argument('-model_name', type=str, default="transformer")
    parser.add_argument('-pkl_dir')

    opt = parser.parse_args()
    start_time = time.localtime()

    opt.device = 0 if opt.no_cuda is False else -1
    if opt.device == 0:
        assert torch.cuda.is_available()

    read_data(opt)
    SRC, TRG = create_fields(opt)
    opt.train = create_dataset(opt, SRC, TRG)
    print("train_len:{}".format(opt.train_len))
    # model = get_model(opt, len(SRC.vocab), len(TRG.vocab))
    model = get_model_token_classification(opt, len(SRC.vocab), len(TRG.vocab))

    dst = "./weights/{}/{}".format(opt.model_name,
                                   time.strftime("%m-%d_%H:%M:%S", start_time))
    if not os.path.exists(dst):
        os.makedirs(dst)
        print("create {} dir".format(dst))
    pickle.dump(SRC, open(f'{dst}/SRC.pkl', 'wb'))
    pickle.dump(TRG, open(f'{dst}/TRG.pkl', 'wb'))

    opt.optimizer = torch.optim.Adam(
        model.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)
    if opt.SGDR == True:
        opt.sched = CosineWithRestarts(opt.optimizer, T_max=opt.train_len)

    if opt.checkpoint > 0:
        print("model weights will be saved every %d minutes and at end of epoch to directory weights/" % (opt.checkpoint))

    # if opt.load_weights is not None and opt.floyd is not None:
    #     os.mkdir('weights')
    #     pickle.dump(SRC, open('weights/SRC.pkl', 'wb'))
    #     pickle.dump(TRG, open('weights/TRG.pkl', 'wb'))

    train_model(model, opt, start_time)
    if opt.floyd is False:
        promptNextAction(model, opt, SRC, TRG, start_time)
def yesno(response):
    while True:
        if response != 'y' and response != 'n':
            response = input('command not recognised, enter y or n : ')
        else:
            return response
def promptNextAction(model, opt, SRC, TRG, start_time):
    saved_once = 1 if opt.load_weights is not None or opt.checkpoint > 0 else 0

    # if opt.load_weights is not None:
    #     dst = opt.load_weights
    # if opt.checkpoint > 0:
    #     dst = 'weights'
    dst = "./weights/{}/{}".format(opt.model_name,
                                   time.strftime("%m-%d_%H:%M:%S", start_time))
# while True:
# save = yesno(input('training complete, save results? [y/n] : '))
# if save == 'y':
# while True:
# if saved_once != 0:
# res = yesno("save to same folder? [y/n] : ")
# if res == 'y':
# break
# dst = input('enter folder name to create for weights (no spaces) : ')
# if ' ' in dst or len(dst) < 1 or len(dst) > 30:
# dst = input("name must not contain spaces and be between 1 and 30 characters length, enter again : ")
# else:
# try:
# os.mkdir(dst)
# except:
# res= yesno(input(dst + " already exists, use anyway? [y/n] : "))
# if res == 'n':
# continue
# break
# print("saving weights to " + dst + "/...")
# torch.save(model.state_dict(), f'{dst}/model_weights')
# if saved_once == 0:
# pickle.dump(SRC, open(f'{dst}/SRC.pkl', 'wb'))
# pickle.dump(TRG, open(f'{dst}/TRG.pkl', 'wb'))
# saved_once = 1
# print("weights and field pickles saved to " + dst)
# res = yesno(input("train for more epochs? [y/n] : "))
# if res == 'y':
# while True:
# epochs = input("type number of epochs to train for : ")
# try:
# epochs = int(epochs)
# except:
# print("input not a number")
# continue
# if epochs < 1:
# print("epochs must be at least 1")
# continue
# else:
# break
# opt.epochs = epochs
# train_model(model, opt)
# else:
# print("exiting program...")
# break
# os.mkdir(dst)
# print("saving weights to " + dst + "/...")
# torch.save(model.state_dict(), f'{dst}/model_weights')
    pickle.dump(SRC, open(f'{dst}/SRC.pkl', 'wb'))
    pickle.dump(TRG, open(f'{dst}/TRG.pkl', 'wb'))
    print("field pickles saved to " + dst)
    # for asking about further training use while true loop, and return

if __name__ == "__main__":
    main()
import argparse
import time

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import dill as pickle

from Models import get_model, get_model_token_classification
from Process import *
from Optim import CosineWithRestarts
from Batch import create_masks
from Beam import beam_search
# from nltk.corpus import wordnet
def get_result(src, model, SRC, TRG, opt):
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    output = model(src, src_mask)
    output = F.softmax(output, dim=-1)
    preds = torch.argmax(output, dim=-1)
    # "_" appears to be a placeholder label and is stripped from the output
    return ''.join([TRG.vocab.itos[tok] for tok in preds[0][:]]).replace("_", "")
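# Decoding sketch: for a batch of one sentence with L tokens,
#   output: (1, L, trg_vocab) -> preds: (1, L), one target token per position.
# Since argmax(softmax(x)) == argmax(x), the softmax is cosmetic; this greedy
# per-position decoding is why no beam search is needed here.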
def translate_sentence(sentence, model, opt, SRC, TRG):
    model.eval()
    indexed = []
    sentence = SRC.preprocess(sentence)
    for tok in sentence:
        if SRC.vocab.stoi[tok] != 0 or opt.floyd == True:
            indexed.append(SRC.vocab.stoi[tok])
        else:
            # indexed.append(get_synonym(tok, SRC))
            pass
    sentence = Variable(torch.LongTensor([indexed]))
    if opt.device == 0:
        sentence = sentence.cuda()
    sentence = get_result(sentence, model, SRC, TRG, opt)
    return sentence
def translate(opt, model, SRC, TRG):
    # sentences = opt.text.lower().split('.')
    sentence = opt.text
    # sentences = [a for a in sentences if len(a) > 0]
    translated = []
    translated.append(translate_sentence(sentence, model, opt, SRC, TRG))
    return ' '.join(translated)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-load_weights', required=True)
    parser.add_argument('-pkl_dir', required=True)
    parser.add_argument('-k', type=int, default=3)
    parser.add_argument('-max_len', type=int, default=80)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-n_layers', type=int, default=6)
    # parser.add_argument('-src_lang', required=True)
    # parser.add_argument('-trg_lang', required=True)
    parser.add_argument('-heads', type=int, default=8)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-floyd', action='store_true')

    opt = parser.parse_args()

    opt.device = 0 if opt.no_cuda is False else -1

    assert opt.k > 0
    assert opt.max_len > 10

    SRC, TRG = create_fields(opt)
    model = get_model_token_classification(opt, len(SRC.vocab), len(TRG.vocab))
    while True:
        opt.text = input("Enter a sentence to translate (type 'f' to load from file, or 'q' to quit):\n")
        if opt.text == "q":
            break
        if opt.text == 'f':
            fpath = input("Enter the path of the file to translate:\n")
            try:
                opt.text = ' '.join(open(fpath, encoding='utf-8').read().split('\n'))
            except OSError:
                print("error opening or reading text file")
                continue
        phrase = translate(opt, model, SRC, TRG)
        print('> ' + phrase + '\n')

if __name__ == '__main__':
    main()
import argparse
import os
import random
import re
import time

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import dill as pickle
from tqdm import tqdm

from Models import get_model, get_model_token_classification
from Process import *
from Optim import CosineWithRestarts
from Batch import create_masks
from Beam import beam_search
# from nltk.corpus import wordnet
def get_result(src, model, SRC, TRG, opt):
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    output = model(src, src_mask)
    output = F.softmax(output, dim=-1)
    preds = torch.argmax(output, dim=-1)
    return ''.join([TRG.vocab.itos[tok] for tok in preds[0][:]]).replace("_", "")
def translate_sentence(sentence, model, opt, SRC, TRG):
    model.eval()
    indexed = []
    sentence = SRC.preprocess(sentence)
    for tok in sentence:
        if SRC.vocab.stoi[tok] != 0 or opt.floyd == True:
            indexed.append(SRC.vocab.stoi[tok])
        else:
            # indexed.append(get_synonym(tok, SRC))
            pass
    sentence = Variable(torch.LongTensor([indexed]))
    if opt.device == 0:
        sentence = sentence.cuda()
    sentence = get_result(sentence, model, SRC, TRG, opt)
    return sentence
def translate(i, opt, model, SRC, TRG):
    # sentences = opt.text.lower().split('.')
    sentence = i
    # sentences = [a for a in sentences if len(a) > 0]
    translated = []
    translated.append(translate_sentence(sentence, model, opt, SRC, TRG))
    return ' '.join(translated)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-load_weights', required=True)
    parser.add_argument('-pkl_dir', required=True)
    parser.add_argument('-k', type=int, default=3)
    parser.add_argument('-max_len', type=int, default=80)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-n_layers', type=int, default=6)
    # parser.add_argument('-src_lang', required=True)
    # parser.add_argument('-trg_lang', required=True)
    parser.add_argument('-heads', type=int, default=8)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-floyd', action='store_true')
    parser.add_argument('-test_dir', type=str, required=True)
    parser.add_argument('-result_dir', type=str, required=True)

    opt = parser.parse_args()

    opt.device = 0 if opt.no_cuda is False else -1

    assert opt.k > 0
    assert opt.max_len > 10

    SRC, TRG = create_fields(opt)
    model = get_model_token_classification(opt, len(SRC.vocab), len(TRG.vocab))
    for file in os.listdir(opt.test_dir):
        print("filename:{}".format(file))
        contents = open(os.path.join(opt.test_dir, file), encoding='utf-8').read().strip().split('\n')
        # contents = random.sample(contents, 10)
        start = time.time()
        translates = [translate(i, opt, model, SRC, TRG) for i in tqdm(contents)]
        print("Average time: {}".format((time.time()-start)/len(contents)))
        with open(os.path.join(opt.result_dir, file), 'w', encoding='utf-8') as f:
            f.write("\n".join(translates))
# while True:
# opt.text =input("Enter a sentence to translate (type 'f' to load from file, or 'q' to quit):\n")
# if opt.text=="q":
# break
# if opt.text=='f':
# fpath =input("Enter a sentence to translate (type 'f' to load from file, or 'q' to quit):\n")
# try:
# opt.text = ' '.join(open(opt.text, encoding='utf-8').read().split('\n'))
# except:
# print("error opening or reading text file")
# continue
# phrase = translate(opt, model, SRC, TRG)
# print('> '+ phrase + '\n')
if __name__ == '__main__':
    main()
...@@ -4,6 +4,7 @@ import os
import time
from Models import get_model
import random
import itertools
from tqdm import tqdm
from translate_file import translate
...@@ -41,6 +42,8 @@ def main():
        # print("filename:{}".format(file))
        contents = pickle.load(open(os.path.join(opt.test_dir, file), "rb"))
        # contents = random.sample(contents, 10)
        contents = [list(itertools.chain.from_iterable(lines)) for lines in contents]
        start = time.time()
        translates = [translate("".join(i), opt, model, SRC, TRG) for i in contents if len("".join(i)) <= 200]
...