audiolm-pytorch's Introduction

AudioLM - Pytorch

Implementation of AudioLM, a Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

It also extends the work with T5-based classifier-free guidance for conditioning. This allows one to do text-to-audio or TTS, which is not offered in the paper. Yes, this means VALL-E can be trained from this repository. It is essentially the same.

Please join us on Discord if you are interested in replicating this work in the open

This repository now also contains an MIT-licensed version of SoundStream. It is also compatible with EnCodec, which is likewise MIT-licensed at the time of writing.

Update: AudioLM was essentially used to 'solve' music generation in the new MusicLM

In the future, this movie clip would no longer make any sense. You would just prompt an AI instead.

Appreciation

  • Stability.ai for the generous sponsorship to work on and open source cutting-edge artificial intelligence research

  • 🤗 Huggingface for their amazing accelerate and transformers libraries

  • MetaAI for Fairseq and the liberal license

  • @eonglints and Joseph for offering their professional advice and expertise as well as pull requests!

  • @djqualia, @yigityu, @inspirit, and @BlackFox1197 for helping with the debugging of soundstream

  • Allen and LWprogramming for reviewing the code and submitting bug fixes!

  • Ilya for finding an issue with multi-scale discriminator downsampling and for soundstream trainer improvements

  • Andrey for identifying a missing loss in soundstream and guiding me through the proper mel spectrogram hyperparameters

  • Alejandro and Ilya for sharing their results with training soundstream, and for working through a few issues with the local attention positional embeddings

  • LWprogramming for adding Encodec compatibility!

  • LWprogramming for finding an issue with handling of the EOS token when sampling from the FineTransformer!

  • @YoungloLee for identifying a big bug in the 1d causal convolution for soundstream related to padding not accounting for strides!

  • Hayden for pointing out some discrepancies in the multi-scale discriminator for Soundstream

Install

$ pip install audiolm-pytorch

Usage

SoundStream & Encodec

There are two options for the neural codec. If you want to use the pretrained 24kHz Encodec, just create an Encodec object as follows:

from audiolm_pytorch import EncodecWrapper
encodec = EncodecWrapper()
# Now you can use the encodec variable in the same way you'd use the soundstream variables below.

Otherwise, to stay truer to the original paper, you can use SoundStream. First, SoundStream needs to be trained on a large corpus of audio data.

import torch
from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size = 4096,
    rq_num_quantizers = 8,
    rq_groups = 2,                       # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765
    use_lookup_free_quantizer = True,    # whether to use residual lookup free quantization - there are now reports of successful usage of this unpublished technique
    use_finite_scalar_quantizer = False, # whether to use residual finite scalar quantization
    attn_window_size = 128,              # local attention receptive field at bottleneck
    attn_depth = 2                       # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = '/path/to/audio/files',
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length_seconds = 2,  # train on 2 second audio
    num_train_steps = 1_000_000
).cuda()

trainer.train()

# after a lot of training, you can test the autoencoding as so

soundstream.eval() # your soundstream must be in eval mode, to avoid having the residual dropout of the residual VQ necessary for training

audio = torch.randn(10080).cuda()
recons = soundstream(audio, return_recons_only = True) # (1, 10080) - 1 channel
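
To listen to a reconstruction, you can write it out with torchaudio. A minimal sketch, assuming a 24kHz sample rate purely for illustration - use whatever rate your SoundStream was actually configured for.

import torchaudio

# recons has shape (1, 10080), i.e. (channels, samples), so it can be saved directly
# the sample rate below is an assumption - match the rate your SoundStream was trained at
torchaudio.save('reconstruction.wav', recons.detach().cpu(), sample_rate = 24_000)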

Your trained SoundStream can then be used as a generic tokenizer for audio

audio = torch.randn(1, 512 * 320)

codes = soundstream.tokenize(audio)

# you can now train anything with the codebook ids

recon_audio_from_codes = soundstream.decode_from_codebook_indices(codes)

# sanity check

assert torch.allclose(
    recon_audio_from_codes,
    soundstream(audio, return_recons_only = True)
)

You can also use SoundStream variants specific to AudioLM and MusicLM by importing AudioLMSoundStream and MusicLMSoundStream, respectively

from audiolm_pytorch import AudioLMSoundStream, MusicLMSoundStream

soundstream = AudioLMSoundStream(...) # say you want the hyperparameters as in the AudioLM paper

# rest is the same as above

As of version 0.17.0, you can invoke the init_and_load_from class method on SoundStream to load from checkpoint files, without having to remember your configurations.

from audiolm_pytorch import SoundStream

soundstream = SoundStream.init_and_load_from('./path/to/checkpoint.pt')

To use Weights & Biases tracking, first set use_wandb_tracking = True on the SoundStreamTrainer, then do the following

trainer = SoundStreamTrainer(
    soundstream,
    ...,
    use_wandb_tracking = True
)

# wrap .train() with contextmanager, specifying project and run name

with trainer.wandb_tracker(project = 'soundstream', run = 'baseline'):
    trainer.train()

Hierarchical Transformers

Then three separate transformers (SemanticTransformer, CoarseTransformer, FineTransformer) need to be trained

ex. SemanticTransformer

import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    flash_attn = True
).cuda()


trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = '/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train()

ex. CoarseTransformer

import torch
from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerTrainer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt')

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6,
    flash_attn = True
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = '/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1_000_000
)

trainer.train()

ex. FineTransformer

import torch
from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer

soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt')

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6,
    flash_attn = True
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = soundstream,
    folder = '/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1_000_000
)

trainer.train()

All together now

from audiolm_pytorch import AudioLM

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

generated_wav = audiolm(batch_size = 1)

# or with priming

generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))

# or with text condition, if given

generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echoes of bells'])

Text Conditioned Audio Synthesis

Update: Looks like this will work, given 'VALL-E'

ex. Semantic Transformer

import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = 500,
    dim = 1024,
    depth = 6,
    has_condition = True,               # this will have to be set to True
    cond_as_self_attn_prefix = True     # whether to condition as prefix to self attention, instead of cross attention, as was done in 'VALL-E' paper
).cuda()

# mock text audio dataset (as an example)

# you will have to extend your own from `Dataset`, and return an audio tensor as well as a string (the audio description) in any order (the framework will autodetect and route it into the transformer)

from torch.utils.data import Dataset

class MockTextAudioDataset(Dataset):
    def __init__(self, length = 100, audio_length = 320 * 32):
        super().__init__()
        self.audio_length = audio_length
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        mock_audio = torch.randn(self.audio_length)
        mock_caption = 'audio caption'
        return mock_caption, mock_audio

dataset = MockTextAudioDataset()

# instantiate semantic transformer trainer and train

trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    dataset = dataset,
    batch_size = 4,
    grad_accum_every = 8,
    data_max_length = 320 * 32,
    num_train_steps = 1_000_000
)

trainer.train()

# after much training above

sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_size = 1, max_length = 2) # (1, < 128) - may terminate early if it detects [eos]

Multi-GPU

Because all the trainer classes use 🤗 Accelerate, you can easily do multi-GPU training by using the accelerate command as follows

At the project root

$ accelerate config

Then, in the same directory

$ accelerate launch train.py
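
The train.py referenced above can be any of the training scripts from the Usage section; nothing in it needs to change for multi-GPU. Below is a minimal sketch reusing the SoundStream setup shown earlier - the folder path and hyperparameters are placeholders.

# train.py - minimal sketch; Accelerate handles device placement and gradient syncing internally
from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size = 4096,
    rq_num_quantizers = 8
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = '/path/to/audio/files',
    batch_size = 4,
    grad_accum_every = 8,
    data_max_length_seconds = 2,
    num_train_steps = 1_000_000
)

trainer.train()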

Todo

  • complete CoarseTransformer

  • use fairseq vq-wav2vec for embeddings

  • add conditioning

  • add classifier free guidance

  • add unique consecutive for

  • incorporate ability to use hubert intermediate features as semantic tokens, recommended by eonglints

  • accommodate variable-length audio, bring in eos token

  • make sure unique consecutive works with coarse transformer

  • pretty printing all discriminator losses to log

  • handle the case when generating semantic tokens where the last logits may not necessarily be the last in the sequence, given unique consecutive processing

  • complete sampling code for both Coarse and Fine Transformers, which will be tricky

  • make sure full inference with or without prompting works on the AudioLM class

  • complete full training code for soundstream, taking care of discriminator training

  • add efficient gradient penalty for discriminators for soundstream

  • wire up sample hz from sound dataset -> transformers, and have proper resampling during training - think about whether to allow the dataset to have sound files of varying sample hz, or enforce the same sample hz

  • full transformer training code for all three transformers

  • refactor so the semantic transformer has a wrapper that handles unique consecutives as well as wav to hubert or vq-wav2vec

  • simply not self attend to eos token on the prompting side (semantic for coarse transformer, coarse for fine transformer)

  • add structured dropout from forgetful causal masking, far better than traditional dropouts

  • figure out how to suppress logging in fairseq

  • assert that all three transformers passed into audiolm are compatible

  • allow for specialized relative positional embeddings in fine transformer based on absolute matching positions of quantizers between coarse and fine

  • allow for grouped residual vq in soundstream (use GroupedResidualVQ from vector-quantize-pytorch lib), from hifi-codec

  • add flash attention with NoPE

  • accept prime wave in AudioLM as a path to an audio file, and auto resample for semantic vs acoustic

  • add key / value caching to all transformers, speeding up inference

  • design a hierarchical coarse and fine transformer

  • investigate spec decoding, first test in x-transformers, then port over if applicable

  • redo the positional embeddings in the presence of groups in residual vq

  • test with speech synthesis for starters

  • cli tool, something like audiolm generate <wav.file | text> and save generated wav file to local directory

  • return a list of waves in the case of variable-length audio

  • just take care of the edge case in coarse transformer text conditioned training, where the raw wave is resampled at different frequencies. autodetermine how to route based on length

Citations

@inproceedings{Borsos2022AudioLMAL,
  title  = {AudioLM: a Language Modeling Approach to Audio Generation},
  author = {Zal{\'a}n Borsos and Rapha{\"e}l Marinier and Damien Vincent and Eugene Kharitonov and Olivier Pietquin and Matthew Sharifi and Olivier Teboul and David Grangier and Marco Tagliasacchi and Neil Zeghidour},
  year   = {2022}
}
@misc{https://doi.org/10.48550/arxiv.2107.03312,
  title  = {SoundStream: An End-to-End Neural Audio Codec},
  author = {Zeghidour, Neil and Luebs, Alejandro and Omran, Ahmed and Skoglund, Jan and Tagliasacchi, Marco},
  publisher = {arXiv},
  url    = {https://arxiv.org/abs/2107.03312},
  year   = {2021}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam M. Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@article{Ho2022ClassifierFreeDG,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2207.12598}
}
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/rivershavewings}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Liu2022FCMFC,
    title   = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners},
    author  = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.13432}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Li2021LocalViTBL,
    title   = {LocalViT: Bringing Locality to Vision Transformers},
    author  = {Yawei Li and K. Zhang and Jie Cao and Radu Timofte and Luc Van Gool},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2104.05707}
}
@article{Defossez2022HighFN,
    title   = {High Fidelity Neural Audio Compression},
    author  = {Alexandre D{\'e}fossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.13438}
}
@article{Hu2017SqueezeandExcitationN,
    title   = {Squeeze-and-Excitation Networks},
    author  = {Jie Hu and Li Shen and Gang Sun},
    journal = {2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    year    = {2017},
    pages   = {7132-7141}
}
@inproceedings{Yang2023HiFiCodecGV,
    title   = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
    author  = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
    year    = {2023}
}
@article{Kazemnejad2023TheIO,
    title   = {The Impact of Positional Encoding on Length Generalization in Transformers},
    author  = {Amirhossein Kazemnejad and Inkit Padhi and Karthikeyan Natesan Ramamurthy and Payel Das and Siva Reddy},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.19466}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
    author  = {Lijun Yu and Jos{\'e} Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Katsch2023GateLoopFD,
    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
    author  = {Tobias Katsch},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265018962}
}

audiolm-pytorch's Issues

What is the meaning of the dimensions?

Thank you for your work.
I ran t5_encode_text like below:
a = t5_encode_text('I like your work.')
print(a.shape)
Then a.shape was torch.Size([17, 3, 768]).

I think the first dimension (17) is the length of the texts and the third dimension is the feature dimension.
What is the second dimension, 3? What does it mean?

Anyway, thank you for your contribution. It is easy to run the code you have opened on GitHub.

typical range of `num_train_steps`?

Thanks for sharing this great repo!

I'm wondering what the typical range of num_train_steps is for a SoundStream model and the others.
I tested with 10,000 and saw the loss go down somewhat smoothly, but it did not generate any meaningful results (very noisy).

Error training soundstream in 0.3.0

Since the latest code (0.3.0) I get the following error attempting to train a SoundStream:

Traceback (most recent call last):
File "train_soundstream.py", line 26, in
trainer.train()
File "/home/qualia/code/audiolm/audiolm_pytorch/trainer.py", line 402, in train
logs = self.train_step()
File "/home/qualia/code/audiolm/audiolm_pytorch/trainer.py", line 320, in train_step
discr_losses = self.soundstream(
File "/home/qualia/anaconda3/envs/audiolm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/qualia/code/audiolm/audiolm_pytorch/soundstream.py", line 416, in forward
stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2
RuntimeError: imag is not implemented for tensors with non-complex dtypes.

I can't make it work

First of all thank you for your work.

I am trying to get the code to work. I got the system to train, using about 5,190 hours of speech (a part of Libri-Light), but the output of generation is only noise. I have tried to encode and decode an audio signal with the SoundStream model that I obtained as the result of training, and I see that it is not working, because there is only noise at the output.

Could you tell me if I'm doing something wrong? Maybe I need to use a larger amount of audio or a specific database?

Thank you!

RuntimeError: stack expects each tensor to be equal size, but got [5440] at entry 0 and [5120] at entry 2

The code appears to be able to handle audio of varying sizes. Indeed, librispeech contains audio of different lengths.

However, when I run on a corpus of mixed size audio, I get the following error:

[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.
Traceback (most recent call last):
  File "/root/foo.py", line 19, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/trainer.py", line 411, in train
    logs = self.train_step()
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/trainer.py", line 302, in train_step
    wave, = next(self.dl_iter)
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/trainer.py", line 70, in cycle
    for data in dl:
  File "/opt/conda/lib/python3.10/site-packages/accelerate/data_loader.py", line 375, in __iter__
    current_batch = next(dataloader_iter)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 61, in fetch
    return self.collate_fn(data)
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/data.py", line 98, in inner
    data = torch.stack(data)
RuntimeError: stack expects each tensor to be equal size, but got [5440] at entry 0 and [5120] at entry 2

use_complex_stft_discriminator = False

Hi Phil,

after setting use_complex_stft_discriminator = False, I can see that the discriminator losses are zero or very small:
soundstream total loss: 177.001, soundstream recon loss: 1.143 | discr (scale 1) loss: 0.000 | discr (scale 0.5) loss: 0.000 | discr (scale 0.25) loss: 0.000

The total loss also seems too high.

When setting use_complex_stft_discriminator = True, the losses look more normal to me:
soundstream total loss: 12.153, soundstream recon loss: 0.007 | discr (scale 1) loss: 1.184 | discr (scale 0.5) loss: 4.218 | discr (scale 0.25) loss: 2.376

Wave discriminator?

The original SoundStream paper also uses a wave-based discriminator:

III.D
"For the wave-based discriminator, we use the same multi-resolution convolutional discriminator proposed in [15] and adopted in [45]. Three structurally identical models are applied to the input audio at different resolutions: original, 2-times down-sampled, and 4-times down-sampled. Each single-scale discriminator consists of an initial plain convolution followed by four grouped convolutions, each of which has a group size of 4, a down-sampling factor of 4, and a channel multiplier of 4 up to a maximum of 1024 output channels. They are followed by two more plain convolution layers to produce the final output, i.e., the logits."

citing: K. Kumar, R. Kumar, T. de Boissiere, L. Gestin, W. Z. Teoh, J. Sotelo, A. de Brebisson, Y. Bengio, and A. Courville, "MelGAN: Generative adversarial networks for conditional waveform synthesis," in Advances in Neural Information Processing Systems, 2019

M. Tagliasacchi, Y. Li, K. Misiunas, and D. Roblek, "SEANet: A multi-modal speech enhancement network," in Interspeech, 2020, pp. 1126–1130.

SoundDataset.target_sample_hz will always exist

In data.py, we have a check if self.target_sample_hz is not None. However, on initialization the field is set to (None,) if the input argument target_sample_hz is None, so the check will always pass. Was the intention here to set it to None instead of a one-element tuple containing None? Otherwise, it seems like that conditional will always execute.
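
A minimal illustration of the behavior described above (not the library code, just the reported pattern):

target_sample_hz = None
field = (target_sample_hz,)   # the field becomes (None,), a one-element tuple

# this guard always passes, because a tuple is never None,
# even though no target sample rate was actually provided
assert field is not None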

ComplexConv2d in ComplexSTFTDiscriminator gives RuntimeError

Hi,
I am trying to train a soundstream model, and I get the following error during the forward pass:

File "audiolm-pytorch/audiolm-pytorch/soundstream.py", line 148, in forward
return F.conv2d(x, weight, bias, stride = self.stride, padding = self.padding)
RuntimeError: getCudnnDataTypeFromScalarType() not supported for ComplexFloat

It seems that conv2d does not support the complex results obtained by the STFT in ComplexSTFTDiscriminator. Has anyone come across the same issue before?

0.1.18 issues training SoundStream

Since the latest version, I'm getting an error training sound stream: AttributeError: 'tuple' object has no attribute 'to'

A more complete stack trace:

Traceback (most recent call last):
File "train_soundstream.py", line 20, in
trainer.train()
File "/mnt/c/audio-ml-workspace/audiolm/audiolm_pytorch/trainer.py", line 376, in train
logs = self.train_step()
File "/mnt/c/audio-ml-workspace/audiolm/audiolm_pytorch/trainer.py", line 271, in train_step
wave = wave.to(device)
AttributeError: 'tuple' object has no attribute 'to'

I'll see if I can figure out what caused this...

pad_id not in nn.Embedding vocab

Hey, I'm sure you're aware of this, but flagging just in case.

As you know, in the SemanticTransformer, the pad_id defaults to -1. When using a batch size greater than 1, tokens are padded by batch_unique_consecutive, using this pad_id.

This then causes an issue when these padded token sequences are passed through the embedding layer as the embedding layer has been initialized with a vocab size of num_semantic_tokens + 1, which doesn't include this default pad_id. This obviously causes an index out of range error (which can be tricky to spot immediately when running with CUDA, easier to spot on CPU).

Perhaps the pad_id should be self.eos_id + 1 (equivalent to num_semantic_tokens + 2)? Or maybe there's a smarter solution...
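
A minimal sketch reproducing the failure mode described above - the sizes are hypothetical, only pad_id = -1 and the + 1 for the eos token follow the description:

import torch
import torch.nn as nn

num_semantic_tokens = 500                           # hypothetical size, for illustration only
emb = nn.Embedding(num_semantic_tokens + 1, 1024)   # + 1 for eos, as described above

padded_ids = torch.tensor([[3, 7, -1, -1]])         # -1 is the default pad_id
emb(padded_ids)                                     # raises IndexError: index out of range in self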

Activation units position

Hi Phil,
I was wondering if the ELU activations should be moved up before the Conv, considering the final encoder/decoder compositions:

  1. Emb conv from input
  2. NxEncoderBlock [3xResidualUnit + DownsampleConv]
    So if we position activations before Conv in ResidualUnit, it will follow the pattern [Conv -> Act -> Conv ...];
    we will also need to add an additional Act to EncoderBlock: [3xResidualUnit + Act + DownsampleConv]
    def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7):

Consider default of normalized=True for all STFT and Mel transforms

normalized = False by default in torch STFTs and Mels. Consider using a default of normalized = True, particularly when computing transforms at multiple window lengths.

This is done in Encodec: https://github.com/facebookresearch/encodec/blob/main/encodec/msstftd.py#L62

And I had a non-audiolm model I trained recently that didn't work until I flipped this variable.

This won't completely solve the loss balancing issue #61 but might nonetheless improve training
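
For reference, a hedged sketch of what flipping that default looks like with torchaudio - the sample rate and n_fft below are placeholders, not the repo's settings:

import torchaudio

# torchaudio's spectrogram transforms default to normalized = False;
# the suggestion above is to flip it, especially when combining multiple window lengths
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate = 24_000,   # placeholder - match your codec's sample rate
    n_fft = 1024,
    normalized = True
)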

How to create wav2vec checkpoints?

First off, thanks for this code! I've been trying to test it out, and am attempting to start training the semantic transformer. Following the current instructions, it's not clear how I should create the files referenced in the wav2vec checkpoint/kmeans paths, e.g:

wav2vec = HubertWithKmeans(
checkpoint_path = './hubert/hubert_base_ls960.pt',
kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

How did you create these? Thanks again!

Global num_workers?

In training SoundStream, I see that there is no way (besides modifying code) to set the num_workers.

But doing so has a dramatic impact on speed on the machine I'm trying right now. A way to set this globally, or for each trainer, would be nice.

BTW, I set num_workers=32 and after some steps it gets this error:

Traceback (most recent call last):
  File "/root/./foo.py", line 21, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/trainer.py", line 411, in train
    logs = self.train_step()
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/trainer.py", line 302, in train_step
    wave, = next(self.dl_iter)
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/trainer.py", line 70, in cycle
    for data in dl:
  File "/opt/conda/lib/python3.10/site-packages/accelerate/data_loader.py", line 383, in __iter__
    next_batch = next(dataloader_iter)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1313, in _next_data
    return self._process_data(data)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
    data.reraise()
  File "/opt/conda/lib/python3.10/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 13.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataset.py", line 295, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/data.py", line 74, in __getitem__
    data_tuple = tuple((resample(d, sample_hz, target_sample_hz) if exists(target_sample_hz) else d) for d, target_sample_hz in zip(data, self.target_sample_hz))
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/data.py", line 74, in <genexpr>
    data_tuple = tuple((resample(d, sample_hz, target_sample_hz) if exists(target_sample_hz) else d) for d, target_sample_hz in zip(data, self.target_sample_hz))
  File "/opt/conda/lib/python3.10/site-packages/torchaudio/functional/functional.py", line 1600, in resample
    resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
  File "/opt/conda/lib/python3.10/site-packages/torchaudio/functional/functional.py", line 1532, in _apply_sinc_resample_kernel
    waveform = waveform.view(-1, shape[-1])
RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0] because the unspecified dimension size -1 can be any value and is ambiguous

Error training soundstream on 0.12.1

Traceback (most recent call last):
File "train_soundstream.py", line 6, in
soundstream = config_util.createSoundStream(config)
File "/home/qualia/code/audiolm/config_util.py", line 28, in createSoundStream
soundstream = SoundStream(
File "/home/qualia/code/audiolm/audiolm_pytorch/soundstream.py", line 411, in init
self.encoder_attn = nn.Sequential([LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None
File "/home/qualia/code/audiolm/audiolm_pytorch/soundstream.py", line 411, in
self.encoder_attn = nn.Sequential(
[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None
File "/home/qualia/code/audiolm/audiolm_pytorch/soundstream.py", line 333, in init
self.attn = LocalMHA(dim = dim, qk_rmsnorm = True, **kwargs)
File "/home/qualia/anaconda3/envs/audiolm/lib/python3.8/site-packages/local_attention/transformer.py", line 55, in init
self.attn_fn = LocalAttention(
TypeError: init() got an unexpected keyword argument 'qk_rmsnorm'

Any suggestion for how to fix? Do I have an old version of some dependency?

hubert instead of w2v-bert ?

Did you use hubert instead of w2v-bert (used in the original paper) because of the availability of the hubert model on HF, or for other reasons, i.e., have you seen hubert results that support this replacement?

soundstream.load() no longer works in 0.12.3

Is this intentional?

My first guess is that we should be calling the associated trainer.load() instead. However, when it comes to inference time, does that mean we need to load all the trainer classes?

In case it helps, here is the exception I got when trying to resume training on 0.12.3 (started on 0.12.1):

File "train_soundstream.py", line 6, in
soundstream = config_util.createSoundStream(config)
File "/home/qualia/code/audiolm/config_util.py", line 38, in createSoundStream
soundstream.load(checkpoint)
File "/home/qualia/code/audiolm/audiolm_pytorch/soundstream.py", line 499, in load
self.load_state_dict(torch.load(str(path)))
File "/home/qualia/anaconda3/envs/audiolm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SoundStream:
Missing key(s) in state_dict: "encoder.0.conv.weight", "encoder.0.conv.bias", "encoder.1.0.fn.0.conv.weight", "encoder.1.0.fn.0.conv.bias", "encoder.1.0.fn.2.conv.weight", "encoder.1.0.fn.2.conv.bias", "encoder.1.1.fn.0.conv.weight", "encoder.1.1.fn.0.conv.bias", "encoder.1.1.fn.2.conv.weight", "encoder.1.1.fn.2.conv.bias", "encoder.1.2.fn.0.conv.weight", "encoder.1.2.fn.0.conv.bias", "encoder.1.2.fn.2.conv.weight", "encoder.1.2.fn.2.conv.bias", "encoder.1.3.conv.weight", "encoder.1.3.conv.bias", "encoder.2.prenorm.weight", "encoder.2.prenorm.bias", "encoder.2.mhema.expansion", "encoder.2.mhema.reduction", "encoder.2.mhema.alphas", "encoder.2.mhema.dampen_factors", "encoder.3.0.fn.0.conv.weight", "encoder.3.0.fn.0.conv.bias", "encoder.3.0.fn.2.conv.weight", "encoder.3.0.fn.2.conv.bias", "encoder.3.1.fn.0.conv.weight", "encoder.3.1.fn.0.conv.bias", "encoder.3.1.fn.2.conv.weight", "encoder.3.1.fn.2.conv.bias", "encoder.3.2.fn.0.conv.weight", "encoder.3.2.fn.0.conv.bias", "encoder.3.2.fn.2.conv.weight", "encoder.3.2.fn.2.conv.bias", "encoder.3.3.conv.weight", "encoder.3.3.conv.bias", "encoder.4.prenorm.weight", "encoder.4.prenorm.bias", "encoder.4.mhema.expansion", "encoder.4.mhema.reduction", "encoder.4.mhema.alphas", "encoder.4.mhema.dampen_factors", "encoder.5.0.fn.0.conv.weight", "encoder.5.0.fn.0.conv.bias", "encoder.5.0.fn.2.conv.weight", "encoder.5.0.fn.2.conv.bias", "encoder.5.1.fn.0.conv.weight", "encoder.5.1.fn.0.conv.bias", "encoder.5.1.fn.2.conv.weight", "encoder.5.1.fn.2.conv.bias", "encoder.5.2.fn.0.conv.weight", "encoder.5.2.fn.0.conv.bias", "encoder.5.2.fn.2.conv.weight", "encoder.5.2.fn.2.conv.bias", "encoder.5.3.conv.weight", "encoder.5.3.conv.bias", "encoder.6.prenorm.weight", "encoder.6.prenorm.bias", "encoder.6.mhema.expansion", "encoder.6.mhema.reduction", "encoder.6.mhema.alphas", "encoder.6.mhema.dampen_factors", "encoder.7.0.fn.0.conv.weight", "encoder.7.0.fn.0.conv.bias", "encoder.7.0.fn.2.conv.weight", "encoder.7.0.fn.2.conv.bias", "encoder.7.1.fn.0.conv.weight", "encoder.7.1.fn.0.conv.bias", "encoder.7.1.fn.2.conv.weight", "encoder.7.1.fn.2.conv.bias", "encoder.7.2.fn.0.conv.weight", "encoder.7.2.fn.0.conv.bias", "encoder.7.2.fn.2.conv.weight", "encoder.7.2.fn.2.conv.bias", "encoder.7.3.conv.weight", "encoder.7.3.conv.bias", "encoder.8.prenorm.weight", "encoder.8.prenorm.bias", "encoder.8.mhema.expansion", "encoder.8.mhema.reduction", "encoder.8.mhema.alphas", "encoder.8.mhema.dampen_factors", "encoder.9.conv.weight", "encoder.9.conv.bias", "encoder_attn.0.attn.q_scale", "encoder_attn.0.attn.k_scale", "encoder_attn.0.attn.norm.weight", "encoder_attn.0.attn.norm.bias", "encoder_attn.0.attn.to_qkv.weight", "encoder_attn.0.attn.attn_fn.rel_pos.inv_freq", "encoder_attn.0.attn.to_out.weight", "encoder_attn.0.ff.0.weight", "encoder_attn.0.ff.0.bias", "encoder_attn.0.ff.1.weight", "encoder_attn.0.ff.4.weight", "encoder_attn.1.attn.q_scale", "encoder_attn.1.attn.k_scale", "encoder_attn.1.attn.norm.weight", "encoder_attn.1.attn.norm.bias", "encoder_attn.1.attn.to_qkv.weight", "encoder_attn.1.attn.attn_fn.rel_pos.inv_freq", "encoder_attn.1.attn.to_out.weight", "encoder_attn.1.ff.0.weight", "encoder_attn.1.ff.0.bias", "encoder_attn.1.ff.1.weight", "encoder_attn.1.ff.4.weight", "rq.layers.0._codebook.initted", "rq.layers.0._codebook.cluster_size", "rq.layers.0._codebook.embed_avg", "rq.layers.0._codebook.embed", "rq.layers.1._codebook.initted", "rq.layers.1._codebook.cluster_size", "rq.layers.1._codebook.embed_avg", "rq.layers.1._codebook.embed", 
"rq.layers.2._codebook.initted", "rq.layers.2._codebook.cluster_size", "rq.layers.2._codebook.embed_avg", "rq.layers.2._codebook.embed", "rq.layers.3._codebook.initted", "rq.layers.3._codebook.cluster_size", "rq.layers.3._codebook.embed_avg", "rq.layers.3._codebook.embed", "rq.layers.4._codebook.initted", "rq.layers.4._codebook.cluster_size", "rq.layers.4._codebook.embed_avg", "rq.layers.4._codebook.embed", "rq.layers.5._codebook.initted", "rq.layers.5._codebook.cluster_size", "rq.layers.5._codebook.embed_avg", "rq.layers.5._codebook.embed", "rq.layers.6._codebook.initted", "rq.layers.6._codebook.cluster_size", "rq.layers.6._codebook.embed_avg", "rq.layers.6._codebook.embed", "rq.layers.7._codebook.initted", "rq.layers.7._codebook.cluster_size", "rq.layers.7._codebook.embed_avg", "rq.layers.7._codebook.embed", "rq.layers.8._codebook.initted", "rq.layers.8._codebook.cluster_size", "rq.layers.8._codebook.embed_avg", "rq.layers.8._codebook.embed", "rq.layers.9._codebook.initted", "rq.layers.9._codebook.cluster_size", "rq.layers.9._codebook.embed_avg", "rq.layers.9._codebook.embed", "rq.layers.10._codebook.initted", "rq.layers.10._codebook.cluster_size", "rq.layers.10._codebook.embed_avg", "rq.layers.10._codebook.embed", "rq.layers.11._codebook.initted", "rq.layers.11._codebook.cluster_size", "rq.layers.11._codebook.embed_avg", "rq.layers.11._codebook.embed", "decoder_attn.0.attn.q_scale", "decoder_attn.0.attn.k_scale", "decoder_attn.0.attn.norm.weight", "decoder_attn.0.attn.norm.bias", "decoder_attn.0.attn.to_qkv.weight", "decoder_attn.0.attn.attn_fn.rel_pos.inv_freq", "decoder_attn.0.attn.to_out.weight", "decoder_attn.0.ff.0.weight", "decoder_attn.0.ff.0.bias", "decoder_attn.0.ff.1.weight", "decoder_attn.0.ff.4.weight", "decoder_attn.1.attn.q_scale", "decoder_attn.1.attn.k_scale", "decoder_attn.1.attn.norm.weight", "decoder_attn.1.attn.norm.bias", "decoder_attn.1.attn.to_qkv.weight", "decoder_attn.1.attn.attn_fn.rel_pos.inv_freq", "decoder_attn.1.attn.to_out.weight", "decoder_attn.1.ff.0.weight", "decoder_attn.1.ff.0.bias", "decoder_attn.1.ff.1.weight", "decoder_attn.1.ff.4.weight", "decoder.0.conv.weight", "decoder.0.conv.bias", "decoder.1.0.conv.weight", "decoder.1.0.conv.bias", "decoder.1.1.fn.0.conv.weight", "decoder.1.1.fn.0.conv.bias", "decoder.1.1.fn.2.conv.weight", "decoder.1.1.fn.2.conv.bias", "decoder.1.2.fn.0.conv.weight", "decoder.1.2.fn.0.conv.bias", "decoder.1.2.fn.2.conv.weight", "decoder.1.2.fn.2.conv.bias", "decoder.1.3.fn.0.conv.weight", "decoder.1.3.fn.0.conv.bias", "decoder.1.3.fn.2.conv.weight", "decoder.1.3.fn.2.conv.bias", "decoder.2.prenorm.weight", "decoder.2.prenorm.bias", "decoder.2.mhema.expansion", "decoder.2.mhema.reduction", "decoder.2.mhema.alphas", "decoder.2.mhema.dampen_factors", "decoder.3.0.conv.weight", "decoder.3.0.conv.bias", "decoder.3.1.fn.0.conv.weight", "decoder.3.1.fn.0.conv.bias", "decoder.3.1.fn.2.conv.weight", "decoder.3.1.fn.2.conv.bias", "decoder.3.2.fn.0.conv.weight", "decoder.3.2.fn.0.conv.bias", "decoder.3.2.fn.2.conv.weight", "decoder.3.2.fn.2.conv.bias", "decoder.3.3.fn.0.conv.weight", "decoder.3.3.fn.0.conv.bias", "decoder.3.3.fn.2.conv.weight", "decoder.3.3.fn.2.conv.bias", "decoder.4.prenorm.weight", "decoder.4.prenorm.bias", "decoder.4.mhema.expansion", "decoder.4.mhema.reduction", "decoder.4.mhema.alphas", "decoder.4.mhema.dampen_factors", "decoder.5.0.conv.weight", "decoder.5.0.conv.bias", "decoder.5.1.fn.0.conv.weight", "decoder.5.1.fn.0.conv.bias", "decoder.5.1.fn.2.conv.weight", "decoder.5.1.fn.2.conv.bias", 
"decoder.5.2.fn.0.conv.weight", "decoder.5.2.fn.0.conv.bias", "decoder.5.2.fn.2.conv.weight", "decoder.5.2.fn.2.conv.bias", "decoder.5.3.fn.0.conv.weight", "decoder.5.3.fn.0.conv.bias", "decoder.5.3.fn.2.conv.weight", "decoder.5.3.fn.2.conv.bias", "decoder.6.prenorm.weight", "decoder.6.prenorm.bias", "decoder.6.mhema.expansion", "decoder.6.mhema.reduction", "decoder.6.mhema.alphas", "decoder.6.mhema.dampen_factors", "decoder.7.0.conv.weight", "decoder.7.0.conv.bias", "decoder.7.1.fn.0.conv.weight", "decoder.7.1.fn.0.conv.bias", "decoder.7.1.fn.2.conv.weight", "decoder.7.1.fn.2.conv.bias", "decoder.7.2.fn.0.conv.weight", "decoder.7.2.fn.0.conv.bias", "decoder.7.2.fn.2.conv.weight", "decoder.7.2.fn.2.conv.bias", "decoder.7.3.fn.0.conv.weight", "decoder.7.3.fn.0.conv.bias", "decoder.7.3.fn.2.conv.weight", "decoder.7.3.fn.2.conv.bias", "decoder.8.prenorm.weight", "decoder.8.prenorm.bias", "decoder.8.mhema.expansion", "decoder.8.mhema.reduction", "decoder.8.mhema.alphas", "decoder.8.mhema.dampen_factors", "decoder.9.conv.weight", "decoder.9.conv.bias", "discriminators.0.init_conv.weight", "discriminators.0.init_conv.bias", "discriminators.0.conv_layers.0.0.weight", "discriminators.0.conv_layers.0.0.bias", "discriminators.0.conv_layers.1.0.weight", "discriminators.0.conv_layers.1.0.bias", "discriminators.0.conv_layers.2.0.weight", "discriminators.0.conv_layers.2.0.bias", "discriminators.0.conv_layers.3.0.weight", "discriminators.0.conv_layers.3.0.bias", "discriminators.0.final_conv.0.weight", "discriminators.0.final_conv.0.bias", "discriminators.0.final_conv.2.weight", "discriminators.0.final_conv.2.bias", "discriminators.1.init_conv.weight", "discriminators.1.init_conv.bias", "discriminators.1.conv_layers.0.0.weight", "discriminators.1.conv_layers.0.0.bias", "discriminators.1.conv_layers.1.0.weight", "discriminators.1.conv_layers.1.0.bias", "discriminators.1.conv_layers.2.0.weight", "discriminators.1.conv_layers.2.0.bias", "discriminators.1.conv_layers.3.0.weight", "discriminators.1.conv_layers.3.0.bias", "discriminators.1.final_conv.0.weight", "discriminators.1.final_conv.0.bias", "discriminators.1.final_conv.2.weight", "discriminators.1.final_conv.2.bias", "discriminators.2.init_conv.weight", "discriminators.2.init_conv.bias", "discriminators.2.conv_layers.0.0.weight", "discriminators.2.conv_layers.0.0.bias", "discriminators.2.conv_layers.1.0.weight", "discriminators.2.conv_layers.1.0.bias", "discriminators.2.conv_layers.2.0.weight", "discriminators.2.conv_layers.2.0.bias", "discriminators.2.conv_layers.3.0.weight", "discriminators.2.conv_layers.3.0.bias", "discriminators.2.final_conv.0.weight", "discriminators.2.final_conv.0.bias", "discriminators.2.final_conv.2.weight", "discriminators.2.final_conv.2.bias", "stft_discriminator.init_conv.weight", "stft_discriminator.init_conv.bias", "stft_discriminator.layers.0.0.weight", "stft_discriminator.layers.0.0.bias", "stft_discriminator.layers.0.1.b", "stft_discriminator.layers.0.2.weight", "stft_discriminator.layers.0.2.bias", "stft_discriminator.layers.1.0.weight", "stft_discriminator.layers.1.0.bias", "stft_discriminator.layers.1.1.b", "stft_discriminator.layers.1.2.weight", "stft_discriminator.layers.1.2.bias", "stft_discriminator.layers.2.0.weight", "stft_discriminator.layers.2.0.bias", "stft_discriminator.layers.2.1.b", "stft_discriminator.layers.2.2.weight", "stft_discriminator.layers.2.2.bias", "stft_discriminator.layers.3.0.weight", "stft_discriminator.layers.3.0.bias", "stft_discriminator.layers.3.1.b", 
"stft_discriminator.layers.3.2.weight", "stft_discriminator.layers.3.2.bias", "stft_discriminator.layers.4.0.weight", "stft_discriminator.layers.4.0.bias", "stft_discriminator.layers.4.1.b", "stft_discriminator.layers.4.2.weight", "stft_discriminator.layers.4.2.bias", "stft_discriminator.layers.5.0.weight", "stft_discriminator.layers.5.0.bias", "stft_discriminator.layers.5.1.b", "stft_discriminator.layers.5.2.weight", "stft_discriminator.layers.5.2.bias", "stft_discriminator.final_conv.weight", "stft_discriminator.final_conv.bias", "mel_spec_transforms.0.spectrogram.window", "mel_spec_transforms.0.mel_scale.fb", "mel_spec_transforms.1.spectrogram.window", "mel_spec_transforms.1.mel_scale.fb", "mel_spec_transforms.2.spectrogram.window", "mel_spec_transforms.2.mel_scale.fb", "mel_spec_transforms.3.spectrogram.window", "mel_spec_transforms.3.mel_scale.fb", "mel_spec_transforms.4.spectrogram.window", "mel_spec_transforms.4.mel_scale.fb", "mel_spec_transforms.5.spectrogram.window", "mel_spec_transforms.5.mel_scale.fb".
Unexpected key(s) in state_dict: "model", "ema_model", "optim", "discr_optim", "multiscale_discr_optimizer_0", "multiscale_discr_optimizer_1", "multiscale_discr_optimizer_2".

Adapting AudioLM to support SingSong style accompaniment generation

Hi @lucidrains - thanks for your awesome work here. Great stuff as always.

I recently came across Google's new SingSong paper (https://arxiv.org/pdf/2301.12662.pdf), in which they adapt AudioLM for generation of instrumental accompaniments conditioned upon sung input vocals, and I was wondering if you (or anyone else 🙂) might have any practical advice on implementing the adaptations necessary.

Also, to this end, would you happen to know if anyone has managed to train a decent soundstream model and made it publicly available yet?

Best, and thanks again for your work here,
Shaun

support SPEAR-TTS

Hi author, thanks for your contribution!

Google has released their new work named SPEAR-TTS (https://arxiv.org/abs/2302.03540), whose architecture is similar to AudioLM. Please consider incorporating it into this repo; that would be really helpful to the community. Thanks!

About the Coarse and Fine acoustic modeling

Hi, I am trying to use your code to reproduce AudioLM. I found that your CoarseTransformer model is based on self-attention. In the AudioLM paper, they state: "During training, we use random cropping to equivalent input lengths of 30, 10 and 3 seconds for the three stages". Considering their soundstream downsamples by 320x, 1 second of audio corresponds to 50 tokens, so 10 seconds of audio is 500 tokens per codebook, and soundstream uses more than 4 codebooks. I want to know whether the CoarseTransformer can deal with such long sequences. Do you have any idea how to solve the long-sequence problem?

multi-scale spectral reconstruction loss

Hello, and thank you for your excellent work; your implementations are really helpful.

I've got a question about the losses in the soundstream model. In the original paper, formulas (4) and (5) on page 6 correspond to the multi-scale spectral reconstruction loss, but I don't see it in your implementation. As far as I can tell, the generator losses are all presented here, and the one I'm talking about is not among them. Maybe I just missed it.

audio2audio?

I would like to train AudioLM in an audio2audio setting, i.e. instead of taking a prime wave and continuing it, taking audio and transforming it into other audio, probably of the same length.

Would that extension be possible?

Training VALL-E

In the given VALL-E example, only a text prefix is given, but in the VALL-E paper we also need to pass a 3-second audio prompt as a prefix along with the text, right? So is it possible to train VALL-E with [text + audio] using the current settings in this repo? If so, can you shed some light on it? Thanks

Use FLAN variant of T5

Not sure where to leave suggestions, but I would consider replacing T5 with T5-FLAN (also available via HuggingFace) for a better bang for the buck.

Batch size > 1 in generate function

Hey, the line last_logits = rearrange(last_logits, 'b 1 c -> b c') in the SemanticTransformer generate() function understandably doesn't like a batch size greater than 1. Is there a nice clean fix for this that doesn't involve doing things like prime_ids = prime_ids[:1, :] earlier in the function? FWIW I don't necessarily think it's a bad thing to only do one generation rather than a whole batch.

Some clarifications and potential bug?

Hi, first of all thank you for putting in the effort to reproduce the Soundstream code. I have a few questions for which I would greatly appreciate any clarifications.

  1. I notice that you use encoder_attn and decoder_attn with LocalMHA, but I couldn't find where the attention parameters are being optimized. If I understood correctly, the function non_discr_parameters only returns the encoder and decoder parameters and not the attention. Please correct me if I missed something.

  2. For the SoundStreamTrainer train_step, shouldn't we first set the discriminator grads to zero before calling backward on the discriminator losses? The generator losses would have also added gradients to the discriminator parameters, and the discriminator losses' backward pass will currently add to those, if I understood correctly. Again, please let me know if I misunderstood something.

Thanks again,
Apoorv

Prefix context in CoarseTransformer and FineTransformer

Hey! What is the reason for not using prefix_context for the semantic tokens In CoarseTransformer and the coarse tokens in FineTransformer? It seems that we want to condition on these tokens without needing to calculate the logits, so it could save some computation.

I also noticed the implementation adds the semantic cross-entropy loss to the CoarseTransformer loss and the coarse cross-entropy loss to the FineTransformer loss. I looked through the paper and couldn't seem to find discussion on this. What effect does this have on training?

About training Soundstream from checkpoint

When trying to train Soundstream from a saved checkpoint, it seems to work well (in a test case, the checkpoint after 9,000 steps is loaded) until it saves the model. The saved model then behaves as if it is learning from the start.
When training for only about 10 steps, the in-memory model (not saved to disk) still seems to behave like the 9,000-step model.

Model parameters

Hello! With the current neural network parameters presented in the README, can I get a decent result if I train AudioLM on Russian LibriSpeech? Or is it just a proof of concept?

Soundstream loss doesn't decrease after 1167 steps - version 0.7.1

Hi,

First of all, thank you for this project and all the other open source projects you're doing. I'm a big fan of your work.

I was training with the latest version on the LibriSpeech dataset, and it looks like recon_loss shoots up and training goes nowhere afterwards. I didn't seem to have this with previous releases, but I will roll back, try again, and report results here. This might be a regression with the latest changes? I wanted to post it in case it helps anyone.

Heads up that ComplexFloat doesn't appear to be supported by DistributedDataParallel (NCCL backend)

SoundStream training seems to be working well on a single GPU but when you attempt to use more than one GPU, this error is thrown up: RuntimeError: Input tensor data type is not supported for NCCL process group: ComplexFloat.

It would appear that the ComplexLeakyReLU class is probably the cause, but the workaround mentioned here doesn't necessarily seem to be a great answer.

Not really a bug for AudioLM, just a heads up for anyone attempting multi-GPU training right now.

(Pytorch 1.13)

The training code of soundstream

Hi, when I try to train the soundstream using your code, I find that the loss keeps increasing: from 10 to 100, to 1000, and so on.
Have you completely trained a soundstream model using this code?

causal convolutions

In the soundstream paper, they claim all the convolutions are causal, but does this apply to the upsampling convtranspose? I'm not sure if it is possible to maintain causality there.

Setup requirements

Hey, just a couple of minor requirements changes. I found that:

  • Hugging Face's T5Tokenizer library requires sentencepiece. The datasets library doesn't include it in the default installation (it is covered in [TESTS])
  • scikit-learn was required for the k-means HuBERT models
  • PIL and torchvision are not required by trainer.py

The semantic tokenizer is training nicely using train.py with these minor changes. Thanks!

Grouped convolution in multi-scale discriminator

Perhaps I'm missing something, but shouldn't we be using groups = groups in this line of MultiScaleDiscriminator?

Each single-scale discriminator consists of an initial plain convolution followed by four grouped convolutions, each of which has a group size of 4, a down-sampling factor of 4, and a channel multiplier of 4 up to a maximum of 1024 output channels.

Support our open source music pretrained Transformer

Hi, we are researchers from the MAP (music audio pre-train) project. We pre-train transformer LMs on large-scale music audio datasets.
See below. Our model, MERT, uses a similar method to HuBERT and has verified its performance on downstream music information retrieval tasks. It has been released on Hugging Face and can be used interchangeably with the HuBERT loading code: model = HubertModel.from_pretrained("m-a-p/MERT-v0")
We are currently working on training a better base model and scaling up to a large model with more music + speech data.
Using our weights as an initialization will be a better start than using the speech HuBERT. Better checkpoints will be released soon.

https://huggingface.co/m-a-p/MERT-v0

Soundstream does not get better in training after update

I updated from 0.2.4 and the previous models didn't work anymore. In 0.2.4, after approx. 70,000 train steps, I got a reconstruction where you could at least see the similarity between the original and the reconstruction.
In the new update, after ~70,000 train steps there is only some improvement in the first 2 seconds, but it's still only noise. After two seconds there is only one uniform signal, and it also looks like that when plotting it. I'm using LibriSpeech to train and the configuration in the example (tried 1024 and 2048 codebook size).
I'm not sure if I'm just using it wrong or if it's a bug.
The printed loss during training seems to go down very slowly, but I can't hear a difference.
Edit: added a picture of the reconstruction at 70,000 train steps.
