Commit 6b8a8e1a authored by szr712's avatar szr712

修改错误

parent e9340bc4
......@@ -33,11 +33,9 @@ def get_result(src, model, SRC, TRG, opt):
src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
output = model(src, src_mask)
output = F.softmax(output, dim=-1)
print(output.shape)
if opt.tone_filter == True:
vals, indices = output.topk(k=5, dim=-1, largest=True, sorted=True)
print(indices.shape)
result = []
for x, index in enumerate(src[0][:]):
if SRC.vocab.itos[index] in opt.yunmus and SRC.vocab.itos[index][-1] != "0":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment