Commit e9340bc4 authored by szr712

增加音调筛选功能,保证翻译结果与盲文中标注声调的拼音同音

parent 118a47d0
......@@ -49,4 +49,4 @@ CUDA_VISIBLE_DEVICES=4 python translate2.py -load_weights weights/token_classifi
CUDA_VISIBLE_DEVICES=4 python translate_file2.py -load_weights weights/token_classifiaction_aug_Chinese/token_classification_split_Chinese_148_0.012723878987599165 -pkl_dir weights/token_classifiaction_aug_Chinese -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt -test_dir data/test_data/end2end/Chinese/pinyin -result_dir data/test_data/end2end/Chinese/hanzi
CUDA_VISIBLE_DEVICES=4 python translate_file2.py -load_weights weights/Chinese_weights/token_classification_split_Chinese_149_0.013830240316456183 -pkl_dir weights/Chinese_weights -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt -test_dir data/test_data/end2end_chinese/pinyin -result_dir data/test_data/end2end_chinese/pre
CUDA_VISIBLE_DEVICES=4 python translate_file2.py -load_weights weights/Chinese_weights/token_classification_split_Chinese_149_0.013830240316456183 -pkl_dir weights/Chinese_weights -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt -test_dir data/test_data/end2end_chinese/pinyin -result_dir data/test_data/end2end_chinese/pre -tone_filter
......@@ -17,6 +17,8 @@ import re
import time
import random
import copy
from pypinyin import pinyin, Style
from pypinyin.style._utils import get_initials, get_finals
def get_yunmus(file_path):
......@@ -31,8 +33,57 @@ def get_result(src, model, SRC, TRG, opt):
src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
output = model(src, src_mask)
output = F.softmax(output, dim=-1)
preds = torch.argmax(output, dim=-1)
return ''.join([TRG.vocab.itos[tok] for tok in preds[0][:] if tok.item() != 0 ]).replace("_", "").replace(" ","")
print(output.shape)
if opt.tone_filter == True:
vals, indices = output.topk(k=5, dim=-1, largest=True, sorted=True)
print(indices.shape)
result = []
for x, index in enumerate(src[0][:]):
if SRC.vocab.itos[index] in opt.yunmus and SRC.vocab.itos[index][-1] != "0":
# 单韵母字
if x == 0 or TRG.vocab.itos[indices[0][x][0]] != "_":
flag = False
for i in range(0, 5):
hanzi = TRG.vocab.itos[indices[0][x][i]]
# 考虑多音字
for p in pinyin(TRG.vocab.itos[indices[0][x][i]], style=Style.TONE3, neutral_tone_with_five=True, heteronym=True)[0]:
t_yunmu = get_finals(p[:-1], True)
if p[-1] == SRC.vocab.itos[index][-1] and t_yunmu == SRC.vocab.itos[index][:-1]:
result.append(hanzi)
flag = True
break
if flag:
break
if not flag:
result.append(TRG.vocab.itos[indices[0][x][0]])
# 查看之前的声母
else:
result = result[:-1]
flag = False
for i in range(0, 5):
tmp = indices[0][x-1][i]
hanzi = TRG.vocab.itos[tmp]
for p in pinyin(TRG.vocab.itos[indices[0][x-1][i]], style=Style.TONE3, neutral_tone_with_five=True, heteronym=True)[0]:
t_shenmu = get_initials(p[:-1], True)
t_yunmu = get_finals(p[:-1], True)
if p[-1] == SRC.vocab.itos[index][-1] and t_yunmu == SRC.vocab.itos[index][:-1] and t_shenmu == SRC.vocab.itos[src[0][x-1]]:
result.append(hanzi)
result.append(TRG.vocab.itos[indices[0][x][0]])
flag = True
break
if flag:
break
if not flag:
result.append(TRG.vocab.itos[indices[0][x-1][0]])
result.append(TRG.vocab.itos[indices[0][x][0]])
else:
result.append(TRG.vocab.itos[indices[0][x][0]])
return ''.join(result).replace("_", "").replace(" ", "")
else:
preds = torch.argmax(output, dim=-1)
return ''.join([TRG.vocab.itos[tok] for tok in preds[0][:] if tok.item() != 0]).replace("_", "").replace(" ", "")
# return ' '.join([TRG.vocab.itos[tok] for tok in preds[0][:]])
......@@ -51,7 +102,7 @@ def translate_sentence(sentence, model, opt, SRC, TRG):
# indexed.append(get_synonym(tok, SRC))
pass
sentence = Variable(torch.LongTensor([indexed]))
# if len(ori.split(" "))!=len(indexed):
# print("\n")
# print(ori)
......@@ -72,10 +123,10 @@ def translate(i, opt, model, SRC, TRG):
translated = []
for sentence in sentences:
sentence=sentence.split(" ")
if len(sentence)==1 and sentence[0]=="":
sentence = sentence.split(" ")
if len(sentence) == 1 and sentence[0] == "":
continue
sentence=" ".join([a for a in sentence if a != " " and a!=""])
sentence = " ".join([a for a in sentence if a != " " and a != ""])
# result=translate_sentence(sentence, model, opt, SRC, TRG)
# if len(sentence.split(" "))!=len(result.split(" ")):
# print("\n")
......@@ -133,6 +184,7 @@ def main():
parser.add_argument('-floyd', action='store_true')
parser.add_argument('-test_dir', type=str, required=True)
parser.add_argument('-result_dir', type=str, required=True)
parser.add_argument('-tone_filter', action='store_true')
parser.add_argument('-src_voc')
parser.add_argument('-trg_voc')
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment