Comments (2)
How about using the time-sync decoder?
from espnet.
@sw005320
Thank you for getting back to me so quickly.
How about using the time-sync decoder?
Yes, that was the second part of the profiling.
On GPU, using 100 sentences:
- BatchBeamSearch: 148.815 seconds
- BeamSearchTimeSync: 203.136 seconds
TimeSync is slower on GPU. Moving all of its operations to the GPU (removing NumPy from TimeSync) is even worse:
- BeamSearchTimeSyncTH: 1222.309 seconds
The PyTorch version of the time-sync beam search is:
class BeamSearchTimeSyncTH(BeamSearchTimeSync):
    """Time-synchronous CTC/attention beam search with the CTC DP kept in PyTorch.

    Variant of ``BeamSearchTimeSync`` whose CTC prefix-score dynamic program
    operates on torch tensors on the encoder's device instead of on NumPy arrays.

    NOTE(review): the DP still issues many tiny per-element tensor ops (one
    kernel launch each on GPU). The NumPy-based parent may remain faster even
    after the sync fixes below, so profile before choosing.
    """

    def cached_score(self, h: Tuple[int], cache: dict, scorer: ScorerInterface) -> Any:
        """Return the accumulated decoder/LM log-score of hypothesis ``h``.

        The scorer runs on the prefix ``h[:-1]`` at most once. Its state,
        next-token scores and accumulated log-sum are memoized in ``cache``,
        keyed by the prefix tuple.

        Args:
            h: Token-id tuple of the hypothesis, starting with ``sos``.
            cache: Maps prefix tuples to ``CacheItem``s. When ``h[:-1]`` is
                missing, the entry for ``h[:-2]`` must already exist (it is
                indexed directly).
            scorer: Decoder or LM exposing ``score(ys, state, enc_output)``.

        Returns:
            float: Log-sum of the prefix plus the score of token ``h[-1]``.
        """
        root = h[:-1]  # prefix
        if root in cache:
            root_scores = cache[root].scores
            root_state = cache[root].state
            root_log_sum = cache[root].log_sum
        else:  # run decoder fwd one step and update cache
            root_root = root[:-1]
            root_root_state = cache[root_root].state
            root_scores, root_state = scorer.score(
                torch.tensor(root, device=self.enc_output.device).long(),
                root_root_state,
                self.enc_output,
            )
            root_log_sum = cache[root_root].log_sum + float(
                cache[root_root].scores[root[-1]]
            )
            cache[root] = CacheItem(
                state=root_state, scores=root_scores, log_sum=root_log_sum
            )
        cand_score = float(root_scores[h[-1]])
        return root_log_sum + cand_score

    def joint_score(self, hyps: Any, ctc_score_dp: Any) -> Any:
        """Combine CTC, attention-decoder, LM and length-penalty scores.

        Args:
            hyps: Iterable of prefix tuples to score.
            ctc_score_dp: ``(p_nb, p_b)`` table with an entry for every prefix
                in ``hyps``.

        Returns:
            dict: Maps each prefix to its weighted score, a one-element tensor
            on the CTC table's device.
        """
        scores = dict()
        for h in hyps:
            # CTC prefix score: total mass ending in either non-blank or blank.
            score = self.ctc_weight * torch.logaddexp(*ctc_score_dp[h])
            # Decoder/LM scoring needs at least one token after sos, because
            # cached_score conditions on h[:-1].
            if len(h) > 1 and self.decoder_weight > 0 and self.decoder is not None:
                score += (
                    self.cached_score(h, self.attn_cache, self.decoder)
                    * self.decoder_weight
                )  # attn score
            if len(h) > 1 and self.lm is not None and self.lm_weight > 0:
                score += (
                    self.cached_score(h, self.lm_cache, self.lm) * self.lm_weight
                )  # lm score
            score += self.penalty * (len(h) - 1)  # penalty score
            scores[h] = score
        return scores

    def time_step(self, p_ctc: Any, ctc_score_dp: Any, hyps: Any) -> Any:
        """Advance the CTC prefix search by one frame.

        Args:
            p_ctc: 1-D tensor of this frame's CTC log-probabilities, indexed by
                token id.
            ctc_score_dp: Maps each prefix tuple to ``(p_nb, p_b)``: the
                log-probability of the prefix ending in non-blank / blank at the
                previous frame.
            hyps: Prefixes kept in the beam after the previous frame.

        Returns:
            tuple: ``(ctc_score_dp_next, hyps, scores)``, where

            - ``ctc_score_dp_next`` is the ``(p_nb, p_b)`` table for this frame;
            - ``hyps`` holds the best ``beam_size`` prefixes, sorted by joint
              score, best first;
            - ``scores`` is the joint-score dict for every extended prefix.
        """
        device = p_ctc.device
        # One shared -inf tensor is the default for unseen prefixes. Entries are
        # only ever replaced (torch.logaddexp returns new tensors), never
        # modified in place, so sharing is safe. It also avoids a host->device
        # copy on every dict miss.
        neg_inf = torch.full((1,), float("-inf"), device=device)

        # Pre-beam: keep every token scoring at least the k-th largest value,
        # ties included. Clamp k so a pre_beam_size above the vocabulary size
        # cannot raise IndexError.
        k = min(self.pre_beam_size, p_ctc.shape[-1])
        pre_beam_threshold = torch.sort(p_ctc).values[-k]
        # Materialize candidate ids as Python ints once. Comparing 0-d device
        # tensors inside the loops would force a device sync on every `==`.
        cands = torch.nonzero(p_ctc >= pre_beam_threshold)[:, 0].tolist()
        if not cands:  # e.g. NaN threshold: fall back to the arg-max token
            cands = [int(torch.argmax(p_ctc))]

        p_blank = p_ctc[self.blank]
        beam = set(hyps)  # O(1) membership test below
        new_hyps = set()
        ctc_score_dp_next = defaultdict(lambda: (neg_inf, neg_inf))  # (p_nb, p_b)
        for hyp_l in hyps:
            p_prev_l = torch.logaddexp(*ctc_score_dp[hyp_l])
            for c in cands:
                p_c = p_ctc[c]
                if c == self.blank:
                    # Blank keeps the prefix unchanged; it now ends in blank.
                    logger.debug("blank cand, hypothesis is %s", hyp_l)
                    p_nb, p_b = ctc_score_dp_next[hyp_l]
                    p_b = torch.logaddexp(p_b, p_c + p_prev_l)
                    ctc_score_dp_next[hyp_l] = (p_nb, p_b)
                    new_hyps.add(hyp_l)
                    continue

                l_plus = hyp_l + (c,)
                logger.debug("non-blank cand, hypothesis is %s", l_plus)
                p_nb, p_b = ctc_score_dp_next[l_plus]
                if c == hyp_l[-1]:
                    # Repeated token. Extending to l_plus needs an intervening
                    # blank (uses p_b_prev). Staying on hyp_l collapses the
                    # repeat (uses p_nb_prev).
                    logger.debug("repeat cand, hypothesis is %s", hyp_l)
                    p_nb_prev, p_b_prev = ctc_score_dp[hyp_l]
                    p_nb = torch.logaddexp(p_nb, p_c + p_b_prev)
                    p_nb_l, p_b_l = ctc_score_dp_next[hyp_l]
                    p_nb_l = torch.logaddexp(p_nb_l, p_c + p_nb_prev)
                    ctc_score_dp_next[hyp_l] = (p_nb_l, p_b_l)
                else:
                    p_nb = torch.logaddexp(p_nb, p_c + p_prev_l)
                if l_plus not in beam and l_plus in ctc_score_dp:
                    # l_plus has a DP entry from the previous frame but is not
                    # in the current beam: fold in its own continuation mass.
                    p_b = torch.logaddexp(
                        p_b, p_blank + torch.logaddexp(*ctc_score_dp[l_plus])
                    )
                    p_nb = torch.logaddexp(p_nb, p_c + ctc_score_dp[l_plus][0])
                ctc_score_dp_next[l_plus] = (p_nb, p_b)
                new_hyps.add(l_plus)

        scores = self.joint_score(new_hyps, ctc_score_dp_next)
        # Move all scores to the host in one transfer, then rank on floats.
        # Sorting on device tensors would sync on every comparison.
        ordered = list(new_hyps)
        flat = (
            torch.cat([scores[h].reshape(-1) for h in ordered]).tolist()
            if ordered
            else []
        )
        host_scores = dict(zip(ordered, flat))
        hyps = sorted(ordered, key=host_scores.__getitem__, reverse=True)[
            : self.beam_size
        ]
        return ctc_score_dp_next, hyps, scores

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform time-synchronous beam search over one utterance.

        Args:
            x: Encoder output for one utterance. ``x.shape[0]`` is logged as the
                input length. Assumes shape ``(T, D)`` (a batch dim is added
                for ``ctc.log_softmax``); TODO confirm.
            maxlenratio: Not used by this implementation; kept for interface
                compatibility.
            minlenratio: Not used by this implementation; kept for interface
                compatibility.

        Returns:
            list[Hypothesis]: The final beam, best first. Each ``yseq`` is the
            prefix tuple with ``sos`` appended as the end marker.
        """
        device = x.device
        logger.info("decoder input lengths: " + str(x.shape[0]))
        lpz = self.ctc.log_softmax(x.unsqueeze(0)).squeeze(0)
        self.reset(x)
        hyps = [(self.sos,)]
        neg_inf = torch.full((1,), float("-inf"), device=device)
        ctc_score_dp = defaultdict(
            lambda: (neg_inf, neg_inf)
        )  # (p_nb, p_b) - dp object tracking p_ctc
        # The sos-only prefix starts with log-probability 0 of ending in blank.
        ctc_score_dp[(self.sos,)] = (neg_inf, torch.zeros(1, device=device))
        # Score the initial beam so zero-frame input returns it instead of
        # raising UnboundLocalError on `scores` below.
        scores = self.joint_score(hyps, ctc_score_dp)
        for t in range(lpz.shape[0]):
            logger.debug("position %s", t)
            ctc_score_dp, hyps, scores = self.time_step(lpz[t, :], ctc_score_dp, hyps)

        ret = [
            Hypothesis(yseq=torch.tensor(list(h) + [self.sos]), score=float(scores[h]))
            for h in hyps
        ]
        best = ret[0]
        # Strip the leading sos and trailing end marker before detokenizing.
        best_hyp = "".join(self.token_list[tok] for tok in best.yseq.tolist()[1:-1])
        best_hyp_len = len(best.yseq) - 2
        best_score = best.score
        logger.info(f"output length: {best_hyp_len}")
        logger.info(f"total log probability: {best_score:.2f}")
        logger.info(f"best hypo: {best_hyp}")
        return ret
from espnet.
Related Issues (20)
- Changes that need to be made when using wav2vec2.0 (CLSRIL-23.pt) features for CTC/Attention-based training HOT 6
- Error when training VITS model for vctk dataset HOT 3
- No such parameter e_branchformer_ctc in encoder parameter HOT 3
- X-vector based TTS model packaging broken in tts.sh HOT 1
- USES `ref_channel` usage HOT 4
- Question regarding switching speakers, weights during runtime.
- Question about asr2.sh and its options to reproduce the librispeech_100 recipe. HOT 5
- An error when using LoRA for s3prl frontend. HOT 1
- TSE with Librimix: mismatch in number of speakers HOT 4
- Streaming ASR model latency issue HOT 6
- asr_train.py: error: unrecognized arguments: use_lora HOT 1
- Espnet Collect stats: s3prl Upstream 'hubert-large-ll60k' HOT 5
- How to use 960h LM? HOT 1
- about dc_crn training HOT 1
- Cannot retrieve the public link of the file when running espnet_tts_demo
- [ERROR] The torch version has been changed. Please report to espnet administrators make: *** [Makefile:203: fairscale.done] Error 1
- [ERROR] The torch version has been changed. Please report to espnet administrators make: *** [Makefile:203: fairscale.done] Error 1
- Which class or python script is used for training the linear preencoder used in the config file used during an ASR task which uses representations from XLSR-128 model??
- How to use whisper as frontend?
- Does conformer support batch_size > 1 in ASR task inference? HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization: use data art.
-
Game
Something interesting about games, to make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
Tencent's open source team in China.
from espnet.