Commit 237ee68c authored by szr712's avatar szr712

Add online random erasure of pinyin tones (data augmentation during training)

parent 8c385980
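
In brief: when an epoch is built, every source example is copied once per value in change_possibility, and in each copy every final (a pinyin token listed in yunmus.txt) has its tone digit replaced with "0" with that probability, so the model also sees tone-less input. A minimal standalone sketch of the idea follows; erase_tones and its argument names are illustrative, not the repository's API:

import random

def erase_tones(tokens, yunmus, p):
    """Return a copy of `tokens` in which each final (a token found in
    `yunmus`, e.g. "ang4") has its tone digit replaced by "0" ("ang0")
    with probability `p`."""
    out = list(tokens)
    for i, tok in enumerate(tokens):
        if tok in yunmus and random.random() < p:
            out[i] = tok[:-1] + "0"
    return out

# One augmented copy per probability level, as MyIterator.data() does below:
# copies = [erase_tones(src_tokens, yunmus, p) for p in (0.5, 0.6, 0.7, 0.8, 0.9, 1)]
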
@@ -2,41 +2,95 @@ import torch
from torchtext import data
import numpy as np
from torch.autograd import Variable
import copy
import random
from tqdm import tqdm
import time
def nopeak_mask(size, opt):
np_mask = np.triu(np.ones((1, size, size)),
k=1).astype('uint8')
np_mask = Variable(torch.from_numpy(np_mask) == 0)
if opt.device == 0:
np_mask = np_mask.cuda()
return np_mask
def create_masks(src, trg, opt):
src_mask = (src != opt.src_pad).unsqueeze(-2)
if trg is not None:
trg_mask = (trg != opt.trg_pad).unsqueeze(-2)
size = trg.size(1) # get seq_len for matrix
np_mask = nopeak_mask(size, opt)
if trg.is_cuda:
np_mask.cuda()
trg_mask = trg_mask & np_mask
else:
trg_mask = None
return src_mask, trg_mask
def create_masks2(src, opt):
src_mask = (src != opt.src_pad).unsqueeze(-2)
return src_mask
# patch on Torchtext's batching process that makes it more efficient
# from http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks
class MyIterator(data.Iterator):
def __init__(self, dataset, batch_size, sort_key=None, device=None,
batch_size_fn=None, train=True,
repeat=False, shuffle=None, sort=None,
sort_within_batch=None, augment=False, change_possibility=[0.5, 0.6, 0.7, 0.8, 0.9, 1]):
super().__init__(dataset, batch_size, sort_key, device,
batch_size_fn, train,
repeat, shuffle, sort,
sort_within_batch)
self.augment = augment
self.change_possibility = change_possibility
print("start copy...")
start = time.time()
self.ori_examples = copy.deepcopy(self.dataset.examples)
print("end copy..., cost total:{}s".format(time.time()-start))
with open("./data/voc/yunmus.txt", "r", encoding="utf-8") as f:
yunmu = f.readlines()
self.yunmus = [a.strip() for a in yunmu]
def data(self):
"""重载data.Iterator的data方法,增加扩容代码
Return the examples in the dataset in order, sorted, or shuffled."""
if self.augment:
print("augmenting data...")
self.dataset.examples = []
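# Rebuild the example list from the pristine copy kept in self.ori_examples:
# for every original example and every probability p in change_possibility,
# append one copy in which each final (a token listed in yunmus.txt) has its
# tone digit replaced by "0" with probability p, so each epoch sees
# len(change_possibility) augmented copies of every original example.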
for ex in tqdm(self.ori_examples):
for p in self.change_possibility:
new_ex = copy.deepcopy(ex)
for i, char in enumerate(ex.src):
r=random.random()
if r < p and char in self.yunmus:
new_ex.src[i] = char[:-1]+"0"
self.dataset.examples.append(new_ex)
print("data len:{}".format(len(self.dataset.examples)))
# print("src:{}\ntrg:{}".format(type(ex.src),type(ex.trg)))
if self.sort:
xs = sorted(self.dataset, key=self.sort_key)
elif self.shuffle:
xs = [self.dataset[i]
for i in self.random_shuffler(range(len(self.dataset)))]
else:
xs = self.dataset
return xs
def create_batches(self):
if self.train:
def pool(d, random_shuffler):
@@ -47,15 +101,17 @@ class MyIterator(data.Iterator):
for b in random_shuffler(list(p_batch)):
yield b
self.batches = pool(self.data(), self.random_shuffler)
else:
self.batches = []
for b in data.batch(self.data(), self.batch_size,
self.batch_size_fn):
self.batches.append(sorted(b, key=self.sort_key))
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
"Keep augmenting batch and calculate total number of tokens + padding."
global max_src_in_batch, max_tgt_in_batch
......
@@ -33,8 +33,15 @@ def read_data(opt):
if opt.src_data is not None:
try:
print("loading src_data")
opt.src_data = open(opt.src_data).read().strip().split('\n')
opt.src_data=[x for x in tqdm(opt.src_data)]
if os.path.isdir(opt.src_data):
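# src_data may point to a directory; if so, read every file inside it
# and gather all of their lines into one flat list.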
train_set=[]
for file in os.listdir(opt.src_data):
train_set.extend(open(os.path.join(opt.src_data,file)).read().strip().split('\n'))
opt.src_data=train_set
opt.src_data=[x for x in tqdm(opt.src_data)]
else:
opt.src_data = open(opt.src_data).read().strip().split('\n')
opt.src_data=[x for x in tqdm(opt.src_data)]
# print(len(opt.src_data))
except:
print("error: '" + opt.src_data + "' file not found")
@@ -108,14 +115,18 @@ def create_dataset(opt, SRC, TRG):
df.to_csv("translate_transformer_temp.csv", index=False)
data_fields = [('src', SRC), ('trg', TRG)]
train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields)
train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields,skip_header=True)
# train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
# repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
# batch_size_fn=batch_size_fn, train=True, shuffle=True)
# train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
# repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
# batch_size_fn=None, train=True, shuffle=True)
train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
batch_size_fn=None, train=True, shuffle=True)
batch_size_fn=None, train=True, shuffle=True,augment=True)
os.remove('translate_transformer_temp.csv')
......
@@ -214,6 +214,10 @@ if __name__ == "__main__":
# with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
# yunmus=f.readlines()
# yunmus=[a.strip() for a in yunmus]
build_corpus("./data/train_set_new.txt",
"./data/pinyin_new_split.txt", "./data/hanzi_new_split.txt")
print("Done")
ori_dir="./data/train_file/ori_file_split_random_wo_tones"
hanzi_dir="./data/train_file/hanzi_split_random_wo_tones"
pinyin_dir="./data/train_file/pinyin_split_random_wo_tones"
for file in os.listdir(ori_dir):
build_corpus(os.path.join(ori_dir,file),
os.path.join(pinyin_dir,file), os.path.join(hanzi_dir,file))
print("Done")
@@ -38,3 +38,6 @@ CUDA_VISIBLE_DEVICES=6 nohup python train_token_classification.py -src_data data
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
CUDA_VISIBLE_DEVICES=2 python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
@@ -81,8 +81,8 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('-src_data', required=True)
parser.add_argument('-trg_data', required=True)
parser.add_argument('-src_lang', required=True,default="en_core_web_sm")
parser.add_argument('-trg_lang', required=True,default="fr_core_news_sm")
parser.add_argument('-src_lang', default="en_core_web_sm")
parser.add_argument('-trg_lang', default="fr_core_news_sm")
parser.add_argument('-no_cuda', action='store_true')
parser.add_argument('-SGDR', action='store_true')
parser.add_argument('-epochs', type=int, default=2)
......
@@ -8,7 +8,8 @@ from Optim import CosineWithRestarts
from Batch import create_masks, create_masks2
import dill as pickle
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from Process import get_len
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def train_model(model, opt, start_time):
@@ -35,7 +36,7 @@ def train_model(model, opt, start_time):
# torch.save(model.state_dict(), 'weights/model_weights')
for i, batch in enumerate(opt.train):
src = batch.src.transpose(0, 1).cuda()
trg = batch.trg.transpose(0, 1).cuda()
# print("src shape:{} trg shape:{}".format(src.shape,trg.shape))
@@ -59,8 +60,8 @@ def train_model(model, opt, start_time):
p = int(100 * (i + 1) / opt.train_len)
avg_loss = total_loss/opt.printevery
if opt.floyd is False:
print(" %dm: epoch %d [%s%s] %d%% loss = %.3f" %
((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
print(" %dm: epoch %d [%s%s] %d%% steps = %d loss = %.3f" %
((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, i,avg_loss), end='\r')
else:
print(" %dm: epoch %d [%s%s] %d%% loss = %.3f" %
((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
@@ -70,8 +71,8 @@ def train_model(model, opt, start_time):
# torch.save(model.state_dict(), 'weights/model_weights')
# cptime = time.time()
print("%dm: epoch %d [%s%s] %d%% loss = %.3f\nepoch %d complete, loss = %.03f" %
((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))
print("%dm: epoch %d [%s%s] %d%% steps = %d loss = %.3f\nepoch %d complete, total steps = %d, loss = %.03f" %
((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, i ,avg_loss, epoch + 1, i, avg_loss))
torch.save(model.state_dict(), os.path.join(
dst, opt.model_name+"_{}_{}".format(epoch + 1, avg_loss)))
print("model saved as {}".format(opt.model_name +
@@ -83,8 +84,8 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('-src_data', required=True)
parser.add_argument('-trg_data', required=True)
parser.add_argument('-src_lang', required=True,default="en_core_web_sm")
parser.add_argument('-trg_lang', required=True,default="fr_core_news_sm")
parser.add_argument('-src_lang', default="en_core_web_sm")
parser.add_argument('-trg_lang', default="fr_core_news_sm")
parser.add_argument('-no_cuda', action='store_true')
parser.add_argument('-SGDR', action='store_true')
parser.add_argument('-epochs', type=int, default=2)
@@ -117,7 +118,7 @@ def main():
SRC, TRG = create_fields(opt)
opt.train = create_dataset(opt, SRC, TRG)
print("train_len:{}".format(opt.train_len))
# print("train_len:{}".format(opt.train_len))
# model = get_model(opt, len(SRC.vocab), len(TRG.vocab))
model = get_model_token_classification(opt, len(SRC.vocab), len(TRG.vocab))
......