
lucidrains / bs-roformer


Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs

License: MIT License

Python 100.00%
artificial-intelligence attention-mechanisms deep-learning music-source-separation transformers

bs-roformer's Introduction

BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs. They beat the previous first place by a large margin. The technique uses axial attention across frequency (hence multi-band) and time. They also have experiments to show that rotary positional encoding led to a huge improvement over learned absolute positions.

It also includes support for stereo training and outputting multiple stems.
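The axial attention described above can be sketched as follows: attention first runs along the time axis within each band, then along the band axis within each frame. This is a simplified illustration assuming a (batch, time, bands, dim) layout; it uses PyTorch's standard nn.MultiheadAttention with plain residual connections and omits the rotary embeddings, normalization and feedforward layers of the actual model. `AxialBlock` is a hypothetical name.

```python
import torch
from torch import nn

class AxialBlock(nn.Module):
    # one simplified layer of axial attention over a (batch, time, bands, dim) tensor
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.time_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.freq_attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, f, d = x.shape
        # time attention: every band becomes its own sequence of t frames
        xt = x.permute(0, 2, 1, 3).reshape(b * f, t, d)
        xt = xt + self.time_attn(xt, xt, xt, need_weights=False)[0]
        x = xt.reshape(b, f, t, d).permute(0, 2, 1, 3)
        # frequency attention: every frame becomes its own sequence of f bands
        xf = x.reshape(b * t, f, d)
        xf = xf + self.freq_attn(xf, xf, xf, need_weights=False)[0]
        return xf.reshape(b, t, f, d)
```

Attending along each axis separately produces on the order of t·f·(t + f) attention scores, rather than (t·f)² for full attention over every time-band token.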

Please join us on Discord if you are interested in replicating a SOTA music source separator out in the open.

Update: This paper has been replicated by Roman, and the weights have been open sourced here

Appreciation

  • StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • Roee and Fabian-Robert for sharing their audio expertise and fixing audio hyperparameters

  • @chenht2010 and Roman for working out the default band splitting hyperparameter!

  • Max Prod for reporting a big bug with Mel-Band Roformer with stereo training!

  • Roman for successfully training the model and open sourcing his training code and weights at this repository!

  • Christopher for fixing an issue with multiple stems in Mel-Band Roformer

  • Iver Jordal for identifying that the default stft window function is not correct

Install

$ pip install BS-RoFormer

Usage

import torch
from bs_roformer import BSRoformer

model = BSRoformer(
    dim = 512,
    depth = 12,
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

x = torch.randn(2, 352800)
target = torch.randn(2, 352800)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)

To use the Mel-Band Roformer proposed in a recent follow-up paper, simply import MelBandRoformer instead:

import torch
from bs_roformer import MelBandRoformer

model = MelBandRoformer(
    dim = 32,
    depth = 1,
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

x = torch.randn(2, 352800)
target = torch.randn(2, 352800)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)

Todo

  • get the multiscale stft loss in there
  • figure out what n_fft should be
  • review band split + mask estimation modules

Citations

@inproceedings{Lu2023MusicSS,
    title   = {Music Source Separation with Band-Split RoPE Transformer},
    author  = {Wei-Tsung Lu and Ju-Chiang Wang and Qiuqiang Kong and Yun-Ning Hung},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:261556702}
}
@inproceedings{Wang2023MelBandRF,
    title   = {Mel-Band RoFormer for Music Source Separation},
    author  = {Ju-Chiang Wang and Wei-Tsung Lu and Minz Won},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263608675}
}
@misc{ho2019axial,
    title  = {Axial Attention in Multidimensional Transformers},
    author = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year   = {2019},
    archivePrefix = {arXiv}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
@inproceedings{ElNouby2021XCiTCI,
    title   = {XCiT: Cross-Covariance Image Transformers},
    author  = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
    booktitle = {Neural Information Processing Systems},
    year    = {2021},
    url     = {https://api.semanticscholar.org/CorpusID:235458262}
}

bs-roformer's People

Contributors

crlandsc, iver56, lucidrains, shenberg


bs-roformer's Issues

MLP design in MaskEstimator

Thanks for your great work.
In the BS-RoFormer paper, the authors mention:

Each MLP layer consists of a RMS Norm layer, a fully connected layer followed by a Tanh activation, and a fully connected layer followed by a gated linear unit (GLU) layer[29].

However, your MaskEstimator implementation is somewhat different from the paper description.

Here's my implementation (I also checked the number of parameters):

import torch
from torch import nn
from torch.nn import Module, ModuleList
from beartype import beartype
from beartype.typing import Tuple

class TanH(Module):
    # fully connected layer followed by a Tanh activation
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x.tanh()

class GLU(Module):
    # fully connected layer followed by a gated linear unit
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * gate.sigmoid()

class MaskEstimator(Module):
    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...],
            mlp_expansion_factor = 4,
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])

        for dim_in in dim_inputs:
            # one MLP per band: dim -> dim * expansion (Tanh) -> dim_in (GLU)
            net = []
            net.append(TanH(dim, dim * mlp_expansion_factor))
            net.append(GLU(dim * mlp_expansion_factor, dim_in))
            self.to_freqs.append(nn.Sequential(*net))

    def forward(self, x):
        # x: (..., num_bands, dim) -> one tensor of band features per band
        x = x.unbind(dim=-2)

        outs = []

        for band_features, to_freq in zip(x, self.to_freqs):
            freq_out = to_freq(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)

MelRoformer parameters from paper

I'm trying to reproduce the model from the paper, but with no luck so far.

My current settings, which only allow a batch size of 2 with 48 GB of memory:

  dim: 192
  depth: 8
  stereo: true
  num_stems: 1
  time_transformer_depth: 1
  freq_transformer_depth: 1
  num_bands: 60
  dim_head: 64
  heads: 8
  attn_dropout: 0.1
  ff_dropout: 0.1
  flash_attn: True
  dim_freqs_in: 1025
  sample_rate: 44100  # needed for mel filter bank from librosa
  stft_n_fft: 2048
  stft_hop_length: 512
  stft_win_length: 2048
  stft_normalized: False
  mask_estimator_depth: 2
  multi_stft_resolution_loss_weight: 1.0
  multi_stft_resolutions_window_sizes: !!python/tuple
  - 4096
  - 2048
  - 1024
  - 512
  - 256
  multi_stft_hop_size: 147
  multi_stft_normalized: False

As input I give 8 seconds of audio at 44100 Hz, so the length is 352800 samples.

I ran the model through torchinfo:

from torchinfo import summary
summary(model, input_size=(1, 2, 352768))

Report is:

==============================================================================================================
Layer (type:depth-idx)                                       Output Shape              Param #
==============================================================================================================
MelBandRoformer                                              [1, 2, 352768]            56,503,768
├─ModuleList: 1-1                                            --                        --
│    └─ModuleList: 2-1                                       --                        384
│    │    └─Transformer: 3-77                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-78                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-2                                       --                        384
│    │    └─Transformer: 3-79                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-80                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-3                                       --                        384
│    │    └─Transformer: 3-81                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-82                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-4                                       --                        384
│    │    └─Transformer: 3-83                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-84                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-5                                       --                        384
│    │    └─Transformer: 3-85                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-86                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-6                                       --                        384
│    │    └─Transformer: 3-87                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-88                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-7                                       --                        384
│    │    └─Transformer: 3-89                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-90                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-8                                       --                        384
│    │    └─Transformer: 3-91                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-92                                [690, 60, 192]            (recursive)
├─BandSplit: 1-2                                             [1, 690, 60, 192]         --
│    └─ModuleList: 2                                         --                        --
....
├─ModuleList: 1-1                                            --                        --
│    └─ModuleList: 2-1                                       --                        384
│    │    └─Transformer: 3-77                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-78                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-2                                       --                        384
│    │    └─Transformer: 3-79                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-80                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-3                                       --                        384
│    │    └─Transformer: 3-81                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-82                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-4                                       --                        384
│    │    └─Transformer: 3-83                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-84                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-5                                       --                        384
│    │    └─Transformer: 3-85                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-86                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-6                                       --                        384
│    │    └─Transformer: 3-87                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-88                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-7                                       --                        384
│    │    └─Transformer: 3-89                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-90                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-8                                       --                        384
│    │    └─Transformer: 3-91                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-92                                [690, 60, 192]            (recursive)
├─ModuleList: 1                                              --                        --
│    └─MaskEstimator: 2-9                                    [1, 690, 7916]            --
==============================================================================================================
Total params: 69,102,468
Trainable params: 69,102,404
Non-trainable params: 64
Total mult-adds (G): 8.35
==============================================================================================================
Input size (MB): 2.82
Forward/backward pass size (MB): 703.40
Params size (MB): 232.17
Estimated Total Size (MB): 938.40
==============================================================================================================

From the report I expected to fit a batch size of more than 48, but in the end I can only use a batch size of 2.

GPU memory usage for batch = 2: [screenshot not reproduced]

To follow the paper, I would have to increase dim to 384 and depth to 12, and decrease stft_hop_length to 441 (10 ms). In that case the batch size would be only 1, or it would not fit in memory at all.

Any ideas how to deal with such high memory usage?
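One factor worth quantifying is the sequence length the time transformer sees. A rough sketch, assuming torch.stft with center=True (which gives 1 + num_samples // hop_length frames): it reproduces the 690 frames in the torchinfo report and shows how a 441-sample hop changes that. With flash_attn enabled, as in this config, the attention score matrix is not stored, so the quadratic factor mainly reflects compute; activations elsewhere grow linearly with the frame count. `stft_num_frames` is a hypothetical helper.

```python
def stft_num_frames(num_samples: int, hop_length: int) -> int:
    # with torch.stft(..., center=True) the signal is padded by n_fft // 2 on
    # each side, which yields 1 + num_samples // hop_length frames (even n_fft)
    return 1 + num_samples // hop_length

num_samples = 8 * 44100                            # 8 s at 44.1 kHz = 352800 samples
frames_512 = stft_num_frames(num_samples, 512)     # current hop length
frames_441 = stft_num_frames(num_samples, 441)     # 10 ms hop from the paper

# the time transformer attends over one sequence of `frames` tokens per band,
# so the number of attention scores grows with frames ** 2
growth = (frames_441 / frames_512) ** 2
print(frames_512, frames_441, round(growth, 2))    # 690 801 1.35
```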

Cut frequencies at half of the Nyquist in the MelBand model

Hi, before starting training on a large corpus, I tested the latest release (3.0) with the new MelBand model. Both models are initialized with the same parameters; no other parameters were changed.

model = MelBandRoformer(dim = 384,
                        depth = 9,
                        time_transformer_depth = 1,
                        freq_transformer_depth = 1, 
                        stereo=stereo, 
                        sample_rate=sr).to(device)

To test under realistic conditions, I load an 8-second stereo audio mixture at 44100 Hz with a matching target (a drums stem). The batch size is 1, so the tensors have size [1, 2, 352800].

After a single training forward step, the losses seem reasonable:

BandSplit : tensor(2.3477, device='cuda:0', grad_fn=<AddBackward0>)
  MelBand : tensor(2.2762, device='cuda:0', grad_fn=<AddBackward0>)

Unfortunately, when I save the audio output of both models, the MelBand model behaves strangely. The spectrograms show it better than words:

Spectrograms (images not reproduced): mixture, target, BandSplit output and MelBand output in mel scale view, plus the MelBand output in linear scale view.

The spectrograms show that the MelBand model output cuts off frequencies above 11025 Hz, i.e. half of the Nyquist frequency of 22050 Hz for 44100 Hz audio.

I don't know if this is normal or a bug, but I prefer to share the information here.

Thanks so much for BS-RoFormer!!!
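A cutoff at exactly 11025 Hz corresponds to STFT bin 512 of 1025, assuming the default n_fft = 2048 (bin k sits at k * sample_rate / n_fft Hz). One way to narrow down the cause is to check which bins receive weight from at least one mel band. A minimal NumPy sketch; the helper names are hypothetical, and the real filterbank matrix would have to be taken from wherever the model builds its mel bands.

```python
import numpy as np

def bin_frequencies(sample_rate: int, n_fft: int) -> np.ndarray:
    # frequency (Hz) of each one-sided STFT bin: k * sample_rate / n_fft
    return np.arange(n_fft // 2 + 1) * sample_rate / n_fft

def uncovered_bins(filterbank: np.ndarray) -> np.ndarray:
    # filterbank: (num_bands, num_bins); a bin is uncovered if no band weights it
    return np.flatnonzero((filterbank > 0).sum(axis=0) == 0)

freqs = bin_frequencies(44100, 2048)
print(freqs[512])  # 11025.0
```

If `uncovered_bins` returns everything from bin 513 upward for the model's filterbank, the cutoff comes from the band assignment rather than from training.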

MelBand doesn't work with stereo

Hello. I tried to run training with:

model = MelBandRoformer(
      stereo=True,
      dim=32,
      depth=1,
      attn_dropout=0.1,
      time_transformer_depth=1,
      freq_transformer_depth=1,
  )

But I got an error:

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 3964 (input tensor's size at dimension -1), but got split_sizes=[24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 28, 32, 32, 32, 36, 40, 44, 44, 44, 52, 56, 60, 64, 64, 68, 76, 80, 84, 92, 100, 104, 112, 120, 124, 136, 148, 156, 164, 176, 188, 200, 216, 228, 244, 264, 280, 296, 316, 340, 364, 388, 412, 440, 472, 504]

Without stereo=True it works normally.

Default time_transformer_depth

In the README you use the following parameters as an example:

time_transformer_depth = 1
freq_transformer_depth = 1

However, these are the defaults in the BSRoformer class:

time_transformer_depth = 2
freq_transformer_depth = 2

What would you recommend?

Feature request: decouple the loss function of the forward function

In the current implementation, the forward() method serves both train and eval mode. In some cases we need not only the loss but also the predicted output, for example to compute extra metrics such as SDR during the validation step.

Because the loss function code is shared by the BSRoformer and MelBandRoformer classes, it might be better to create a new class such as MultiResLoss for maximum flexibility:

import torch
import torch.nn.functional as F
from einops import rearrange
from beartype import beartype
from beartype.typing import Tuple

class MultiResLoss():
    @beartype
    def __init__(
        self,
        num_stems,
        stft_n_fft,
        multi_stft_resolution_loss_weight = 1.,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size = 147,
        multi_stft_normalized = False
    ):
        self.num_stems = num_stems

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft

        self.multi_stft_kwargs = dict(
            hop_length = multi_stft_hop_size,
            normalized = multi_stft_normalized
        )
    
    def __call__(
        self, 
        predict, 
        targets, 
        return_loss_breakdown = False
    ):
        if self.num_stems > 1:
            assert targets.ndim == 4 and targets.shape[1] == self.num_stems
        
        if targets.ndim == 2:
            targets = rearrange(targets, '... t -> ... 1 t')

        targets = targets[..., :predict.shape[-1]] # protect against lost length on istft

        loss = F.l1_loss(predict, targets)

        multi_stft_resolution_loss = 0.

        for window_size in self.multi_stft_resolutions_window_sizes:

            res_stft_kwargs = dict(
                n_fft = max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length = window_size,
                return_complex = True,
                **self.multi_stft_kwargs,
            )

            predict_Y = torch.stft(rearrange(predict, '... s t -> (... s) t'), **res_stft_kwargs)
            targets_Y = torch.stft(rearrange(targets, '... s t -> (... s) t'), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(predict_Y, targets_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss =  loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)

In the same spirit, a small refactoring could move the common classes into a new file:

- RMSNorm
- FeedForward
- Attention
- Transformer
- BandSplit
- MLP
- MaskEstimator

Would that make future changes to the code easier?
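With the loss decoupled from forward(), a validation step could compute SDR directly from the prediction. A minimal sketch of plain (not scale-invariant) SDR in PyTorch; the function name `sdr` is hypothetical and not part of bs_roformer, and `eps` only guards against division by zero.

```python
import torch

def sdr(estimate: torch.Tensor, reference: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # signal-to-distortion ratio in dB over the last (time) axis:
    # 10 * log10(||reference||^2 / ||reference - estimate||^2)
    num = reference.pow(2).sum(dim=-1)
    den = (reference - estimate).pow(2).sum(dim=-1)
    return 10 * torch.log10((num + eps) / (den + eps))
```

In a validation loop this would sit next to the loss, e.g. `pred = model(x)`, then `loss_fn(pred, target)` and `sdr(pred, target).mean()`.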

Potential bug with `num_stems > 1`

Hi @lucidrains

I'm working on training code, which I will share here, that attempts to reproduce the same workflow as the SAMI-ByteDance paper. To make the training code generic, and because I'm fairly new to this kind of task, I want to be sure I clearly understand your code.

Num_Stems parameter

EDIT: I removed this very stupid assumption.

Divergence of default parameters

This is less important, but it would make the code easier to follow:

  1. In the "Music Source Separation with Band-Split Rope Transformer" paper:

We use D = 384 for feature dimension, L = 12 for the number of Transformer blocks, 8 heads for each Transformer, and a dropout rate of 0.1.

I don't know if the choice of attn_dropout = 0. and ff_dropout = 0. as default parameters for the Transformers is motivated by some technical considerations. If not, maybe it's better to change the values to follow the original paper?

  2. In the "Mel-Band Roformer for Music Source Separation" paper:

We use 60 Mel-bands, as it is similar to the number of subbands, i.e., 62, adopted by BS-RoFormer.

Same question for the default num_bands = 62 in MelBandRoformer: should it be 60, as in the paper?

Hidden dim in mask estimation module

In the paper, the second sentence of Section 4.4 (Configuration) states:

The multi-band mask estimation module utilizes MLPs with a hidden layer dimension of 4D.

But as far as I understand the actual code, there is no hidden layer, only two linear layers cleverly combined into a single one.
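For comparison, one reading of the paper's description (RMSNorm, a fully connected layer with Tanh, then a fully connected layer followed by GLU, with a hidden dimension of 4D) can be sketched per band as below. This is an assumption about the paper's intent, not the repository's MaskEstimator; RMSNorm is written out by hand so the sketch does not depend on a recent PyTorch version, and `band_mlp` is a hypothetical helper.

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    # root-mean-square normalization with a learnable per-channel gain
    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.gamma

def band_mlp(dim: int, dim_out: int, expansion: int = 4) -> nn.Sequential:
    # RMSNorm -> Linear(D, 4D) + Tanh -> Linear(4D, 2 * dim_out) + GLU
    hidden = dim * expansion
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, hidden),
        nn.Tanh(),
        nn.Linear(hidden, dim_out * 2),
        nn.GLU(dim=-1),  # splits the last dim in half: a * sigmoid(b) -> dim_out
    )
```

Here the 4D hidden layer is explicit (the output of the first Linear), and GLU is the final gating step rather than being merged into the hidden projection.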

Training of BS-RoFormer

I tried to train this network without any success. The SDR gets stuck at around 2.1 for vocals and never improves. If anybody has better results, please let me know.

How to use it for audio source separation?

I am new to this field and would like to perform audio source separation. This project seems to be designed for that purpose, but I couldn't work out how to use it for source separation from the examples in the documentation.
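Going by the README, a trained model is called as `out = model(x)` without a target. For full-length songs, a common approach is to process overlapping chunks and average the overlaps. A minimal sketch under assumptions: `separate` is a hypothetical helper, and the model is assumed to accept (1, channels, time) and return the same layout (for example a stereo, single-stem model); a mono model would need the channel axis dropped.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def separate(model, mix: torch.Tensor, chunk: int, hop: int) -> torch.Tensor:
    # mix: (channels, time). Run `model` on overlapping windows of `chunk`
    # samples, stepping by `hop`, and average wherever windows overlap.
    channels, length = mix.shape
    out = torch.zeros_like(mix)
    counts = torch.zeros(length, device=mix.device)
    start = 0
    while True:
        end = min(start + chunk, length)
        # zero-pad the last window so the model always sees `chunk` samples
        piece = F.pad(mix[:, start:end], (0, chunk - (end - start)))
        est = model(piece.unsqueeze(0)).squeeze(0)  # assumed: (channels, time')
        n = min(est.shape[-1], end - start)          # iSTFT may return fewer samples
        out[:, start:start + n] += est[..., :n]
        counts[start:start + n] += 1
        if end == length:
            break
        start += hop
    return out / counts.clamp(min=1)
```

For example, 8-second chunks at 44.1 kHz with 50% overlap would be `separate(model.eval(), mix, chunk=352800, hop=176400)`. A windowed (e.g. Hann) overlap-add would reduce seams at chunk boundaries but is left out for brevity.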

Hi @lucidrains + @ZFTurbo

I work with birds and have an audio dataset of the sounds they make that I would like to train with this code; I have had success with other models that were designed for music. Could a way to train this model be added to the code? Sorry, I cannot accomplish this myself.

Originally posted by @lyndonlauder in #4 (comment)
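Per the README, calling the model with `target=` returns the loss, so a basic training loop needs only an optimizer and a data source yielding (mixture, target) pairs. A minimal sketch; the optimizer choice, learning rate and gradient clipping value are placeholders rather than values from the paper, and `train` is a hypothetical helper.

```python
import torch
from torch import nn

def train(model: nn.Module, loader, epochs: int = 1, lr: float = 3e-4,
          max_grad_norm: float = 1.0):
    # loader yields (mixture, target) pairs shaped like the README example;
    # lr and max_grad_norm are placeholder values, not taken from the paper
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    last_loss = None
    for _ in range(epochs):
        for mix, target in loader:
            loss = model(mix, target=target)  # README: returns the loss when a target is given
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            opt.step()
            last_loss = loss.item()
    return last_loss
```

For bird recordings, `loader` could be a torch.utils.data.DataLoader over fixed-length clips of mixtures and the isolated target sounds.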

Gates in Attention module of bs_roformer.py

I am a bit confused by the gates in the Attention module of bs_roformer.py. The code in lines 103-105 is:

out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

As far as I understand, this is not the standard multi-head attention approach, and the paper does not mention using anything else. Therefore, I would remove the gates, resulting in the following code:

out = self.attend(q, k, v)

out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)

What did I get wrong? What are the gates for, and why are they used? Can you clear this up?
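The gating appears to follow the gated attention idea from "Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing" (Bondarenko et al., 2023), which is in the README's citation list; that connection is an assumption, not confirmed in this thread. Each head's output is scaled by a sigmoid gate computed from the input token, so a head can learn to switch itself (nearly) off. A standalone sketch of just the gating step; `GatedHeads` is a hypothetical name.

```python
import torch
from torch import nn

class GatedHeads(nn.Module):
    # per-head output gating: each head's output is scaled by sigmoid(W_h x + b_h)
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.to_gates = nn.Linear(dim, heads)

    def forward(self, x: torch.Tensor, attn_out: torch.Tensor) -> torch.Tensor:
        # x: (b, n, dim) input tokens; attn_out: (b, h, n, d) per-head outputs
        gates = self.to_gates(x).sigmoid()            # (b, n, h)
        gates = gates.permute(0, 2, 1).unsqueeze(-1)  # (b, h, n, 1), like 'b n h -> b h n 1'
        return attn_out * gates
```

With a strongly negative gate bias the head's contribution goes to zero, which is the "do nothing" behaviour the paper argues standard softmax attention can only approximate with outlier activations.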

Problem with DataParallel

There is a problem with the model if you try to use multiple GPUs:

Traceback (most recent call last):
  File "train_local_run.py", line 44, in <module>
    train_model(args)
  File "train.py", line 239, in train_model
    loss = model(x, y)
  File "\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "\site-packages\torch\nn\parallel\data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "\site-packages\torch\nn\parallel\data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "\site-packages\torch\nn\parallel\parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "\site-packages\torch\_utils.py", line 644, in reraise
    raise exception
StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
  File "\site-packages\torch\nn\parallel\parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "\models\bs_roformer\mel_band_roformer.py", line 437, in forward
    stft_window = self.stft_window_fn(device=self.device)
  File "\models\bs_roformer\mel_band_roformer.py", line 403, in device
    return next(self.parameters()).device
StopIteration

What is the correct way to fix it?
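As far as we understand, the StopIteration comes from `next(self.parameters())`: in recent PyTorch versions, DataParallel replicas no longer expose their weights through `parameters()`, so the iterator is empty. A common workaround is to take the device from the input tensor instead. A sketch with a hypothetical `STFTFrontend` module standing in for the model's STFT step; the Hann window is an assumption.

```python
import torch
from torch import nn

class STFTFrontend(nn.Module):
    # hypothetical stand-in for the model's STFT step; it has no parameters at all
    def __init__(self, n_fft: int = 2048, hop_length: int = 512):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length

    def forward(self, raw_audio: torch.Tensor) -> torch.Tensor:
        # use the input's device instead of next(self.parameters()).device,
        # which raises StopIteration when parameters() yields nothing
        window = torch.hann_window(self.n_fft, device=raw_audio.device)
        return torch.stft(raw_audio, self.n_fft, hop_length=self.hop_length,
                          window=window, return_complex=True)
```

Alternatively, DistributedDataParallel, which PyTorch's documentation recommends over DataParallel, runs a full copy of the module per process, so `parameters()` behaves normally there.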

[bad assertion] strange bottleneck performance

Hi, let me introduce this little benchmark of the model running on GPU (device='cuda'):

---------- Training ----------
Time: 0.019369 seconds in "STFT" stage.
Time: 0.704290 seconds in "Band Split" stage.
Time: 0.124404 seconds in "Transformer" stage.
Time: 0.017830 seconds in "Mask Estimator" stage.
Time: 6.671448 seconds in "ISTFT" stage.
--------------------------------
Time: 0.147087 seconds in "Multiresolution Loss" stage.
Time: 0.493661 seconds in "Loss Backward" stage.
--------------------------------
Time: 9.102034 seconds for the whole process.
 
---------- Evaluation ----------
Time: 0.000107 seconds in "STFT" stage.
Time: 0.003609 seconds in "Band Split" stage.
Time: 0.015426 seconds in "Transformer" stage.
Time: 0.009862 seconds in "Mask Estimator" stage.
Time: 3.528062 seconds in "ISTFT" stage.
--------------------------------
Time: 3.565335 seconds for the whole process.

The model specifications follow the original paper:

    model = BSRoformer(dim = 384,
                       depth = 12,
                       time_transformer_depth = 1,
                       freq_transformer_depth = 1,
                       mask_estimator_depth = 4).to(device)

The tensors for testing are initialized using:

    audio_size = 4 * 44100  # 4 seconds @ 44.1 kHz

    sample = torch.randn(2, audio_size, dtype=torch.float32).to(device)
    target = torch.randn(2, audio_size, dtype=torch.float32).to(device)

The benchmark does not include the einops operations or other tensor manipulations, but brackets the stages of the model like this:

        self.bench.start('Band Split')
        x = self.band_split(x)
        self.bench.stop()
        
        # axial / hierarchical attention

        self.bench.start('Transformer')
        for time_transformer, freq_transformer in self.layers:

            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = time_transformer(x)

            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')

            x = freq_transformer(x)

            x, = unpack(x, ps, '* f d')

        x = self.final_norm(x)
        self.bench.stop()

        num_stems = len(self.mask_estimators)

        self.bench.start('Mask Estimator')
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim = 1)
        self.bench.stop()

The Benchmark class is a pretty trivial one:

from time import perf_counter

class Benchmark():
    def __init__(self):
        pass

    def start(self, stage):
        self.stage = stage
        self.time_start = perf_counter()

    def stop(self):
        self.time_duration = perf_counter() - self.time_start
        print(f'Time: {self.time_duration:.6f} seconds in "{self.stage}" stage.')

Conclusion:

66% of the model's time is spent in torch.istft, while torch.stft is not slow at all.

Am I the only one to notice this?

Edit:

Wrong conclusion, see next message.
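A common pitfall when timing GPU code with perf_counter is that CUDA kernels run asynchronously, so elapsed time tends to be attributed to whichever later operation forces a synchronization. Whether this is what the follow-up message found is not shown here. A variant of the Benchmark class above that synchronizes before reading the clock:

```python
from time import perf_counter
import torch

class Benchmark():
    def __init__(self, sync: bool = True):
        # synchronize only when CUDA is actually available
        self.sync = sync and torch.cuda.is_available()

    def _now(self):
        # wait for all queued CUDA kernels so the timestamp reflects finished work
        if self.sync:
            torch.cuda.synchronize()
        return perf_counter()

    def start(self, stage):
        self.stage = stage
        self.time_start = self._now()

    def stop(self):
        self.time_duration = self._now() - self.time_start
        print(f'Time: {self.time_duration:.6f} seconds in "{self.stage}" stage.')
        return self.time_duration
```

The synchronization itself adds a small overhead, so it is best kept to profiling runs.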

Input tensor size

I am confused about why the tensor length in the README is 131680. From the paper, I understood that they use 8 s of audio at 44.1 kHz, which by my calculation gives 352800 samples 😅

Linear Attention temperature initialization

First of all, thanks for all the good work (including in this repo)!

There's a potential bug in LinearAttention that I thought I should bring to your attention: In contrast to the original paper, you learn the logarithmic temperature, but still initialize the parameter with ones instead of zeros. This means the temperature will initially be Euler's number. Maybe it doesn't make much of a difference in practice or works even better, but it looks like it may have been unintentional.
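To illustrate the initialization point, assuming LinearAttention follows the XCiT cross-covariance formulation cited in the README: with a learned log-temperature, zeros give a starting temperature of exp(0) = 1, while ones give exp(1) = e. A hypothetical sketch of the logits computation (here the temperature scales the logits, as the issue describes):

```python
import torch
import torch.nn.functional as F
from torch import nn

heads = 8
# a learned *log*-temperature: zeros start the temperature at exp(0) = 1,
# whereas initializing with ones would start it at exp(1) = e
log_temperature = nn.Parameter(torch.zeros(heads, 1, 1))

def xca_logits(q, k, log_temperature):
    # q, k: (batch, heads, tokens, dim_head); L2-normalize along the token
    # axis, as in XCiT, then form the (dim_head x dim_head) cross-covariance
    q = F.normalize(q, dim=-2)
    k = F.normalize(k, dim=-2)
    return (q.transpose(-2, -1) @ k) * log_temperature.exp()
```

Because the columns are unit-normalized, the diagonal of `xca_logits(q, q, ...)` equals the temperature itself, which makes the e versus 1 difference easy to see.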

Flash Attention support

Thank you very much for your code. You rock!

Is Flash Attention only supported on A100 GPUs?

Hop length

I am still a bit confused because of the following comment in line 241 of the code:

stft_hop_length = 512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better

It's true that the paper mentions a hop size of 10 ms at a 44.1 kHz sample rate, but by my calculation 512 samples is not 10 ms but ~11.6 ms.

Did I misunderstand something, or what is the intention here?
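The arithmetic in the question checks out: a hop of 512 samples at 44.1 kHz is about 11.61 ms, while 10 ms corresponds to 441 samples. Two small hypothetical helpers:

```python
def hop_ms(hop_length: int, sample_rate: int) -> float:
    # duration of one hop in milliseconds
    return 1000 * hop_length / sample_rate

def hop_samples(ms: float, sample_rate: int) -> int:
    # hop length in samples for a hop of `ms` milliseconds
    return round(sample_rate * ms / 1000)

print(round(hop_ms(512, 44100), 2))  # 11.61
print(hop_samples(10, 44100))        # 441
```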

Input data shape

Hello, there is a problem with the input data. You write:

x = torch.randn(2, 131680)
target = torch.randn(2, 131680)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)

But in reality there must also be a batch dimension:

x = torch.randn(10, 2, 131680)
target = torch.randn(10, 2, 131680)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)

If I use it as in the second example, I get an error from stft:

RuntimeError: stft(torch.cuda.FloatTensor[6, 2, 133728], n_fft=2048, hop_length=512, win_length=2048, window=None, normalized=0, onesided=None, return_complex=1) : expected a 1D or 2D tensor
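The error arises because torch.stft only accepts 1D or 2D input. The `stereo=True` option used in other issues here presumably folds the channel axis internally; the sketch below only illustrates that folding for a (batch, channels, time) tensor. `batched_stft` is a hypothetical helper and the Hann window is an assumption.

```python
import torch

def batched_stft(audio: torch.Tensor, n_fft: int = 2048, hop_length: int = 512) -> torch.Tensor:
    # torch.stft accepts only 1D (time) or 2D (batch, time) input, so fold
    # (batch, channels, time) into (batch * channels, time) first
    b, c, t = audio.shape
    window = torch.hann_window(n_fft, device=audio.device)
    spec = torch.stft(audio.reshape(b * c, t), n_fft, hop_length=hop_length,
                      window=window, return_complex=True)
    # unfold back to (batch, channels, freq_bins, frames)
    return spec.reshape(b, c, *spec.shape[-2:])
```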

Hard to train with 8 s audio at a batch size of 2

Hi, according to the paper (for the L=6 model):
We do not use the In-House dataset for ablation study. The effective batch size is 64 (i.e., 4 for each GPU) using accumulate grad batches=2.
and
Our model takes a segment of 8-seconds waveform for input and output.
Based on my understanding, they fit 2 segments of 8 s on each GPU per forward pass (before gradient accumulation).
However, I can hardly fit that when training on a single V100.

Another issue is the size of the model.
In the paper, they mention:
The numbers of parameters for BS-RoFormer and BS-Transformer with L=6 are 72.2M and 72.5M.
But when I follow the default settings from the paper, my model has 108M parameters:

model = Model(
    dim=384,
    depth=6,
    time_transformer_depth=1,
    freq_transformer_depth=1,
    heads=8,
    attn_dropout=0.1,
    ff_dropout=0.1,
    dim_head=48,
    stereo=True
)

Am I doing anything wrong?
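To see where the extra parameters sit, it can help to compare a per-module breakdown against the paper's figure. A small hypothetical helper using only standard PyTorch:

```python
import torch
from torch import nn

def count_params(model: nn.Module):
    # total parameter count plus a breakdown per top-level child module
    total = sum(p.numel() for p in model.parameters())
    breakdown = {name: sum(p.numel() for p in child.parameters())
                 for name, child in model.named_children()}
    return total, breakdown
```

For example, `total, parts = count_params(model)`, then sort `parts` by value to find the largest components.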

Flash attention error in Linear Attention layer

I noticed an error in the Linear Attention layer when flash_attn is set to True:

File "attend.py", line 84, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: No available kernel. Aborting execution. 

In the standard Attention (time_transformer and freq_transformer) it works fine.

So, as a workaround, I currently made the following change in LinearAttention (setting flash to False):

self.attend = Attend(
    scale=scale,
    dropout=dropout,
    flash=False
)
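Rather than disabling flash for the whole layer, one option is to fall back per call when no fused kernel accepts the inputs. A minimal sketch; `attend` is a hypothetical helper, not the repository's Attend class, it always uses the default 1/sqrt(d) scale, and catching RuntimeError broadly may also hide unrelated errors.

```python
import torch
import torch.nn.functional as F

def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, heads, tokens, dim_head)
    try:
        # PyTorch picks a fused kernel (flash, memory-efficient or math) if one applies
        return F.scaled_dot_product_attention(q, k, v)
    except RuntimeError:
        # e.g. "No available kernel": compute softmax(q k^T / sqrt(d)) v explicitly
        scale = q.shape[-1] ** -0.5
        attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
        return attn @ v
```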
