Commit 1c06d13f authored by szr712's avatar szr712

修改output max的错误

parent a0b3ed3a
...@@ -72,6 +72,8 @@ def get_result(src, model, SRC, TRG, opt): ...@@ -72,6 +72,8 @@ def get_result(src, model, SRC, TRG, opt):
if opt.tone_filter == True: if opt.tone_filter == True:
output_min, _ = output.min(dim=-1) output_min, _ = output.min(dim=-1)
output_min = output_min.min() output_min = output_min.min()
output_max, _ = output.max(dim=-1)
output_max = output_max.max()
vals, indices = output.topk(k=10, dim=-1, largest=True, sorted=True) vals, indices = output.topk(k=10, dim=-1, largest=True, sorted=True)
result = [] result = []
for x, index in enumerate(src[0][:]): for x, index in enumerate(src[0][:]):
...@@ -94,6 +96,7 @@ def get_result(src, model, SRC, TRG, opt): ...@@ -94,6 +96,7 @@ def get_result(src, model, SRC, TRG, opt):
output[0, x, indices[0][x][i]] = output_min output[0, x, indices[0][x][i]] = output_min
if not flag: if not flag:
result.append(TRG.vocab.itos[indices[0][x][0]]) result.append(TRG.vocab.itos[indices[0][x][0]])
output[0, x, indices[0][x][0]] = output_max
# 查看之前的声母 # 查看之前的声母
else: else:
result = result[:-1] result = result[:-1]
...@@ -112,9 +115,10 @@ def get_result(src, model, SRC, TRG, opt): ...@@ -112,9 +115,10 @@ def get_result(src, model, SRC, TRG, opt):
# if flag: # if flag:
# break # break
if not flag: if not flag:
output[0, x, indices[0][x][i]] = output_min output[0, x-1, indices[0][x-1][i]] = output_min
if not flag: if not flag:
result.append(TRG.vocab.itos[indices[0][x-1][0]]) result.append(TRG.vocab.itos[indices[0][x-1][0]])
output[0, x, indices[0][x-1][0]] = output_max
result.append(TRG.vocab.itos[indices[0][x][0]]) result.append(TRG.vocab.itos[indices[0][x][0]])
else: else:
result.append(TRG.vocab.itos[indices[0][x][0]]) result.append(TRG.vocab.itos[indices[0][x][0]])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment