mbrukman / audio-diffusion-pytorch

This project is a fork of archinetai/audio-diffusion-pytorch.


Audio generation using diffusion models, in PyTorch.

License: MIT License


audio-diffusion-pytorch's Introduction

Unconditional audio generation using diffusion models, in PyTorch. The goal of this repository is to explore different architectures and diffusion models to generate audio (speech and music) directly from/to the waveform. Progress will be documented in the experiments section.

Install

pip install audio-diffusion-pytorch


Usage

import torch
from audio_diffusion_pytorch import AudioDiffusionModel

model = AudioDiffusionModel()

# Train model with audio sources
x = torch.randn(2, 1, 2 ** 18) # [batch, channels, samples], 2**18 ≈ 12s of audio at a sample rate of 22050 Hz
loss = model(x)
loss.backward() # Do this many times

# Sample 2 sources given start noise
noise = torch.randn(2, 1, 2 ** 18)
sampled = model.sample(
    noise=noise,
    num_steps=5 # Suggested range: 2-100
) # [2, 1, 262144]
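
A complete training step also needs an optimizer; the snippet above only computes the loss. Below is a minimal sketch of a full loop, where random tensors stand in for a real audio DataLoader and Adam is one reasonable optimizer choice (both are assumptions, not prescribed by the library):

import torch
from audio_diffusion_pytorch import AudioDiffusionModel

model = AudioDiffusionModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumption: Adam with lr=1e-4

for step in range(100):  # in practice, iterate over batches from a real audio DataLoader
    x = torch.randn(2, 1, 2 ** 18)  # stand-in for a batch of real waveforms
    loss = model(x)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()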

Usage with Components

UNet1d

from audio_diffusion_pytorch import UNet1d

# UNet used to denoise our 1D (audio) data
unet = UNet1d(
    in_channels=1,
    patch_size=16,
    channels=128,
    multipliers=[1, 2, 4, 4, 4, 4, 4],
    factors=[4, 4, 4, 2, 2, 2],
    attentions=[False, False, False, True, True, True],
    num_blocks=[2, 2, 2, 2, 2, 2],
    attention_heads=8,
    attention_features=64,
    attention_multiplier=2,
    resnet_groups=8,
    kernel_multiplier_downsample=2,
    kernel_sizes_init=[1, 3, 7],
    use_nearest_upsample=False,
    use_skip_scale=True,
    use_attention_bottleneck=True,
    use_learned_time_embedding=True,
)

x = torch.randn(3, 1, 2 ** 15)
t = torch.tensor([0.2, 0.8, 0.3])

y = unet(x, t) # [3, 1, 32768], 3 samples of ~1.5 seconds at 22050 Hz, denoised at the given noise levels t
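
As a sanity check on these hyperparameters: the temporal resolution is reduced by patch_size at the input and by each entry of factors on the way down, so the bottleneck length is the input length divided by their product. A back-of-the-envelope sketch (illustrative only; exact internal shapes may vary by library version):

patch_size = 16
factors = [4, 4, 4, 2, 2, 2]

total_downsample = patch_size
for f in factors:
    total_downsample *= f  # 16 * (4*4*4) * (2*2*2) = 8192

print(total_downsample)             # 8192
print(2 ** 15 // total_downsample)  # bottleneck length for the 2**15-sample input above: 4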

Diffusion

Training

from audio_diffusion_pytorch import Diffusion, LogNormalDistribution

diffusion = Diffusion(
    net=unet,
    sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
    sigma_data=0.1,
    dynamic_threshold=0.95
)

x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
loss = diffusion(x)
loss.backward() # Do this many times
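
A note on sigma_distribution: following Karras et al. (2022), cited below, a log-normal distribution samples the noise level as sigma = exp(z) with z ~ N(mean, std). Assuming that standard parameterization (an assumption about the library's internals), a quick sketch of what mean=-3.0, std=1.0 implies:

import torch

mean, std = -3.0, 1.0
sigmas = (torch.randn(100_000) * std + mean).exp()  # sigma = exp(z), z ~ N(mean, std)
print(sigmas.median())  # close to exp(-3) ≈ 0.0498, i.e. most training noise levels are small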

Sampling

from audio_diffusion_pytorch import ADPM2Sampler, DiffusionSampler, KarrasSchedule

sampler = DiffusionSampler(
    diffusion,
    num_steps=5, # Suggested range 2-100; more steps give higher quality but take longer
    sampler=ADPM2Sampler(rho=1.0),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)
)
# Generate a sample starting from the provided noise
y = sampler(noise=torch.randn(1, 1, 2 ** 18))
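
KarrasSchedule presumably implements the noise schedule from Karras et al. (2022), cited below. Assuming the standard formula from that paper, the num_steps noise levels interpolate between sigma_max and sigma_min in rho-warped space:

import torch

def karras_sigmas(n, sigma_min=0.0001, sigma_max=3.0, rho=9.0):
    # sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    i = torch.arange(n)
    warped = sigma_max ** (1 / rho) + (i / (n - 1)) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    return warped ** rho

print(karras_sigmas(5))  # 5 noise levels decreasing from 3.0 toward 0.0001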

Inpainting

from audio_diffusion_pytorch import DiffusionInpainter, KarrasSchedule, ADPM2Sampler

inpainter = DiffusionInpainter(
    diffusion,
    num_steps=5, # Suggested range 2-100, higher for better quality
    num_resamples=1, # Suggested range 1-10, higher for better quality
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
    sampler=ADPM2Sampler(rho=1.0),
)

inpaint = torch.randn(1, 1, 2 ** 18) # Start track, e.g. one sampled with DiffusionSampler
inpaint_mask = torch.randint(0, 2, (1, 1, 2 ** 18), dtype=torch.bool) # Set to `True` the parts you want to keep
y = inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
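
In practice the mask is usually contiguous rather than random. For example, to keep the first half of the track and regenerate the second half (a hypothetical mask for illustration, not part of the library):

import torch

length = 2 ** 18
inpaint_mask = torch.zeros(1, 1, length, dtype=torch.bool)
inpaint_mask[:, :, : length // 2] = True  # keep the first half, regenerate the rest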

Infinite Generation

from audio_diffusion_pytorch import SpanBySpanComposer

composer = SpanBySpanComposer(
    inpainter,
    num_spans=4 # Number of spans to inpaint after the provided input
)
y_long = composer(y, keep_start=True) # [1, 1, 98304]

Experiments

Report  | Snapshot   | Description
--------|------------|------------
Alpha   | 6bd9279f19 | Initial tests on LJSpeech dataset with new architecture and basic DDPM diffusion model.
Bravo   | a05f30aa94 | Elucidated diffusion, improved architecture with patching, longer duration, initial good (unsupervised) results on LJSpeech.
Charlie | 50ecc30d70 | Train on music with YoutubeDataset, larger patch tests for longer tracks, inpainting tests, initial test with infinite generation using SpanBySpanComposer.
Delta   | (current)  | Test model with the faster ADPM2 sampler and dynamic thresholding.

TODO

  • Add elucidated diffusion.
  • Add ancestral DPM2 sampler.
  • Add dynamic thresholding.
  • Add (variational) autoencoder option to compress audio before diffusion.
  • Fix inpainting and make it work with ADPM2 sampler.

Appreciation

Citations

DDPM

@misc{2006.11239,
Author = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
Title = {Denoising Diffusion Probabilistic Models},
Year = {2020},
Eprint = {arXiv:2006.11239},
}

Diffusion inpainting

@misc{2201.09865,
Author = {Andreas Lugmayr and Martin Danelljan and Andres Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
Title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
Year = {2022},
Eprint = {arXiv:2201.09865},
}

Diffusion weighted loss

@misc{2204.00227,
Author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo Kim and Sungroh Yoon},
Title = {Perception Prioritized Training of Diffusion Models},
Year = {2022},
Eprint = {arXiv:2204.00227},
}

Improved UNet architecture

@misc{2205.11487,
Author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho and David J Fleet and Mohammad Norouzi},
Title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
Year = {2022},
Eprint = {arXiv:2205.11487},
}

Elucidated diffusion

@misc{2206.00364,
Author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
Title = {Elucidating the Design Space of Diffusion-Based Generative Models},
Year = {2022},
Eprint = {arXiv:2206.00364},
}

