
archinetai / audio-diffusion-pytorch


Audio generation using diffusion models, in PyTorch.

License: MIT License

Python 100.00%
artificial-intelligence audio-generation deep-learning denoising-diffusion


audio-diffusion-pytorch's Issues

What loss function is being used?

Hi! I'm working on getting a version of this set up locally and running it on some of my own data. The pattern I see with a lot of models is something along the lines of:

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
# taken from https://pytorch.org/docs/stable/optim.html

Here an output is returned from the model, and the loss is computed with a loss function before optimizer.step() is called. However, in all of the examples you provide, the training step looks something like:

loss = model(input)
loss.backward()

What loss function is being used here? Is it being inherited from a-unet or one of the parent classes somewhere along the line? TIA.
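
For readers with the same question: the traceback quoted in the AssertionError issue further down this page shows VDiffusion.forward computing v_target = alphas * noise - betas * x and returning F.mse_loss(v_pred, v_target). So, at least for that diffusion type, the loss appears to be a mean squared error computed inside the model's own forward. The sketch below illustrates that pattern with toy stand-ins. ToyVDiffusion and ToyNet are hypothetical names, and the uniform noise-level sampling, the cosine/sine mapping, and the x_noisy mix are assumptions, not the library's confirmed code.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyNet(nn.Module):
    """Hypothetical denoiser; a real U-Net would also condition on the noise level."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
        return self.conv(x)  # sigmas is ignored in this toy


class ToyVDiffusion(nn.Module):
    """Hypothetical wrapper whose forward() returns the training loss directly."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.shape[0]
        # Assumption: one noise level per batch item, drawn uniformly from [0, 1].
        sigmas = torch.rand(batch, device=x.device)
        # Assumption: map the noise level to an angle so that alpha^2 + beta^2 = 1.
        angle = sigmas.view(batch, 1, 1) * math.pi / 2
        alphas, betas = torch.cos(angle), torch.sin(angle)
        noise = torch.randn_like(x)
        x_noisy = alphas * x + betas * noise
        # These two lines mirror the ones visible in the quoted traceback.
        v_target = alphas * noise - betas * x
        v_pred = self.net(x_noisy, sigmas)
        return F.mse_loss(v_pred, v_target)


torch.manual_seed(0)
model = ToyVDiffusion(ToyNet(channels=2))
loss = model(torch.randn(4, 2, 1024))  # [batch, in_channels, length]
loss.backward()
```

Because the wrapper builds its own regression target from the input and freshly drawn noise, no separate target tensor or loss_fn is needed in the training loop.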

How to use our own noisy-background dataset to generate samples?

Thank you for sharing this awesome work!

How can I load our own dataset (for example, a dataset with background noise) into the pretrained diffusion model? Is it possible to use a PyTorch DataLoader?

The plan is to then run the sampler and get the sample output.

Thank you!!!
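
On the DataLoader part of the question, a minimal sketch is given below, assuming the waveforms are already in memory as tensors shaped [channels, samples]. FixedLengthAudio is a hypothetical helper, not part of the library. It pads or crops every clip to one length so that batches come out in the [batch, in_channels, length] layout used by the examples on this page. In practice the tensors could come from torchaudio.load.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class FixedLengthAudio(Dataset):
    """Hypothetical dataset: pads or crops in-memory waveforms to a fixed length."""

    def __init__(self, waveforms, length):
        self.waveforms = waveforms  # list of tensors shaped [channels, samples]
        self.length = length

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, index):
        wave = self.waveforms[index]
        missing = self.length - wave.shape[-1]
        if missing > 0:
            # Zero-pad on the right when the clip is too short.
            wave = F.pad(wave, (0, missing))
        # Crop on the right when the clip is too long.
        return wave[..., : self.length]


# Random tensors stand in for real recordings here.
waveforms = [torch.randn(2, n) for n in (3000, 5000, 4096, 8000)]
loader = DataLoader(FixedLengthAudio(waveforms, length=4096), batch_size=2)
batch = next(iter(loader))  # shape [2, 2, 4096]
```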

Typo in Paper

Great work! In section 5.2 of the paper there is a typo. I think you meant "weight", not "wight".


Pre-trained Weights of AutoEncoder

Hi, it seems that the checkpoint files on Hugging Face are all for the Model1d class. Are there any checkpoints available for the DiffusionAutoencoder1d class to perform the latent encoding?

How to train conditional audio diffusion without text conditioning?

This is a wonderful project!
I have a zero-shot dataset that contains a one-dimensional feature vector (size [2048]) and its corresponding one-dimensional attribute (size ). Can the one-dimensional attribute be used directly, instead of text, to guide the generation?
Thank you very much!

Future Work - Models

Hi!

I am very curious about the future work part of the paper.

There were a few suggestions in the paper. Let me talk about two.

1. Use perceptual losses.

You have just merged a PR that allows for loss customization. Which perceptual loss did you have in mind when you wrote the suggestion?

2. Using mel spectrograms instead of magnitude spectrograms as input.

dmae1d-ATC64-v2 uses the magnitude spectrogram.

What would be a good mel feature extractor?

I sometimes ran into this one but I would like to know what you think about it:

encoder=MelE1d( # The encoder used, in this case a mel-spectrogram encoder
                in_channels=in_channels,
                channels=512,
                multipliers=[1, 1],
                factors=[2],
                num_blocks=[12],
                out_channels=32,
                mel_channels=80,
                mel_sample_rate=48000,
                mel_normalize_log=True,
                bottleneck=TanhBottleneck(),
            ),

I believe it extracts a lot of features, thus putting a strain on the GPU.

Curious what you have to say about 1 and 2.

Cheers,
Tristan

custom dataset

Hi,

Is there an API that can be used to register a new dataset and model?

New Try

Is it possible to generate new, similar audio by conditioning on both a piece of audio and text information?
(This is similar to style transfer in computer vision.)

Question: Scaling guide/suggested parameters?

Hello! I'm in the process of training a model on a top-40s dataset using your library. However, I want to experiment with long-term consistency, so I've scaled sample rate/channels accordingly to fit ~90s windows during training. I think my results could be improved by further scaling up the number of model parameters, but I'm not sure what to change and by what ratios to get the most bang for my buck/VRAM/compute. Do you guys have a "scaled" config you could share or a general guide (e.g., 2X attention heads, 1.5X mults) for this? Thanks!

AssertionError: ClassiferFreeGuidancePlugin requires embedding

Hi, I tested the example you gave for conditioning on text, but got an error:

# Train model with audio waveforms
audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
loss = model(
    audio_wave,
    text=['The audio description'], # Text conditioning, one element per batch
    embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
)
loss.backward()

# Turn noise into new audio sample with diffusion
noise = torch.randn(1, 2, 2**18)
sample = model.sample(
    noise,
    text=['The audio description'],
    embedding_scale=5.0, # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale)
    num_steps=2 # Higher for better quality, suggested num_steps: 10-100
)

Error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[49], line 3
      1 # Train model with audio waveforms
      2 audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
----> 3 loss = model(
      4     audio_wave,
      5     text=['The audio description'], # Text conditioning, one element per batch
      6     embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
      7 )
      8 loss.backward()
     10 # Turn noise into new audio sample with diffusion

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/audio_diffusion_pytorch/models.py:40, in DiffusionModel.forward(self, *args, **kwargs)
     39 def forward(self, *args, **kwargs) -> Tensor:
---> 40     return self.diffusion(*args, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/audio_diffusion_pytorch/diffusion.py:93, in VDiffusion.forward(self, x, **kwargs)
     91 v_target = alphas * noise - betas * x
     92 # Predict velocity and return loss
---> 93 v_pred = self.net(x_noisy, sigmas, **kwargs)
     94 return F.mse_loss(v_pred, v_target)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:63, in Module.<locals>.Module.forward(self, *args, **kwargs)
     62 def forward(self, *args, **kwargs):
---> 63     return forward_fn(*args, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:594, in TimeConditioningPlugin.<locals>.Net.<locals>.forward(x, time, features, **kwargs)
    592 # Merge time features with features if provided
    593 features = features + time_features if exists(features) else time_features
--> 594 return net(x, features=features, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:63, in Module.<locals>.Module.forward(self, *args, **kwargs)
     62 def forward(self, *args, **kwargs):
---> 63     return forward_fn(*args, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:534, in ClassifierFreeGuidancePlugin.<locals>.Net.<locals>.forward(x, embedding, embedding_scale, embedding_mask_proba, **kwargs)
    526 def forward(
    527     x: Tensor,
    528     embedding: Optional[Tensor] = None,
   (...)
    531     **kwargs,
    532 ):
    533     msg = "ClassiferFreeGuidancePlugin requires embedding"
--> 534     assert exists(embedding), msg
    535     b, device = embedding.shape[0], embedding.device
    536     embedding_mask = fixed_embedding(embedding)

AssertionError: ClassiferFreeGuidancePlugin requires embedding

Is it about dependencies?
Which dependencies am I supposed to install?

P.S. Could you please show two simple Colab examples:

  • training on your own WAV files
  • using pretrained networks and fine-tuning on your own WAV files

I am trying to understand how to condition on text to validate a research idea in bioacoustics, but I don't yet have a strong enough foundation to understand your code well, so a tutorial would be really helpful.
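
A possible reading of the traceback above: the call goes from TimeConditioningPlugin straight into ClassifierFreeGuidancePlugin, whereas the traceback in the MergeModulate issue further down this page has an extra frame in between (return net(x, embedding=text_embedding, **kwargs)). One explanation, offered as an assumption and not a confirmed diagnosis, is that the model was built with classifier-free guidance enabled but without text conditioning, so the text argument is never turned into an embedding. The keyword arguments below are copied from the DiffusionModel call in that later issue. Whether they resolve this report is untested, and check_conditioning is a hypothetical helper.

```python
# Conditioning-related keyword arguments, copied from the DiffusionModel call in the
# MergeModulate issue on this page. Not verified against any specific library version.
text_conditioning_kwargs = dict(
    use_text_conditioning=True,   # turns the `text` argument into an embedding
    use_embedding_cfg=True,       # classifier-free guidance on that embedding
    embedding_max_length=64,
    embedding_features=768,
    cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1],
)


def check_conditioning(kwargs, passes_embedding_directly=False):
    """Hypothetical sanity check: guidance needs some source of embeddings."""
    if not kwargs.get("use_embedding_cfg", False):
        return True
    return bool(kwargs.get("use_text_conditioning", False)) or passes_embedding_directly


config_ok = check_conditioning(text_conditioning_kwargs)
config_missing_text = check_conditioning({"use_embedding_cfg": True})
```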

Unconditional Generation generates noise

Hi,

I'm training on a dataset of songs with this package. After about 10 epochs (of 1000 samples each) the loss seems to converge, but when I sample I get pure noise. My intuition is that even if the model is converging to a local minimum, or I haven't trained for long enough, it should still produce some output (garbage in, garbage out should still produce something other than pure noise). So I'm led to believe that there's an issue with the way I'm generating the audio. I've attached my code below.

Any suggestions, or anything more I need to provide?

def generate_samples(model, num_samples, sample_rate, audio_length_seconds, device):
    """
    Generate audio samples from the trained diffusion model.

    :param model: The trained diffusion model.
    :param num_samples: The number of audio samples to generate.
    :param sample_rate: The sample rate of the audio.
    :param audio_length_seconds: The length of the audio to generate, in seconds.
    :param device: The device ('cpu' or 'cuda') to run the sampling on.
    :return: A tensor containing the generated audio samples.
    """
    audio_length = sample_rate * audio_length_seconds
    # Initialize with random noise
    noise = torch.randn(num_samples, 1, audio_length, device=device)

    model.eval() 

    with torch.no_grad():  
        samples = model.sample(noise, num_steps=100)  

    return samples

waveform, sample_rate = torchaudio.load(audio_path)
# Example usage after training the model:
num_samples = 1  # Number of samples to generate
#sample_rate = dataset.sample_rate  # Sample rate of the audio
audio_length_seconds = 20  # Length of the audio to generate, in seconds
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate samples
generated_audio = generate_samples(model, num_samples, sample_rate, audio_length_seconds, device)

# Save the generated samples as FLAC files
for i, audio_tensor in enumerate(generated_audio):
    filename = f"generated_sample_{i+1}.flac"
    torchaudio.save(filename, audio_tensor.cpu(), sample_rate)
    print(f"Saved: {filename}")

I have a few questions about 1D-UNet

The code shows that a-unet is used to construct the U-Net, but looking at a-unet, the U-Net is built as a nested structure. So does this U-Net have middle blocks in addition to the encoder and decoder parts, as in other diffusion models? If not, what does a U-Net without middle blocks look like?

Exploding loss

The loss suddenly increases from <0.1 to billions over one or two epochs.

I'm training an AudioDiffusionModel and I've had this happen with both the default diffusion_type='v' and with diffusion_type='vk'. It also happens both with and without gradient clipping. It's happened with several datasets and different batch sizes (the output below is from a particularly small dataset with a large batch size).

It seems to happen more often the closer the loss gets to 0.

Output:

1328 Loss : 0.0562
100% 6/6 [00:01<00:00,  3.93it/s]
1329 Loss : 0.0517
100% 6/6 [00:01<00:00,  3.95it/s]
1330 Loss : 0.0500
100% 6/6 [00:01<00:00,  3.95it/s]
1331 Loss : 0.0374
100% 6/6 [00:01<00:00,  3.93it/s]
1332 Loss : 0.0519
100% 6/6 [00:01<00:00,  3.69it/s]
1333 Loss : 0.0557
100% 6/6 [00:01<00:00,  3.47it/s]
1334 Loss : 0.0499
100% 6/6 [00:01<00:00,  3.33it/s]
1335 Loss : 0.0482
100% 6/6 [00:01<00:00,  3.74it/s]
1336 Loss : 1.4608
100% 6/6 [00:01<00:00,  3.89it/s]
1337 Loss : 35551447.3009
100% 6/6 [00:01<00:00,  3.91it/s]
1338 Loss : 17436217794.0833
100% 6/6 [00:01<00:00,  3.86it/s]
1339 Loss : 15120838197.3333
100% 6/6 [00:01<00:00,  3.88it/s]
1340 Loss : 1137136360.0000
100% 6/6 [00:01<00:00,  3.83it/s]
1341 Loss : 184102040.6667
100% 6/6 [00:01<00:00,  3.80it/s]
1342 Loss : 24171988.5000
100% 6/6 [00:01<00:00,  3.85it/s]
1343 Loss : 100907.1549
100% 6/6 [00:01<00:00,  3.80it/s]
1344 Loss : 10494.4541
100% 6/6 [00:01<00:00,  3.83it/s]
1345 Loss : 989.2273

The model:

class DiffModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AudioDiffusionModel(in_channels=1, diffusion_type='vk', diffusion_sigma_distribution=VKDistribution())
        self.optimizer = torch.optim.AdamW(list(self.model.parameters()))

    def train(self, x):
        self.optimizer.zero_grad()

        loss = self.model(x)
        loss.backward()

        clip_grad_norm_(self.model.parameters(), 1.)

        self.optimizer.step()

        return loss.item()
    ...

Training:

for epoch in range(load_epoch + 1, MAX_EPOCHS):
    acc_loss = 0
    for x in tqdm(dataloader):
        x = x.to(device)
        acc_loss += model.train(x)
    loss = acc_loss / epoch_steps
    print(f'{epoch} Loss : {loss:.4f}')
    ...
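
As a defensive measure while the cause is investigated, a training step can refuse to update the weights when the loss or the gradient norm is not finite. The sketch below is an assumption-laden illustration, not a fix for the underlying instability: training_step and ToyLossModel are hypothetical names, and the lower learning rate (the report uses AdamW with no lr argument, so PyTorch's default of 1e-3 applies) is an untested suggestion.

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_


def training_step(model, optimizer, x, max_norm=1.0):
    """Run one update; return None and leave the weights alone if anything is non-finite."""
    optimizer.zero_grad()
    loss = model(x)
    if not torch.isfinite(loss):
        return None  # skip the batch before any gradient is computed
    loss.backward()
    # clip_grad_norm_ returns the total gradient norm measured before clipping.
    grad_norm = clip_grad_norm_(model.parameters(), max_norm)
    if not torch.isfinite(grad_norm):
        optimizer.zero_grad()
        return None
    optimizer.step()
    return loss.item()


class ToyLossModel(nn.Module):
    """Hypothetical stand-in for a model whose forward() returns a scalar loss."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return (self.linear(x) - x).pow(2).mean()


torch.manual_seed(0)
toy = ToyLossModel()
opt = torch.optim.AdamW(toy.parameters(), lr=1e-4)  # assumption: lower than the default
good = training_step(toy, opt, torch.randn(4, 8))
weights_before = toy.linear.weight.detach().clone()
bad = training_step(toy, opt, torch.full((4, 8), float("nan")))
```

Separately, defining train(self, x) on an nn.Module subclass shadows the built-in nn.Module.train(mode), which eval() calls internally, so a different method name such as training_step is safer.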

NaN after training for a while

Hi!

I'm having an issue training with the basic model provided in the README. After training on the LibriSpeech dataset for about 20 epochs, I start getting NaN losses returned from the model, and when sampling and saving to a file I just get silent audio randomly.

I had a go at debugging but couldn't really find the issue, other than the first NaN in the forward pass I could find was the input to the ResNet block. Not sure if this is helpful, but I've added my debug output here: output.log. The prints were just me testing in various forward functions whether the inputs and outputs were NaN but I didn't isolate any lines.

My training script is pretty short and I'm not doing anything particularly weird that would cause this I don't think! You can have a look here: https://github.com/jameshball/audio-diffusion/blob/master/train.py

Also, here's a snippet of my training output where the loss turns to NaN: nan.txt

I should also be able to follow up with a google drive link to download the checkpoint so you can test it more easily - you might need to modify train.py to remove some wandb calls functions and just load the checkpoint from disk but should be straightforward. Alternatively, I also get NaNs when just sampling from the model here: https://github.com/jameshball/audio-diffusion/blob/master/sample.py

Please let me know if there's anything I can help with as this would be great to fix!

James
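
The manual NaN prints described above can be automated with forward hooks. The following is a minimal sketch, assuming every module of interest returns a single tensor (tuple outputs are ignored). find_first_nan_module and SqrtLayer are hypothetical names used for illustration only.

```python
import torch
import torch.nn as nn


def find_first_nan_module(model, *inputs):
    """Return the name of the first submodule whose output contains NaN, or None."""
    offenders = []

    def make_hook(name):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor) and torch.isnan(output).any():
                offenders.append(name)
        return hook

    # Skip the root module (empty name) so only submodules are reported.
    handles = [
        module.register_forward_hook(make_hook(name))
        for name, module in model.named_modules()
        if name
    ]
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()
    return offenders[0] if offenders else None


class SqrtLayer(nn.Module):
    """Toy layer that produces NaN for negative inputs."""

    def forward(self, x):
        return torch.sqrt(x)


toy = nn.Sequential(nn.Identity(), SqrtLayer(), nn.ReLU())
nan_source = find_first_nan_module(toy, torch.tensor([-1.0, 4.0]))
clean_run = find_first_nan_module(toy, torch.tensor([1.0, 4.0]))
```

Hooks fire as each submodule finishes its forward pass, so the first recorded name is the earliest module to emit a NaN. If that module's inputs were already NaN, the source lies in code outside any hooked module, for example in the parent's forward.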

Unconditional model generates okay quality of fake human voice but failed on music.

Hi, I've been playing with this diffusion model library for a few days. It is great to have a library that lets ordinary users train on audio data with limited resources.

I have a problem with the training data and the output. I fed the unconditional model Mozilla's Common Voice dataset. I used only one language, and the size is about 15k. I resampled the files to 44.1 kHz and padded the shorter ones to 2^18 samples per file. The unconditional results were okay: at least I could tell it was a human speaking, although the words were never actually intelligible.

But when I replace the training data with music (mostly solo piano, same sample rate but 2^17 samples per input tensor), the model does not generate outputs that sound like piano. In fact, they are mostly noise.

I used the same layer configurations for both models, and tried lowering the downsampling factors or increasing the attention heads, but saw no significant difference. Any tips on why this happens?

Error Locating Target

Hello,

I upgraded the trainer and audio diffusion to the latest releases. I am now getting this error when trying to run experiments:

[2022-10-19 09:47:41,538][main][INFO] - Instantiating model <main.module_base.Model>.
Error executing job with overrides: ['exp=base_youtube_l_3.yaml']
Error locating target 'audio_diffusion_pytorch.VDistribution', see chained exception above.
full_key: model.model.diffusion_sigma_distribution

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

I deleted the conda environment and reinstalled all of the requirements from scratch, and I am still getting the above. Any help would be appreciated.

Thanks,
MP
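
One quick check for this kind of Hydra error is whether the installed package still exposes the name used as the target. The sketch below uses only the standard library. module_exports is a hypothetical helper, and the idea that VDistribution was renamed or removed between releases is an assumption, not something confirmed on this page.

```python
import importlib


def module_exports(module_name, attribute):
    """Return True if the module can be imported and exposes the given attribute."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attribute)


# Demonstration on a standard-library module:
has_sqrt = module_exports("math", "sqrt")
has_missing = module_exports("math", "not_a_real_name")

# For the error in this issue, the corresponding call would be:
# module_exports("audio_diffusion_pytorch", "VDistribution")
```

If that call returns False in the environment that runs the experiment, the config key model.model.diffusion_sigma_distribution points at a name the installed release does not provide.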

RuntimeError: The size of tensor a (37) must match the size of tensor b (36) at non-singleton dimension 2

I got an error when I tried to train the model.
It occurs in MergeModulate and is caused by a mismatch between two tensor sizes.
Could you suggest any ideas to resolve it?

1. input
batch_size:64
audio shape: (1,19200)
text embedding: T5

2. model

def create_model():
    return DiffusionModel(
        net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
        in_channels=1, # U-Net: number of input/output (audio) channels
        channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
        # channels=[12, 36, 72, 144, 288, 576, 576, 1152, 1152], # U-Net: channels at each layer
        factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
        items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
        attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
        attention_heads=8, # U-Net: number of attention heads per attention item
        attention_features=64, # U-Net: number of attention features per attention item
        diffusion_t=VDiffusion, # The diffusion method used
        sampler_t=VSampler, # The diffusion sampler used
        use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base)
        use_embedding_cfg=True, # U-Net: enables classifier free guidance
        embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base)
        embedding_features=768, # U-Net: text embedding features (default for T5-base)
        cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer
        use_time_conditioning=True,
    )
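
The sizes in this report can be checked with plain arithmetic. The sketch below assumes each level floors the length when downsampling by its factor and multiplies it exactly when upsampling. That assumption reproduces the 37 versus 36 mismatch in the error title for an input of 19200 samples and the factors listed above. The function names are hypothetical.

```python
from math import prod


def trace_lengths(length, factors):
    """Sequence length after each downsampling level (floor division assumed)."""
    lengths = [length]
    for factor in factors:
        lengths.append(lengths[-1] // factor)
    return lengths


def first_mismatch(length, factors):
    """Return (skip_length, upsampled_length) at the deepest level where they differ."""
    lengths = trace_lengths(length, factors)
    for level in reversed(range(len(factors))):
        upsampled = lengths[level + 1] * factors[level]
        if upsampled != lengths[level]:
            return lengths[level], upsampled
    return None


def padded_length(length, factors):
    """Smallest length >= `length` that is a multiple of the total downsampling factor."""
    total = prod(factors)
    return -(-length // total) * total  # ceiling division


factors = [1, 4, 4, 4, 2, 2, 2, 2, 2]
total_factor = prod(factors)              # 2048
mismatch = first_mismatch(19200, factors)  # (37, 36)
target = padded_length(19200, factors)     # 20480
```

Under these assumptions, padding the audio from 19200 to 20480 samples (or cropping to 18432) makes every level divide evenly, so the skip connection and the upsampled tensor have the same length.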

3. error
TimeConditioningPlugin features: torch.Size([64, 1024]) time_features: torch.Size([64, 1024])
TextConditioningPlugin x torch.Size([64, 1, 19200]) text_embotting torch.Size([64, 64, 768])
MergeModulate x: torch.Size([64, 1024, 18]) y: torch.Size([64, 1024, 18]) features: torch.Size([64, 1024])
MergeModulate x: torch.Size([64, 512, 37]) y: torch.Size([64, 512, 36]) features: torch.Size([64, 1024])
0%| | 0/18 [00:02<?, ?it/s]
Traceback (most recent call last):
File "/home/jovyan/meta-dataset/audio-diffusion-pytorch-trainer/tests/test_train_cond_dac.py", line 311, in
main()
File "/home/jovyan/meta-dataset/audio-diffusion-pytorch-trainer/tests/test_train_cond_dac.py", line 254, in main
loss = model(audio, text=info["text"], embedding_mask_proba=0.1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/meta-dataset/audio-diffusion-pytorch/audio_diffusion_pytorch/models.py", line 41, in forward
return self.diffusion(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/meta-dataset/audio-diffusion-pytorch/audio_diffusion_pytorch/diffusion.py", line 94, in forward
v_pred = self.net(x_noisy, sigmas, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 63, in forward
return forward_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 596, in forward
return net(x, features=features, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 63, in forward
return forward_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 624, in forward
return net(x, embedding=text_embedding, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 63, in forward
return forward_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 553, in forward
return net(x, embedding=embedding, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/apex.py", line 431, in forward
return self.net(x, features, embedding, channels) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/apex.py", line 382, in forward
x = self.block(x, features, embedding, channels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 77, in forward
x = block(x, *args)
^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/apex.py", line 382, in forward
x = self.block(x, features, embedding, channels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 77, in forward
x = block(x, *args)
^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/apex.py", line 382, in forward
x = self.block(x, features, embedding, channels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 77, in forward
x = block(x, *args)
^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/apex.py", line 382, in forward
x = self.block(x, features, embedding, channels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 77, in forward
x = block(x, *args)
^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/apex.py", line 382, in forward
x = self.block(x, features, embedding, channels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 77, in forward
x = block(x, *args)
^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/apex.py", line 382, in forward
x = self.block(x, features, embedding, channels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 77, in forward
x = block(x, *args)
^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/apex.py", line 382, in forward
x = self.block(x, features, embedding, channels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 77, in forward
x = block(x, *args)
^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/apex.py", line 383, in forward
x = self.skip(skip, x, features)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 63, in forward
return forward_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/miniconda3/envs/mg/lib/python3.11/site-packages/a_unet/blocks.py", line 433, in forward
return x + scale * y
~~^~~~~~~~~~~
RuntimeError: The size of tensor a (37) must match the size of tensor b (36) at non-singleton dimension 2

Add support to clip predicted samples to the desired range.

In diffusion it is common to clip samples to a desired range such as [-1, 1]. I think previous versions of this package supported this, but the current implementation does not.

I think it would be useful to support clipping samples to a desired range.

VSampler

  def forward(..., clip_denoised: bool = False, dynamic_threshold: float = 0.0) -> Tensor:
    ...
    x_pred = alphas[i] * x_noisy - betas[i] * v_pred 
    # Add clipping support here 
    if clip_denoised:
      clip(x_pred, dynamic_threshold=dynamic_threshold)
    ...

I am happy to open a PR if this is acceptable.
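
For illustration, here is a minimal sketch of what such a clip helper could do. It assumes the static range is [-1, 1] and that dynamic_threshold is a quantile of the absolute sample values, with the clamped result rescaled back into [-1, 1]. The name clip_sample is hypothetical and not part of the package. Plain Python lists are used so the sketch runs without dependencies; in PyTorch the same steps would map onto torch.quantile and Tensor.clamp.

```python
def clip_sample(values, dynamic_threshold=0.0):
    """Clip a flat list of sample values (hypothetical helper, not the package API)."""
    if dynamic_threshold == 0.0:
        # Static clipping to the fixed range [-1, 1].
        return [max(-1.0, min(1.0, v)) for v in values]
    # Dynamic thresholding: take the magnitude at the requested quantile.
    magnitudes = sorted(abs(v) for v in values)
    index = min(len(magnitudes) - 1, int(dynamic_threshold * len(magnitudes)))
    # Never shrink the range below 1, so in-range samples are left untouched.
    scale = max(magnitudes[index], 1.0)
    # Clamp to [-scale, scale], then rescale into [-1, 1].
    return [max(-scale, min(scale, v)) / scale for v in values]


static = clip_sample([-2.0, 0.5, 3.0])
dynamic = clip_sample([0.5, -1.0, 2.0, -4.0], dynamic_threshold=0.5)
```

With dynamic_threshold=0.0 this is a hard clamp; with a non-zero value, outliers beyond the chosen quantile are clamped and the rest are rescaled, which keeps the output inside [-1, 1] either way.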

Trained models

In the abstract of the paper, you mention providing trained models: "In addition to trained models, we provide a collection of open-source libraries", but I couldn't find them anywhere. Is there any possibility of sharing those models?

Add trainer

Hey there, I have been following this project pretty closely...looks great. Could you share the lightning trainer you're using here/any associated scripts for training?

I've put together my own trainer with accelerate, which is working fine, but it would be nice to work from the same one as you for reproducibility's sake.

Am I training the model correctly?

Hello, I am new to neural network models and would like to ask whether I am training the model correctly.
Here is the relevant part of the code:

model = AudioDiffusionModel(in_channels=1).to("cuda")
optimizer = Adam(model.parameters(),lr=0.0001)

for i in range(epochs):
  for x in iter(Data):
    loss = model(x)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    loss_history.append(loss)

Spectrogram-based diffusion model

Thanks for your contribution to this repository! I wonder if we can use it to develop a diffusion model based on spectrograms instead of waveforms. While implementing this, I discovered that UNetV0 has a dim=2 option that allows 2D convolutions to be used on spectrograms. However, some of my UNetV0 hyperparameters seem to be inconsistent with this input and lead to an error. It's a bit hard for me to debug since the model relies heavily on a-unet. Below I'll give more context.

Suppose we have a spectrogram whose shape is [1, 1, 256, 512] # [B, C, T, F].

Here's the model architecture I used:

return DiffusionModel(
        net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
        dim=2, # for spectrogram we use 2D-CNN
        in_channels=1, # U-Net: number of input/output (audio) channels
        channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
        factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
        items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
        attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
        attention_heads=8, # U-Net: number of attention heads per attention item
        attention_features=64, # U-Net: number of attention features per attention item
        diffusion_t=VDiffusion, # The diffusion method used
        sampler_t=VSampler, # The diffusion sampler used
        embedding_features=512, # U-Net: embedding features
        cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer 
    )

Here's the output error:

Traceback (most recent call last):
  File "/data/tinglok/texture/ldm.py", line 164, in <module>
    main()
  File "/data/tinglok/texture/ldm.py", line 105, in main
    loss = model(audio, embedding=cond_embed)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/audio_diffusion_pytorch/models.py", line 40, in forward
    return self.diffusion(*args, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/audio_diffusion_pytorch/diffusion.py", line 93, in forward
    v_pred = self.net(x_noisy, sigmas, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 63, in forward
    return forward_fn(*args, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 594, in forward
    return net(x, features=features, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 431, in forward
    return self.net(x, features, embedding, channels)  # type: ignore
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 92, in forward
    return self.block(*args_fn(*args), **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (2 x 2). Kernel size can't be greater than actual input size

Can we resolve this error by disabling some downsampling of UNetV0?
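
One way to check whether the downsampling factors are the culprit is to trace how the spatial size shrinks layer by layer. The sketch below assumes that each layer divides every spatial dimension by its factor (rounding down) and that a layer fails once a dimension is smaller than the factor. This is a simplification and may not match a-unet's exact behaviour; trace_downsampling is a hypothetical helper, not part of the library.

```python
def trace_downsampling(shape, factors):
    """Return the sizes reached and the index of the first layer that cannot downsample."""
    sizes = [tuple(shape)]
    for layer, factor in enumerate(factors):
        current = sizes[-1]
        if any(dim < factor for dim in current):
            # The input is already smaller than the downsampling kernel.
            return sizes, layer
        sizes.append(tuple(dim // factor for dim in current))
    return sizes, None


factors = [1, 4, 4, 4, 2, 2, 2, 2, 2]  # factors from the configuration above
sizes, failed_at = trace_downsampling((256, 512), factors)
shorter_sizes, shorter_failed_at = trace_downsampling((256, 512), factors[:6])
```

Under these assumptions the total downsampling (a factor of 2048) exceeds the 256 x 512 input, so the size collapses before the last layers are reached. Using fewer layers or smaller factors, or a larger spectrogram, avoids that.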

Question: sigma_t is not sampled from 0 to 1 in v-diffusion, unlike what your thesis describes. Will it cause any trouble?

The sigma_t is not sampled from 0 to 1 in v-diffusion, unlike what your thesis describes. Will this cause any trouble?

By sampling a random σt ∈ [0,1], we are more likely to pick a value that resembles x0 instead of pure noise ε, meaning that the model will more often see data with a smaller amount of noise
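
For context, here is a scalar sketch of how σt ∈ [0,1] is typically used in v-diffusion. It assumes the usual angle parameterization (alpha = cos(σ·π/2), beta = sin(σ·π/2)) and the standard v-objective target; those two formulas are assumptions and are not quoted from the package. The recovery step follows the x_pred = alphas * x_noisy - betas * v_pred line quoted in the clipping issue above. All function names are hypothetical.

```python
import math
import random


def alpha_beta(sigma):
    """Map sigma in [0, 1] to signal and noise weights on the unit circle."""
    angle = sigma * math.pi / 2
    return math.cos(angle), math.sin(angle)


def noise_and_target(x0, sigma, noise):
    """Return the noisy input and the velocity target (assumed v-objective)."""
    alpha, beta = alpha_beta(sigma)
    x_noisy = alpha * x0 + beta * noise
    v_target = alpha * noise - beta * x0
    return x_noisy, v_target


def recover_x0(x_noisy, v, sigma):
    """Invert the noising step given the true (or predicted) velocity."""
    alpha, beta = alpha_beta(sigma)
    return alpha * x_noisy - beta * v


rng = random.Random(0)
x0 = 0.3
sigma = rng.uniform(0.0, 1.0)  # uniform sigma, as described in the quote
noise = rng.gauss(0.0, 1.0)
x_noisy, v_target = noise_and_target(x0, sigma, noise)
x_recovered = recover_x0(x_noisy, v_target, sigma)
```

At σ = 0 the input is pure data and at σ = 1 it is pure noise, so the range from which σ is drawn determines which noise levels the network sees during training.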

Model architectures from the paper

Hi!

I am still having a lot of fun with this repo 🤗

I had a closer look at the paper and realized that I am training a different architecture. I will use the architecture from the paper next.

I am not 100% confident that I read the paper correctly. Is this the 185M-parameter autoencoder from section 5.2 of the Moûsai paper?

model = DiffusionAE(
            encoder=MelE1d( # The encoder used, in this case a mel-spectrogram encoder
                in_channels=in_channels,
                channels=512,
                multipliers=[1, 1],
                factors=[2],
                num_blocks=[12],
                out_channels=32,
                mel_channels=80,
                mel_sample_rate=48000,
                mel_normalize_log=True,
                bottleneck=TanhBottleneck(),
            ),
            inject_depth=4,
            net_t=UNetV0, # The model type used for diffusion upsampling
            in_channels=in_channels, # U-Net: number of input/output (audio) channels
            channels=[256, 512, 512, 512, 1024, 1024, 1024], # U-Net: channels at each layer
            factors=[1, 2, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
            items=[1, 2, 2, 2, 2, 2, 2], # U-Net: number of repeating items at each layer
            diffusion_t=VDiffusion, # The diffusion method used
            sampler_t=VSampler, # The diffusion sampler used
        )

Cheers,
Tristan

Alternative Noises: Offset, Pyramid, Pink

I've been seeing some promising results from using alternative noise methods to teach the model to adjust the lower-frequency components of an input. Pure randn noise is mostly high-frequency content, and Stable Diffusion (and possibly other diffusion models trained on randn noise) learned to create images with the same average value, so it can't make brighter or darker images. At sampling time, normal randn noise appears to be used for offset and pyramid noise; I'm not certain about pink noise.

With offset noise it learns to shift the output up or down more. It's a very small change to the noise generation for training: noise = torch.randn_like(latents) + 0.1 * torch.randn(latents.shape[0], latents.shape[1], 1, 1)
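
Here is a dependency-free sketch of the same idea for audio shaped [batch, channels, time]. It assumes one shared offset per (item, channel) pair, which mirrors the trailing singleton dimensions in the image version above. The function name offset_noise and its parameters are hypothetical. In PyTorch the equivalent for a tensor x of that shape would presumably be torch.randn_like(x) + 0.1 * torch.randn(x.shape[0], x.shape[1], 1).

```python
import random


def offset_noise(batch, channels, length, strength=0.1, seed=None):
    """Gaussian noise plus one random constant shift per (item, channel)."""
    rng = random.Random(seed)
    noise = []
    for _ in range(batch):
        item = []
        for _ in range(channels):
            # One offset shared by every time step of this channel.
            offset = strength * rng.gauss(0.0, 1.0)
            item.append([rng.gauss(0.0, 1.0) + offset for _ in range(length)])
        noise.append(item)
    return noise


plain = offset_noise(2, 1, 4, strength=0.0, seed=1)
shifted = offset_noise(2, 1, 4, strength=0.5, seed=1)
```

With the same seed, the two results differ only by a constant per channel, which is the low-frequency (DC) component that the offset is meant to expose to the model.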

With pyramid noise the input is more evenly masked across different frequencies, rather than just high frequency content. The noise is generated by scaling a low resolution noise up to a random scale (they wanted to avoid always doing 2x upscale), adding more noise after upsampling, and repeating. The code they use is given in the article, Ctrl+F for def pyramid_noise_like(x, discount=0.9):.

With pink noise (EleutherAI Discord message link) I'm not 100% sure on the benefit. It's apparently closer to the noise found in images so it seems to make sense for image generation, but perhaps it'll be good for audio too.

In case you can't open the Discord link, the code provided by crowsonkb / alstroemeria313 is

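import math
from dctorch import functional as DF
import torch

# Matrix square root of a symmetric positive semi-definite matrix via its
# eigendecomposition (not called elsewhere in this snippet).
def sqrtm(x):
    vals, vecs = torch.linalg.eigh(x)
    return vecs * vals.sqrt() @ vecs.T

# shape is expected to end in (channels, height, width).
def colored_noise(shape, power=2.0, mean=None, color=None, device='cpu', dtype=torch.float32):
    mean = torch.zeros([shape[-3]]) if mean is None else mean
    color = torch.eye(shape[-3]) if color is None else color
    # Frequency grids along height and width.
    f_h = math.pi * torch.arange(shape[-2], device=device, dtype=dtype) / shape[-2]
    f_w = math.pi * torch.arange(shape[-1], device=device, dtype=dtype) / shape[-1]
    freqs_sq = (f_h[:, None] ** 2 + f_w[None, :] ** 2)
    # Replace the zero (DC) entry so the negative power below stays finite.
    freqs_sq[..., 0, 0] = freqs_sq[..., 0, 1]
    # Spectral density proportional to 1 / |f|^power, normalized to mean 1.
    spd = freqs_sq ** -(power / 2)
    spd /= spd.mean()
    noise = torch.randn(shape, device=device, dtype=dtype)
    # Mix channels with the color matrix.
    noise = torch.einsum('...chw,cd->...dhw', noise, color.to(device, dtype))
    # Scale the white noise per frequency, then apply the inverse 2-D DCT.
    noise = DF.idct2(noise * spd.sqrt())
    # Add the per-channel mean.
    noise = noise + mean.to(device, dtype)[..., None, None]
    return noise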

Ideally this will help with generating lower frequency components in audio.

RuntimeError: The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2

It seems like all my batches have some underlying issue where they're all off by one. I've seen other issues opened about this, but no proper explanation. Could I get some help on this?

Failed during forward The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2

I verified that the text and wavs both match the batch size (16); all wavs are padded to a length of 84480 in this case.

RuntimeError: The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2                                                                                                                                                                                
Failed during forward The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2                                                                                                                                                                        
The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2                                                                                                                                                                                              
torch.Size([16, 84480]) 16                                                                                                                  
Traceback (most recent call last):                                                                                                          
  File "/mnt/nvme/programs/qTTS/train_ttv_v1.py", line 184, in train_and_evaluate                                                           
    loss_gen_all = net_g(                                                                                                                   
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward                                                                                                                                                                                 
    else self._run_ddp_forward(*inputs, **kwargs)                                                                                           
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward                                                                                                                                                                        
    return self.module(*inputs, **kwargs)  # type: ignore[index]                                                                            
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/audio_diffusion_pytorch/models.py", line 40, in forward                                                                                                                                                                                  
    return self.diffusion(*args, **kwargs)                                                                                                  
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/audio_diffusion_pytorch/diffusion.py", line 93, in forward                                                                                                                                                                               
    v_pred = self.net(x_noisy, sigmas, **kwargs)                                                                                            
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward                                                      
    return forward_fn(*args, **kwargs)                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 594, in forward                                                                                                                                                                                                  
    return net(x, features=features, **kwargs)                                                                                              
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl           
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward                                                      
    return forward_fn(*args, **kwargs)                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 621, in forward                                                     
    return net(x, embedding=text_embedding, **kwargs)                                                                                       
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward                                                                                                                                                                                                   
    return forward_fn(*args, **kwargs)                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 552, in forward                                                                                                                                                                                                  
    return net(x, embedding=embedding, **kwargs)                                                                                            
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/apex.py", line 431, in forward                                                                                                                                                                                                    
    return self.net(x, features, embedding, channels)  # type: ignore                                                                       
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/apex.py", line 382, in forward                                                                                                                                                                                                    
    x = self.block(x, features, embedding, channels)                                                                                        
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 77, in forward                                                      
    x = block(x, *args)                                                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                              

I followed the example from the README:

  net_g = DiffusionModel(
      net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
      in_channels=1, # U-Net: number of input/output (audio) channels
      channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
      factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
      items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
      attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
      attention_heads=8, # U-Net: number of attention heads per attention item
      attention_features=64, # U-Net: number of attention features per attention item
      diffusion_t=VDiffusion, # The diffusion method used
      sampler_t=VSampler, # The diffusion sampler used
      use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base)
      use_embedding_cfg=True, # U-Net: enables classifier free guidance
      embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base)
      embedding_features=768, # U-Net: text embedding features (default for T5-base)
      cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer
  )
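One thing that may be worth checking (an assumption, not a diagnosis confirmed anywhere in this thread) is whether the padded length is divisible by the product of the U-Net factors. With the factors above the product is 2048, and 84480 is not a multiple of 2048, so some downsampled length has to be rounded, which is a common way for a skip connection to end up one element off. The sketch below pads a batch to the next multiple; pad_to_multiple is a hypothetical helper, and the [batch, channels, length] layout is assumed.

```python
import math

import torch
import torch.nn.functional as F

# Factors copied from the configuration above.
factors = [1, 4, 4, 4, 2, 2, 2, 2, 2]
total_factor = math.prod(factors)  # 2048


def pad_to_multiple(wave: torch.Tensor, multiple: int) -> torch.Tensor:
    """Right-pad the last dimension with zeros up to the next multiple (hypothetical helper)."""
    length = wave.shape[-1]
    target = math.ceil(length / multiple) * multiple
    return F.pad(wave, (0, target - length))


wavs = torch.zeros(16, 1, 84480)           # assumed [batch, channels, length]
remainder = wavs.shape[-1] % total_factor  # 512, so 84480 is not a multiple of 2048
padded = pad_to_multiple(wavs, total_factor)
print(remainder, padded.shape)             # 512 torch.Size([16, 1, 86016])
```

If the mismatch persists with a length that is a multiple of 2048, the cause lies elsewhere.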

Support usage with non-audio data, e.g. spectrograms

I am trying to use the package to work with spectrograms, but I have encountered a problem. Some of the operations in the package are only designed to work with 3-d tensors, which limits their usability.

Request

I would like to request a change to make these operations more generic, so that they can be used with spectrograms (or any other data that may not necessarily be 3-d tensors). This would enable more users to use the package for a wider range of applications, and improve the overall usability of the package.

Examples

To illustrate the issue and the desired change, I have provided some examples below.

Sequential mask generation

The sequential_mask operation generates a boolean mask for a tensor. The original version of the operation is shown below:

def sequential_mask(like: Tensor, start: int) -> Tensor:
    length, device = like.shape[2], like.device
    mask = torch.ones_like(like, dtype=torch.bool)
    mask[:, :, start:] = torch.zeros((length - start,), device=device)
    return mask

To make this operation more generic, we could change the third dimension (dim=2) to the last dimension (dim=-1). This would allow the operation to work with any tensor, regardless of its shape. The revised version of the operation would look like this:

def sequential_mask(like: Tensor, start: int) -> Tensor:
    length, device = like.shape[-1], like.device
    mask = torch.ones_like(like, dtype=torch.bool)
    mask[..., start:] = torch.zeros((length - start,), device=device)
    return mask
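A quick check of the revised version on both a 3-D waveform batch and a 4-D spectrogram batch is sketched below. The tensor shapes are arbitrary examples, and this variant assigns False directly instead of building a zeros tensor, which is a small simplification of the proposal above rather than the package's code.

```python
import torch
from torch import Tensor


def sequential_mask(like: Tensor, start: int) -> Tensor:
    # True before `start` along the last dimension, False from `start` onward.
    mask = torch.ones_like(like, dtype=torch.bool)
    mask[..., start:] = False
    return mask


wave = torch.randn(2, 1, 8)     # [batch, channels, length]
spec = torch.randn(2, 1, 4, 8)  # [batch, channels, freq_bins, frames]

wave_mask = sequential_mask(wave, start=3)
spec_mask = sequential_mask(spec, start=3)
```

Both calls return a mask with the same shape as the input, so the same function serves waveforms and spectrograms.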

I am happy to contribute to address these issues.

Questions about conditional generation

Hi!

I have worked with unconditional generation using this fine repo. It is a lot of fun! I will do latent diffusion next. I am already looking forward to it.

Text conditional generation promises a lot of fun. I have a few questions.

  • The conditional section of the README says "Text conditioning, one element per batch". This means "one text per waveform" and thus "a batch of texts for a batch of waveforms", right? Not "one text for a batch of waveforms"?

  • I believe latent diffusion and text conditioning to be orthogonal. Is it safe to assume that DiffuserAE would work with text conditioning by just adding the right kwargs?

  • What would be necessary in order to replace the T5 embeddings with something else?

  • What would be the consequences of extending the number of tokens for T5?

This is so cool!

Best,
Tristan

Using the audio_975 model with colab fails

Hello,

I'm trying to get the colab to run with the new larger model available here:
https://huggingface.co/archinetai/audio-diffusion-pytorch/resolve/main/audio_975.pt

I've modified the colab to use the latest git code:
!pip install -e git+https://github.com/archinetai/audio-diffusion-pytorch@main#egg=audio-diffusion-pytorch

However, loading the latest model fails with:
Can't get attribute 'CrossEmbed1d' on <module 'audio_diffusion_pytorch.modules' from '/usr/local/lib/python3.7/dist-packages/audio_diffusion_pytorch/modules.py'>

Is there a way to get the new model to run?

Languages

Hello. The software looks great. Thank you. I'm trying to play and understand it. I have a question: are languages other than English supported? I'm not really sure.
