Commit 25cce464 authored by szr712's avatar szr712

增加汉字logits拼接函数

parent d8b00dd6
......@@ -29,6 +29,16 @@ def get_yunmus(file_path):
return yunmus
def get_hanzi_logits(logits, SRC, TRG, opt):
preds = torch.argmax(logits, dim=-1)
result = torch.randn(1, logits.shape[2]).cuda()
for i, tok in enumerate(preds[0][:]):
if '\u4e00' <= TRG.vocab.itos[tok] <= '\u9fa5': # 判断是否是中文
result = torch.cat((result, logits[0][i:i+1]), dim=0)
result = result.unsqueeze(dim=0)
return result
def get_result(src, model, SRC, TRG, opt):
src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
output = model(src, src_mask)
......@@ -80,6 +90,7 @@ def get_result(src, model, SRC, TRG, opt):
return ''.join(result).replace("_", "").replace(" ", "")
else:
# output=get_hanzi_logits(output, SRC, TRG, opt)
preds = torch.argmax(output, dim=-1)
return ''.join([TRG.vocab.itos[tok] for tok in preds[0][:] if tok.item() != 0]).replace("_", "").replace(" ", "")
# return ' '.join([TRG.vocab.itos[tok] for tok in preds[0][:]])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment