Commit a0e5b946 authored by szr712

Use pinyin initials and finals (声母/韵母) as tokens

parent 636dc31e
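The commit moves tokenization from whole characters to pinyin initials and finals (声母/韵母). A minimal sketch of the idea, separate from the diff below; it only uses pypinyin calls that also appear in this commit, and the helper name to_tokens and the sample output are illustrative only:

from pypinyin import Style
from pypinyin.core import lazy_pinyin
from pypinyin.style._utils import get_initials, get_finals

def to_tokens(sent):
    # Split each syllable into initial + final; syllables without an initial
    # (e.g. "an1" with strict=False) contribute a single token.
    tokens = []
    for py in lazy_pinyin(sent, style=Style.TONE3):
        initial = get_initials(py, False)
        final = get_finals(py, False)
        tokens.extend([t for t in (initial, final) if t != ""])
    return tokens

print(to_tokens("中国"))  # roughly ['zh', 'ong1', 'g', 'uo2']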
from http.client import NO_CONTENT
import pandas as pd
import torchtext
from torchtext import data
@@ -8,6 +9,9 @@ import dill as pickle
from pypinyin import Style, pinyin
from pypinyin.core import lazy_pinyin
from tqdm import tqdm
from torchtext import vocab
from torchtext.vocab import build_vocab_from_iterator
import io
@@ -19,13 +23,18 @@ def wenzi2pinyin(text):
    pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
    return "".join(pinyin_list)
def yield_tokens(file_path):
    with io.open(file_path, encoding='utf-8') as f:
        for line in f:
            yield line.strip().split()
def read_data(opt):
    if opt.src_data is not None:
        try:
            print("loading src_data")
            opt.src_data = open(opt.src_data).read().strip().split('\n')
            opt.src_data = [x for x in tqdm(opt.src_data)]
            # print(len(opt.src_data))
        except:
            print("error: '" + opt.src_data + "' file not found")
@@ -36,7 +45,7 @@ def read_data(opt):
            print("loading trg_data")
            opt.trg_data = open(opt.trg_data).read().strip().split('\n')
            # opt.trg_data=[x for x in tqdm(opt.trg_data) if len(wenzi2pinyin(x))<=200]
            opt.trg_data = [x for x in tqdm(opt.trg_data)]
        except:
            print("error: '" + opt.trg_data + "' file not found")
            quit()
@@ -45,6 +54,9 @@ def read_data(opt):
def my_tokenize(text):
    return list(text)
def my_tokenize2(text):
    return text.split(" ")
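For context (not part of the diff): my_tokenize splits a sentence into single characters, while my_tokenize2 assumes the corpus has already been pre-split into space-separated initial/final tokens, so it only splits on spaces. A small illustrative comparison:

print(my_tokenize("中国"))            # ['中', '国']
print(my_tokenize2("zh ong1 g uo2"))  # ['zh', 'ong1', 'g', 'uo2']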
def create_fields(opt):
    spacy_langs = ['en', 'fr', 'de', 'es', 'pt', 'it', 'nl']
@@ -62,8 +74,12 @@ def create_fields(opt):
    # SRC = data.Field(lower=True, tokenize=t_src.tokenizer)
    # TRG = data.Field(tokenize=my_tokenize, init_token='<sos>', eos_token='<eos>')
    if opt.src_voc is None and opt.trg_voc is None:
        TRG = data.Field(tokenize=my_tokenize)
        SRC = data.Field(tokenize=my_tokenize)
    else:
        TRG = data.Field(tokenize=my_tokenize2)
        SRC = data.Field(tokenize=my_tokenize2)
    if opt.pkl_dir is not None:
        try:
@@ -83,7 +99,7 @@ def create_dataset(opt, SRC, TRG):
    raw_data = {'src' : [line for line in opt.src_data], 'trg': [line for line in opt.trg_data]}
    df = pd.DataFrame(raw_data, columns=["src", "trg"])
    print(df.sample(5))
    mask = (df['src'].str.count(' ') < opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
    # print(mask)
@@ -104,9 +120,19 @@ def create_dataset(opt, SRC, TRG):
    os.remove('translate_transformer_temp.csv')
    if opt.load_weights is None:
        if opt.src_voc is not None:
            vocab = build_vocab_from_iterator(yield_tokens(opt.src_voc))
            SRC.vocab = vocab
        else:
            SRC.build_vocab(train)
        # print(SRC.vocab.stoi)
        if opt.trg_voc is not None:
            vocab = build_vocab_from_iterator(yield_tokens(opt.trg_voc))
            TRG.vocab = vocab
        else:
            TRG.build_vocab(train)
        # print(TRG.vocab.stoi)
    if opt.checkpoint > 0:
        try:
...
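Alongside the diff above: when -src_voc/-trg_voc are supplied, the vocabularies are built straight from those token files with build_vocab_from_iterator(yield_tokens(...)) and attached to the Fields, instead of being built from the training set. A self-contained sketch of that flow, assuming a torchtext version where a Vocab built this way can simply be assigned to Field.vocab; load_field_vocab is a hypothetical helper, not from the commit:

from torchtext.vocab import build_vocab_from_iterator

def load_field_vocab(field, voc_path):
    # Same reading convention as yield_tokens: each line is a whitespace-separated token list.
    def token_lines():
        with open(voc_path, encoding="utf-8") as f:
            for line in f:
                yield line.strip().split()
    field.vocab = build_vocab_from_iterator(token_lines())
    return field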
@@ -24,4 +24,6 @@ CUDA_VISIBLE_DEVICES=1 python train2.py -src_data data/pinyin_2.txt -trg_data da
CUDA_VISIBLE_DEVICES=1 nohup python train2.py -src_data data/pinyin_2.txt -trg_data data/hanzi_2.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification
CUDA_VISIBLE_DEVICES=3 python translate_pkl.py -load_weights weights/token_classification/11-09_22:00:55/token_classification_35_0.055335590355098246 -pkl_dir weights/token_classification/11-09_22:00:55 -test_dir data/pkl/test-pkl -result_dir data/pkl/test-pkl-result
\ No newline at end of file
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_split.txt -trg_data data/hanzi_split.txt -src_lang en_core_web_sm -trg_lang fr_core_news_sm -epochs 100 -model_name token_classification_split -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
\ No newline at end of file
@@ -20,8 +20,26 @@ from threading import Semaphore
import regex # pip install regex
from xpinyin import Pinyin # pip install xpinyin
import traceback
from pypinyin.style._utils import get_initials, get_finals
from pypinyin import Style, pinyin
from pypinyin.core import lazy_pinyin
import itertools
pinyin = Pinyin()
def wenzi2pinyin(text):
    pinyin_list = lazy_pinyin(text, style=Style.TONE3)
    # print(pinyin_list)
    # tones_list = [int(py[-1]) if py[-1].isdigit()
    #               else 0 for py in pinyin_list]
    # pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
    return pinyin_list
def split_initials_finals(pinyin):
    if get_initials(pinyin, False) != "":
        return [get_initials(pinyin, False), get_finals(pinyin, False)]
    else:
        return [get_finals(pinyin, False)]
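A hedged usage note, not part of the diff: with strict=False a syllable that has an initial should come back as two tokens, and one without an initial as a single token. The expected values below are illustrative, not verified output:

print(split_initials_finals("zhong1"))  # expected: ['zh', 'ong1']
print(split_initials_finals("an1"))     # expected: ['an1'] (no initial)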
def align(sent):
    '''
@@ -60,6 +78,57 @@ def align(sent):
    assert len(pnyns) == len(hanzis), "The hanzis and the pinyins must be the same in length."
    return pnyns, hanzis
def align2(sent):
    '''
    Args:
      sent: A string. A sentence.

    Returns:
      A tuple of pinyin and chinese sentence.
    '''
    # pnyns = pinyin.get_pinyin(sent, " ").split()
    pinyins = wenzi2pinyin(sent)
    pnyns = [split_initials_finals(a) for a in pinyins]
    hanzis = []
    x = 0  # extra offset into sent for multi-character non-hanzi tokens
    for i in range(len(pnyns)):
        if i + x < len(sent.replace(" ", "")):
            char = sent[i + x]
            p = pnyns[i]
            if '\u4e00' <= char <= '\u9fa5':
                # hanzi: keep the character and pad with "_" so it lines up
                # with its initial/final tokens
                hanzis.extend([char] + ["_"] * (len(p) - 1))
            else:
                # non-hanzi: copy its token(s) through and skip past them
                for q in p:
                    hanzis.append(q)
                    x += len(q)
                x -= 1
                # word=""
                # while not '\u4e00' <= char <= '\u9fa5':
                #     word=word+char
                #     x=x+1
                #     if x+i>=len(sent.replace(" ", "")):
                #         break
                #     char=sent[i+x]
                # hanzis.extend([word])
                # x=x-1
    # for char, p in zip(sent.replace(" ", ""), pnyns):
    #     if '\u4e00' <= char <= '\u9fa5':
    #         hanzis.extend([char] + ["_"] * (len(p) - 1))
    pnyns = " ".join(list(itertools.chain.from_iterable(pnyns)))
    hanzis = " ".join(hanzis)
    if len(pnyns.split(" ")) != len(hanzis.split(" ")):
        print(sent)
        print(pnyns)
        print(hanzis)
    # assert len(pnyns.split(" ")) == len(hanzis.split(" ")), "The hanzis and the pinyins must be the same in length."
    return pnyns, hanzis
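For reference (not part of the commit), the intended shape of align2's output: every hanzi contributes its initial/final tokens on the pinyin side and the character plus "_" padding on the hanzi side, so the two space-separated sequences stay equal in length. Assuming pypinyin's default reading of the sample word:

pnyns, hanzis = align2("中国")
print(pnyns)   # roughly: zh ong1 g uo2
print(hanzis)  # roughly: 中 _ 国 _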
def clean(text):
    # if regex.search("[A-Za-z0-9]", text) is not None: # For simplicity, roman alphanumeric characters are removed.
    # # if regex.search("[A-Za-z0-9]", text) is not None: # For simplicity, roman alphanumeric characters are removed.
@@ -79,37 +148,41 @@ def clean(text):
    if len(text_new) < 10: return ""
    return text_new
def build_corpus(src_file, pinyin_file, hanzi_file):
    pinyin_list = []
    hanzi_list = []
    # with codecs.open("data/zho_news_2007-2009_1M-sentences.txt", 'r', 'utf-8') as fin:
    with codecs.open(src_file, 'r', 'utf-8') as fin:
        i = 1
        while 1:
            line = fin.readline()
            if not line: break

            try:
                # idx, sent = line.strip().split("\t")
                # if idx == "234":
                #     print(sent)
                # sent = clean(sent)
                # if len(sent) > 0:
                #     pnyns, hanzis = align(sent)
                #     fout.write(u"{}\t{}\t{}\n".format(idx, pnyns, hanzis))
                sent = line.strip()
                sent = sent.replace(" ", "")
                # sent = clean(sent)
                if len(sent) > 0:
                    pnyns, hanzis = align2(sent)
                    pinyin_list.append(pnyns)
                    hanzi_list.append(hanzis)
            except:
                traceback.print_exc()
                continue # it's okay as we have a pretty big corpus!

            if i % 10000 == 0: print(i, )
            i += 1

    with open(pinyin_file, 'w', encoding='utf-8') as f:
        f.write("\n".join(pinyin_list))
    with open(hanzi_file, 'w', encoding='utf-8') as f:
        f.write("\n".join(hanzi_list))
if __name__ == "__main__": if __name__ == "__main__":
build_corpus(); print("Done") build_corpus("./data/train_set_total.txt","./data/pinyin_spilit.txt","./data/hanzi_spilit.txt"); print("Done")
from pypinyin.style._utils import get_initials, get_finals
from pypinyin import Style, pinyin
from build_corpus import align2, split_initials_finals, wenzi2pinyin
from pypinyin.core import lazy_pinyin

a = 'ying1'
strict = False
if get_initials(a, strict) == "":
    print("True")

text = "s,但季节分布不均,汛期流量极大,而枯水期几乎河底朝天,旱突频发,居民吃水困难。"
# pinyins,tones=wenzi2pinyin(text)
# print(lazy_pinyin(text, style=Style.TONE3))
# print(lazy_pinyin(text, style=Style.TONE3))
# pnyns = [split_initials_finals(a) for a in lazy_pinyin(text, style=Style.TONE3)]
# print(pnyns)
text = text.replace(" ", "")
pnyns, hanzis = align2(text)
print(pnyns)
print(hanzis)
print('%s %s' % (get_initials(a, strict), get_finals(a, strict)))
\ No newline at end of file
@@ -102,6 +102,8 @@ def main():
    parser.add_argument('-checkpoint', type=int, default=0)
    parser.add_argument('-model_name', type=str, default="transformer")
    parser.add_argument('-pkl_dir')
    parser.add_argument('-src_voc')
    parser.add_argument('-trg_voc')
    opt = parser.parse_args()
...
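A note on the two new flags (an assumption read off yield_tokens, which only splits each line on whitespace): -src_voc and -trg_voc point at plain-text files listing the vocabulary tokens, one whitespace-separated group per line. The contents sketched below are hypothetical; only the paths come from the training command earlier in this commit:

# ./data/voc/pinyin.txt (hypothetical contents)
#   zh ong1 g uo2
#   b ei3 j ing1
# build_vocab_from_iterator(yield_tokens(opt.src_voc)) then collects every
# whitespace-separated token from such a file into the SRC vocabulary.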