Commit f6f3e5dc authored by 李承曦(20硕)

Add k2-based WFST contextual biasing code

parent d833c315
@@ -5,6 +5,10 @@
/data/result_data
/data/test_data
/data/*.txt
/data/biased
/k2_WFSTdecoder/*.png
/k2_WFSTdecoder/*.txt
/k2_WFSTdecoder/test.sh
sum_char.py
test_oov.py
import argparse
import logging
from pathlib import Path
import os
import k2
import torch
from .icefall.lexicon import Lexicon
def compile_GG(lm_dir: str) -> k2.Fsa:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info(f"Loading G_general.fst.txt")
with open(f"{lm_dir}/G_general.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
G = k2.arc_sort(G)
logging.info(f"Loading G_special.fst.txt")
with open(f"{lm_dir}/G_special.fst.txt") as f:
G_ = k2.Fsa.from_openfst(f.read(), acceptor=False)
G_ = k2.arc_sort(G_)  # arc-sort the special (biasing) graph itself, not G again
G = k2.invert(G)
G_ = k2.remove_epsilon_and_add_self_loops(G_)
del G.aux_labels
GG = k2.intersect(G_, G)
GG_arcs = GG.arcs.values()
max_V = GG_arcs[:, :2].max() - 1
backward_GG_arcs = torch.zeros([max_V + 1, 4], dtype=torch.int32)
for i in range(max_V):
if i == 0:
continue
backward_GG_arcs[i, 0] = i + 1
backward_GG_arcs[-1, 0] = 0
backward_GG_arcs[-1, 1] = max_V + 1
backward_GG_arcs[-1, 2] = -1
backward_GG_arcs[-1, 3] = -1
# backward_GG_arcs = backward_GG_arcs[1:, :]
# backward_GG_arcs[:, 1] = 0
# backward_GG_arcs[:, 2] = 0
backward_GG_arcs[:, 3] = 0
arcs = torch.cat([GG_arcs, backward_GG_arcs], dim=0)
temp_index = arcs[:, 0].sort()[1]
arcs = arcs[temp_index]
temp_aux = torch.cat([GG.aux_labels, torch.zeros(backward_GG_arcs.shape[0], dtype=torch.int32)], dim=0)[temp_index]
# import ipdb
# ipdb.set_trace()
GG = k2.Fsa(arcs=arcs, aux_labels=temp_aux).to(device)
return GG
def compile_G(lm_dir: str, key: str) -> k2.Fsa:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info(f"Loading G_{key}.fst.txt")
# lexicon = Lexicon(params.lang_dir)
if not os.path.exists(f'{lm_dir}/G_{key}.pt'):
with open(f"{lm_dir}/G_{key}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
if key != 'general':
# del G.aux_labels
# G.labels[G.labels >= 0] = 0
G.__dict__["_properties"] = None
# G = k2.add_epsilon_self_loops(G)
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
# G["dummy"] = 1
torch.save(G.as_dict(), f'{lm_dir}/G_{key}.pt')
else:
logging.info(f"Loading pre-compiled G_{key}.pt")
d = torch.load(f'{lm_dir}/G_{key}.pt', map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
G.lm_scores = G.scores.clone()
return G
\ No newline at end of file
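A minimal usage sketch for the two helpers above (not part of this commit): it assumes compile_GG and compile_G are in scope, that lm_dir contains G_general.fst.txt and G_special.fst.txt in OpenFST text format, and that a word-level lattice is available elsewhere; the paths and the final intersection step are illustrative.

import k2

lm_dir = "data/lm"  # assumed directory layout

# Combined general + biasing (special) LM graph with the added return arcs.
GG = compile_GG(lm_dir)

# A single LM graph selected by key, compiled once and cached as G_<key>.pt.
G = compile_G(lm_dir, key="general")

# Typical biasing/rescoring step (sketch): attach the graph's scores to a
# word-level lattice by intersection.
# lattice = ...  # k2.Fsa whose labels are word IDs
# rescored = k2.intersect(k2.arc_sort(G), lattice)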
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import sys
from pathlib import Path
from typing import List, Tuple
import k2
import torch
def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
"""Read a lexicon from `filename`.
Each line in the lexicon contains "word p1 p2 p3 ...".
That is, the first field is a word and the remaining
fields are tokens. Fields are separated by space(s).
Args:
filename:
Path to the lexicon.txt
Returns:
A list of tuples, e.g., [('w', ['p1', 'p2']), ('w1', ['p3', 'p4'])]
"""
ans = []
with open(filename, "r", encoding="utf-8") as f:
whitespace = re.compile("[ \t]+")
for line in f:
a = whitespace.split(line.strip(" \t\r\n"))
if len(a) == 0:
continue
if len(a) < 2:
logging.info(
f"Found bad line {line} in lexicon file {filename}"
)
logging.info(
"Every line is expected to contain at least 2 fields"
)
sys.exit(1)
word = a[0]
if word == "<eps>":
logging.info(
f"Found bad line {line} in lexicon file {filename}"
)
logging.info("<eps> should not be a valid word")
sys.exit(1)
tokens = a[1:]
ans.append((word, tokens))
return ans
def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
"""Write a lexicon to a file.
Args:
filename:
Path to the lexicon file to be generated.
lexicon:
It can be the return value of :func:`read_lexicon`.
"""
with open(filename, "w", encoding="utf-8") as f:
for word, tokens in lexicon:
f.write(f"{word} {' '.join(tokens)}\n")
def convert_lexicon_to_ragged(
filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
) -> k2.RaggedTensor:
"""Read a lexicon and convert it to a ragged tensor.
The ragged tensor has two axes: [word][token].
Caution:
We assume that each word has a unique pronunciation.
Args:
filename:
Filename of the lexicon. It has a format that can be read
by :func:`read_lexicon`.
word_table:
The word symbol table.
token_table:
The token symbol table.
Returns:
A k2 ragged tensor with two axes [word][token].
"""
disambig_id = word_table["#0"]
# We reuse the same words.txt from the phone based lexicon
# so that we can share the same G.fst. Here, we have to
# exclude some words present only in the phone based lexicon.
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
# epsilon is not a word, but it occupies a position
#
row_splits = [0]
token_ids_list = []
lexicon_tmp = read_lexicon(filename)
lexicon = dict(lexicon_tmp)
if len(lexicon_tmp) != len(lexicon):
raise RuntimeError(
"It's assumed that each word has a unique pronunciation"
)
for i in range(disambig_id):
w = word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
tokens = lexicon[w]
token_ids = [token_table[k] for k in tokens]
row_splits.append(row_splits[-1] + len(token_ids))
token_ids_list.extend(token_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits,
None,
cached_tot_size,
)
values = torch.tensor(token_ids_list, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
class Lexicon(object):
"""Phone based lexicon."""
def __init__(
self,
lang_dir: Path,
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Args:
lang_dir:
Path to the lang directory. It is expected to contain the following
files:
- tokens.txt
- words.txt
- L.pt
The above files are produced by the script `prepare.sh`. You
should have run that before running the training code.
disambig_pattern:
It contains the pattern for disambiguation symbols.
"""
lang_dir = Path(lang_dir)
self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
if (lang_dir / "Linv.pt").exists():
logging.info(f"Loading pre-compiled {lang_dir}/Linv.pt")
L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt"))
else:
logging.info("Converting L.pt to Linv.pt")
L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt"))
L_inv = k2.arc_sort(L.invert())
torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
# We save L_inv instead of L because it will be used to intersect with
# transcript FSAs, both of whose labels are word IDs.
self.L_inv = L_inv
self.disambig_pattern = disambig_pattern
@property
def tokens(self) -> List[int]:
"""Return a list of token IDs excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = self.token_table.symbols
ans = []
for s in symbols:
if not self.disambig_pattern.match(s):
ans.append(self.token_table[s])
if 0 in ans:
ans.remove(0)
ans.sort()
return ans
class UniqLexicon(Lexicon):
def __init__(
self,
lang_dir: Path,
uniq_filename: str = "uniq_lexicon.txt",
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Refer to the help information in Lexicon.__init__.
uniq_filename: It is assumed to be inside the given `lang_dir`.
Each word in the lexicon is assumed to have a unique pronunciation.
"""
lang_dir = Path(lang_dir)
super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)
self.ragged_lexicon = convert_lexicon_to_ragged(
filename=lang_dir / uniq_filename,
word_table=self.word_table,
token_table=self.token_table,
)
# TODO: should we move it to a certain device ?
def texts_to_token_ids(
self, texts: List[str], oov: str = "<UNK>"
) -> k2.RaggedTensor:
"""
Args:
texts:
A list of transcripts. Each transcript contains space(s)
separated words. An example of texts is::
['HELLO k2', 'HELLO icefall']
oov:
The OOV word. If a word in `texts` is not in the lexicon, it is
replaced with `oov`.
Returns:
Return a ragged int tensor with 2 axes [utterance][token_id]
"""
oov_id = self.word_table[oov]
word_ids_list = []
for text in texts:
word_ids = []
for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
word_ids.append(oov_id)
word_ids_list.append(word_ids)
ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
ans = self.ragged_lexicon.index(ragged_indexes)
ans = ans.remove_axis(ans.num_axes - 2)
return ans
def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor containing token IDs.
We assume there are no OOVs in "words".
"""
word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32)
ragged, _ = self.ragged_lexicon.index(
indexes=word_ids,
axis=0,
need_value_indexes=False,
)
return ragged
\ No newline at end of file
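For reference, a short usage sketch of the lexicon classes above (illustrative only): the lang_dir path is an assumption and is expected to contain tokens.txt, words.txt, L.pt and uniq_lexicon.txt as described in the docstrings.

lexicon = UniqLexicon("data/lang_char")

# Token IDs, excluding disambiguation symbols and 0.
token_ids = lexicon.tokens

# Map transcripts to token IDs; OOV words are replaced with <UNK>.
ragged = lexicon.texts_to_token_ids(["HELLO k2", "HELLO icefall"])
print(ragged)  # k2.RaggedTensor with axes [utterance][token_id]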
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2
import k2.version
def get_texts(
best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
"""Extract the texts (as word IDs) from the best-path FSAs.
Args:
best_paths:
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
return_ragged:
True to return a ragged tensor with two axes [utt][word_id].
False to return a list-of-list word IDs.
Returns:
Returns a list of lists of int containing the decoded word IDs, or a
k2.RaggedTensor with two axes [utt][word_id] if return_ragged is True.
"""
if isinstance(best_paths.aux_labels, k2.RaggedTensor):
# remove 0's and -1's.
aux_labels = best_paths.aux_labels.remove_values_leq(0)
# TODO: change arcs.shape() to arcs.shape
aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
# remove the states and arcs axes.
aux_shape = aux_shape.remove_axis(1)
aux_shape = aux_shape.remove_axis(1)
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
else:
# remove axis corresponding to states.
aux_shape = best_paths.arcs.shape().remove_axis(1)
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
# remove 0's and -1's.
aux_labels = aux_labels.remove_values_leq(0)
assert aux_labels.num_axes == 2
if return_ragged:
return aux_labels
else:
return aux_labels.tolist()
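A brief sketch of how get_texts is typically combined with k2.shortest_path (illustrative only): lattice and the words.txt path are assumed inputs, not defined in this commit.

best_paths = k2.shortest_path(lattice, use_double_scores=True)
hyps = get_texts(best_paths)  # List[List[int]] of word IDs per utterance

word_table = k2.SymbolTable.from_file("data/lang_char/words.txt")
texts = [" ".join(word_table[i] for i in ids) for ids in hyps]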
1 0
0 49
0 85
0 1 6 -77.41171594483659
0 2 0 -77.52820497047072
1 3 0 -76.30668366093958
1 4 0 -76.30668366093958
1 5 7 -76.12490222679712
2 4 6 -76.05664693464729
2 3 0 -76.30668366093958
3 7 7 -76.80004123240462
3 6 0 -77.52820497047072
4 6 0 -77.52820497047072
4 8 0 -77.52820497047072
4 9 7 -77.0766342784592
5 6 0 -77.52820497047072
5 9 0 -77.52820497047072
6 10 0 -76.30668366093958
7 10 0 -76.30668366093958
8 10 0 -76.30668366093958
9 10 0 -76.30668366093958
9 11 0 -76.30668366093958
10 12 1 -76.78742485656035
10 13 0 -77.52820497047072
11 13 0 -77.52820497047072
11 14 0 -77.52820497047072
12 15 0 -76.30668366093958
12 16 0 -76.30668366093958
12 17 2 -76.12121222008078
13 15 1 -76.08124381241319
13 16 0 -76.30668366093958
14 16 0 -76.30668366093958
14 18 0 -76.30668366093958
15 19 0 -77.52820497047072
15 20 0 -77.52820497047072
15 21 2 -77.04467980124757
16 20 0 -77.52820497047072
17 21 0 -77.52820497047072
18 20 0 -77.52820497047072
18 22 0 -77.52820497047072
19 23 0 -76.30668366093958
19 24 0 -76.30668366093958
19 25 2 -76.12121222008078
20 24 0 -76.30668366093958
21 25 0 -76.30668366093958
21 24 0 -76.30668366093958
21 26 3 -76.12121222008078
22 24 0 -76.30668366093958
22 27 0 -76.30668366093958
23 28 0 -77.52820497047072
23 29 0 -77.52820497047072
24 29 0 -77.52820497047072
24 31 3 -76.80004123240462
25 30 0 -77.52820497047072
25 29 0 -77.52820497047072
25 32 3 -77.04467980124757
26 32 0 -77.52820497047072
27 29 0 -77.52820497047072
27 33 0 -77.52820497047072
28 34 0 -76.30668366093958
29 34 0 -76.30668366093958
30 35 0 -76.30668366093958
30 34 0 -76.30668366093958
30 36 3 -76.12121222008078
31 34 0 -76.30668366093958
32 36 0 -76.30668366093958
32 34 0 -76.30668366093958
33 34 0 -76.30668366093958
34 39 9 -76.75389983574527
34 37 0 -76.62906099113884
34 40 5 -76.70998441304296
35 38 0 -76.62906099113884
35 37 0 -76.62906099113884
36 41 0 -76.62906099113884
36 42 5 -77.85764925878665
36 37 0 -76.62906099113884
37 43 0 -76.30668366093958
38 44 0 -76.30668366093958
38 43 0 -76.30668366093958
39 43 0 -76.30668366093958
40 43 0 -76.30668366093958
41 43 0 -76.30668366093958
42 43 0 -76.30668366093958
43 45 3 -76.80008784358508
44 46 3 -77.04476177443331
45 47 0 -76.30668366093958
46 47 0 -76.30668366093958
47 48 -1 -99