
Comments (29)

pkufool commented on July 1, 2024

When did it turn into nan or inf, at the beginning of training or in the middle? Could you please upload the training log here? Thanks!


Butterfly-c commented on July 1, 2024

When did it turn into nan or inf, at the beginning of training or in the middle? Could you please upload the training log here? Thanks!

At the beginning of training (pruned_loss_scaled = 0) the loss turned into nan. After 10000 num_updates, pruned_loss_scaled was set to 0.1 and the loss turned into inf.

Sorry, something went wrong when I uploaded the log.


pkufool commented on July 1, 2024

Do you have any sequences where U > T, i.e. the number of tokens in the transcript is greater than the number of frames?


Butterfly-c commented on July 1, 2024

Do you have any sequences where U > T, i.e. the number of tokens in the transcript is greater than the number of frames?

The subsampling rate is 4 (from 2 max-pooling layers), so the number of tokens U is unlikely to be greater than T.

I put some logs here:

epoch 3 ; loss inf; num updates 16100 ; lr 0.000704907
epoch 3 ; loss 1.13339; num updates 16200 ; lr 0.000702728
epoch 3 ; loss 1.13215; num updates 16300 ; lr 0.000700569
epoch 3 ; loss inf; num updates 16400 ; lr 0.000698043


danpovey commented on July 1, 2024

What iteration did the loss become inf on, and what kind of model were you using?


Butterfly-c commented on July 1, 2024

What iteration did the loss become inf on, and what kind of model were you using?

The loss became inf at epoch 2, when pruned_loss_scaled was set to 0.1.

The ConformerTransducer model is configured as follows:
Encoder: 2 VGG blocks + 12 conformer layers + 1 LSTMP layer + 1 layer norm
Decoder: 2 LSTM layers + dropout
Joiner: configured as in k2


Butterfly-c commented on July 1, 2024

What iteration did the loss become inf on, and what kind of model were you using?

Other configurations of the joiner are as follows:
lm_only_scale = 0.25
am_only_scale = 0
prune_range = 4
simple_loss_scale = 0.5

pruned_loss_scaled = 0 if num_updates <= 10000
pruned_loss_scaled = 0.1 if 10000 < num_updates <= 20000
pruned_loss_scaled = 1 if num_updates > 20000
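
For clarity, here is a minimal sketch of how such a schedule could be combined with the two losses. The names num_updates, simple_loss and pruned_loss are taken from the discussion above; this is an illustration, not code from fast_rnnt:

# Sketch of the loss-scale schedule described above (assumed variable names).
def get_pruned_loss_scale(num_updates: int) -> float:
    if num_updates <= 10000:
        return 0.0
    elif num_updates <= 20000:
        return 0.1
    else:
        return 1.0

simple_loss_scale = 0.5
loss = (
    simple_loss_scale * simple_loss
    + get_pruned_loss_scale(num_updates) * pruned_loss
)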


pkufool commented on July 1, 2024

Can you dump the input of the batches that lead to the inf loss, so we can use it to debug this issue? Thanks.


danpovey commented on July 1, 2024

@pkufool perhaps it was not obvious to him how to do this?
Also, @Butterfly-c , are you using fp16 / half-precision for training? It can be tricky to tune a network to perform OK with fp16.
One possibility is to detect inf in the loss , e.g. by comparing (loss - loss) to 0, and skip the update and print a warning.
If you have any utterances in your training set that have too-long transcripts for the utterance length, those could lead to inf loss. It's possible that the model is training OK, if the individual losses on most batches stay finite, even though the overall loss may be infinite. Cases with too-long transcripts will generate infinite loss but will not generate infinite gradients.
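
A minimal sketch of that kind of guard (loss and optimizer are whatever the surrounding training loop defines; this is an illustration, not fast_rnnt code):

import logging

# Skip the parameter update when the loss is inf or nan.
# loss - loss is 0 for a finite loss and nan otherwise, so this catches both.
if (loss - loss) != 0:
    logging.warning(f"Non-finite loss {loss.item()}; skipping this update")
else:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()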


pkufool commented on July 1, 2024

@Butterfly-c Suppose you used pruned loss like this:

import torch
import fast_rnnt

# simple (un-pruned) loss; also returns the gradients w.r.t. px/py,
# which are used below to compute the pruning bounds
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_smoothed(
    lm=decoder_out,
    am=encoder_out,
    symbols=symbols,
    termination_symbol=blank_id,
    lm_only_scale=lm_scale,
    am_only_scale=am_scale,
    boundary=boundary,
    reduction="sum",
    return_grad=True,
)

# ranges : [B, T, prune_range]
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad,
    py_grad=py_grad,
    boundary=boundary,
    s_range=prune_range,
)

# am_pruned : [B, T, prune_range, C]
# lm_pruned : [B, T, prune_range, C]
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
    am=encoder_out, lm=decoder_out, ranges=ranges
)

# logits : [B, T, prune_range, C]
logits = joiner(am_pruned, lm_pruned)

pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits,
    symbols=symbols,
    ranges=ranges,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
)

You can dump the bad cases as follows:

# x - x is 0 for finite x, and nan for inf or nan, so this catches both
if simple_loss - simple_loss != 0:
    simple_input = {"encoder_out": encoder_out, "decoder_out": decoder_out, "symbols": symbols, "boundary": boundary}
    torch.save(simple_input, "simple_bad_case.pt")

if pruned_loss - pruned_loss != 0:
    pruned_input = {"logits": logits, "ranges": ranges, "symbols": symbols, "boundary": boundary}
    torch.save(pruned_input, "pruned_bad_case.pt")
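
Once such a file exists, it can be reloaded to reproduce the pruned loss in isolation. A minimal sketch (blank_id is assumed to be the same blank/termination symbol used in training):

import torch
import fast_rnnt

# Reload a dumped bad case and recompute the pruned loss in isolation.
bad = torch.load("pruned_bad_case.pt")
pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=bad["logits"],
    symbols=bad["symbols"],
    ranges=bad["ranges"],
    termination_symbol=blank_id,  # same blank id as in training
    boundary=bad["boundary"],
    reduction="sum",
)
print(pruned_loss)  # expected to be inf or nan for a dumped bad case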


Butterfly-c commented on July 1, 2024

@pkufool perhaps it was not obvious to him how to do this? Also, @Butterfly-c , are you using fp16 / half-precision for training? It can be tricky to tune a network to perform OK with fp16. One possibility is to detect inf in the loss , e.g. by comparing (loss - loss) to 0, and skip the update and print a warning. If you have any utterances in your training set that have too-long transcripts for the utterance length, those could lead to inf loss. It's possible that the model is training OK, if the individual losses on most batches stay finite, even though the overall loss may be infinite. Cases with too-long transcripts will generate infinite loss but will not generate infinite gradients.

Thanks for your kind reply!
I have decoded a model from epoch 4, and the decoding result is OK. But I'm still confused by the inf loss.
max_frame is set to 2500 (i.e. 25s) in my training environment.
I'm curious how long a transcript has to be to count as a too-long transcript?


Butterfly-c commented on July 1, 2024

@Butterfly-c Suppose you used pruned loss like this: [...] You can dump the bad cases as follows: [...]

Thanks for your suggestion, I'm trying to upload the pruned_bad_case.pt for you to debug the inf issue. It'll take me some time.


Butterfly-c commented on July 1, 2024

We have compared two models, trained with warp-transducer and fast_rnnt respectively, but the GPU usage does not decrease significantly.

Roughly, the training time of the two models is as follows:

loss              times_per_update
warp-transducer   7m40s
fast-rnnt         6m40s

The models above are both trained with V100-32G-4GPU * 2 (i.e. 8 GPUs).
Do you have any suggestions for accelerating the training?


csukuangfj commented on July 1, 2024
  1. What is your vocabulary size?
  2. What is your batch size? And how much data does each batch contain (i.e., what is the total duration)?
  3. Is your GPU usage over 90%? (You can get such information with watch -n 0.5 nvidia-smi)
  4. What is the value of prune_range?

the GPU usage does not decrease significantly.

What do you want to express?


csukuangfj commented on July 1, 2024

I'm curious how long a transcript has to be to count as a too-long transcript?

If the sentence is broken into BPE tokens, it is "too long" if the number of BPE tokens is larger than the number of acoustic frames (after subsampling) of this sentence.
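
A minimal sketch of such a check (tokens, num_frames and subsampling_factor are placeholder names for whatever the data pipeline provides):

# Hypothetical per-utterance check: an utterance is "too long" if it has more
# BPE tokens than acoustic frames after subsampling, in which case no valid
# RNN-T alignment exists and the loss becomes inf.
def transcript_too_long(tokens, num_frames: int, subsampling_factor: int = 4) -> bool:
    frames_after_subsampling = num_frames // subsampling_factor
    return len(tokens) > frames_after_subsampling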


Butterfly-c commented on July 1, 2024
  1. What is your vocabulary size?
  2. What is your batch size? And how much data does each batch contain (i.e., what is the total duration)?
  3. Is your GPU usage over 90%? (You can get such information with watch -n 0.5 nvidia-smi)
  4. What is the value of prune_range?

the GPU usage does not decrease significantly.

What do you want to express?

Some configurations of my environment are as follows:

1. The vocabulary size is 8245, which contains 6726 Chinese characters, 1514 BPE subwords and 5 special symbols.
2. The batch size is 5000 frames (i.e. 50s).
3. When "watch -n 0.5 nvidia-smi" is run, the peak volatile GPU utilization is over 90%, but most of the time it is between 80% and 90%.
4. The prune_range is 4.

As shown in this paper https://arxiv.org/abs/2206.13236, the peak GPU usage of fast_rnnt is far below warp-transducer, and the training time is also greatly reduced. But when fast_rnnt is run in our environment, the training time is not reduced as expected.
With the same batch size (50s), the statistics of the training time are as follows:

loss              times_per_update
warp-transducer   7m40s
fast-rnnt         6m40s

Finally, I have another question about the training time. As shown in the paper, the training time per batch of optimized_transducer is over 4 times that of fast_rnnt, but the training time per epoch of optimized_transducer is only about 2 times that of fast_rnnt.

I really appreciate your reply.


danpovey commented on July 1, 2024

I think the comparisons in the paper may have just been for the core RNN-T loss. It does not count the neural net forward, which would not be affected by speedups in the loss computation.


Butterfly-c commented on July 1, 2024

I think the comparisons in the paper may have just been for the core RNN-T loss. It does not count the neural net forward, which would not be affected by speedups in the loss computation.

Thanks for your reply, which cleared up my confusion.


Butterfly-c commented on July 1, 2024

@Butterfly-c Suppose you used pruned loss like this: [...] You can dump the bad cases as follows: [...]

Based on your suggestion, I saved some bad cases. What's interesting is that most of the 'ranges' tensors are all zeros.

For example, when the training sample is background music, the label is only one symbol. The lm and am shapes are as follows:
decoder_out [1, 2, 8245]
encoder_out [1, 314, 8245]

Will the training loss become inf when the input and output lengths are very unbalanced (here the output is far shorter than the input)?
Can you give some explanation?


Butterfly-c commented on July 1, 2024

After filtering the training data as follows (see the sketch after the list), the inf problem has decreased:
1. label_len > 2
2. feat_len // label_len > 30
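
A rough sketch of such a filter, assuming both rules are keep-conditions applied per utterance (the helper name keep_utterance and the sample fields are hypothetical):

# Hypothetical filter implementing the two rules above as keep-conditions.
# label_len: number of labels in the transcript; feat_len: number of feature frames.
def keep_utterance(feat_len: int, label_len: int) -> bool:
    if label_len <= 2:               # rule 1: keep only label_len > 2
        return False
    if feat_len // label_len <= 30:  # rule 2: keep only feat_len // label_len > 30
        return False
    return True

train_samples = [s for s in train_samples if keep_utterance(s["feat_len"], s["label_len"])]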


Butterfly-c commented on July 1, 2024

Due to network limitations, I will share the pruned_bad_case.pt later.


pkufool commented on July 1, 2024

For example, when the training sample is background music, the label is only one symbol. The lm and am shapes are as follows:
decoder_out [1, 2, 8245]
encoder_out [1, 314, 8245]

Does only one sequence have only one symbol, or do all the sequences in one batch have only one symbol?
Thanks, this is very valuable information for us.


Butterfly-c commented on July 1, 2024

For example, when the training sample is background music, the label is only one symbol. The lm and am shapes are as follows:
decoder_out [1, 2, 8245]
encoder_out [1, 314, 8245]

Does only one sequence have only one symbol, or do all the sequences in one batch have only one symbol? Thanks, this is very valuable information for us.

Based on 40 pruned_bad_case.pt files, in all of the bad cases every sequence in the batch has only one symbol, and the 'ranges' tensors are all zeros.


pkufool commented on July 1, 2024

OK, thanks! That's it. I think our code did not handle S == 1 properly; we will try to fix it.


pkufool commented on July 1, 2024

@Butterfly-c If you have problems uploading your bad cases to GitHub, could you send them to me via email ([email protected])? I want to use them to test my fixes. Thanks!


Butterfly-c commented on July 1, 2024

@Butterfly-c If you have problems uploading your bad cases to GitHub, could you send them to me via email ([email protected])? I want to use them to test my fixes. Thanks!

Due to data permissions, I can't share the bad-case information until I get permission. The permission is on the way.


pkufool commented on July 1, 2024

OK, I think there won't be any characters or waveforms in your bad cases, only float and integer numbers. I hope you can get the permissions; I am testing with randomly generated bad cases. Thanks.


Butterfly-c commented on July 1, 2024

OK, I think there won't be any characters or waveforms in your bad cases, only float and integer numbers. I hope you can get the permissions; I am testing with randomly generated bad cases. Thanks.

OK, I will contact you as soon as I get the permission.


Butterfly-c commented on July 1, 2024

After updating fast_rnnt to the "fix_s_range" version, the "inf" problem has been fixed. Thanks!

