Commit 5a6bdd15 authored by szr712's avatar szr712

修改模型load

parent 7cb62a36
...@@ -93,15 +93,15 @@ def get_model_token_classification(opt, src_vocab, trg_vocab): ...@@ -93,15 +93,15 @@ def get_model_token_classification(opt, src_vocab, trg_vocab):
if opt.load_weights is not None: if opt.load_weights is not None:
print("loading pretrained weights...") print("loading pretrained weights...")
model.load_state_dict(torch.load(f'{opt.load_weights}')) # model.load_state_dict(torch.load(f'{opt.load_weights}'))
# checkpoint = torch.load(opt.load_weights, map_location=lambda storage, loc: storage) checkpoint = torch.load(opt.load_weights, map_location=lambda storage, loc: storage)
# state_dict = {} state_dict = {}
# for k, v in checkpoint.items(): for k, v in checkpoint.items():
# if k.startswith('module'): if k.startswith('module'):
# state_dict[k[7:]] = v state_dict[k[7:]] = v
# else: else:
# state_dict[k] = v state_dict[k] = v
# model.load_state_dict(state_dict) model.load_state_dict(state_dict)
else: else:
for p in model.parameters(): for p in model.parameters():
if p.dim() > 1: if p.dim() > 1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment