Comments (8)
Hi @yiqiaoc11
I'm training on TPUs to validate this.
Are you using the warp-transducer loss or the rnnt loss in tensorflow?
I've been testing with the rnnt loss in TensorFlow for the past few months, and it has some convergence issues. But I don't have the resources to test on GPUs.
from tensorflowasr.
@usimarit Thanks for the comments. All my tests were conducted on GPUs with batch size 2.
I first tried TensorFlowASR\examples\rnn_transducer\config.yml with the warp-transducer loss and observed the same problem as #231. Then I switched to rnnt-loss using the pretrained model's .yml, which contains warmup_steps for the transformer scheduler. The loss is shown below:
Epoch 1/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 339.2118 - val_loss: 160.0162
Epoch 2/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 254.4009 - val_loss: 147.7653
Epoch 3/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 241.0356 - val_loss: 144.2561
Epoch 4/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 231.6980 - val_loss: 140.2191
Epoch 5/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 223.2308 - val_loss: 137.6041
Epoch 6/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 216.7098 - val_loss: 136.0396
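As a quick read on the trend in this log, the epoch-over-epoch improvement can be computed from the values copied above:

```python
# Losses copied from the training log above (8-layer encoder, rnnt-loss).
train = [339.2118, 254.4009, 241.0356, 231.6980, 223.2308, 216.7098]
val = [160.0162, 147.7653, 144.2561, 140.2191, 137.6041, 136.0396]

# Relative improvement per epoch; the gains shrink steadily after epoch 2,
# which is consistent with the slow convergence described in this thread.
for prev, cur in zip(val, val[1:]):
    print(f"val_loss improvement: {100 * (prev - cur) / prev:.1f}%")
```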
The 4- and 6-layer encoders converged with rnnt-loss under different warmup steps, but the 8-layer encoder did not. I'm just trying to reproduce the performance of the pretrained model. Conformer reportedly works, and it differs from rnn_transducer only in the encoder.
Feel free to advise, and I can try it on GPU here.
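For context, transformer schedulers of the kind mentioned above typically follow the Noam warmup schedule. Here is a minimal sketch, assuming d_model=320 and the warmup_steps=40000 value discussed in this thread; the exact schedule in TensorFlowASR may differ:

```python
# Minimal sketch of a transformer (Noam) warmup schedule: the learning
# rate rises linearly for `warmup_steps`, then decays as step^-0.5.
# d_model=320 is an assumption, not a value taken from TensorFlowASR.
def transformer_lr(step, d_model=320, warmup_steps=40000):
    step = max(step, 1)  # guard against step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```

Note that a larger warmup_steps lowers the peak learning rate, which is one reason deeper encoders are sometimes trainable only with more warmup.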
@yiqiaoc11 Could you help me train 2 models for 30 epochs using rnnt-loss:
- 4-layers encoder
- 8-layers encoder
Then plot the losses of the two models for better comparison?
Other configs are the same.
@usimarit, using the streaming config.yml (https://drive.google.com/file/d/1xYFYi3z94ZqaQZ-cTyiNekBwhITh1Ru4l) with warmup_steps=40000, right?
From the timeline, it seems you used the warp-transducer loss to produce the pretrained .h5 weights.
@yiqiaoc11 Yes, with the pretrained config.
I trained the rnn transducer on TPUs, where the warp-transducer loss cannot be applied; only rnnt-loss can be used there. But you can experiment with the warp-transducer loss too, plotting the losses of the two models for better comparison.
@usimarit, I only have 2 x 3090s, so 2 x 30 epochs will take a fairly long time with rnnt-loss. Currently the 8-layer encoder doesn't converge, while the 4-layer one converges with > 40000 warmup steps. Conformer works with the same rnnt-loss. Could the rnn_transducer you pretrained differ somehow, given the same loss, the same optimizer, and the same number of weights?
@yiqiaoc11 The rnn_transducer structure stays the same in version v1.0.x
Is the number of weights in your case the same as in the pretrained example?
Yes, the number of weights and the layer distributions are the same, but other config details of the pretrained model aren't traceable. Not sure what causes the underfitting observed.
Preliminary loss curves for the 4- and 8-layer models are posted for comparison: green curves are the 8-layer model, blue the 4-layer. The losses are very similar even though both models were trained with the same .yml from GDrive. Neither converges.
Model summary (4-layer encoder):
Layer (type)                                               Output Shape   Param #
=================================================================================
streaming_transducer_encoder_reshape (Reshape)             multiple       0
streaming_transducer_encoder_block_0 (RnnTransducerBlock)  multiple       5511488
streaming_transducer_encoder_block_1 (RnnTransducerBlock)  multiple       7149888
streaming_transducer_encoder_block_2 (RnnTransducerBlock)  multiple       5839168
streaming_transducer_encoder_block_3 (RnnTransducerBlock)  multiple       5839168
=================================================================================
Total params: 24,339,712
Trainable params: 24,339,712
Non-trainable params: 0
Model summary (8-layer encoder):
Layer (type)                                               Output Shape   Param #
=================================================================================
streaming_transducer_encoder_reshape (Reshape)             multiple       0
streaming_transducer_encoder_block_0 (RnnTransducerBlock)  multiple       5511488
streaming_transducer_encoder_block_1 (RnnTransducerBlock)  multiple       7149888
streaming_transducer_encoder_block_2 (RnnTransducerBlock)  multiple       5839168
streaming_transducer_encoder_block_3 (RnnTransducerBlock)  multiple       5839168
streaming_transducer_encoder_block_4 (RnnTransducerBlock)  multiple       5839168
streaming_transducer_encoder_block_5 (RnnTransducerBlock)  multiple       5839168
streaming_transducer_encoder_block_6 (RnnTransducerBlock)  multiple       5839168
streaming_transducer_encoder_block_7 (RnnTransducerBlock)  multiple       5839168
=================================================================================
Total params: 47,696,384
Trainable params: 47,696,384
Non-trainable params: 0
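As a quick sanity check, the per-block parameter counts in the two summaries add up to the reported totals, and the four extra RnnTransducerBlocks account for the entire difference:

```python
# Per-block parameter counts copied from the two model summaries above.
blocks_4 = [0, 5511488, 7149888, 5839168, 5839168]  # reshape + 4 encoder blocks
blocks_8 = blocks_4 + [5839168] * 4                 # same, plus blocks 4-7

assert sum(blocks_4) == 24_339_712  # matches "Total params" of the 4-layer model
assert sum(blocks_8) == 47_696_384  # matches "Total params" of the 8-layer model
print(sum(blocks_8) - sum(blocks_4))  # → 23356672, i.e. 4 * 5,839,168
```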