GithubHelp home page GithubHelp logo

Comments (2)

sw005320 avatar sw005320 commented on May 27, 2024

How about using the time-sync decoder?

from espnet.

Fhrozen avatar Fhrozen commented on May 27, 2024

@sw005320
Thank you for getting back to me so quickly.

How about using the time-sync decoder?

Yes, that was the second part of the profiling.

On GPU, and using 100 sentences:

  • BatchBeamSearch: 148.815 seconds
  • BeamSearchTimeSync: 203.136 seconds

TimeSync is slower on GPU, and moving all the operations to the GPU (removing numpy from TimeSync) is even worse:

  • BeamSearchTimeSyncTH: 1222.309 seconds

The code of the PyTorch-only BeamSearch is:

class BeamSearchTimeSyncTH(BeamSearchTimeSync):
    """Time-synchronous beam search whose CTC bookkeeping uses torch tensors.

    Overrides ``cached_score``, ``joint_score``, ``time_step`` and ``forward``
    of ``BeamSearchTimeSync``.  The CTC prefix scores ``(p_nb, p_b)`` are
    stored as one-element tensors on the device of the input, while the
    search structure itself (hypothesis tuples, candidate ids, ranking) is
    handled with plain Python objects so that the inner loops do not trigger
    a device-to-host synchronisation on every comparison.
    """

    def cached_score(
        self, h: Tuple[int, ...], cache: dict, scorer: ScorerInterface
    ) -> Any:
        """Retrieve decoder/LM scores which may be cached.

        Args:
            h: Hypothesis as a tuple of token ids.  ``joint_score`` only calls
                this for hypotheses longer than one token.
            cache: Mapping from prefix tuple to ``CacheItem`` (``state``,
                ``scores``, ``log_sum``).  On a miss for ``h[:-1]`` the entry
                for ``h[:-2]`` must already be present; the new entry for
                ``h[:-1]`` is then stored in ``cache``.
            scorer: Object whose ``score(prefix, state, enc_output)`` returns
                ``(scores, state)``.

        Returns:
            The cached ``log_sum`` of the prefix ``h[:-1]`` plus the scorer's
            score of the last token ``h[-1]`` given that prefix.

        """
        prefix = h[:-1]
        if prefix in cache:
            entry = cache[prefix]
            prefix_scores = entry.scores
            prefix_log_sum = entry.log_sum
        else:
            # Run the scorer forward one step from the parent prefix and
            # remember the result so sibling hypotheses can reuse it.
            parent = cache[prefix[:-1]]
            prefix_scores, prefix_state = scorer.score(
                torch.tensor(prefix, device=self.enc_output.device).long(),
                parent.state,
                self.enc_output,
            )
            prefix_log_sum = parent.log_sum + float(parent.scores[prefix[-1]])
            cache[prefix] = CacheItem(
                state=prefix_state, scores=prefix_scores, log_sum=prefix_log_sum
            )
        return prefix_log_sum + float(prefix_scores[h[-1]])

    def joint_score(self, hyps: Any, ctc_score_dp: Any) -> Any:
        """Calculate joint score for hyps.

        Args:
            hyps: Iterable of hypotheses (tuples of token ids).
            ctc_score_dp: Mapping from hypothesis to its ``(p_nb, p_b)`` pair
                of CTC log-score tensors.

        Returns:
            Dict mapping each hypothesis to its joint score tensor: weighted
            CTC score, plus weighted decoder and LM scores when enabled and
            the hypothesis has more than one token, plus the length penalty.

        """
        # These switches do not depend on the hypothesis; evaluate them once.
        use_decoder = self.decoder_weight > 0 and self.decoder is not None
        use_lm = self.lm is not None and self.lm_weight > 0

        scores = {}
        for h in hyps:
            # Additions are out-of-place so the tensors held in
            # ``ctc_score_dp`` (which may be shared defaults) are never
            # modified.
            score = self.ctc_weight * torch.logaddexp(*ctc_score_dp[h])  # ctc score
            if len(h) > 1:
                if use_decoder:
                    score = score + (
                        self.cached_score(h, self.attn_cache, self.decoder)
                        * self.decoder_weight
                    )  # attn score
                if use_lm:
                    score = score + (
                        self.cached_score(h, self.lm_cache, self.lm)
                        * self.lm_weight
                    )  # lm score
            score = score + self.penalty * (len(h) - 1)  # penalty score
            scores[h] = score
        return scores

    def time_step(self, p_ctc: Any, ctc_score_dp: Any, hyps: Any) -> Any:
        """Execute a single time step.

        Args:
            p_ctc: CTC log-probabilities for the current frame (``forward``
                passes one row of the CTC log-softmax output).
            ctc_score_dp: Mapping from hypothesis to ``(p_nb, p_b)`` tensors
                from the previous frame.
            hyps: Hypotheses (tuples of token ids) kept after the previous
                frame.

        Returns:
            Tuple ``(ctc_score_dp_next, hyps, scores)``: the updated CTC
            score mapping, the best ``self.beam_size`` hypotheses sorted by
            descending joint score, and the joint scores of every hypothesis
            considered in this step.

        """
        device = p_ctc.device
        blank = self.blank

        # Pre-beam: keep every token whose log-probability reaches the
        # ``pre_beam_size``-th largest value of this frame.
        sorted_p, _ = torch.sort(p_ctc)
        threshold = sorted_p[-self.pre_beam_size]
        cands = (p_ctc >= threshold).nonzero()[:, 0]
        # One transfer to host: iterating the tensor directly would force a
        # device synchronisation on every comparison in the loops below.
        cands = torch.unique(cands, sorted=False).tolist()
        if not cands:
            cands = [int(torch.argmax(p_ctc))]

        # Single shared "-inf" tensor used as the default for unseen
        # hypotheses.  Safe to share because every update below builds a new
        # tensor instead of writing in place.
        neg_inf = torch.Tensor([float("-inf")]).to(device)
        ctc_score_dp_next = defaultdict(lambda: (neg_inf, neg_inf))  # (p_nb, p_b)

        prev_hyps = set(hyps)  # constant-time membership test in the inner loop
        p_blank = p_ctc[blank]
        new_hyps = set()
        for hyp_l in hyps:
            p_nb_prev, p_b_prev = ctc_score_dp[hyp_l]
            p_prev_l = torch.logaddexp(p_nb_prev, p_b_prev)
            for c in cands:
                p_c = p_ctc[c]
                if c == blank:
                    # Blank keeps the hypothesis unchanged.
                    logger.debug("blank cand, hypothesis is %s", hyp_l)
                    p_nb, p_b = ctc_score_dp_next[hyp_l]
                    p_b = torch.logaddexp(p_b, p_c + p_prev_l)
                    ctc_score_dp_next[hyp_l] = (p_nb, p_b)
                    new_hyps.add(hyp_l)
                    continue

                l_plus = hyp_l + (c,)
                logger.debug("non-blank cand, hypothesis is %s", l_plus)
                p_nb, p_b = ctc_score_dp_next[l_plus]
                if c == hyp_l[-1]:
                    # Repeated token: extending requires a preceding blank,
                    # while the non-blank path collapses into ``hyp_l``.
                    logger.debug("repeat cand, hypothesis is %s", hyp_l)
                    p_nb = torch.logaddexp(p_nb, p_c + p_b_prev)
                    p_nb_l, p_b_l = ctc_score_dp_next[hyp_l]
                    p_nb_l = torch.logaddexp(p_nb_l, p_c + p_nb_prev)
                    ctc_score_dp_next[hyp_l] = (p_nb_l, p_b_l)
                else:
                    p_nb = torch.logaddexp(p_nb, p_c + p_prev_l)

                if l_plus not in prev_hyps and l_plus in ctc_score_dp:
                    # ``l_plus`` was scored in the previous frame but pruned
                    # from the beam: fold its previous mass back in.
                    p_nb_old, p_b_old = ctc_score_dp[l_plus]
                    p_b = torch.logaddexp(
                        p_b, p_blank + torch.logaddexp(p_nb_old, p_b_old)
                    )
                    p_nb = torch.logaddexp(p_nb, p_c + p_nb_old)
                ctc_score_dp_next[l_plus] = (p_nb, p_b)
                new_hyps.add(l_plus)

        scores = self.joint_score(new_hyps, ctc_score_dp_next)

        # Rank on host floats fetched in one transfer; sorting the tensors
        # directly would synchronise with the device on every comparison.
        ranked = list(new_hyps)
        if ranked:
            values = torch.cat([scores[h].reshape(1) for h in ranked]).tolist()
            value_of = dict(zip(ranked, values))
            ranked.sort(key=value_of.__getitem__, reverse=True)
        hyps = ranked[: self.beam_size]
        return ctc_score_dp_next, hyps, scores

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoder output of one utterance; ``x.shape[0]``
                is logged as the input length and one search step is run per
                row of the CTC log-softmax output.
            maxlenratio (float): Not used by this method.
            minlenratio (float): Not used by this method.

        Return:
            list[Hypothesis]: Final hypotheses, best joint score first.  Each
            ``yseq`` is the hypothesis tokens followed by ``self.sos``.

        """
        device = x.device
        logger.info("decoder input lengths: %s", x.shape[0])
        lpz = self.ctc.log_softmax(x.unsqueeze(0))
        lpz = lpz.squeeze(0)
        self.reset(x)

        start = (self.sos,)
        hyps = [start]
        # Shared default for unseen hypotheses; never modified in place.
        neg_inf = torch.Tensor([float("-inf")]).to(device)
        ctc_score_dp = defaultdict(
            lambda: (neg_inf, neg_inf)
        )  # (p_nb, p_b) - dp object tracking p_ctc
        ctc_score_dp[start] = (neg_inf, torch.Tensor([0.0]).to(device))

        # Score the start hypothesis so ``scores`` is defined even when the
        # input has no frames and the loop below never runs.
        scores = self.joint_score(hyps, ctc_score_dp)
        for t in range(lpz.shape[0]):
            logger.debug("position %s", t)
            ctc_score_dp, hyps, scores = self.time_step(lpz[t, :], ctc_score_dp, hyps)

        ret = [
            Hypothesis(yseq=torch.tensor(list(h) + [self.sos]), score=float(scores[h]))
            for h in hyps
        ]
        best = ret[0]
        # Drop the leading and trailing sos before mapping ids to tokens.
        best_tokens = best.yseq.tolist()[1:-1]
        best_hyp = "".join(self.token_list[tok] for tok in best_tokens)
        best_hyp_len = len(best.yseq) - 2
        best_score = best.score
        logger.info(f"output length: {best_hyp_len}")
        logger.info(f"total log probability: {best_score:.2f}")
        logger.info(f"best hypo: {best_hyp}")

        return ret

from espnet.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.