Commit f6f3e5dc authored by 李承曦(20硕)'s avatar 李承曦(20硕)

增加基于k2的WFST语境bias代码

parent d833c315
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
/data/result_data /data/result_data
/data/test_data /data/test_data
/data/*.txt /data/*.txt
/data/biased
/k2_WFSTdecoder/*.png
/k2_WFSTdecoder/*.txt
/k2_WFSTdecoder/test.sh
sum_char.py sum_char.py
test_oov.py test_oov.py
......
import argparse
import logging
from pathlib import Path
import os
import k2
import torch
from .icefall.lexicon import Lexicon
def compile_GG(lm_dir: str) -> k2.Fsa:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info(f"Loading G_general.fst.txt")
with open(f"{lm_dir}/G_general.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
G = k2.arc_sort(G)
logging.info(f"Loading G_special.fst.txt")
with open(f"{lm_dir}/G_special.fst.txt") as f:
G_ = k2.Fsa.from_openfst(f.read(), acceptor=False)
G_ = k2.arc_sort(G)
G = k2.invert(G)
G_ = k2.remove_epsilon_and_add_self_loops(G_)
del G.aux_labels
GG = k2.intersect(G_, G)
GG_arcs = GG.arcs.values()
max_V = GG_arcs[:, :2].max() - 1
backward_GG_arcs = torch.zeros([max_V + 1, 4], dtype=torch.int32)
for i in range(max_V):
if i == 0:
continue
backward_GG_arcs[i, 0] = i + 1
backward_GG_arcs[-1, 0] = 0
backward_GG_arcs[-1, 1] = max_V + 1
backward_GG_arcs[-1, 2] = -1
backward_GG_arcs[-1, 3] = -1
# backward_GG_arcs = backward_GG_arcs[1:, :]
# backward_GG_arcs[:, 1] = 0
# backward_GG_arcs[:, 2] = 0
backward_GG_arcs[:, 3] = 0
arcs = torch.cat([GG_arcs, backward_GG_arcs], dim=0)
temp_index = arcs[:, 0].sort()[1]
arcs = arcs[temp_index]
temp_aux = torch.cat([GG.aux_labels, torch.zeros(backward_GG_arcs.shape[0])], dim=-0)[temp_index]
# import ipdb
# ipdb.set_trace()
GG = k2.Fsa(arcs=arcs, aux_labels=temp_aux).to(device)
return GG
def compile_G(lm_dir: str, key: str) -> k2.Fsa:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info(f"Loading G_{key}.fst.txt")
# lexicon = Lexicon(params.lang_dir)
if not os.path.exists(f'{lm_dir}/G_{key}.pt'):
with open(f"{lm_dir}/G_{key}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
if not key == 'general':
# del G.aux_labels
# G.labels[G.labels >= 0] = 0
G.__dict__["_properties"] = None
# G = k2.add_epsilon_self_loops(G)
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
# G["dummy"] = 1
torch.save(G.as_dict(), f'{lm_dir}/G_{key}.pt')
else:
logging.info(f"Loading pre-compiled G_{key}.pt")
d = torch.load(f'{lm_dir}/G_{key}.pt', map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
G.lm_scores = G.scores.clone()
return G
\ No newline at end of file
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional, Union
import k2
import torch
from .utils import get_texts
def _intersect_device(
a_fsas: k2.Fsa,
b_fsas: k2.Fsa,
b_to_a_map: torch.Tensor,
sorted_match_a: bool,
batch_size: int = 50,
) -> k2.Fsa:
"""This is a wrapper of k2.intersect_device and its purpose is to split
b_fsas into several batches and process each batch separately to avoid
CUDA OOM error.
The arguments and return value of this function are the same as
:func:`k2.intersect_device`.
"""
num_fsas = b_fsas.shape[0]
if num_fsas <= batch_size:
return k2.intersect_device(
a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
)
num_batches = (num_fsas + batch_size - 1) // batch_size
splits = []
for i in range(num_batches):
start = i * batch_size
end = min(start + batch_size, num_fsas)
splits.append((start, end))
ans = []
for start, end in splits:
indexes = torch.arange(start, end).to(b_to_a_map)
fsas = k2.index_fsa(b_fsas, indexes)
b_to_a = k2.index_select(b_to_a_map, indexes)
path_lattice = k2.intersect_device(
a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
)
ans.append(path_lattice)
return k2.cat(ans)
def get_lattice(
nnet_output: torch.Tensor,
decoding_graph: k2.Fsa,
supervision_segments: torch.Tensor,
search_beam: float,
output_beam: float,
min_active_states: int,
max_active_states: int,
subsampling_factor: int = 1,
) -> k2.Fsa:
"""Get the decoding lattice from a decoding graph and neural
network output.
Args:
nnet_output:
It is the output of a neural model of shape `(N, T, C)`.
decoding_graph:
An Fsa, the decoding graph. It can be either an HLG
(see `compile_HLG.py`) or an H (see `k2.ctc_topo`).
supervision_segments:
A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
Each row contains information for a supervision segment. Column 0
is the `sequence_index` indicating which sequence this segment
comes from; column 1 specifies the `start_frame` of this segment
within the sequence; column 2 contains the `duration` of this
segment.
search_beam:
Decoding beam, e.g. 20. Smaller is faster, larger is more exact
(less pruning). This is the default value; it may be modified by
`min_active_states` and `max_active_states`.
output_beam:
Beam to prune output, similar to lattice-beam in Kaldi. Relative
to best path of output.
min_active_states:
Minimum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to have fewer than this number active.
Set it to zero if there is no constraint.
max_active_states:
Maximum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to exceed that but may not always succeed.
You can use a very large number if no constraint is needed.
subsampling_factor:
The subsampling factor of the model.
Returns:
An FsaVec containing the decoding result. It has axes [utt][state][arc].
"""
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=subsampling_factor - 1,
)
lattice = k2.intersect_dense_pruned(
decoding_graph,
dense_fsa_vec,
search_beam=search_beam,
output_beam=output_beam,
min_active_states=min_active_states,
max_active_states=max_active_states,
)
return lattice
class Nbest(object):
"""
An Nbest object contains two fields:
(1) fsa. It is an FsaVec containing a vector of **linear** FSAs.
Its axes are [path][state][arc]
(2) shape. Its type is :class:`k2.RaggedShape`.
Its axes are [utt][path]
The field `shape` has two axes [utt][path]. `shape.dim0` contains
the number of utterances, which is also the number of rows in the
supervision_segments. `shape.tot_size(1)` contains the number
of paths, which is also the number of FSAs in `fsa`.
Caution:
Don't be confused by the name `Nbest`. The best in the name `Nbest`
has nothing to do with `best scores`. The important part is
`N` in `Nbest`, not `best`.
"""
def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None:
"""
Args:
fsa:
An FsaVec with axes [path][state][arc]. It is expected to contain
a list of **linear** FSAs.
shape:
A ragged shape with two axes [utt][path].
"""
assert len(fsa.shape) == 3, f"fsa.shape: {fsa.shape}"
assert shape.num_axes == 2, f"num_axes: {shape.num_axes}"
if fsa.shape[0] != shape.tot_size(1):
raise ValueError(
f"{fsa.shape[0]} vs {shape.tot_size(1)}\n"
"Number of FSAs in `fsa` does not match the given shape"
)
self.fsa = fsa
self.shape = shape
def __str__(self):
s = "Nbest("
s += f"Number of utterances:{self.shape.dim0}, "
s += f"Number of Paths:{self.fsa.shape[0]})"
return s
@staticmethod
def from_lattice(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
nbest_scale: float = 0.5,
) -> "Nbest":
"""Construct an Nbest object by **sampling** `num_paths` from a lattice.
Each sampled path is a linear FSA.
We assume `lattice.labels` contains token IDs and `lattice.aux_labels`
contains word IDs.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to **sample** from the lattice
using :func:`k2.random_paths`.
use_double_scores:
True to use double precision in :func:`k2.random_paths`.
False to use single precision.
scale:
Scale `lattice.score` before passing it to :func:`k2.random_paths`.
A smaller value leads to more unique paths at the risk of being not
to sample the path with the best score.
Returns:
Return an Nbest instance.
"""
saved_scores = lattice.scores.clone()
lattice.scores *= nbest_scale
# path is a ragged tensor with dtype torch.int32.
# It has three axes [utt][path][arc_pos]
path = k2.random_paths(
lattice, num_paths=num_paths, use_double_scores=use_double_scores
)
lattice.scores = saved_scores
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
# It axes is [utt][path][word_id]
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
word_seq = word_seq.remove_values_leq(0)
# Each utterance has `num_paths` paths but some of them transduces
# to the same word sequence, so we need to remove repeated word
# sequences within an utterance. After removing repeats, each utterance
# contains different number of paths
#
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
_, _, new2old = word_seq.unique(
need_num_repeats=False, need_new2old_indexes=True
)
# kept_path is a ragged tensor with dtype torch.int32.
# It has axes [utt][path][arc_pos]
kept_path, _ = path.index(new2old, axis=1, need_value_indexes=False)
# utt_to_path_shape has axes [utt][path]
utt_to_path_shape = kept_path.shape.get_layer(0)
# Remove the utterance axis.
# Now kept_path has only two axes [path][arc_pos]
kept_path = kept_path.remove_axis(0)
# labels is a ragged tensor with 2 axes [path][token_id]
# Note that it contains -1s.
labels = k2.ragged.index(lattice.labels.contiguous(), kept_path)
# Remove -1 from labels as we will use it to construct a linear FSA
labels = labels.remove_values_eq(-1)
if isinstance(lattice.aux_labels, k2.RaggedTensor):
# lattice.aux_labels is a ragged tensor with dtype torch.int32.
# It has 2 axes [arc][word], so aux_labels is also a ragged tensor
# with 2 axes [arc][word]
aux_labels, _ = lattice.aux_labels.index(
indexes=kept_path.values, axis=0, need_value_indexes=False
)
else:
assert isinstance(lattice.aux_labels, torch.Tensor)
aux_labels = k2.index_select(lattice.aux_labels, kept_path.values)
# aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.
fsa = k2.linear_fsa(labels)
fsa.aux_labels = aux_labels
# Caution: fsa.scores are all 0s.
# `fsa` has only one extra attribute: aux_labels.
return Nbest(fsa=fsa, shape=utt_to_path_shape)
def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
"""Intersect this Nbest object with a lattice, get 1-best
path from the resulting FsaVec, and return a new Nbest object.
The purpose of this function is to attach scores to an Nbest.
Args:
lattice:
An FsaVec with axes [utt][state][arc]. If it has `aux_labels`, then
we assume its `labels` are token IDs and `aux_labels` are word IDs.
If it has only `labels`, we assume its `labels` are word IDs.
use_double_scores:
True to use double precision when computing shortest path.
False to use single precision.
Returns:
Return a new Nbest. This new Nbest shares the same shape with `self`,
while its `fsa` is the 1-best path from intersecting `self.fsa` and
`lattice`. Also, its `fsa` has non-zero scores and inherits attributes
for `lattice`.
"""
# Note: We view each linear FSA as a word sequence
# and we use the passed lattice to give each word sequence a score.
#
# We are not viewing each linear FSAs as a token sequence.
#
# So we use k2.invert() here.
# We use a word fsa to intersect with k2.invert(lattice)
word_fsa = k2.invert(self.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as it is not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
word_fsa
)
path_to_utt_map = self.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
one_best = k2.shortest_path(
path_lattice, use_double_scores=use_double_scores
)
one_best = k2.invert(one_best)
# Now one_best has token IDs as labels and word IDs as aux_labels
return Nbest(fsa=one_best, shape=self.shape)
def compute_am_scores(self) -> k2.RaggedTensor:
"""Compute AM scores of each linear FSA (i.e., each path within
an utterance).
Hint:
`self.fsa.scores` contains two parts: acoustic scores (AM scores)
and n-gram language model scores (LM scores).
Caution:
We require that ``self.fsa`` has an attribute ``lm_scores``.
Returns:
Return a ragged tensor with 2 axes [utt][path_scores].
Its dtype is torch.float64.
"""
scores_shape = self.fsa.arcs.shape().remove_axis(1)
# scores_shape has axes [path][arc]
am_scores = self.fsa.scores - self.fsa.lm_scores
ragged_am_scores = k2.RaggedTensor(scores_shape, am_scores.contiguous())
tot_scores = ragged_am_scores.sum()
return k2.RaggedTensor(self.shape, tot_scores)
def compute_lm_scores(self) -> k2.RaggedTensor:
"""Compute LM scores of each linear FSA (i.e., each path within
an utterance).
Hint:
`self.fsa.scores` contains two parts: acoustic scores (AM scores)
and n-gram language model scores (LM scores).
Caution:
We require that ``self.fsa`` has an attribute ``lm_scores``.
Returns:
Return a ragged tensor with 2 axes [utt][path_scores].
Its dtype is torch.float64.
"""
scores_shape = self.fsa.arcs.shape().remove_axis(1)
# scores_shape has axes [path][arc]
ragged_lm_scores = k2.RaggedTensor(
scores_shape, self.fsa.lm_scores.contiguous()
)
tot_scores = ragged_lm_scores.sum()
return k2.RaggedTensor(self.shape, tot_scores)
def tot_scores(self) -> k2.RaggedTensor:
"""Get total scores of FSAs in this Nbest.
Note:
Since FSAs in Nbest are just linear FSAs, log-semiring
and tropical semiring produce the same total scores.
Returns:
Return a ragged tensor with two axes [utt][path_scores].
Its dtype is torch.float64.
"""
scores_shape = self.fsa.arcs.shape().remove_axis(1)
# scores_shape has axes [path][arc]
ragged_scores = k2.RaggedTensor(
scores_shape, self.fsa.scores.contiguous()
)
tot_scores = ragged_scores.sum()
return k2.RaggedTensor(self.shape, tot_scores)
def build_levenshtein_graphs(self) -> k2.Fsa:
"""Return an FsaVec with axes [utt][state][arc]."""
word_ids = get_texts(self.fsa, return_ragged=True)
return k2.levenshtein_graph(word_ids)
def one_best_decoding(
lattice: k2.Fsa,
use_double_scores: bool = True,
) -> k2.Fsa:
"""Get the best path from a lattice.
Args:
lattice:
The decoding lattice returned by :func:`get_lattice`.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
Return:
An FsaVec containing linear paths.
"""
best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
return best_path
def nbest_decoding(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
nbest_scale: float = 1.0,
) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists.
The basic idea is to first extract `num_paths` paths from the given lattice,
build a word sequence from these paths, and compute the total scores
of the word sequence in the tropical semiring. The one with the max score
is used as the decoding output.
Caution:
Don't be confused by `best` in the name `n-best`. Paths are selected
**randomly**, not by ranking their scores.
Hint:
This decoding method is for demonstration only and it does
not produce a lower WER than :func:`one_best_decoding`.
Args:
lattice:
The decoding lattice, e.g., can be the return value of
:func:`get_lattice`. It has 3 axes [utt][state][arc].
num_paths:
It specifies the size `n` in n-best. Note: Paths are selected randomly
and those containing identical word sequences are removed and only one
of them is kept.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
nbest_scale:
It's the scale applied to the `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
Returns:
An FsaVec containing **linear** FSAs. It axes are [utt][state][arc].
"""
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores contains 0s
nbest = nbest.intersect(lattice)
# now nbest.fsa.scores gets assigned
# max_indexes contains the indexes for the path with the maximum score
# within an utterance.
max_indexes = nbest.tot_scores().argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
return best_path
def nbest_oracle(
lattice: k2.Fsa,
num_paths: int,
ref_texts: List[str],
word_table: k2.SymbolTable,
use_double_scores: bool = True,
nbest_scale: float = 0.5,
oov: str = "<UNK>",
) -> Dict[str, List[List[int]]]:
"""Select the best hypothesis given a lattice and a reference transcript.
The basic idea is to extract `num_paths` paths from the given lattice,
unique them, and select the one that has the minimum edit distance with
the corresponding reference transcript as the decoding output.
The decoding result returned from this function is the best result that
we can obtain using n-best decoding with all kinds of rescoring techniques.
This function is useful to tune the value of `nbest_scale`.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
Note: We assume its `aux_labels` contains word IDs.
num_paths:
The size of `n` in n-best.
ref_texts:
A list of reference transcript. Each entry contains space(s)
separated words
word_table:
It is the word symbol table.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
oov:
The out of vocabulary word.
Return:
Return a dict. Its key contains the information about the parameters
when calling this function, while its value contains the decoding output.
`len(ans_dict) == len(ref_texts)`
"""
device = lattice.device
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
hyps = nbest.build_levenshtein_graphs()
oov_id = word_table[oov]
word_ids_list = []
for text in ref_texts:
word_ids = []
for word in text.split():
if word in word_table:
word_ids.append(word_table[word])
else:
word_ids.append(oov_id)
word_ids_list.append(word_ids)
refs = k2.levenshtein_graph(word_ids_list, device=device)
levenshtein_alignment = k2.levenshtein_alignment(
refs=refs,
hyps=hyps,
hyp_to_ref_map=nbest.shape.row_ids(1),
sorted_match_ref=True,
)
tot_scores = levenshtein_alignment.get_tot_scores(
use_double_scores=False, log_semiring=False
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
return best_path
def rescore_with_n_best_list(
lattice: k2.Fsa,
G: k2.Fsa,
num_paths: int,
lm_scale_list: List[float],
nbest_scale: float = 1.0,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""Rescore an n-best list with an n-gram LM.
The path with the maximum score is used as the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc]. It must have the following
attributes: ``aux_labels`` and ``lm_scores``. Its labels are
token IDs and ``aux_labels`` word IDs.
G:
An FsaVec containing only a single FSA. It is an n-gram LM.
num_paths:
Size of nbest list.
lm_scale_list:
A list of float representing LM score scales.
nbest_scale:
Scale to be applied to ``lattice.score`` when sampling paths
using ``k2.random_paths``.
use_double_scores:
True to use double precision during computation. False to use
single precision.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value is the
best decoding path for each utterance in the lattice.
"""
device = lattice.device
assert len(lattice.shape) == 3
assert hasattr(lattice, "aux_labels")
assert hasattr(lattice, "lm_scores")
assert G.shape == (1, None, None)
assert G.device == device
assert hasattr(G, "aux_labels") is False
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set
assert hasattr(nbest.fsa, "lm_scores")
am_scores = nbest.compute_am_scores()
nbest = nbest.intersect(G)
# Now nbest contains only lm scores
lm_scores = nbest.tot_scores()
ans = dict()
for lm_scale in lm_scale_list:
tot_scores = am_scores.values / lm_scale + lm_scores.values
tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
def rescore_with_whole_lattice(
lattice: k2.Fsa,
G_with_epsilon_loops: k2.Fsa,
lm_scale_list: Optional[List[float]] = None,
use_double_scores: bool = True,
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
"""Intersect the lattice with an n-gram LM and use shortest path
to decode.
The input lattice is obtained by intersecting `HLG` with
a DenseFsaVec, where the `G` in `HLG` is in general a 3-gram LM.
The input `G_with_epsilon_loops` is usually a 4-gram LM. You can consider
this function as a second pass decoding. In the first pass decoding, we
use a small G, while we use a larger G in the second pass decoding.
Args:
lattice:
An FsaVec with axes [utt][state][arc]. Its `aux_lables` are word IDs.
It must have an attribute `lm_scores`.
G_with_epsilon_loops:
An FsaVec containing only a single FSA. It contains epsilon self-loops.
It is an acceptor and its labels are word IDs.
lm_scale_list:
Optional. If none, return the intersection of `lattice` and
`G_with_epsilon_loops`.
If not None, it contains a list of values to scale LM scores.
For each scale, there is a corresponding decoding result contained in
the resulting dict.
use_double_scores:
True to use double precision in the computation.
False to use single precision.
Returns:
If `lm_scale_list` is None, return a new lattice which is the intersection
result of `lattice` and `G_with_epsilon_loops`.
Otherwise, return a dict whose key is an entry in `lm_scale_list` and the
value is the decoding result (i.e., an FsaVec containing linear FSAs).
"""
# Nbest is not used in this function
assert hasattr(lattice, "lm_scores")
assert G_with_epsilon_loops.shape == (1, None, None)
device = lattice.device
lattice.scores = lattice.scores - lattice.lm_scores
# We will use lm_scores from G, so remove lats.lm_scores here
del lattice.lm_scores
assert hasattr(G_with_epsilon_loops, "lm_scores")
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
# Its `aux_labels` is token IDs
inv_lattice = k2.invert(lattice)
num_seqs = lattice.shape[0]
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
max_loop_count = 10
loop_count = 0
while loop_count <= max_loop_count:
loop_count += 1
try:
rescoring_lattice = k2.intersect_device(
G_with_epsilon_loops,
inv_lattice,
b_to_a_map,
sorted_match_a=True,
)
rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
break
except RuntimeError as e:
logging.info(f"Caught exception:\n{e}\n")
logging.info(
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
)
logging.info(
"This OOM is not an error. You can ignore it. "
"If your model does not converge well, or --max-duration "
"is too large, or the input sound file is difficult to "
"decode, you will meet this exception."
)
# NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
# to avoid OOM. You may need to fine tune it.
inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True)
logging.info(
f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
)
if loop_count > max_loop_count:
logging.info("Return None as the resulting lattice is too large")
return None
# lat has token IDs as labels
# and word IDs as aux_labels.
lat = k2.invert(rescoring_lattice)
if lm_scale_list is None:
return lat
ans = dict()
saved_am_scores = lat.scores - lat.lm_scores
for lm_scale in lm_scale_list:
am_scores = saved_am_scores / lm_scale
lat.scores = am_scores + lat.lm_scores
best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
def rescore_with_attention_decoder(
lattice: k2.Fsa,
num_paths: int,
model: torch.nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
nbest_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""This function extracts `num_paths` paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest score is
the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `(T, N, C)`.
memory_key_padding_mask:
The padding mask for memory with shape `(N, T)`.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
best decoding path for each utterance in the lattice.
"""
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set.
# Also, nbest.fsa inherits the attributes from `lattice`.
assert hasattr(nbest.fsa, "lm_scores")
am_scores = nbest.compute_am_scores()
ngram_lm_scores = nbest.compute_lm_scores()
# The `tokens` attribute is set inside `compile_hlg.py`
assert hasattr(nbest.fsa, "tokens")
assert isinstance(nbest.fsa.tokens, torch.Tensor)
path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
# the shape of memory is (T, N, C), so we use axis=1 here
expanded_memory = memory.index_select(1, path_to_utt_map)
if memory_key_padding_mask is not None:
# The shape of memory_key_padding_mask is (N, T), so we
# use axis=0 here.
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_utt_map
)
else:
expanded_memory_key_padding_mask = None
# remove axis corresponding to states.
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
tokens = tokens.remove_values_leq(0)
token_ids = tokens.tolist()
if len(token_ids) == 0:
print("Warning: rescore_with_attention_decoder(): empty token-ids")
return None
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
attention_scale_list = [attention_scale]
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
tot_scores = (
am_scores.values
+ n_scale * ngram_lm_scores.values
+ a_scale * attention_scores
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_path
\ No newline at end of file
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import sys
from pathlib import Path
from typing import List, Tuple
import k2
import torch
def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
"""Read a lexicon from `filename`.
Each line in the lexicon contains "word p1 p2 p3 ...".
That is, the first field is a word and the remaining
fields are tokens. Fields are separated by space(s).
Args:
filename:
Path to the lexicon.txt
Returns:
A list of tuples., e.g., [('w', ['p1', 'p2']), ('w1', ['p3, 'p4'])]
"""
ans = []
with open(filename, "r", encoding="utf-8") as f:
whitespace = re.compile("[ \t]+")
for line in f:
a = whitespace.split(line.strip(" \t\r\n"))
if len(a) == 0:
continue
if len(a) < 2:
logging.info(
f"Found bad line {line} in lexicon file {filename}"
)
logging.info(
"Every line is expected to contain at least 2 fields"
)
sys.exit(1)
word = a[0]
if word == "<eps>":
logging.info(
f"Found bad line {line} in lexicon file {filename}"
)
logging.info("<eps> should not be a valid word")
sys.exit(1)
tokens = a[1:]
ans.append((word, tokens))
return ans
def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
"""Write a lexicon to a file.
Args:
filename:
Path to the lexicon file to be generated.
lexicon:
It can be the return value of :func:`read_lexicon`.
"""
with open(filename, "w", encoding="utf-8") as f:
for word, tokens in lexicon:
f.write(f"{word} {' '.join(tokens)}\n")
def convert_lexicon_to_ragged(
filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
) -> k2.RaggedTensor:
"""Read a lexicon and convert it to a ragged tensor.
The ragged tensor has two axes: [word][token].
Caution:
We assume that each word has a unique pronunciation.
Args:
filename:
Filename of the lexicon. It has a format that can be read
by :func:`read_lexicon`.
word_table:
The word symbol table.
token_table:
The token symbol table.
Returns:
A k2 ragged tensor with two axes [word][token].
"""
disambig_id = word_table["#0"]
# We reuse the same words.txt from the phone based lexicon
# so that we can share the same G.fst. Here, we have to
# exclude some words present only in the phone based lexicon.
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
# epsilon is not a word, but it occupies a position
#
row_splits = [0]
token_ids_list = []
lexicon_tmp = read_lexicon(filename)
lexicon = dict(lexicon_tmp)
if len(lexicon_tmp) != len(lexicon):
raise RuntimeError(
"It's assumed that each word has a unique pronunciation"
)
for i in range(disambig_id):
w = word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
tokens = lexicon[w]
token_ids = [token_table[k] for k in tokens]
row_splits.append(row_splits[-1] + len(token_ids))
token_ids_list.extend(token_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits,
None,
cached_tot_size,
)
values = torch.tensor(token_ids_list, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
class Lexicon(object):
"""Phone based lexicon."""
def __init__(
self,
lang_dir: Path,
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Args:
lang_dir:
Path to the lang directory. It is expected to contain the following
files:
- tokens.txt
- words.txt
- L.pt
The above files are produced by the script `prepare.sh`. You
should have run that before running the training code.
disambig_pattern:
It contains the pattern for disambiguation symbols.
"""
lang_dir = Path(lang_dir)
self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
if (lang_dir / "Linv.pt").exists():
logging.info(f"Loading pre-compiled {lang_dir}/Linv.pt")
L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt"))
else:
logging.info("Converting L.pt to Linv.pt")
L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt"))
L_inv = k2.arc_sort(L.invert())
torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
# We save L_inv instead of L because it will be used to intersect with
# transcript FSAs, both of whose labels are word IDs.
self.L_inv = L_inv
self.disambig_pattern = disambig_pattern
@property
def tokens(self) -> List[int]:
"""Return a list of token IDs excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = self.token_table.symbols
ans = []
for s in symbols:
if not self.disambig_pattern.match(s):
ans.append(self.token_table[s])
if 0 in ans:
ans.remove(0)
ans.sort()
return ans
class UniqLexicon(Lexicon):
def __init__(
self,
lang_dir: Path,
uniq_filename: str = "uniq_lexicon.txt",
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Refer to the help information in Lexicon.__init__.
uniq_filename: It is assumed to be inside the given `lang_dir`.
Each word in the lexicon is assumed to have a unique pronunciation.
"""
lang_dir = Path(lang_dir)
super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)
self.ragged_lexicon = convert_lexicon_to_ragged(
filename=lang_dir / uniq_filename,
word_table=self.word_table,
token_table=self.token_table,
)
# TODO: should we move it to a certain device ?
def texts_to_token_ids(
self, texts: List[str], oov: str = "<UNK>"
) -> k2.RaggedTensor:
"""
Args:
texts:
A list of transcripts. Each transcript contains space(s)
separated words. An example texts is::
['HELLO k2', 'HELLO icefall']
oov:
The OOV word. If a word in `texts` is not in the lexicon, it is
replaced with `oov`.
Returns:
Return a ragged int tensor with 2 axes [utterance][token_id]
"""
oov_id = self.word_table[oov]
word_ids_list = []
for text in texts:
word_ids = []
for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
word_ids.append(oov_id)
word_ids_list.append(word_ids)
ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
ans = self.ragged_lexicon.index(ragged_indexes)
ans = ans.remove_axis(ans.num_axes - 2)
return ans
def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor containing token IDs.
We assume there are no OOVs in "words".
"""
word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32)
ragged, _ = self.ragged_lexicon.index(
indexes=word_ids,
axis=0,
need_value_indexes=False,
)
return ragged
\ No newline at end of file
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2
import k2.version
def get_texts(
best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
"""Extract the texts (as word IDs) from the best-path FSAs.
Args:
best_paths:
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
return_ragged:
True to return a ragged tensor with two axes [utt][word_id].
False to return a list-of-list word IDs.
Returns:
Returns a list of lists of int, containing the label sequences we
decoded.
"""
if isinstance(best_paths.aux_labels, k2.RaggedTensor):
# remove 0's and -1's.
aux_labels = best_paths.aux_labels.remove_values_leq(0)
# TODO: change arcs.shape() to arcs.shape
aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
# remove the states and arcs axes.
aux_shape = aux_shape.remove_axis(1)
aux_shape = aux_shape.remove_axis(1)
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
else:
# remove axis corresponding to states.
aux_shape = best_paths.arcs.shape().remove_axis(1)
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
# remove 0's and -1's.
aux_labels = aux_labels.remove_values_leq(0)
assert aux_labels.num_axes == 2
if return_ragged:
return aux_labels
else:
return aux_labels.tolist()
1 0
0 49
0 85
0 1 6 -77.41171594483659
0 2 0 -77.52820497047072
1 3 0 -76.30668366093958
1 4 0 -76.30668366093958
1 5 7 -76.12490222679712
2 4 6 -76.05664693464729
2 3 0 -76.30668366093958
3 7 7 -76.80004123240462
3 6 0 -77.52820497047072
4 6 0 -77.52820497047072
4 8 0 -77.52820497047072
4 9 7 -77.0766342784592
5 6 0 -77.52820497047072
5 9 0 -77.52820497047072
6 10 0 -76.30668366093958
7 10 0 -76.30668366093958
8 10 0 -76.30668366093958
9 10 0 -76.30668366093958
9 11 0 -76.30668366093958
10 12 1 -76.78742485656035
10 13 0 -77.52820497047072
11 13 0 -77.52820497047072
11 14 0 -77.52820497047072
12 15 0 -76.30668366093958
12 16 0 -76.30668366093958
12 17 2 -76.12121222008078
13 15 1 -76.08124381241319
13 16 0 -76.30668366093958
14 16 0 -76.30668366093958
14 18 0 -76.30668366093958
15 19 0 -77.52820497047072
15 20 0 -77.52820497047072
15 21 2 -77.04467980124757
16 20 0 -77.52820497047072
17 21 0 -77.52820497047072
18 20 0 -77.52820497047072
18 22 0 -77.52820497047072
19 23 0 -76.30668366093958
19 24 0 -76.30668366093958
19 25 2 -76.12121222008078
20 24 0 -76.30668366093958
21 25 0 -76.30668366093958
21 24 0 -76.30668366093958
21 26 3 -76.12121222008078
22 24 0 -76.30668366093958
22 27 0 -76.30668366093958
23 28 0 -77.52820497047072
23 29 0 -77.52820497047072
24 29 0 -77.52820497047072
24 31 3 -76.80004123240462
25 30 0 -77.52820497047072
25 29 0 -77.52820497047072
25 32 3 -77.04467980124757
26 32 0 -77.52820497047072
27 29 0 -77.52820497047072
27 33 0 -77.52820497047072
28 34 0 -76.30668366093958
29 34 0 -76.30668366093958
30 35 0 -76.30668366093958
30 34 0 -76.30668366093958
30 36 3 -76.12121222008078
31 34 0 -76.30668366093958
32 36 0 -76.30668366093958
32 34 0 -76.30668366093958
33 34 0 -76.30668366093958
34 39 9 -76.75389983574527
34 37 0 -76.62906099113884
34 40 5 -76.70998441304296
35 38 0 -76.62906099113884
35 37 0 -76.62906099113884
36 41 0 -76.62906099113884
36 42 5 -77.85764925878665
36 37 0 -76.62906099113884
37 43 0 -76.30668366093958
38 44 0 -76.30668366093958
38 43 0 -76.30668366093958
39 43 0 -76.30668366093958
40 43 0 -76.30668366093958
41 43 0 -76.30668366093958
42 43 0 -76.30668366093958
43 45 3 -76.80008784358508
44 46 3 -77.04476177443331
45 47 0 -76.30668366093958
46 47 0 -76.30668366093958
47 48 -1 -99
This source diff could not be displayed because it is too large. You can view the blob instead.
No preview for this file type
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment