lucidrains / soundstorm-pytorch

Implementation of SoundStorm, Efficient Parallel Audio Generation from Google DeepMind, in PyTorch

License: MIT License

Language: Python (100.00%)
artificial-intelligence audio-generation deep-learning non-autoregressive transformers attention-mechanism

soundstorm-pytorch's Introduction

Soundstorm - Pytorch

Implementation of SoundStorm, Efficient Parallel Audio Generation from Google DeepMind, in PyTorch.

They basically applied MaskGIT to the residual vector-quantized codes from SoundStream. The transformer architecture they chose is the Conformer, which fits well with the audio domain.
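For intuition, here is a minimal sketch of a MaskGIT-style cosine unmasking schedule (illustrative only; none of these names come from the library):

import math

def cosine_schedule(t):
    # fraction of positions still masked at normalized time t in [0, 1]
    return math.cos(t * math.pi / 2)

steps, seq_len = 18, 1024
for step in range(steps):
    frac_masked = cosine_schedule((step + 1) / steps)
    num_masked = math.ceil(frac_masked * seq_len)
    # each step predicts every masked position in parallel, keeps the most
    # confident predictions, and re-masks the remaining `num_masked` positions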

Project Page

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

  • Lucas Newman for numerous contributions, including the initial training code, acoustic prompting logic, per-level quantizer decoding!

  • 🤗 Accelerate for providing a simple and powerful solution for training

  • Einops for the indispensable abstraction that makes building neural networks fun, easy, and uplifting

  • Steven Hillis for submitting the correct masking strategy and for verifying that the repository works! 🙏

  • Lucas Newman for basically training a small working SoundStorm with models across multiple repositories, showing it all works end-to-end. Models include SoundStream, Text-to-Semantic T5, and finally the SoundStorm transformer here.

  • @Jiang-Stan for identifying a critical bug in the iterative demasking!

Install

$ pip install soundstorm-pytorch

Usage

import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper

conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 12,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

model = SoundStorm(
    conformer,
    steps = 18,          # 18 steps, as in original maskgit paper
    schedule = 'cosine'  # currently the best schedule is cosine
)

# get your pre-encoded codebook ids from the soundstream, computed over a lot of raw audio

codes = torch.randint(0, 1024, (2, 1024, 12)) # (batch, seq, num residual VQ)

# do the below in a loop for a ton of data

loss, _ = model(codes)
loss.backward()

# model can now generate in 18 steps. ~2 seconds sounds reasonable

generated = model.generate(1024, batch_size = 2) # (2, 1024)
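Put together, a minimal training loop might look like the following; the dataloader and hyperparameters below are placeholders, not part of the library:

from torch.optim import Adam

optim = Adam(model.parameters(), lr = 3e-4)

for codes in dataloader:        # each batch: (batch, seq, num residual VQ) codebook ids
    loss, _ = model(codes)
    loss.backward()
    optim.step()
    optim.zero_grad()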

To directly train on raw audio, you need to pass in your pretrained SoundStream into SoundStorm. You can train your own SoundStream at audiolm-pytorch.

import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper, Conformer, SoundStream

conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 12,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 12,
    attn_window_size = 128,
    attn_depth = 2
)

model = SoundStorm(
    conformer,
    soundstream = soundstream   # pass in the soundstream
)

# find as much audio as you'd like the model to learn from

audio = torch.randn(2, 10080)

# course it through the model and take a gazillion tiny steps

loss, _ = model(audio)
loss.backward()

# and now you can generate state-of-the-art speech

generated_audio = model.generate(seconds = 30, batch_size = 2)  # generate 30 seconds of audio (it will calculate the length in seconds based off the sampling frequency and cumulative downsamples in the soundstream passed in above)
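To listen to the output, the raw waveform can be written to disk; this is only a sketch, and the 16 kHz sample rate is an assumption that should be replaced with whatever sample rate your SoundStream was trained at:

import torchaudio

# generated_audio: (batch, num samples) raw waveform
torchaudio.save('sample.wav', generated_audio[:1].cpu(), sample_rate = 16000)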

Complete text-to-speech will rely on a trained TextToSemantic encoder / decoder transformer. You will then load the weights and pass it into SoundStorm as spear_tts_text_to_semantic.

This is a work-in-progress, as spear-tts-pytorch only has the model architecture complete, and not the pretraining + pseudo-labeling + backtranslation logic.

from spear_tts_pytorch import TextToSemantic

text_to_semantic = TextToSemantic(
    dim = 512,
    source_depth = 12,
    target_depth = 12,
    num_text_token_ids = 50000,
    num_semantic_token_ids = 20000,
    use_openai_tokenizer = True
)

# load the trained text-to-semantic transformer

text_to_semantic.load('/path/to/trained/model.pt')

# pass it into the soundstorm

model = SoundStorm(
    conformer,
    soundstream = soundstream,
    spear_tts_text_to_semantic = text_to_semantic
).cuda()

# and now you can generate state-of-the-art speech

generated_speech = model.generate(
    texts = [
        'the rain in spain stays mainly in the plain',
        'the quick brown fox jumps over the lazy dog'
    ]
) # (2, n) - raw waveform decoded from soundstream

Todo

  • integrate soundstream

  • when generating, allow length to be defined in seconds (taking sampling frequency etc. into account)

  • make sure grouped rvq is supported. concat embeddings rather than sum across group dimension

  • just copy conformer over and redo shaw's relative positional embedding with rotary embedding. nobody uses shaw anymore.

  • default flash attention to true

  • remove batchnorm, and just use layernorm, but after the swish (as in normformer paper)

  • trainer with accelerate - thanks to @lucasnewman

  • allow for variable-length sequence training and generation, by passing in a mask at forward and generate

  • option to return list of audio files when generating

  • turn it into a command line tool

  • add cross attention and adaptive layernorm conditioning

Citations

@misc{borsos2023soundstorm,
    title   = {SoundStorm: Efficient Parallel Audio Generation}, 
    author  = {Zalán Borsos and Matt Sharifi and Damien Vincent and Eugene Kharitonov and Neil Zeghidour and Marco Tagliasacchi},
    year    = {2023},
    eprint  = {2305.09636},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Chang2022MaskGITMG,
    title   = {MaskGIT: Masked Generative Image Transformer},
    author  = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11305-11315}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@inproceedings{Nijkamp2021SCRIPTSP,
    title   = {SCRIPT: Self-Critic PreTraining of Transformers},
    author  = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
    booktitle = {North American Chapter of the Association for Computational Linguistics},
    year    = {2021}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

soundstorm-pytorch's People

Contributors

chenht2021, lucasnewman, lucidrains, osehmathias, stevenhillis


soundstorm-pytorch's Issues

"conformer" type hint <built-in function any> either PEP-noncompliant or currently unsupported by @beartype

See pull request: #33

---------------------------------------------------------------------------
BeartypeDecorHintNonpepException          Traceback (most recent call last)
[<ipython-input-2-712e33fc0646>](https://localhost:8080/#) in <cell line: 1>()
----> 1 from soundstorm_pytorch import SoundStorm, ConformerWrapper

26 frames
[/usr/local/lib/python3.10/dist-packages/beartype/_util/hint/nonpep/utilnonpeptest.py](https://localhost:8080/#) in die_unless_hint_nonpep(hint, is_str_valid, exception_cls, exception_prefix)
    201 
    202     # Raise a generic exception.
--> 203     raise exception_cls(
    204         f'{exception_prefix}type hint {repr(hint)} either '
    205         f'PEP-noncompliant or currently unsupported by @beartype.'

BeartypeDecorHintNonpepException: Method soundstorm_pytorch.soundstorm.ConformerWrapper.__init__() parameter "conformer" type hint <built-in function any> either PEP-noncompliant or currently unsupported by @beartype.

Any help needed?

I noticed this repo is quite active, and you may need someone to run parallel experiments or help with coding. I'm willing to help and am working on getting context on the whole repo.

[BUG] Questions about 'get_mask_subset_prob' and MLM in soundstorm.py

Hi, great project. While reading your training code in soundstorm.py (commit 264a1f2), I found two places that are hard for me to understand. Could you please explain them? Thanks.

Question 1: get_mask_subset_prob at line 86. I understand that this function has two purposes: 1. randomly sample from elements that are True in mask to create a subset_mask; 2. elements that are True in subset_mask should also be True at the corresponding positions in mask. However, it seems that it cannot achieve purpose 1. Please refer to the following example:
index:     0    1    2    3    4    5    6    7
mask:      1    1    1    1    1    1    0    0     prob = 0.3, min_mask = 0 -> num_to_mask = 1.8, num_padding = 2
logits:    0.2  0.4  0.1  0.3  0.6  0.5  -1   -1    # line 98
randperm:  6    7    2    0    3    1    5    4     # line 100
randperm:  5    6    1    -1   2    0    4    3     # line 103
sub_mask:  0    0    1    1    0    1    0    0     # line 105
sub_mask:  0    0    1    1    0    1    0    0     # line 106

So it cannot achieve purpose 1 above: sum(sub_mask) > num_to_mask.
Another problem is that the elements in the range 0 ~ num_padding of sub_mask are most likely to be False when num_padding > 0, because the padding elements' (larger) indices are always sorted to the front at # line 100, which results in the model only learning the token embeddings from the latter half of the training data.

I attempted to fix the above issue; the code is as follows:

from typing import Union

import torch
from torch import Tensor
from einops import rearrange

def get_mask_subset_prob(
    mask: Tensor,
    prob: Union[float, Tensor],
    min_mask: int = 0
):
    batch, seq, device = *mask.shape, mask.device

    if isinstance(prob, Tensor):
        prob = rearrange(prob, 'b -> b 1')

    num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)
    logits = torch.rand((batch, seq), device = device)
    # logits = logits.masked_fill(~mask, -1)
    logits = logits.masked_fill(~mask, 2)               # mod: push padding elements to the end positions

    randperm = logits.argsort(dim = -1)
    randperm = randperm.argsort(dim = -1).float()       # mod: randperm should be a rank (score) instead of the original index

    # num_padding = (~mask).sum(dim = -1, keepdim = True)    # mod: delete
    # randperm -= num_padding                                # mod: delete

    subset_mask = randperm < num_to_mask
    # subset_mask.masked_fill_(~mask, False)                 # mod: delete
    return subset_mask
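A quick sanity check of the fixed helper on the example above (the values are just the example's; the only hard assertion is that padding positions are never selected):

mask = torch.tensor([[True, True, True, True, True, True, False, False]])
subset = get_mask_subset_prob(mask, prob = 0.3)
assert not subset[~mask].any()     # padding positions stay unselected
print(subset.sum().item())         # roughly prob * number of valid positions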

Question 2: I do not quite understand the role of introducing MLM in lines 1129 to 1140. In lines 1129 to 1133, randomly selecting some elements not to be replaced seems only to reduce the number of masked tokens. The tokens randomly replaced in lines 1135 to 1140 are only used as input and are not predicted for the loss. It is like creating errors in the input and having the model predict the masked tokens from erroneous input. I guess this is to minimize the gap between training and inference, because SoundStorm performs multi-step inference at level 0, and some tokens predicted in earlier steps may be wrong?

I sincerely await your reply.

Simple mini train and test example?

As always, great appreciation for your prompt and illustrative implementation. Thank you very much.
Would it be possible to add some simple mini examples to train and test with the current implementation, alongside your repo? That would immensely speed up community understanding of the closed-source results and perhaps accelerate verification of their claimed results.
Cheers Patrick

README.md sample code error

I found a small error in the README.md file, where num_quantizers in the ConformerWrapper does not agree with rq_num_quantizers in the SoundStream.

In particular, the second code snippet should have been:

conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 12,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 12,
    attn_window_size = 128,
    attn_depth = 2
)

I reckon 12 could have been anything, but they need to be the same. Otherwise, we will bump into some assertion errors.

A bug in generate

scores = 1 - logits.softmax(dim = -1)

I think the score here should be logits.softmax rather than 1 - logits.softmax. The logits during training are computed normally, so the higher the logit, the better the prediction should be.

I also tested both versions and found that if 1 - softmax is used, unmasked tokens change in every step of the first layer. But if it is replaced by softmax, most of the unmasked tokens predicted for the first layer are kept in the next step.
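For context, MaskGIT-style sampling keeps the highest-confidence predictions and re-masks the rest; the following is a minimal sketch of that ranking, independent of this repository's internals:

import torch

logits = torch.randn(1, 1024, 1024)                 # (batch, seq, codebook size)
pred = logits.argmax(dim = -1)                      # token choice per position
confidence = logits.softmax(dim = -1).gather(-1, pred.unsqueeze(-1)).squeeze(-1)   # (batch, seq)

num_to_keep = 256
keep_indices = confidence.topk(num_to_keep, dim = -1).indices   # unmask the most confident positions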

RoPE Embeddings question

In the code you provided, I observed that instead of rotating every two elements, half of the vector along the feature dimension is rotated:

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

I am curious to know the reason behind it.
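For what it's worth, this is the GPT-NeoX-style convention: dimension i is paired with dimension i + d/2 instead of 2i with 2i + 1. Because the pairing is just a fixed permutation of the feature dimensions, which the learned q/k projections can absorb, it yields the same attention scores while being cheaper to implement with a chunk and concatenate. A small self-contained check of this equivalence (illustrative, not repository code):

import torch

def rotate_half(x):                 # GPT-NeoX style: pair dim i with dim i + d/2
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def rotate_every_two(x):            # RoFormer style: pair dim 2i with dim 2i + 1
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim = -1).flatten(-2)

d, seq = 8, 4
pos = torch.arange(seq).float()
freqs = 1. / (10000 ** (torch.arange(0, d, 2).float() / d))   # (d / 2,)
angles = pos[:, None] * freqs[None, :]                        # (seq, d / 2)

cos_half, sin_half = map(lambda t: torch.cat((t, t), dim = -1), (angles.cos(), angles.sin()))
cos_int,  sin_int  = map(lambda t: t.repeat_interleave(2, dim = -1), (angles.cos(), angles.sin()))

q, k = torch.randn(seq, d), torch.randn(seq, d)
perm = torch.cat((torch.arange(0, d, 2), torch.arange(1, d, 2)))  # interleaved -> half-split layout

rope_half = lambda x: x * cos_half + rotate_half(x) * sin_half
rope_int  = lambda x: x * cos_int  + rotate_every_two(x) * sin_int

scores_half = rope_half(q[:, perm]) @ rope_half(k[:, perm]).t()
scores_int  = rope_int(q) @ rope_int(k).t()
assert torch.allclose(scores_half, scores_int, atol = 1e-5)   # identical attention scores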

Try using T5 relative positioning instead of RoPE

This paper shows that RoPE consistently underperforms T5 relative position bias on various language-based logic problems. It might be interesting to explore a version of SoundStorm that uses T5 relative positions instead? Although these tasks are not exactly comparable, it does look like RoPE consistently underperforms other approaches.
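For reference, a simplified sketch of T5-style relative position bias (clipped rather than log-bucketed, so only illustrative; the module name is hypothetical): a learned bias per relative offset and head is added to the attention logits.

import torch
import torch.nn as nn

class RelativePositionBias(nn.Module):
    # simplified sketch; the real T5 buckets distances logarithmically
    def __init__(self, heads, max_distance = 128):
        super().__init__()
        self.max_distance = max_distance
        self.bias = nn.Embedding(2 * max_distance + 1, heads)

    def forward(self, seq_len):
        pos = torch.arange(seq_len)
        rel = pos[None, :] - pos[:, None]                                   # (seq, seq) relative offsets
        rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance
        return self.bias(rel).permute(2, 0, 1)                              # (heads, seq, seq)

# usage: attn_logits = attn_logits + RelativePositionBias(heads = 8)(seq_len = 1024)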

Trouble using flash attention on A100

When I enable flash attention on an A100, I get two errors. First, the input is not float16. Even if that is fixed, flash attention does not seem to support a non-empty mask.

Pytorch: 2.1.1+cu121
Cuda: 12.1
A100

import torch
import soundstorm_pytorch
myattn = soundstorm_pytorch.attend.Attend(flash=True)

x = torch.randint(0, 1024, (1, 8, 1024, 64)).to('cuda')   # (batch, heads, seq, dim head)
z = myattn(x, x, x)                                        # Fails
z = myattn(x.half(), x.half(), x.half())                   # Works

mask = torch.ones((1, 8, 1024, 1024)).to('cuda').bool()
z = myattn(x.half(), x.half(), x.half(), mask=mask)        # Fails

Error messages are

/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:367.)
  out = F.scaled_dot_product_attention(
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:437.)
  out = F.scaled_dot_product_attention(
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:369.)
  out = F.scaled_dot_product_attention(
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Expected query, key and value to all be of dtype: {Half, BFloat16}. Got Query dtype: long int, Key dtype: long int, and Value dtype: long int instead. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:92.)
  out = F.scaled_dot_product_attention(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 135, in forward
    return self.flash_attn(q, k, v, mask = mask)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 109, in flash_attn
    out = F.scaled_dot_product_attention(

and

/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Both fused kernels do not support non-null attn_mask. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:261.)
  out = F.scaled_dot_product_attention(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 135, in forward
    return self.flash_attn(q, k, v, mask = mask)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 109, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: No available kernel. Aborting execution.
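For anyone else hitting this: as the warnings above state, the fused flash kernel in PyTorch 2.1 requires fp16/bf16 inputs and does not accept a non-null attn_mask, so both failures are expected dispatch behaviour. A minimal check outside the library (shapes are illustrative) that isolates the flash kernel:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device = 'cuda', dtype = torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# restrict dispatch to the flash kernel only, to confirm it is actually being used
with torch.backends.cuda.sdp_kernel(enable_flash = True, enable_math = False, enable_mem_efficient = False):
    out = F.scaled_dot_product_attention(q, k, v)   # works: fp16 inputs, no attn_mask
    # adding attn_mask inside this context raises "No available kernel", matching the error above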

Minimum number of training steps for unconditional synthesis

Hi,

This is a great library. I am training SoundStorm on the LibriSpeech 1000-hour dataset and want to know how many training steps are required to start hearing sensible audio from the generate function. Currently it has trained for 100K steps and the audio is still pure noise. Can you specify after how many steps you started to hear sensible audio?

Thanks
