Commit 237ee68c authored by szr712

Add online random tone erasure

parent 8c385980
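In brief, the augmentation this commit introduces works as follows: for each training example and for each value p in change_possibility, a copy of the example is made, and every source token that appears in the finals (yunmu) vocabulary has its trailing tone digit replaced by "0" with probability p. Below is a minimal sketch of that logic; the helper name erase_tones and the sample tokens are illustrative assumptions, only the char[:-1] + "0" rewrite and the probability list come from the diff.

import copy
import random

def erase_tones(src_tokens, yunmus, p):
    # Hypothetical helper mirroring MyIterator.data(): with probability p,
    # replace the tone digit of any token found in the finals list with "0".
    new_tokens = copy.deepcopy(src_tokens)
    for i, tok in enumerate(src_tokens):
        if random.random() < p and tok in yunmus:
            new_tokens[i] = tok[:-1] + "0"   # e.g. "ang4" -> "ang0"
    return new_tokens

# Each original example expands into one copy per probability level.
yunmus = ["ang4", "an1"]                      # illustrative finals vocabulary
example = ["zh", "ang4", "s", "an1"]
augmented = [erase_tones(example, yunmus, p) for p in [0.5, 0.6, 0.7, 0.8, 0.9, 1]]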
...@@ -2,41 +2,95 @@ import torch
from torchtext import data
import numpy as np
from torch.autograd import Variable
import copy
import random
from tqdm import tqdm
import time
def nopeak_mask(size, opt):
    np_mask = np.triu(np.ones((1, size, size)),
                      k=1).astype('uint8')
    np_mask = Variable(torch.from_numpy(np_mask) == 0)
    if opt.device == 0:
        np_mask = np_mask.cuda()
    return np_mask

def create_masks(src, trg, opt):
    src_mask = (src != opt.src_pad).unsqueeze(-2)
    if trg is not None:
        trg_mask = (trg != opt.trg_pad).unsqueeze(-2)
        size = trg.size(1)  # get seq_len for matrix
        np_mask = nopeak_mask(size, opt)
        if trg.is_cuda:
            np_mask.cuda()
        trg_mask = trg_mask & np_mask
    else:
        trg_mask = None
    return src_mask, trg_mask

def create_masks2(src, opt):
    src_mask = (src != opt.src_pad).unsqueeze(-2)
    return src_mask

# patch on Torchtext's batching process that makes it more efficient
# from http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks
class MyIterator(data.Iterator):
    def __init__(self, dataset, batch_size, sort_key=None, device=None,
                 batch_size_fn=None, train=True,
                 repeat=False, shuffle=None, sort=None,
                 sort_within_batch=None, augment=False, change_possibility=[0.5, 0.6, 0.7, 0.8, 0.9, 1]):
        super().__init__(dataset, batch_size, sort_key, device,
                         batch_size_fn, train,
                         repeat, shuffle, sort,
                         sort_within_batch)
        self.augment = augment
        self.change_possibility = change_possibility
        print("start copy...")
        start = time.time()
        # keep an untouched copy of the original examples for augmentation
        self.ori_examples = copy.deepcopy(self.dataset.examples)
        print("end copy..., cost total:{}s".format(time.time() - start))
        # finals (yunmu) vocabulary: tokens whose tone may be erased
        with open("./data/voc/yunmus.txt", "r", encoding="utf-8") as f:
            yunmu = f.readlines()
        self.yunmus = [a.strip() for a in yunmu]
    def data(self):
        """Override data.Iterator.data() to add on-the-fly augmentation.
        Return the examples in the dataset in order, sorted, or shuffled."""
        if self.augment:
            print("augmenting data...")
            self.dataset.examples = []
            for ex in tqdm(self.ori_examples):
                # one augmented copy per probability level
                for p in self.change_possibility:
                    new_ex = copy.deepcopy(ex)
                    for i, char in enumerate(ex.src):
                        r = random.random()
                        # with probability p, replace the trailing tone digit with "0"
                        if r < p and char in self.yunmus:
                            new_ex.src[i] = char[:-1] + "0"
                    self.dataset.examples.append(new_ex)
            print("data len:{}".format(len(self.dataset.examples)))
            # print("src:{}\ntrg:{}".format(type(ex.src),type(ex.trg)))
        if self.sort:
            xs = sorted(self.dataset, key=self.sort_key)
        elif self.shuffle:
            xs = [self.dataset[i]
                  for i in self.random_shuffler(range(len(self.dataset)))]
        else:
            xs = self.dataset
        return xs
    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
...@@ -47,15 +101,17 @@ class MyIterator(data.Iterator):
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))

global max_src_in_batch, max_tgt_in_batch

def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
......
...@@ -33,8 +33,15 @@ def read_data(opt):
    if opt.src_data is not None:
        try:
            print("loading src_data")
-           opt.src_data = open(opt.src_data).read().strip().split('\n')
-           opt.src_data = [x for x in tqdm(opt.src_data)]
+           if os.path.isdir(opt.src_data):
+               # src_data may now be a directory: collect lines from every file in it
+               train_set = []
+               for file in os.listdir(opt.src_data):
+                   train_set.extend(open(os.path.join(opt.src_data, file)).read().strip().split('\n'))
+               opt.src_data = train_set
+               opt.src_data = [x for x in tqdm(opt.src_data)]
+           else:
+               opt.src_data = open(opt.src_data).read().strip().split('\n')
+               opt.src_data = [x for x in tqdm(opt.src_data)]
            # print(len(opt.src_data))
        except:
            print("error: '" + opt.src_data + "' file not found")
...@@ -108,14 +115,18 @@ def create_dataset(opt, SRC, TRG):
    df.to_csv("translate_transformer_temp.csv", index=False)
    data_fields = [('src', SRC), ('trg', TRG)]
-   train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields)
+   train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields, skip_header=True)
    # train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
    #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
    #                         batch_size_fn=batch_size_fn, train=True, shuffle=True)
+   # train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
+   #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
+   #                         batch_size_fn=None, train=True, shuffle=True)
    train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
-                           batch_size_fn=None, train=True, shuffle=True)
+                           batch_size_fn=None, train=True, shuffle=True, augment=True)
    os.remove('translate_transformer_temp.csv')
......
...@@ -214,6 +214,10 @@ if __name__ == "__main__":
    # with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
    #     yunmus=f.readlines()
    #     yunmus=[a.strip() for a in yunmus]
-   build_corpus("./data/train_set_new.txt",
-                "./data/pinyin_new_split.txt", "./data/hanzi_new_split.txt")
-   print("Done")
+   ori_dir = "./data/train_file/ori_file_split_random_wo_tones"
+   hanzi_dir = "./data/train_file/hanzi_split_random_wo_tones"
+   pinyin_dir = "./data/train_file/pinyin_split_random_wo_tones"
+   for file in os.listdir(ori_dir):
+       build_corpus(os.path.join(ori_dir, file),
+                    os.path.join(pinyin_dir, file), os.path.join(hanzi_dir, file))
+   print("Done")
...@@ -38,3 +38,6 @@ CUDA_VISIBLE_DEVICES=6 nohup python train_token_classification.py -src_data data
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
CUDA_VISIBLE_DEVICES=2 python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
...@@ -81,8 +81,8 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-src_data', required=True)
    parser.add_argument('-trg_data', required=True)
-   parser.add_argument('-src_lang', required=True, default="en_core_web_sm")
-   parser.add_argument('-trg_lang', required=True, default="fr_core_news_sm")
+   parser.add_argument('-src_lang', default="en_core_web_sm")
+   parser.add_argument('-trg_lang', default="fr_core_news_sm")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-SGDR', action='store_true')
    parser.add_argument('-epochs', type=int, default=2)
......
...@@ -8,7 +8,8 @@ from Optim import CosineWithRestarts
from Batch import create_masks, create_masks2
import dill as pickle
import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+from Process import get_len
+# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def train_model(model, opt, start_time):
...@@ -35,7 +36,7 @@ def train_model(model, opt, start_time):
        # torch.save(model.state_dict(), 'weights/model_weights')
        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()
            trg = batch.trg.transpose(0, 1).cuda()
            # print("src shape:{} trg shape:{}".format(src.shape,trg.shape))
...@@ -59,8 +60,8 @@ def train_model(model, opt, start_time):
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss/opt.printevery
                if opt.floyd is False:
-                   print(" %dm: epoch %d [%s%s] %d%% loss = %.3f" %
-                         ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
+                   print(" %dm: epoch %d [%s%s] %d%% steps = %d loss = %.3f" %
+                         ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, i, avg_loss), end='\r')
                else:
                    print(" %dm: epoch %d [%s%s] %d%% loss = %.3f" %
                          ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
...@@ -70,8 +71,8 @@ def train_model(model, opt, start_time):
        # torch.save(model.state_dict(), 'weights/model_weights')
        # cptime = time.time()
-       print("%dm: epoch %d [%s%s] %d%% loss = %.3f\nepoch %d complete, loss = %.03f" %
-             ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))
+       print("%dm: epoch %d [%s%s] %d%% steps = %d loss = %.3f\nepoch %d complete, total steps = %d, loss = %.03f" %
+             ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, i, avg_loss, epoch + 1, i, avg_loss))
        torch.save(model.state_dict(), os.path.join(
            dst, opt.model_name+"_{}_{}".format(epoch + 1, avg_loss)))
        print("model saved as {}".format(opt.model_name +
...@@ -83,8 +84,8 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-src_data', required=True)
    parser.add_argument('-trg_data', required=True)
-   parser.add_argument('-src_lang', required=True, default="en_core_web_sm")
-   parser.add_argument('-trg_lang', required=True, default="fr_core_news_sm")
+   parser.add_argument('-src_lang', default="en_core_web_sm")
+   parser.add_argument('-trg_lang', default="fr_core_news_sm")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-SGDR', action='store_true')
    parser.add_argument('-epochs', type=int, default=2)
...@@ -117,7 +118,7 @@ def main():
    SRC, TRG = create_fields(opt)
    opt.train = create_dataset(opt, SRC, TRG)
-   print("train_len:{}".format(opt.train_len))
+   # print("train_len:{}".format(opt.train_len))
    # model = get_model(opt, len(SRC.vocab), len(TRG.vocab))
    model = get_model_token_classification(opt, len(SRC.vocab), len(TRG.vocab))
......
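For orientation, a minimal usage sketch of the augmented pipeline, assembled only from calls visible in this diff (variable names follow Process.py and the training script; the epoch loop is simplified). Since MyIterator.data() rebuilds dataset.examples each time create_batches() runs, every pass over the iterator should draw a fresh random set of tone-erased copies; with the default change_possibility list of six values, each original example contributes six copies.

train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
                        repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=None, train=True, shuffle=True, augment=True)

for epoch in range(opt.epochs):
    for i, batch in enumerate(train_iter):
        src = batch.src.transpose(0, 1).cuda()  # (batch, seq_len)
        trg = batch.trg.transpose(0, 1).cuda()
        # ... create_masks / forward / backward as in train_model()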