邵子睿(21软) / pinyin2hanzi / Commits

Commit 237ee68c
authored Dec 06, 2021 by szr712

    Add online random erasure of tones (增加在线随机抹除音调)

parent 8c385980
Showing 6 changed files with 103 additions and 28 deletions:

    Batch.py                        +66  -10
    Process.py                      +15   -4
    build_corpus.py                  +7   -3
    log.txt                          +3   -0
    train.py                         +2   -2
    train_token_classification.py   +10   -9
Batch.py
...
...
@@ -2,41 +2,95 @@ import torch
from torchtext import data
import numpy as np
from torch.autograd import Variable
import copy
import random
from tqdm import tqdm
import time


def nopeak_mask(size, opt):
    np_mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    np_mask = Variable(torch.from_numpy(np_mask) == 0)
    if opt.device == 0:
        np_mask = np_mask.cuda()
    return np_mask


def create_masks(src, trg, opt):
    src_mask = (src != opt.src_pad).unsqueeze(-2)
    if trg is not None:
        trg_mask = (trg != opt.trg_pad).unsqueeze(-2)
        size = trg.size(1)  # get seq_len for matrix
        np_mask = nopeak_mask(size, opt)
        if trg.is_cuda:
            np_mask.cuda()
        trg_mask = trg_mask & np_mask
    else:
        trg_mask = None
    return src_mask, trg_mask


def create_masks2(src, opt):
    src_mask = (src != opt.src_pad).unsqueeze(-2)
    return src_mask


# patch on Torchtext's batching process that makes it more efficient
# from http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks
class MyIterator(data.Iterator):
    def __init__(self, dataset, batch_size, sort_key=None, device=None,
                 batch_size_fn=None, train=True, repeat=False, shuffle=None,
                 sort=None, sort_within_batch=None, augment=False,
                 change_possibility=[0.5, 0.6, 0.7, 0.8, 0.9, 1]):
        super().__init__(dataset, batch_size, sort_key, device, batch_size_fn,
                         train, repeat, shuffle, sort, sort_within_batch)
        self.augment = augment
        self.change_possibility = change_possibility
        print("start copy...")
        start = time.time()
        # keep a pristine copy of the examples so augmentation can be redone online
        self.ori_examples = copy.deepcopy(self.dataset.examples)
        print("end copy..., cost total:{}s".format(time.time() - start))
        with open("./data/voc/yunmus.txt", "r", encoding="utf-8") as f:
            yunmu = f.readlines()
        self.yunmus = [a.strip() for a in yunmu]

    def data(self):
        """Overrides data.Iterator.data to add the augmentation (dataset
        expansion) step.
        Return the examples in the dataset in order, sorted, or shuffled."""
        if self.augment:
            print("augmenting data...")
            self.dataset.examples = []
            for ex in tqdm(self.ori_examples):
                for p in self.change_possibility:
                    new_ex = copy.deepcopy(ex)
                    for i, char in enumerate(ex.src):
                        r = random.random()
                        if r < p and char in self.yunmus:
                            # erase the tone: replace the trailing tone digit with "0"
                            new_ex.src[i] = char[:-1] + "0"
                    self.dataset.examples.append(new_ex)
            print("data len:{}".format(len(self.dataset.examples)))
            # print("src:{}\ntrg:{}".format(type(ex.src), type(ex.trg)))
        if self.sort:
            xs = sorted(self.dataset, key=self.sort_key)
        elif self.shuffle:
            xs = [self.dataset[i]
                  for i in self.random_shuffler(range(len(self.dataset)))]
        else:
            xs = self.dataset
        return xs

    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
...
...
@@ -47,15 +101,17 @@ class MyIterator(data.Iterator):
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size, self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))


global max_src_in_batch, max_tgt_in_batch


def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
...
...
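The data() override above is where the commit's "online random tone erasure" happens: for each probability p in change_possibility, every source syllable found in the yunmus list has its trailing tone digit replaced by "0" with probability p. A minimal, self-contained sketch of that rule (erase_tones and the toy YUNMUS set are illustrative stand-ins, not names from the repo, which reads its finals from ./data/voc/yunmus.txt):

import random

# Toy stand-in for the tone-bearing finals loaded from ./data/voc/yunmus.txt.
YUNMUS = {"i3", "ao3", "ui4"}

def erase_tones(src, p, yunmus=YUNMUS):
    # With probability p, rewrite a matching syllable so its tone digit becomes "0".
    out = list(src)
    for i, char in enumerate(src):
        if random.random() < p and char in yunmus:
            out[i] = char[:-1] + "0"
    return out

print(erase_tones(["n", "i3", "h", "ao3"], p=0.8))
# e.g. ['n', 'i0', 'h', 'ao0'] -- output differs run to run

Training on these partially toneless sequences should make the model more tolerant of pinyin input whose tones are missing or unreliable.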
Process.py
...
...
@@ -33,8 +33,15 @@ def read_data(opt):
    if opt.src_data is not None:
        try:
            print("loading src_data")
            if os.path.isdir(opt.src_data):
                # a directory: gather the lines of every file inside it
                train_set = []
                for file in os.listdir(opt.src_data):
                    train_set.append(open(os.path.join(opt.src_data, file))
                                     .read().strip().split('\n'))
                opt.src_data = train_set
                opt.src_data = [x for x in tqdm(opt.src_data)]
            else:
                opt.src_data = open(opt.src_data).read().strip().split('\n')
                opt.src_data = [x for x in tqdm(opt.src_data)]
            # print(len(opt.src_data))
        except:
            print("error: '" + opt.src_data + "' file not found")
...
...
@@ -108,14 +115,18 @@ def create_dataset(opt, SRC, TRG):
    df.to_csv("translate_transformer_temp.csv", index=False)

    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv', format='csv',
                                fields=data_fields, skip_header=True)

    # train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
    #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
    #                         batch_size_fn=batch_size_fn, train=True, shuffle=True)
    # train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
    #                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
    #                         batch_size_fn=None, train=True, shuffle=True)
    train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=None, train=True, shuffle=True, augment=True)

    os.remove('translate_transformer_temp.csv')
...
...
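One side effect worth noting: with augment=True, MyIterator.data() rebuilds dataset.examples with one copy per entry in change_possibility, so each epoch iterates over the dataset multiplied by the length of that list. A quick check under the default six probabilities (the example count here is assumed):

# Back-of-envelope size of the augmented dataset per epoch.
n_examples = 10000                                 # assumed original corpus size
change_possibility = [0.5, 0.6, 0.7, 0.8, 0.9, 1]
print(n_examples * len(change_possibility))        # 60000 augmented examples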
build_corpus.py
...
...
@@ -214,6 +214,10 @@ if __name__ == "__main__":
    # with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
    #     yunmus=f.readlines()
    #     yunmus=[a.strip() for a in yunmus]
    ori_dir = "./data/train_file/ori_file_split_random_wo_tones"
    hanzi_dir = "./data/train_file/hanzi_split_random_wo_tones"
    pinyin_dir = "./data/train_file/pinyin_split_random_wo_tones"
    for file in os.listdir(ori_dir):
        build_corpus(os.path.join(ori_dir, file),
                     os.path.join(pinyin_dir, file),
                     os.path.join(hanzi_dir, file))
    print("Done")
log.txt
...
...
@@ -38,3 +38,6 @@ CUDA_VISIBLE_DEVICES=6 nohup python train_token_classification.py -src_data data
CUDA_VISIBLE_DEVICES=2 nohup python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt >log1 2>&1 &
CUDA_VISIBLE_DEVICES=2 python train_token_classification.py -src_data data/pinyin_new_split.txt -trg_data data/hanzi_new_split.txt -epochs 100 -model_name token_classification_split_new -src_voc ./data/voc/pinyin.txt -trg_voc ./data/voc/hanzi.txt
train.py
...
...
@@ -81,8 +81,8 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-src_data', required=True)
    parser.add_argument('-trg_data', required=True)
    parser.add_argument('-src_lang', default="en_core_web_sm")
    parser.add_argument('-trg_lang', default="fr_core_news_sm")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-SGDR', action='store_true')
    parser.add_argument('-epochs', type=int, default=2)
...
...
train_token_classification.py
...
...
@@ -8,7 +8,8 @@ from Optim import CosineWithRestarts
from Batch import create_masks, create_masks2
import dill as pickle
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from Process import get_len
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"


def train_model(model, opt, start_time):
...
...
@@ -35,7 +36,7 @@ def train_model(model, opt, start_time):
        # torch.save(model.state_dict(), 'weights/model_weights')
        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()
            trg = batch.trg.transpose(0, 1).cuda()
            # print("src shape:{} trg shape:{}".format(src.shape, trg.shape))
...
...
@@ -59,8 +60,8 @@ def train_model(model, opt, start_time):
            p = int(100 * (i + 1) / opt.train_len)
            avg_loss = total_loss / opt.printevery
            if opt.floyd is False:
                print("%dm: epoch %d [%s%s] %d%% steps = %d loss = %.3f" %
                      ((time.time() - start) // 60, epoch + 1,
                       "".join('#' * (p // 5)), "".join(' ' * (20 - (p // 5))),
                       p, i, avg_loss), end='\r')
            else:
                print("%dm: epoch %d [%s%s] %d%% loss = %.3f" %
                      ((time.time() - start) // 60, epoch + 1,
                       "".join('#' * (p // 5)), "".join(' ' * (20 - (p // 5))),
                       p, avg_loss))
...
...
@@ -70,8 +71,8 @@ def train_model(model, opt, start_time):
        # torch.save(model.state_dict(), 'weights/model_weights')
        # cptime = time.time()
        print("%dm: epoch %d [%s%s] %d%% steps = %d loss = %.3f\nepoch %d complete, total steps = %d, loss = %.03f" %
              ((time.time() - start) // 60, epoch + 1,
               "".join('#' * (100 // 5)), "".join(' ' * (20 - (100 // 5))),
               100, i, avg_loss, epoch + 1, i, avg_loss))
        torch.save(model.state_dict(),
                   os.path.join(dst, opt.model_name + "_{}_{}".format(epoch + 1, avg_loss)))
        print("model saved as {}".format(opt.model_name +
...
...
@@ -83,8 +84,8 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-src_data', required=True)
    parser.add_argument('-trg_data', required=True)
    parser.add_argument('-src_lang', default="en_core_web_sm")
    parser.add_argument('-trg_lang', default="fr_core_news_sm")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-SGDR', action='store_true')
    parser.add_argument('-epochs', type=int, default=2)
...
...
@@ -117,7 +118,7 @@ def main():
    SRC, TRG = create_fields(opt)
    opt.train = create_dataset(opt, SRC, TRG)
    # print("train_len:{}".format(opt.train_len))
    # model = get_model(opt, len(SRC.vocab), len(TRG.vocab))
    model = get_model_token_classification(opt, len(SRC.vocab), len(TRG.vocab))
...
...
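The progress lines reworked above lean on printf-style formatting: %% emits a literal percent sign and end='\r' redraws the same console line. A standalone sketch with assumed placeholder values (p, i, avg_loss and the rest are illustrative, not taken from a real run):

import time

start = time.time()
epoch, i, p, avg_loss = 0, 120, 37, 1.234  # assumed placeholder values
bar = "".join('#' * (p // 5))
pad = "".join(' ' * (20 - (p // 5)))
print("%dm: epoch %d [%s%s] %d%% steps = %d loss = %.3f" %
      ((time.time() - start) // 60, epoch + 1, bar, pad, p, i, avg_loss),
      end='\r')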