
Comments (8)

nglehuy commented on June 11, 2024

Hi @yiqiaoc11
I'm training on TPUs to validate this.
Are you using the warp-transducer loss or the rnnt loss in tensorflow?
I've been testing with the rnnt loss in tensorflow for the past few months, and it has some convergence issues. But I don't have the resources to test on GPUs.

from tensorflowasr.

yiqiaoc11 commented on June 11, 2024

@usimarit Thanks for the comments. All my tests were conducted on GPU with batch size 2.
I tried once with TensorFlowASR\examples\rnn_transducer\config.yml and the warp-transducer loss, and observed the same behavior as #231. Then I switched to rnnt-loss while using the pretrained model's .yml, which contains warmup_steps for the transformer scheduler. The loss is shown below:
Epoch 1/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 339.2118 - val_loss: 160.0162
Epoch 2/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 254.4009 - val_loss: 147.7653
Epoch 3/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 241.0356 - val_loss: 144.2561
Epoch 4/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 231.6980 - val_loss: 140.2191
Epoch 5/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 223.2308 - val_loss: 137.6041
Epoch 6/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 216.7098 - val_loss: 136.0396

With rnnt-loss, the 4- and 6-layer encoders worked with different warmup steps, but the 8-layer one did not. I'm just trying to recover the performance of the pretrained model. Conformer reportedly works, and the setup differs only in using rnn_transducer.
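For reference, the transformer schedule that `warmup_steps` feeds into is, presumably, the one from the original Transformer paper: linear warmup for `warmup_steps`, then inverse-square-root decay. A minimal sketch (the `d_model=320` default is illustrative, not taken from the config):

```python
def transformer_lr(step, d_model=320, warmup_steps=40000):
    """Transformer learning-rate schedule: linear warmup for
    `warmup_steps` steps, then decay proportional to 1/sqrt(step)."""
    step = max(step, 1)  # avoid division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```

The peak learning rate occurs exactly at `step == warmup_steps`, and a larger `warmup_steps` lowers that peak, which may be why the deeper encoder is more sensitive to this value.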

Feel free to advise and I can try it on GPU here.


nglehuy commented on June 11, 2024

@yiqiaoc11 Could you help me train 2 models for 30 epochs using rnnt-loss:

  1. 4-layer encoder
  2. 8-layer encoder

Then plot the losses of the 2 models for better comparison?
The other configs stay the same.
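One way to produce that comparison is to parse the per-epoch `loss`/`val_loss` values out of the Keras console logs (like the ones pasted earlier) and plot both runs on one axis. A minimal sketch, assuming logs in the format shown above (file names in the comments are illustrative):

```python
import re

def epoch_losses(log_text):
    """Extract (loss, val_loss) pairs per epoch from Keras progress-bar
    logs of the form '... - loss: 254.4009 - val_loss: 147.7653'."""
    pat = re.compile(r"loss: ([\d.]+) - val_loss: ([\d.]+)")
    return [(float(a), float(b)) for a, b in pat.findall(log_text)]

# Plotting step (matplotlib assumed available):
# import matplotlib.pyplot as plt
# losses_4 = epoch_losses(open("train_4layer.log").read())
# losses_8 = epoch_losses(open("train_8layer.log").read())
# plt.plot([l for l, _ in losses_4], "b", label="4-layer loss")
# plt.plot([l for l, _ in losses_8], "g", label="8-layer loss")
# plt.legend(); plt.xlabel("epoch"); plt.ylabel("rnnt loss"); plt.show()
```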


yiqiaoc11 commented on June 11, 2024

@usimarit, using the streaming config.yml (https://drive.google.com/file/d/1xYFYi3z94ZqaQZ-cTyiNekBwhITh1Ru4l) with warmup_steps=40000, right?

From the timeline, you seem to have applied the warp-transducer loss to get the pretrained .h5 weights.


nglehuy commented on June 11, 2024

@yiqiaoc11 Yes, with the pretrained config.

I trained the rnn transducer on TPUs, so the warp-transducer loss cannot be applied there; only rnnt-loss can be used. But you can also experiment with the warp-transducer loss, plotting the losses of the 2 models for better comparison.


yiqiaoc11 commented on June 11, 2024

@usimarit, I have 2 x 3090 now, so 2 x 30 epochs will take a fairly long time with rnnt-loss. Currently the 8-layer model doesn't converge, while the 4-layer one converges with > 40000 warmup steps. Conformer works with the same rnnt-loss. Could rnn_transducer have behaved differently when you pretrained it, given the same loss, same optimizer, and same number of weights?


nglehuy commented on June 11, 2024

@yiqiaoc11 The rnn_transducer structure stays the same in version v1.0.x.
Is the number of weights in your case the same as in the pretrained example?
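One way to check this beyond the grand total is layer by layer: dump (layer name, param count) pairs from both models (e.g. via `layer.count_params()` in Keras) and diff them. A sketch of just the comparison step, on plain lists:

```python
def diff_param_counts(a, b):
    """Compare two [(layer_name, param_count)] lists and return the
    entries that differ or exist in only one model, as
    (name, count_in_a, count_in_b) with None for a missing layer."""
    da, db = dict(a), dict(b)
    names = sorted(set(da) | set(db))
    return [(n, da.get(n), db.get(n)) for n in names if da.get(n) != db.get(n)]
```

An empty result means the two models agree layer-for-layer; any entry pinpoints where the architectures (and hence the configs) diverge.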


yiqiaoc11 commented on June 11, 2024

Yes, the number of weights and the distribution of layers are the same, but the rest of the config information from the pretrained model isn't traceable. Not sure what leads to the underfitting observed.

Preliminary loss curves for the 4- and 8-layer models are posted for comparison: green curves are the 8-layer model, blue the 4-layer one. The losses are very similar, and both models were trained under the same .yml from GDrive. Neither converges.
[figure: training loss curves; green = 8-layer, blue = 4-layer]

[2023-02-09 09:03:10] PRINT Layer (type)                                               Output Shape   Param #
[2023-02-09 09:03:10] PRINT ====================================================================================================
[2023-02-09 09:03:10] PRINT streaming_transducer_encoder_reshape (Reshape)             multiple       0
[2023-02-09 09:03:10] PRINT streaming_transducer_encoder_block_0 (RnnTransducerBlock)  multiple       5511488
[2023-02-09 09:03:10] PRINT streaming_transducer_encoder_block_1 (RnnTransducerBlock)  multiple       7149888
[2023-02-09 09:03:10] PRINT streaming_transducer_encoder_block_2 (RnnTransducerBlock)  multiple       5839168
[2023-02-09 09:03:10] PRINT streaming_transducer_encoder_block_3 (RnnTransducerBlock)  multiple       5839168
[2023-02-09 09:03:10] PRINT ====================================================================================================
[2023-02-09 09:03:10] PRINT Total params: 24,339,712
[2023-02-09 09:03:10] PRINT Trainable params: 24,339,712
[2023-02-09 09:03:10] PRINT Non-trainable params: 0

[2023-02-09 09:03:15] PRINT Layer (type)                                               Output Shape   Param #
[2023-02-09 09:03:15] PRINT ====================================================================================================
[2023-02-09 09:03:15] PRINT streaming_transducer_encoder_reshape (Reshape)             multiple       0
[2023-02-09 09:03:15] PRINT streaming_transducer_encoder_block_0 (RnnTransducerBlock)  multiple       5511488
[2023-02-09 09:03:15] PRINT streaming_transducer_encoder_block_1 (RnnTransducerBlock)  multiple       7149888
[2023-02-09 09:03:15] PRINT streaming_transducer_encoder_block_2 (RnnTransducerBlock)  multiple       5839168
[2023-02-09 09:03:15] PRINT streaming_transducer_encoder_block_3 (RnnTransducerBlock)  multiple       5839168
[2023-02-09 09:03:15] PRINT streaming_transducer_encoder_block_4 (RnnTransducerBlock)  multiple       5839168
[2023-02-09 09:03:15] PRINT streaming_transducer_encoder_block_5 (RnnTransducerBlock)  multiple       5839168
[2023-02-09 09:03:15] PRINT streaming_transducer_encoder_block_6 (RnnTransducerBlock)  multiple       5839168
[2023-02-09 09:03:15] PRINT streaming_transducer_encoder_block_7 (RnnTransducerBlock)  multiple       5839168
[2023-02-09 09:03:15] PRINT ====================================================================================================
[2023-02-09 09:03:15] PRINT Total params: 47,696,384
[2023-02-09 09:03:15] PRINT Trainable params: 47,696,384
[2023-02-09 09:03:15] PRINT Non-trainable params: 0
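As a sanity check, the per-block counts in the two summaries do add up to the reported totals; the 8-layer model is the 4-layer one plus four more identical 5,839,168-parameter blocks:

```python
# Per-block parameter counts copied from the two model summaries above.
reshape = 0
blocks_4 = [5511488, 7149888, 5839168, 5839168]
blocks_8 = [5511488, 7149888] + [5839168] * 6

total_4 = reshape + sum(blocks_4)  # matches reported 24,339,712
total_8 = reshape + sum(blocks_8)  # matches reported 47,696,384
extra = total_8 - total_4          # four extra blocks, ~23.4M params
```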

