Commit a0e5b946 authored by szr712

Use pinyin initials and finals (声母/韵母) as tokens

parent 636dc31e
from http.client import NO_CONTENT
import pandas as pd
import torchtext
from torchtext import data
@@ -8,6 +9,9 @@ import dill as pickle
from pypinyin import Style, pinyin
from pypinyin.core import lazy_pinyin
from tqdm import tqdm
from torchtext import vocab
from torchtext.vocab import build_vocab_from_iterator
import io
@@ -19,13 +23,18 @@ def wenzi2pinyin(text):
pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
return "".join(pinyin_list)
def yield_tokens(file_path):
with io.open(file_path, encoding = 'utf-8') as f:
for line in f:
yield line.strip().split()
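# Hedged sketch of how yield_tokens is meant to be consumed: it streams the
# whitespace-separated tokens of a vocabulary file (e.g. the ./data/voc/pinyin.txt
# passed via -src_voc), one line at a time, and build_vocab_from_iterator turns
# that stream into a vocabulary, which create_dataset then attaches to a Field.
# The exact file contents are an illustrative assumption.
#
#   src_vocab = build_vocab_from_iterator(yield_tokens("./data/voc/pinyin.txt"))
#   # SRC.vocab = src_vocab   (as done in create_dataset when opt.src_voc is set)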
def read_data(opt):
if opt.src_data is not None:
try:
print("loading src_data")
opt.src_data = open(opt.src_data).read().strip().split('\n')
opt.src_data=[x for x in tqdm(opt.src_data) if len(x)<=200]
opt.src_data=[x for x in tqdm(opt.src_data)]
# print(len(opt.src_data))
except:
print("error: '" + opt.src_data + "' file not found")
@@ -36,7 +45,7 @@ def read_data(opt):
print("loading trg_data")
opt.trg_data = open(opt.trg_data).read().strip().split('\n')
# opt.trg_data=[x for x in tqdm(opt.trg_data) if len(wenzi2pinyin(x))<=200]
opt.trg_data=[x for x in tqdm(opt.trg_data) if len(x)<=200]
opt.trg_data=[x for x in tqdm(opt.trg_data)]
except:
print("error: '" + opt.trg_data + "' file not found")
quit()
@@ -45,6 +54,9 @@ def read_data(opt):
def my_tokenize(text):
return list(text)
def my_tokenize2(text):
return text.split(" ")
def create_fields(opt):
spacy_langs = ['en', 'fr', 'de', 'es', 'pt', 'it', 'nl']
@@ -62,8 +74,12 @@ def create_fields(opt):
# SRC = data.Field(lower=True, tokenize=t_src.tokenizer)
# TRG = data.Field(tokenize=my_tokenize, init_token='<sos>', eos_token='<eos>')
TRG = data.Field(tokenize=my_tokenize)
SRC = data.Field(tokenize=my_tokenize)
if opt.src_voc is None and opt.trg_voc is None:
TRG = data.Field(tokenize=my_tokenize)
SRC = data.Field(tokenize=my_tokenize)
else:
TRG = data.Field(tokenize=my_tokenize2)
SRC = data.Field(tokenize=my_tokenize2)
if opt.pkl_dir is not None:
try:
@@ -83,7 +99,7 @@ def create_dataset(opt, SRC, TRG):
raw_data = {'src' : [line for line in opt.src_data], 'trg': [line for line in opt.trg_data]}
df = pd.DataFrame(raw_data, columns=["src", "trg"])
print(df.head())
print(df.sample(5))
mask = (df['src'].str.count(' ') < opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
# print(mask)
@@ -104,9 +120,19 @@ def create_dataset(opt, SRC, TRG):
os.remove('translate_transformer_temp.csv')
if opt.load_weights is None:
SRC.build_vocab(train)
if opt.src_voc is not None:
vocab=build_vocab_from_iterator(yield_tokens(opt.src_voc))
SRC.vocab=vocab
else:
SRC.build_vocab(train)
# print(SRC.vocab.stoi)
TRG.build_vocab(train)
if opt.trg_voc is not None:
vocab=build_vocab_from_iterator(yield_tokens(opt.trg_voc))
TRG.vocab=vocab
else:
TRG.build_vocab(train)
# print(TRG.vocab.stoi)
if opt.checkpoint > 0:
try:
@@ -24,4 +24,6 @@ CUDA_VISIBLE_DEVICES=1 python train2.py -src_data data/pinyin_2.txt -trg_data da
CUDA_VISIBLE_DEVICES=1 nohup python train2.py -src_data data/pinyin_2.txt -trg_data data/hanzi_2.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification
CUDA_VISIBLE_DEVICES=3 python translate_pkl.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_35_0.055335590355098246 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/pkl/test-pkl -result_dir data/pkl/test-pkl-result
\ No newline at end of file
CUDA_VISIBLE_DEVICES=3 python translate_pkl.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_35_0.055335590355098246 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/pkl/test-pkl -result_dir data/pkl/test-pkl-result
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_split.txt -trg_data data/hanzi_split.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification_split -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
\ No newline at end of file
@@ -20,8 +20,26 @@ from threading import Semaphore
import regex # pip install regex
from xpinyin import Pinyin # pip install xpinyin
import traceback
from pypinyin.style._utils import get_initials, get_finals
from pypinyin import Style, pinyin
from pypinyin.core import lazy_pinyin
import itertools
pinyin = Pinyin()
def wenzi2pinyin(text):
pinyin_list = lazy_pinyin(text, style=Style.TONE3)
# print(pinyin_list)
# tones_list = [int(py[-1]) if py[-1].isdigit()
# else 0 for py in pinyin_list]
# pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
return pinyin_list
def split_initials_finals(pinyin):
if get_initials(pinyin, False)!="":
return [get_initials(pinyin, False), get_finals(pinyin, False)]
else:
return [get_finals(pinyin, False)]
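# Illustrative behaviour (assuming pypinyin's non-strict mode, as used above):
#   split_initials_finals("zhong") -> ["zh", "ong"]
#   split_initials_finals("ai")    -> ["ai"]    # zero-initial syllable: final only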
def align(sent):
'''
@@ -60,6 +78,57 @@ def align(sent):
assert len(pnyns) == len(hanzis), "The hanzis and the pinyins must be the same in length."
return pnyns, hanzis
def align2(sent):
'''
Args:
sent: A string. A sentence.
Returns:
A tuple of pinyin and chinese sentence.
'''
# pnyns = pinyin.get_pinyin(sent, " ").split()
pinyins=wenzi2pinyin(sent)
pnyns = [split_initials_finals(a) for a in pinyins]
hanzis = []
x=0
for i in range(len(pnyns)):
if i+x<len(sent.replace(" ", "")):
char=sent[i+x]
p=pnyns[i]
if '\u4e00' <= char <= '\u9fa5':
hanzis.extend([char] + ["_"] * (len(p) - 1))
else:
for q in p:
hanzis.append(q)
x+=len(q)
x-=1
# word=""
# while not '\u4e00' <= char <= '\u9fa5':
# word=word+char
# x=x+1
# if x+i>=len(sent.replace(" ", "")):
# break
# char=sent[i+x]
# hanzis.extend([word])
# x=x-1
# for char, p in zip(sent.replace(" ", ""), pnyns):
# if '\u4e00' <= char <= '\u9fa5':
# hanzis.extend([char] + ["_"] * (len(p) - 1))
pnyns = " ".join(list(itertools.chain.from_iterable(pnyns)))
hanzis = " ".join(hanzis)
if len(pnyns.split(" ")) != len(hanzis.split(" ")):
print(sent)
print(pnyns)
print(hanzis)
# assert len(pnyns.split(" ")) == len(hanzis.split(" ")), "The hanzis and the pinyins must be the same in length."
return pnyns, hanzis
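# Hedged sketch of what align2 produces (the example sentence is an assumption;
# with Style.TONE3 the tone digit stays attached to the final token):
#   pnyns, hanzis = align2("中国")
#   # pnyns  -> "zh ong1 g uo2"
#   # hanzis -> "中 _ 国 _"
# Every hanzi is padded with "_" so the hanzi sequence stays aligned one-to-one
# with the initial/final token sequence.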
def clean(text):
# if regex.search("[A-Za-z0-9]", text) is not None: # For simplicity, roman alphanumeric characters are removed.
# # if regex.search("[A-Za-z0-9]", text) is not None: # For simplicity, roman alphanumeric characters are removed.
@@ -79,37 +148,41 @@ def clean(text):
if len(text_new)<10: return ""
return text_new
def build_corpus():
def build_corpus(src_file,pinyin_file,hanzi_file):
pinyin_list=[]
hanzi_list=[]
with codecs.open("data/zh.tsv", 'w', 'utf-8') as fout:
# with codecs.open("data/zho_news_2007-2009_1M-sentences.txt", 'r', 'utf-8') as fin:
with codecs.open("data/train_set_total.txt", 'r', 'utf-8') as fin:
i = 1
while 1:
line = fin.readline()
if not line: break
try:
# idx, sent = line.strip().split("\t")
# if idx == "234":
# print(sent)
# sent = clean(sent)
# if len(sent) > 0:
# pnyns, hanzis = align(sent)
# fout.write(u"{}\t{}\t{}\n".format(idx, pnyns, hanzis))
sent = line.strip()
sent = sent.replace(" ","")
# sent = clean(sent)
if len(sent) > 0:
pnyns, hanzis = align(sent)
fout.write(u"{}\t{}\n".format(pnyns, hanzis))
except:
traceback.print_exc()
continue # it's okay as we have a pretty big corpus!
if i % 10000 == 0: print(i, )
i += 1
# with codecs.open("data/zho_news_2007-2009_1M-sentences.txt", 'r', 'utf-8') as fin:
with codecs.open(src_file, 'r', 'utf-8') as fin:
i = 1
while 1:
line = fin.readline()
if not line: break
try:
# idx, sent = line.strip().split("\t")
# if idx == "234":
# print(sent)
# sent = clean(sent)
# if len(sent) > 0:
# pnyns, hanzis = align(sent)
# fout.write(u"{}\t{}\t{}\n".format(idx, pnyns, hanzis))
sent = line.strip()
sent = sent.replace(" ","")
# sent = clean(sent)
if len(sent) > 0:
pnyns, hanzis = align2(sent)
pinyin_list.append(pnyns)
hanzi_list.append(hanzis)
except:
traceback.print_exc()
continue # it's okay as we have a pretty big corpus!
if i % 10000 == 0: print(i, )
i += 1
with open(pinyin_file,'w',encoding='utf-8') as f:
f.write("\n".join(pinyin_list))
with open(hanzi_file,'w',encoding='utf-8') as f:
f.write("\n".join(hanzi_list))
if __name__ == "__main__":
build_corpus(); print("Done")
build_corpus("./data/train_set_total.txt","./data/pinyin_spilit.txt","./data/hanzi_spilit.txt"); print("Done")