Comments (15)
I would love to contribute those back, but unfortunately there's a fairly involved open-source contribution process at my organization that would take a while; it would probably be best for someone else to do so.
However, I did test this out locally, and re-ran the benchmarking at https://github.com/csukuangfj/transducer-loss-benchmarking - the results look really good, peak memory usage goes from 3820 all the way down to 1182 (!), and from 2647 to 835 when sorting utterances. Step time (on my hardware) went from 343k to 280k us.
Pretty cool! Always gotta be careful with those torch.gathers.
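For context, the kind of indexing involved looks roughly like this (a generic sketch with made-up shapes, not fast_rnnt's actual code): after pruning, the joint network only scores a small window of `s_range` symbol positions per frame, and `torch.gather` picks out each target's logit without materializing the full (B, T, S+1, V) tensor.

```python
import torch

# Made-up shapes for illustration: B utterances, T frames,
# s_range pruned symbol positions per frame, vocab size V.
B, T, s_range, V = 2, 5, 3, 10
logits = torch.randn(B, T, s_range, V)           # pruned joint output
targets = torch.randint(0, V, (B, T, s_range))   # target symbol id per position

# Gather each target's logit along the vocab dimension.
picked = logits.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
print(picked.shape)  # torch.Size([2, 5, 3])
```

The memory caveat in the comment above is about exactly this pattern: a careless gather (or an intermediate expanded to the full vocab) can dominate peak memory.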
from fast_rnnt.
@danpovey Yifan has already made PRs here (#26 and #24); you can merge them.
> In a regular rnnt
As you have mentioned, that is for regular RNN-T.
The version we are using is not regular. It has the same condition as CTC training, i.e., S <= T.
Here is the paper about fast_rnnt:
https://arxiv.org/pdf/2206.13236.pdf
Here is the code in icefall that filters out data not satisfying S <= T:
https://github.com/k2-fsa/icefall/blob/f13cf61b05432a989e6a42c95b843a56639bcbde/egs/librispeech/ASR/pruned_transducer_stateless2/train.py#L958
# In ./conformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
    logging.warning(
        f"Exclude cut with ID {c.id} from training. "
        f"Number of frames (before subsampling): {c.num_frames}. "
        f"Number of frames (after subsampling): {T}. "
        f"Text: {c.supervisions[0].text}. "
        f"Tokens: {tokens}. "
        f"Number of tokens: {len(tokens)}"
    )
    return False
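To make the arithmetic concrete, here is a small standalone sketch of the same filter (the helper names `subsampled_frames` and `keep_cut` are my own; the real logic lives in icefall's train.py linked above):

```python
def subsampled_frames(num_frames: int) -> int:
    # Two stride-2 convolutions in the conformer front-end roughly
    # quarter the frame count; this is the same formula as above.
    return ((num_frames - 1) // 2 - 1) // 2

def keep_cut(num_frames: int, num_tokens: int) -> bool:
    # The modified topology emits at most one symbol per frame,
    # so we need S <= T after subsampling.
    return num_tokens <= subsampled_frames(num_frames)

print(subsampled_frames(100))  # 24
print(keep_cut(100, 24))       # True
print(keep_cut(100, 25))       # False
```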
Thanks for your fast reply.
I have tried to modify my code based on this example, and I think it's a normal transducer. I can filter the data as you said to make it work. I just wonder why we have this limitation (is it for optimization?). Actually, I read your paper yesterday but didn't notice this condition; I will double-check it. Could I just comment out this assert to make the pruned loss work like rnnt_loss (as in torchaudio or warp-transducer)?
@BuaaAlban as you noted, this constraint is indeed not required for the "regular" RNNT topology. Only if you train with the "modified" topology, where you are constrained to emit exactly 1 symbol per time frame, will this constraint be required. We have a PR here (k2-fsa/k2#1149) to remove this constraint from k2. I will also make a similar PR for fast_rnnt.
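To see why the constraint applies only to the modified topology, here is a toy alignment-counting sketch (my own illustration, not fast_rnnt code): in the regular topology a frame can emit any number of symbols before advancing, so alignments exist for any S, while in the modified topology every step advances the frame index and emits at most one symbol, so no alignment exists when S > T.

```python
from functools import lru_cache

def regular_alignments(T: int, S: int) -> int:
    # Regular RNN-T: from (t, s), emit blank -> (t+1, s) or a symbol -> (t, s+1).
    @lru_cache(maxsize=None)
    def f(t, s):
        if t == T and s == S:
            return 1
        total = 0
        if t < T:
            total += f(t + 1, s)
        if s < S:
            total += f(t, s + 1)
        return total
    return f(0, 0)

def modified_alignments(T: int, S: int) -> int:
    # Modified topology: every step advances t; either emit a symbol
    # (t+1, s+1) or blank (t+1, s), i.e. at most one symbol per frame.
    @lru_cache(maxsize=None)
    def f(t, s):
        if t == T:
            return 1 if s == S else 0
        total = f(t + 1, s)
        if s < S:
            total += f(t + 1, s + 1)
        return total
    return f(0, 0)

print(regular_alignments(3, 5))   # 56 -- alignments exist even though S > T
print(modified_alignments(3, 5))  # 0  -- impossible when S > T
print(modified_alignments(5, 3))  # 10 -- C(5, 3) when S <= T
```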
@desh2608 are you still planning to make this PR? This would be very useful for my work!
@arkadyark sorry, I forgot to actually push the changes. BTW, I believe Dan fixed some OOM issues in the pruned transducer loss in k2 that haven't yet been merged into fast_rnnt, so you may want to make those changes yourself.
Thanks! Which changes are you referring to? Looking through recent changes to rnnt_loss.py I don't see anything there.
Check k2-fsa/k2#1177 and k2-fsa/k2#1183
Ah yes. Arkady, it would be great if you could make a PR to fast_rnnt with those changes; I had forgotten about that. If not, let me know and I'll ask someone here.
Hey @danpovey , just wanted to follow up - is anybody able to make those changes here?
@pkufool could you please have a look at this?
closed by #29