Commit fd2c1cfb authored by szr712's avatar szr712

修改部分代码 翻译增加短句 batch增加最后判断

parent 5e1da936
...@@ -99,7 +99,8 @@ class MyIterator(data.Iterator): ...@@ -99,7 +99,8 @@ class MyIterator(data.Iterator):
sorted(p, key=self.sort_key), sorted(p, key=self.sort_key),
self.batch_size, self.batch_size_fn) self.batch_size, self.batch_size_fn)
for b in random_shuffler(list(p_batch)): for b in random_shuffler(list(p_batch)):
yield b if len(b) == self.batch_size:
yield b
self.batches = pool(self.data(), self.random_shuffler) self.batches = pool(self.data(), self.random_shuffler)
else: else:
......
...@@ -94,6 +94,14 @@ def get_model_token_classification(opt, src_vocab, trg_vocab): ...@@ -94,6 +94,14 @@ def get_model_token_classification(opt, src_vocab, trg_vocab):
if opt.load_weights is not None: if opt.load_weights is not None:
print("loading pretrained weights...") print("loading pretrained weights...")
model.load_state_dict(torch.load(f'{opt.load_weights}')) model.load_state_dict(torch.load(f'{opt.load_weights}'))
# checkpoint = torch.load(opt.load_weights, map_location=lambda storage, loc: storage)
# state_dict = {}
# for k, v in checkpoint.items():
# if k.startswith('module'):
# state_dict[k[7:]] = v
# else:
# state_dict[k] = v
# model.load_state_dict(state_dict)
else: else:
for p in model.parameters(): for p in model.parameters():
if p.dim() > 1: if p.dim() > 1:
......
...@@ -57,7 +57,7 @@ def cer(preFile,textFile): ...@@ -57,7 +57,7 @@ def cer(preFile,textFile):
round(float(total_edit_distance)/num_chars, 5))) round(float(total_edit_distance)/num_chars, 5)))
if __name__ == "__main__": if __name__ == "__main__":
preFile = "./data/test_data/split_new_data_daoxuehao/result_random_change_tones" preFile = "./data/test_data/tmp/2"
textFile = "./data/test_data/hanzi_new" textFile = "./data/test_data/tmp/3"
cer(preFile,textFile) cer(preFile,textFile)
\ No newline at end of file
...@@ -6,7 +6,7 @@ from build_corpus import split_initials_finals, wenzi2pinyin ...@@ -6,7 +6,7 @@ from build_corpus import split_initials_finals, wenzi2pinyin
import random import random
def random_change_tones(tones): def random_change_tones(tones):
change_possibility=[0.5, 0.6, 0.7, 0.8, 0.9, 1] change_possibility=[0.78]
change_possibility=random.choice(change_possibility) change_possibility=random.choice(change_possibility)
random.seed(42) random.seed(42)
for i,x in enumerate(tones): for i,x in enumerate(tones):
...@@ -39,12 +39,12 @@ def convert_pinyin(file,hanzi_dir,pinyin_dir,new_file): ...@@ -39,12 +39,12 @@ def convert_pinyin(file,hanzi_dir,pinyin_dir,new_file):
f.write("\n".join(result)) f.write("\n".join(result))
if __name__=="__main__": if __name__=="__main__":
hanzi_dir="./data/test_data/split_random_wo_tones/hanzi" hanzi_dir="./data/test/hanzi"
pinyin_dir="./data/test_data/split_random_wo_tones/pinyin2" pinyin_dir="./data/test/pinyin"
# with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f: # with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
# yunmus=f.readlines() # yunmus=f.readlines()
# yunmus=[a.strip() for a in yunmus] # yunmus=[a.strip() for a in yunmus]
convert_pinyin("dev_hanzi.txt","./data/dev","./data/dev","dev_pinyin_split.txt") # convert_pinyin("dev_hanzi.txt","./data/dev","./data/dev","dev_pinyin_split.txt")
# for file in os.listdir(hanzi_dir): for file in os.listdir(hanzi_dir):
# convert_pinyin(file,hanzi_dir,pinyin_dir) convert_pinyin(file,hanzi_dir,pinyin_dir,file)
\ No newline at end of file \ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
No preview for this file type
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment