zhangxinfd / soundstorm-pytorch

This project forked from lucidrains/soundstorm-pytorch



License: MIT License


Soundstorm - Pytorch (wip)

Implementation of SoundStorm, Efficient Parallel Audio Generation from Google DeepMind, in PyTorch.

They essentially applied MaskGIT to the residual vector-quantized codes from SoundStream. The transformer architecture they chose is the Conformer, which fits the audio domain well.
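To make "MaskGIT on residual vector-quantized codes" concrete, here is a minimal sketch of how one training mask could be built over a (frames x quantizer levels) grid of codes, roughly following the procedure described in the SoundStorm paper: a prompt prefix is left unmasked, one RVQ level is picked, a cosine-scheduled fraction of that level's remaining frames is masked, and every finer level is fully masked. This is an illustration under those assumptions, not this repository's implementation; the function names are hypothetical.

```python
import math
import random
from typing import List, Tuple

def cosine_mask_ratio(u: float) -> float:
    """MaskGIT-style cosine schedule: maps u in [0, 1) to a masking ratio in (0, 1]."""
    return math.cos(u * math.pi / 2)

def soundstorm_style_mask(
    seq_len: int, num_quantizers: int, rng: random.Random
) -> Tuple[List[List[bool]], int, int]:
    """Return (mask, prompt_len, level); mask[t][q] is True where code (t, q) is hidden."""
    prompt_len = rng.randrange(seq_len)      # frames [0, prompt_len) act as an unmasked prompt
    level = rng.randrange(num_quantizers)    # the RVQ level being trained in this example
    candidates = list(range(prompt_len, seq_len))
    ratio = cosine_mask_ratio(rng.random())
    num_masked = max(1, math.ceil(ratio * len(candidates)))
    masked_frames = set(rng.sample(candidates, num_masked))

    mask = [[False] * num_quantizers for _ in range(seq_len)]
    for t in candidates:
        for q in range(num_quantizers):
            # coarser levels (q < level) stay visible, finer levels (q > level) are always hidden
            if q > level or (q == level and t in masked_frames):
                mask[t][q] = True
    return mask, prompt_len, level
```

In this sketch the loss would only be computed on the masked positions of the chosen level, so the model learns to fill in one level at a time while conditioning on the prompt and on all coarser levels.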

Project Page

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

  • Lucas Newman for numerous contributions, including the initial training code, acoustic prompting logic, and per-level quantizer decoding!

  • 🤗 Accelerate for providing a simple and powerful solution for training

  • Einops for the indispensable abstraction that makes building neural networks fun, easy, and uplifting

  • Steven Hillis for submitting the correct masking strategy and for verifying that the repository works! 🙏

  • Lucas Newman for training a small working SoundStorm with models across multiple repositories, showing that it all works end-to-end. The models include SoundStream, Text-to-Semantic T5, and finally the SoundStorm transformer here.

Install

$ pip install soundstorm-pytorch

Usage

import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper

conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 12,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

model = SoundStorm(
    conformer,
    steps = 18,          # 18 steps, as in original maskgit paper
    schedule = 'cosine'  # currently the best schedule is cosine
)

# get your pre-encoded codebook ids by running a lot of raw audio through the soundstream

codes = torch.randint(0, 1024, (2, 1024, 12)) # (batch, seq, num residual VQ)

# do the below in a loop for a ton of data

loss, _ = model(codes)
loss.backward()

# model can now generate in 18 steps. ~2 seconds sounds reasonable

generated = model.generate(1024, batch_size = 2) # (2, 1024)
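The `steps = 18` and `schedule = 'cosine'` arguments control MaskGIT-style iterative decoding. The sketch below shows the general idea for a single sequence of tokens (one RVQ level): start fully masked, and at each step commit the most confident predictions so that roughly `seq_len * cos(pi/2 * step/steps)` positions remain masked. `predict` is a hypothetical stand-in for the Conformer, and the sketch omits details such as temperature and noise annealing, so the repository's `generate` may differ.

```python
import math
import random
from typing import Callable, List, Optional, Tuple

Prediction = Tuple[int, float]  # (predicted token id, confidence)

def maskgit_decode(
    seq_len: int,
    steps: int,
    predict: Callable[[List[Optional[int]]], List[Prediction]],
) -> List[int]:
    """Fill a fully masked sequence in `steps` rounds; None marks a masked position."""
    tokens: List[Optional[int]] = [None] * seq_len
    for step in range(1, steps + 1):
        masked = [i for i, tok in enumerate(tokens) if tok is None]
        if not masked:
            break
        preds = predict(tokens)
        # cosine schedule: how many positions should still be masked after this step
        keep_masked = math.floor(seq_len * math.cos(math.pi / 2 * step / steps))
        num_to_fill = max(1, len(masked) - keep_masked)
        # commit the most confident masked predictions; already committed tokens stay fixed
        masked.sort(key=lambda i: preds[i][1], reverse=True)
        for i in masked[:num_to_fill]:
            tokens[i] = preds[i][0]
    return tokens  # keep_masked is 0 at the final step, so every position is filled

def make_random_predictor(codebook_size: int, seed: int = 0):
    """Hypothetical toy model: random token ids with random confidences."""
    rng = random.Random(seed)
    def predict(tokens):
        return [(rng.randrange(codebook_size), rng.random()) for _ in tokens]
    return predict

codes = maskgit_decode(1024, 18, make_random_predictor(1024))
```

Because the cosine curve is flat near the start, only a few tokens are committed in the early steps and progressively more in later ones, which is the behavior the cosine schedule is chosen for.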

To train directly on raw audio, pass your pretrained SoundStream into SoundStorm. You can train your own SoundStream with audiolm-pytorch.

import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper, Conformer, SoundStream

conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 12,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 4,
    attn_window_size = 128,
    attn_depth = 2
)

model = SoundStorm(
    conformer,
    soundstream = soundstream   # pass in the soundstream
)

# find as much audio as you'd like the model to learn from

audio = torch.randn(2, 10080)

# course it through the model and take a gazillion tiny steps

loss, _ = model(audio)
loss.backward()

# and now you can generate state-of-the-art speech

# generate 30 seconds of audio; the number of frames to generate is computed from the sampling frequency
# and the cumulative downsampling of the soundstream passed in above

generated_audio = model.generate(seconds = 30, batch_size = 2)
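The seconds-to-length conversion is simple arithmetic: the codec emits one frame per `prod(strides)` input samples, so covering `seconds` of audio takes `ceil(seconds * sample_rate / prod(strides))` frames. The helper below is a hypothetical sketch of that calculation; the 16 kHz sample rate and encoder strides `(2, 4, 5, 8)` in the example are assumptions, not values read from the SoundStream above.

```python
import math
from typing import Sequence

def seconds_to_num_frames(seconds: float, sample_rate: int, strides: Sequence[int]) -> int:
    """Number of codec frames needed to cover `seconds` of audio."""
    hop = math.prod(strides)  # cumulative downsampling factor of the encoder
    return math.ceil(seconds * sample_rate / hop)

# assumed values: 16 kHz audio, strides (2, 4, 5, 8) -> 320 samples per frame, 50 frames per second
num_frames = seconds_to_num_frames(30, 16000, (2, 4, 5, 8))  # 1500
```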

Complete text-to-speech relies on a trained TextToSemantic encoder/decoder transformer. Load its weights and pass it into SoundStorm as spear_tts_text_to_semantic.

This is a work in progress, as spear-tts-pytorch only has the model architecture complete, not the pretraining, pseudo-labeling, and backtranslation logic.

from spear_tts_pytorch import TextToSemantic

text_to_semantic = TextToSemantic(
    dim = 512,
    source_depth = 12,
    target_depth = 12,
    num_text_token_ids = 50000,
    num_semantic_token_ids = 20000,
    use_openai_tokenizer = True
)

# load the trained text-to-semantic transformer

text_to_semantic.load('/path/to/trained/model.pt')

# pass it into the soundstorm

model = SoundStorm(
    conformer,
    soundstream = soundstream,
    spear_tts_text_to_semantic = text_to_semantic
).cuda()

# and now you can generate state-of-the-art speech

generated_speech = model.generate(
    texts = [
        'the rain in spain stays mainly in the plain',
        'the quick brown fox jumps over the lazy dog'
    ]
) # (2, n) - raw waveform decoded from soundstream

Todo

  • integrate soundstream

  • when generating, allow the length to be defined in seconds (taking the sampling frequency etc. into account)

  • make sure grouped rvq is supported. concat embeddings rather than sum across group dimension

  • just copy the conformer over and redo Shaw's relative positional embedding with rotary embeddings. nobody uses Shaw anymore.

  • default flash attention to true

  • remove batchnorm and just use layernorm, but after the swish (as in the NormFormer paper)

  • trainer with accelerate - thanks to @lucasnewman

  • allow for variable-length sequence training and generation, by passing in a mask at forward and generate

  • option to return list of audio files when generating

  • turn it into a command line tool

  • add cross attention and adaptive layernorm conditioning
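The grouped RVQ item above contrasts two ways of turning one frame of codes into a single input vector. The SoundStorm paper sums the embeddings of a frame's tokens across RVQ levels; for grouped RVQ the todo proposes summing within each group and concatenating across groups instead. The sketch below shows both, assuming one embedding table per quantizer level; the function names and table layout are hypothetical, not this repository's code.

```python
from typing import List, Sequence

Table = Sequence[Sequence[float]]  # one embedding row per codebook entry

def embed_frame_sum(codes: Sequence[int], tables: Sequence[Table]) -> List[float]:
    """Sum the embeddings of one frame's residual-VQ codes (one table per level)."""
    out = [0.0] * len(tables[0][0])
    for level, code in enumerate(codes):
        for d, value in enumerate(tables[level][code]):
            out[d] += value
    return out

def embed_frame_grouped(
    group_codes: Sequence[Sequence[int]], group_tables: Sequence[Sequence[Table]]
) -> List[float]:
    """Grouped RVQ: sum within each group, then concatenate the per-group results."""
    out: List[float] = []
    for codes, tables in zip(group_codes, group_tables):
        out.extend(embed_frame_sum(codes, tables))
    return out

# two levels, codebook size 2, embedding dim 2
tables = [[[1.0, 0.0], [0.0, 1.0]], [[10.0, 0.0], [0.0, 10.0]]]
summed = embed_frame_sum([0, 1], tables)                            # [1.0, 10.0]
grouped = embed_frame_grouped([[0, 1], [1, 0]], [tables, tables])  # [1.0, 10.0, 10.0, 1.0]
```

Concatenating keeps each group's contribution in its own slice of the vector, so the model dimension grows with the number of groups, whereas summing keeps it fixed.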

Citations

@misc{borsos2023soundstorm,
    title   = {SoundStorm: Efficient Parallel Audio Generation}, 
    author  = {Zalán Borsos and Matt Sharifi and Damien Vincent and Eugene Kharitonov and Neil Zeghidour and Marco Tagliasacchi},
    year    = {2023},
    eprint  = {2305.09636},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Chang2022MaskGITMG,
    title   = {MaskGIT: Masked Generative Image Transformer},
    author  = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11305-11315}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@inproceedings{Nijkamp2021SCRIPTSP,
    title   = {SCRIPT: Self-Critic PreTraining of Transformers},
    author  = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
    booktitle = {North American Chapter of the Association for Computational Linguistics},
    year    = {2021}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
