Commit d833c315 authored by szr712's avatar szr712

修改logits拼接函数

parent ea10822c
b
m
d
p
f
t
n
g
j
h
x
ch
r
c
l
k
q
zh
sh
z
s
\ No newline at end of file
......@@ -30,11 +30,20 @@ def get_yunmus(file_path):
return yunmus
def get_hanzi_logits(logits, SRC, TRG, opt):
    """Gather the logit rows whose predicted target token is a Chinese character.

    Only batch element 0 is inspected. The prediction for each position is
    the argmax over the last dimension of ``logits``; a position is kept when
    its target-vocabulary string compares within ``'\\u4e00'..'\\u9fa5'``.

    Args:
        logits: Tensor indexed as ``logits[0][i]`` with ``logits.shape[2]``
            as the row width (presumably ``(1, seq_len, vocab)`` -- TODO confirm).
        SRC: Source field; unused here, kept for a uniform signature.
        TRG: Target field exposing ``vocab.itos``.
        opt: Options object; unused here.

    Returns:
        Tensor of shape ``(1, 1 + kept, logits.shape[2])``.
    """
    preds = torch.argmax(logits, dim=-1)
    # NOTE(review): the first row is random noise, exactly as in the original
    # code; it is preserved so the output shape stays the same. Confirm whether
    # callers really want this placeholder row.
    # Follow the device of ``logits`` instead of hard-coding ``.cuda()`` so the
    # function also works on CPU-only hosts and on non-default GPUs.
    rows = [torch.randn(1, logits.shape[2]).to(logits.device)]
    for i, tok in enumerate(preds[0]):
        if '\u4e00' <= TRG.vocab.itos[tok] <= '\u9fa5':  # is it a Chinese character?
            rows.append(logits[0][i:i + 1])
    # Concatenate once rather than re-allocating the result on every match.
    return torch.cat(rows, dim=0).unsqueeze(dim=0)


def get_hanzi_logits_yunmu(src, logits, SRC, TRG, opt):
    """Gather the logit rows at positions whose source token is a final (yunmu).

    Only batch element 0 is inspected. A position ``i`` is kept when
    ``SRC.vocab.itos[src[0][i]]`` is contained in ``opt.yunmus``.

    Args:
        src: Source token ids, indexed as ``src[0][i]``.
        logits: Tensor indexed as ``logits[0][i]`` with ``logits.shape[2]``
            as the row width (presumably ``(1, seq_len, vocab)`` -- TODO confirm).
        SRC: Source field exposing ``vocab.itos``.
        TRG: Target field; unused here, kept for a uniform signature.
        opt: Options object exposing the ``yunmus`` collection.

    Returns:
        Tensor of shape ``(1, 1 + kept, logits.shape[2])``.
    """
    # NOTE(review): the first row is random noise, exactly as in the original
    # code; it is preserved so the output shape stays the same. Confirm whether
    # callers really want this placeholder row.
    # Follow the device of ``logits`` instead of hard-coding ``.cuda()`` so the
    # function also works on CPU-only hosts and on non-default GPUs.
    rows = [torch.randn(1, logits.shape[2]).to(logits.device)]
    for i, index in enumerate(src[0]):
        if SRC.vocab.itos[index] in opt.yunmus:  # is it a final (yunmu)?
            rows.append(logits[0][i:i + 1])
    # Concatenate once rather than re-allocating the result on every match.
    return torch.cat(rows, dim=0).unsqueeze(dim=0)
def get_hanzi_logits_shengmu(src, logits, SRC, TRG, opt):
    """Gather the logit rows at positions whose source token is an initial (shengmu).

    Only batch element 0 is inspected. A position ``i`` is kept when
    ``SRC.vocab.itos[src[0][i]]`` is contained in ``opt.shengmu``.

    Args:
        src: Source token ids, indexed as ``src[0][i]``.
        logits: Tensor indexed as ``logits[0][i]`` with ``logits.shape[2]``
            as the row width (presumably ``(1, seq_len, vocab)`` -- TODO confirm).
        SRC: Source field exposing ``vocab.itos``.
        TRG: Target field; unused here, kept for a uniform signature.
        opt: Options object exposing the ``shengmu`` collection.

    Returns:
        Tensor of shape ``(1, 1 + kept, logits.shape[2])``.
    """
    # NOTE(review): the first row is random noise, exactly as in the original
    # code; it is preserved so the output shape stays the same. Confirm whether
    # callers really want this placeholder row.
    # Follow the device of ``logits`` instead of hard-coding ``.cuda()`` so the
    # function also works on CPU-only hosts and on non-default GPUs.
    rows = [torch.randn(1, logits.shape[2]).to(logits.device)]
    for i, index in enumerate(src[0]):
        if SRC.vocab.itos[index] in opt.shengmu:  # is it an initial (shengmu)?
            rows.append(logits[0][i:i + 1])
    # Concatenate once rather than re-allocating the result on every match.
    return torch.cat(rows, dim=0).unsqueeze(dim=0)
......@@ -91,7 +100,7 @@ def get_result(src, model, SRC, TRG, opt):
return ''.join(result).replace("_", "").replace(" ", "")
else:
# output=get_hanzi_logits(output, SRC, TRG, opt)
# output=get_hanzi_logits_shengmu(src,output, SRC, TRG, opt)
preds = torch.argmax(output, dim=-1)
return ''.join([TRG.vocab.itos[tok] for tok in preds[0][:] if tok.item() != 0]).replace("_", "").replace(" ", "")
# return ' '.join([TRG.vocab.itos[tok] for tok in preds[0][:]])
......@@ -208,6 +217,8 @@ def main():
SRC, TRG = create_fields(opt)
model = get_model_token_classification(opt, len(SRC.vocab), len(TRG.vocab))
opt.yunmus = get_yunmus("./data/voc/yunmus.txt")
opt.shengmu = get_yunmus("./data/voc/shengmu.txt")
for file in os.listdir(opt.test_dir):
print("filename:{}".format(file))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment