Commit 5e1da936 authored by szr712

修改UNK问题

parent 961b319e
......@@ -99,6 +99,7 @@ def create_fields(opt):
print("loading presaved fields...")
SRC = pickle.load(open(f'{opt.pkl_dir}/SRC.pkl', 'rb'))
TRG = pickle.load(open(f'{opt.pkl_dir}/TRG.pkl', 'rb'))
# print(SRC.vocab.stoi)
except:
print("error opening SRC.pkl and TXT.pkl field files, please ensure they are in " + opt.load_weights + "/")
quit()
......
......@@ -43,3 +43,5 @@ CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data
CUDA_VISIBLE_DEVICES=1 python train_token_classification.py -src_data data/train_file/pinyin_split_random_wo_tones -trg_data data/train_file/hanzi_split_random_wo_tones -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
CUDA_VISIBLE_DEVICES=5 python train_token_classification.py -src_data data/train_file/pinyin_split_random_wo_tones -trg_data data/train_file/hanzi_split_random_wo_tones -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt -gpus 0
CUDA_VISIBLE_DEVICES=4 python translate2.py -load_weights weights/token_classification_split_new/11-28_14:53:18/token_classification_split_new_26_0.025378503443207592 -pkl_dir weights/token_classification_split_new/11-28_14:53:18 -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
......@@ -21,7 +21,7 @@ def get_result(src, model, SRC, TRG, opt):
output = F.softmax(output,dim=-1)
preds = torch.argmax(output, dim=-1)
# return ''.join([TRG.vocab.itos[tok] for tok in preds[0][:]]).replace("_", "")
return ' '.join([TRG.vocab.itos[tok] for tok in preds[0][:]])
return ' '.join([TRG.vocab.itos[tok] if TRG.vocab.itos[tok]!=0 else "" for tok in preds[0][:]])
def translate_sentence(sentence, model, opt, SRC, TRG):
......@@ -35,6 +35,9 @@ def translate_sentence(sentence, model, opt, SRC, TRG):
else:
# indexed.append(get_synonym(tok, SRC))
pass
if len(indexed)==0:
return ""
sentence = Variable(torch.LongTensor([indexed]))
if opt.device == 0:
sentence = sentence.cuda()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment