Commit 3fc8f0c1 authored by szr712

Finished modifying the token classification model

parent f93d31d8
@@ -28,6 +28,11 @@ def create_masks(src, trg, opt):
        trg_mask = None
    return src_mask, trg_mask

# Source-only mask: the token classification model has no decoder,
# so no target mask is needed.
def create_masks2(src, opt):
    src_mask = (src != opt.src_pad).unsqueeze(-2)
    return src_mask

# patch on Torchtext's batching process that makes it more efficient
# from http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks
......
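For context on the new helper: create_masks2 builds only the source padding mask, since the encoder-only model never attends to a target sequence. A minimal standalone sketch of the same expression (the pad index and tensor values here are made up for illustration):

    import torch
    from types import SimpleNamespace

    opt = SimpleNamespace(src_pad=1)               # hypothetical options object
    src = torch.tensor([[5, 8, 2, 1, 1]])          # one sentence, right-padded with pad id 1
    src_mask = (src != opt.src_pad).unsqueeze(-2)  # same expression as create_masks2
    print(src_mask)  # tensor([[[True, True, True, False, False]]]), shape (1, 1, 5)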
@@ -51,6 +51,19 @@ class Transformer(nn.Module):
        output = self.out(d_output)
        return output

class TransformerForTokenClassification(nn.Module):
    """Encoder-only variant: the decoder is dropped and a linear head
    maps each encoder state to a logit vector over the target vocabulary."""

    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
        # self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, trg_vocab)

    def forward(self, src, src_mask):
        outputs = self.encoder(src, src_mask)
        # d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        output = self.out(outputs)
        return output
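With the decoder removed, the model emits one logit vector per source position, so training reduces to per-token cross-entropy between model(src, src_mask) and the label sequence. A minimal sketch of one training step under that assumption (the ignore_index and optimizer handling are illustrative, not taken from this commit):

    import torch.nn.functional as F

    def train_step(model, optimizer, src, labels, src_mask, pad_label=1):
        # src: (batch, seq_len); labels: (batch, seq_len) of target-vocab ids
        optimizer.zero_grad()
        logits = model(src, src_mask)                   # (batch, seq_len, trg_vocab)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                               labels.view(-1),
                               ignore_index=pad_label)  # skip padded positions
        loss.backward()
        optimizer.step()
        return loss.item()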
def get_model(opt, src_vocab, trg_vocab):
    assert opt.d_model % opt.heads == 0
@@ -70,4 +83,25 @@ def get_model(opt, src_vocab, trg_vocab):
        model = model.cuda()
    return model
def get_model_token_classification(opt, src_vocab, trg_vocab):
    assert opt.d_model % opt.heads == 0
    assert opt.dropout < 1

    model = TransformerForTokenClassification(src_vocab, trg_vocab, opt.d_model, opt.n_layers, opt.heads, opt.dropout)

    if opt.load_weights is not None:
        print("loading pretrained weights...")
        model.load_state_dict(torch.load(f'{opt.load_weights}'))
    else:
        # Xavier-initialise all weight matrices; biases and other 1-D
        # parameters keep their default initialisation.
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    if opt.device == 0:
        model = model.cuda()

    return model
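For reference, a sketch of instantiating the new model path with a hand-built options object (the field values are placeholders; real values come from the argument parsing in train2.py):

    from types import SimpleNamespace

    opt = SimpleNamespace(d_model=512, n_layers=6, heads=8, dropout=0.1,
                          load_weights=None,  # None -> Xavier init branch
                          device=-1)          # anything but 0 stays on CPU here
    model = get_model_token_classification(opt, src_vocab=4000, trg_vocab=4000)
    print(sum(p.numel() for p in model.parameters()), "parameters")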
@@ -60,7 +60,9 @@ def create_fields(opt):
    # TRG = data.Field(lower=True, tokenize=t_trg.tokenizer, init_token='<sos>', eos_token='<eos>')
    # SRC = data.Field(lower=True, tokenize=t_src.tokenizer)

    # TRG = data.Field(tokenize=my_tokenize, init_token='<sos>', eos_token='<eos>')
    TRG = data.Field(tokenize=my_tokenize)  # <sos>/<eos> dropped so labels align one-to-one with source tokens
    SRC = data.Field(tokenize=my_tokenize)

    if opt.pkl_dir is not None:
@@ -92,17 +94,20 @@ def create_dataset(opt, SRC, TRG):
    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields)

    # train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
    #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
    #                         batch_size_fn=batch_size_fn, train=True, shuffle=True)
    train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=None, train=True, shuffle=True)

    os.remove('translate_transformer_temp.csv')
    if opt.load_weights is None:
        SRC.build_vocab(train)
        # print(SRC.vocab.stoi)
        TRG.build_vocab(train)
        # print(TRG.vocab.stoi)

    if opt.checkpoint > 0:
        try:
            os.mkdir("weights")
......
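One behavioural note on the iterator change above: with batch_size_fn=None, torchtext's batching falls back to counting examples, so batch_size now means sentences per batch rather than the dynamic token budget the commented-out batch_size_fn provided. A paraphrase of the default torchtext substitutes (illustrative, not code in this repo):

    # What torchtext effectively uses when batch_size_fn is None:
    def default_batch_size_fn(new_example, count, sofar):
        # count = number of examples already in the current batch
        return count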
@@ -2,12 +2,22 @@ CUDA_VISIBLE_DEVICES=5 nohup python train.py -src_data data/train_set_onlyChines
CUDA_VISIBLE_DEVICES=5 python train.py -src_data data/train_set_pinyin_total.txt -trg_data data/train_set_total.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name pinyin_to_hanzi_total -load_weights weights/pinyin_to_hanzi_total/10-29_18:51:57/pinyin_to_hanzi_total_10_0.1508198243379593 -pkl_dir weights/pinyin_to_hanzi_total/10-29_18:51:57
CUDA_VISIBLE_DEVICES=6 python translate.py -load_weights weights/pinyin_to_hanzi_total/10-29_18:51:57/pinyin_to_hanzi_total_10_0.1508198243379593 -pkl_dir weights/pinyin_to_hanzi_total/10-29_18:51:57
CUDA_VISIBLE_DEVICES=3 python translate.py -load_weights weights/pinyin_to_hanzi_total/10-29_18:51:57/pinyin_to_hanzi_total_10_0.1508198243379593 -pkl_dir weights/pinyin_to_hanzi_total/10-29_18:51:57
CUDA_VISIBLE_DEVICES=3 python translate2.py -load_weights weights/token_classification/11-09_17:20:17/token_classification_4_0.09534996442496776 -pkl_dir weights/token_classification/11-09_17:20:17
CUDA_VISIBLE_DEVICES=5 python translate_file.py -load_weights weights/pinyin_to_hanzi_total/10-30_21:22:49/pinyin_to_hanzi_total_9_0.12325442619621754 -pkl_dir weights/pinyin_to_hanzi_total/10-30_21:22:49 -test_dir data/test_data/pinyin_short -result_dir data/test_data/result_tmp
CUDA_VISIBLE_DEVICES=6 python translate_pkl.py -load_weights weights/pinyin_to_hanzi_total/10-30_21:22:49/pinyin_to_hanzi_total_9_0.12325442619621754 -pkl_dir weights/pinyin_to_hanzi_total/10-30_21:22:49 -test_dir data/pkl/label_pkl -result_dir data/pkl/lable_pkl_result
CUDA_VISIBLE_DEVICES=1 python translate_pkl.py -load_weights weights/pinyin_to_hanzi_total/10-30_21:22:49/pinyin_to_hanzi_total_9_0.12325442619621754 -pkl_dir weights/pinyin_to_hanzi_total/10-30_21:22:49 -test_dir data/pkl/1105-pinyin-pkl -result_dir data/pkl/1105-pinyin-pkl-result
CUDA_VISIBLE_DEVICES=2 nohup python translate_file.py -load_weights weights/pinyin_to_hanzi_total/10-30_21:22:49/pinyin_to_hanzi_total_59_0.07513352055102587 -pkl_dir weights/pinyin_to_hanzi_total/10-30_21:22:49 -test_dir data/test_data/pinyin_short -result_dir data/test_data/result_short >log2 2>&1 &
CUDA_VISIBLE_DEVICES=3 nohup python translate_file.py -load_weights weights/pinyin_to_hanzi_onlyChinese/11-04_16:36:46/pinyin_to_hanzi_onlyChinese_27_0.0009592685928873834 -pkl_dir weights/pinyin_to_hanzi_onlyChinese/11-04_16:36:46 -test_dir data/test_data/pinyin_onlyChinese -result_dir data/test_data/result_onlyChinese >log2 2>&1 &
CUDA_VISIBLE_DEVICES=3 python translate_file2.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_17_0.06582485044375062 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/test_data/pinyin -result_dir data/test_data/result_token_classification
CUDA_VISIBLE_DEVICES=1 python train2.py -src_data data/pinyin_2.txt -trg_data data/hanzi_2.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification
CUDA_VISIBLE_DEVICES=1 nohup python train2.py -src_data data/pinyin_2.txt -trg_data data/hanzi_2.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification
\ No newline at end of file