Commit 2e453384 authored by szr712's avatar szr712

支持多文件训练集

parent 237ee68c
......@@ -31,28 +31,34 @@ def yield_tokens(file_path):
def read_data(opt):
    """Load the source and target training data onto ``opt``.

    Each of ``opt.src_data`` and ``opt.trg_data`` may name either a single
    text file or a directory.  A directory is treated as a multi-file
    training set: the line lists of all files inside it are concatenated.
    After this call each attribute holds a list of stripped lines
    (one training example per line).

    On a missing/unreadable path, prints an error and exits the process
    (same behaviour as the original, but only for OS-level errors instead
    of a bare ``except:`` that hid every bug).
    """

    def _load_lines(path):
        # Helper shared by the src and trg branches (they were duplicated).
        if os.path.isdir(path):
            # Multi-file training set: concatenate every file in the directory.
            lines = []
            for name in os.listdir(path):
                # `with` closes each handle; the original leaked them.
                with open(os.path.join(path, name)) as f:
                    lines += f.read().strip().split('\n')
        else:
            with open(path) as f:
                lines = f.read().strip().split('\n')
        # tqdm only shows progress; the comprehension materializes the list,
        # matching the original `[x for x in tqdm(...)]`.
        return [line for line in tqdm(lines)]

    if opt.src_data is not None:
        print("loading src_data")
        try:
            opt.src_data = _load_lines(opt.src_data)
        except OSError:
            print("error: '" + opt.src_data + "' file not found")
            quit()

    if opt.trg_data is not None:
        print("loading trg_data")
        try:
            opt.trg_data = _load_lines(opt.trg_data)
        except OSError:
            print("error: '" + opt.trg_data + "' file not found")
            quit()
......
......@@ -214,10 +214,12 @@ if __name__ == "__main__":
# with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
# yunmus=f.readlines()
# yunmus=[a.strip() for a in yunmus]
ori_dir="./data/train_file/ori_file_split_random_wo_tones"
hanzi_dir="./data/train_file/hanzi_split_random_wo_tones"
pinyin_dir="./data/train_file/pinyin_split_random_wo_tones"
for file in os.listdir(ori_dir):
build_corpus(os.path.join(ori_dir,file),
os.path.join(pinyin_dir,file), os.path.join(hanzi_dir,file))
print("Done")
# ori_dir="./data/train_file/ori_file_split_random_wo_tones"
# hanzi_dir="./data/train_file/hanzi_split_random_wo_tones"
# pinyin_dir="./data/train_file/pinyin_split_random_wo_tones"
# for file in os.listdir(ori_dir):
# build_corpus(os.path.join(ori_dir,file),
# os.path.join(pinyin_dir,file), os.path.join(hanzi_dir,file))
# print("Done")
build_corpus("./data/dev/dev_hanzi.txt",
"./data/dev/dev_pinyin_split.txt", "./data/dev/dev_hanzi_split.txt")
......@@ -13,8 +13,8 @@ def random_change_tones(tones):
tones[i]=random.choice(options)
return tones
hanzi_dir="./data/test_data/hanzi_new"
pinyin_dir="./data/test_data/split_new_data_daoxuehao/pinyin_random_change_tones"
hanzi_dir="./data/test_data/split_random_wo_tones/hanzi"
pinyin_dir="./data/test_data/split_random_wo_tones/pinyin"
with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
yunmus=f.readlines()
......
......@@ -40,4 +40,4 @@ CUDA_VISIBLE_DEVICES=6 nohup python train_token_classification.py -src_data data
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
CUDA_VISIBLE_DEVICES=2 python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
CUDA_VISIBLE_DEVICES=1 python train_token_classification.py -src_data data/train_file/pinyin_split_random_wo_tones -trg_data data/train_file/hanzi_split_random_wo_tones -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.