Comments (8)

TParcollet commented on June 5, 2024

Hello @egaznep, I am not sure I understand the issue here. Could you provide a code snippet that explicitly shows the error? The function length_to_mask() is expected to produce, for each sequence, a mask marking its real length within the padded input tensor, with the remaining positions treated as padding. Could you detail a bit more what you are trying to achieve?
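As a rough illustration of what a length-to-mask helper does, here is a minimal pure-Python sketch. It is not SpeechBrain's length_to_mask(); the name lengths_to_mask and its optional max_len argument are hypothetical, chosen for this example. Each row is True for real frames and False for padding, and when max_len is omitted the mask is only as wide as the longest length passed in.

```python
from typing import List, Optional


def lengths_to_mask(lengths: List[int], max_len: Optional[int] = None) -> List[List[bool]]:
    """Hypothetical helper: True marks a real frame, False marks padding.

    Illustration only; this is not SpeechBrain's implementation.
    """
    if max_len is None:
        # Without an explicit width, the mask is as wide as the longest sequence.
        max_len = max(lengths)
    return [[t < length for t in range(max_len)] for length in lengths]


mask = lengths_to_mask([2, 4, 3])
print(mask[0])  # [True, True, False, False]

# An explicit max_len decouples the mask width from the lengths themselves.
wide = lengths_to_mask([1, 2], max_len=5)
print(len(wide[0]))  # 5
```

The distinction between "width inferred from the lengths" and "width given explicitly" is the one the rest of this thread turns on.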

Adel-Moumen commented on June 5, 2024

Hello @egaznep, thanks for opening this issue!

Could you please have a look @TParcollet? Thanks :)

egaznep commented on June 5, 2024

@TParcollet Here is a minimal working (or in this case crashing) example:

import torch
import torch.nn as nn
from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR


# Instantiate the TransformerASR model
model = TransformerASR(
    tgt_vocab=720,
    input_size=80,
    d_model=512,
    nhead=1,
    num_encoder_layers=1,
    num_decoder_layers=1,
)

# Generate some dummy input with different lengths (10, 20, ..., 100 frames)
input_lengths = torch.tensor(list(range(10, 101, 10)))
input_data = [torch.randn(length.item(), 80) for length in input_lengths]
input_targets = [torch.randint(low=0, high=720, size=(length.item(),)) for length in input_lengths]
# Pad the input sequences to have the same length (100, the longest one)
input_data = nn.utils.rnn.pad_sequence(input_data, batch_first=True)
input_targets = nn.utils.rnn.pad_sequence(input_targets, batch_first=True)
# Convert to relative lengths (0.1, 0.2, ..., 1.0), the form wav_len takes
input_lengths = input_lengths / 100.0

print(input_data.shape, input_targets.shape, input_lengths.shape)

# works: the longest sequence fills the padded width (max relative length 1.0)
output = model.forward(input_data, input_targets, wav_len=input_lengths)
# fails: tensors are still padded to 100 frames, but the max relative length is 0.9
output = model.forward(input_data[:-1], input_targets[:-1], wav_len=input_lengths[:-1])

The first call to model.forward (marked # works) receives sequences of 10, 20, ..., 100 frames (wav_len 0.1, 0.2, ..., 1.0), all padded to 100. The mask computation succeeds because the longest sequence fills the padded width, so the inferred mask size matches the input.

The second call (marked # fails), however, receives sequences of 10, 20, ..., 90 frames, which are still padded to 100 because that is the longest length in the entire dataset. The mask computation fails because the mask generator does not check the padded length of the input; it takes the mask width from the largest value in wav_len, which here corresponds to 90 frames instead of 100.
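The arithmetic behind the failure can be reproduced without SpeechBrain or PyTorch. This sketch assumes, as the report describes, that relative lengths are turned into frame counts by multiplying by the padded time dimension and that the mask width is taken from the largest resulting count. The helpers absolute_lengths and mask_width are hypothetical names for illustration, not SpeechBrain functions.

```python
from typing import List, Optional


def absolute_lengths(rel_lens: List[float], padded_len: int) -> List[int]:
    # Hypothetical: relative lengths in (0, 1] -> frame counts for a tensor padded to padded_len.
    return [round(r * padded_len) for r in rel_lens]


def mask_width(abs_lens: List[int], max_len: Optional[int] = None) -> int:
    # Hypothetical: the explicit max_len if given, otherwise the longest length.
    return max_len if max_len is not None else max(abs_lens)


padded_len = 100  # time dimension after pad_sequence over all 10 sequences
rel_lens = [l / 100 for l in range(10, 101, 10)]  # 0.1, 0.2, ..., 1.0

full = absolute_lengths(rel_lens, padded_len)
print(mask_width(full), padded_len)  # prints: 100 100 (widths agree)

subset = absolute_lengths(rel_lens[:-1], padded_len)  # longest dropped, padding kept at 100
print(mask_width(subset), padded_len)  # prints: 90 100 (widths disagree)

# Passing the padded width explicitly keeps the shapes consistent.
print(mask_width(subset, max_len=padded_len), padded_len)  # prints: 100 100
```

Either re-padding the subset to its own maximum or passing the padded width explicitly avoids the mismatch; which one is appropriate depends on whether wav_len is meant to be relative to the batch.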

TParcollet commented on June 5, 2024

Hello, thanks. SpeechBrain padding is relative to the batch, not the dataset: the largest value in wav_lens corresponds to the longest sequence in the batch.
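A minimal sketch of batch-relative padding, assuming sequences are plain Python lists: every sequence is padded to the longest one in the batch, and each length is expressed as a fraction of that width, so the largest relative length is always 1.0. pad_batch is a hypothetical helper that mirrors the convention described above; it is not a SpeechBrain API.

```python
from typing import List, Tuple


def pad_batch(seqs: List[List[float]], pad_value: float = 0.0) -> Tuple[List[List[float]], List[float]]:
    """Hypothetical helper: pad to this batch's longest sequence, return relative lengths."""
    width = max(len(s) for s in seqs)  # batch maximum, not dataset maximum
    padded = [s + [pad_value] * (width - len(s)) for s in seqs]
    rel_lens = [len(s) / width for s in seqs]
    return padded, rel_lens


batch = [[1.0] * 10, [2.0] * 20, [3.0] * 40]
padded, rel_lens = pad_batch(batch)
print(len(padded[0]), rel_lens)  # 40 [0.25, 0.5, 1.0]
```

Under this convention, taking a subset of an already padded batch (as in the failing call earlier in the thread) breaks the invariant unless the subset is re-padded.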

egaznep commented on June 5, 2024

I had this error while training a model using DistributedDataParallel on 2 GPUs. Could it be that the data sampler initially pads relative to the complete batch, and the error only occurs after the batch is distributed to the individual GPUs?

TParcollet commented on June 5, 2024

@Gastron correct me if I am wrong, but as far as I know, the DDP sampler is per-process, hence the padding should be relative to the batch of each process. @egaznep, any chance you could give us an example where this happens?
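To make the DDP question concrete, the sketch below contrasts two orders of operations on four toy utterance lengths split across two ranks. It assumes an interleaved split (items[rank::world_size]), which is how PyTorch's DistributedSampler assigns indices when shuffling is off; shard and max_rel_len are hypothetical helpers. If padding happened before the split, a rank could receive a shard whose longest item is shorter than the padded width, so its maximum relative length falls below 1.0, the same condition as the failing call above. Splitting first and padding per process keeps it at 1.0.

```python
from typing import List


def shard(items: List[int], rank: int, world_size: int) -> List[int]:
    # Interleaved split, like DistributedSampler with shuffle=False (assumption for this sketch).
    return items[rank::world_size]


def max_rel_len(lengths: List[int], padded_width: int) -> float:
    # Largest relative length when the sequences are padded to padded_width.
    return max(lengths) / padded_width


lengths = [100, 40, 80, 20]  # frame counts of four toy utterances
world_size = 2

# Order A (hypothesised bug): pad to the global maximum, then split across ranks.
global_width = max(lengths)
for rank in range(world_size):
    local = shard(lengths, rank, world_size)
    print("pad-then-split", rank, max_rel_len(local, global_width))  # rank 1 gets 0.4

# Order B (per-process padding): split first, then each rank pads to its own maximum.
for rank in range(world_size):
    local = shard(lengths, rank, world_size)
    print("split-then-pad", rank, max_rel_len(local, max(local)))  # 1.0 on every rank
```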

Adel-Moumen commented on June 5, 2024

Hello @egaznep, any news on your side, please?

egaznep commented on June 5, 2024

I was swamped with other projects until now, and I'm out of office this week. I will try to reproduce it when I am back, but I suspect it's more likely an issue with that specific project and not really related to SpeechBrain internals. Thank you for the reminder.
