Commit 5a6bdd15 authored by szr712's avatar szr712

修改模型load

parent 7cb62a36
......@@ -93,15 +93,15 @@ def get_model_token_classification(opt, src_vocab, trg_vocab):
if opt.load_weights is not None:
print("loading pretrained weights...")
model.load_state_dict(torch.load(f'{opt.load_weights}'))
# checkpoint = torch.load(opt.load_weights, map_location=lambda storage, loc: storage)
# state_dict = {}
# for k, v in checkpoint.items():
# if k.startswith('module'):
# state_dict[k[7:]] = v
# else:
# state_dict[k] = v
# model.load_state_dict(state_dict)
# model.load_state_dict(torch.load(f'{opt.load_weights}'))
checkpoint = torch.load(opt.load_weights, map_location=lambda storage, loc: storage)
state_dict = {}
for k, v in checkpoint.items():
if k.startswith('module'):
state_dict[k[7:]] = v
else:
state_dict[k] = v
model.load_state_dict(state_dict)
else:
for p in model.parameters():
if p.dim() > 1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment