zelaki / dreamsound

Code for Investigating Personalization Methods in Text to Music Generation

Home Page: https://zelaki.github.io/

Python 100.00%
text-to-music audiodiffusion dreambooth textual-inversion music-style-transfer audioldm

dreamsound's Introduction

Text-to-Music Personalization


Recently, text-to-music generation models have achieved unprecedented results in synthesizing high-quality and diverse music samples from a given text prompt. Despite these advances, it remains unclear how one can generate personalized, user-specific musical concepts, manipulate them, and combine them with existing ones. For example, can one generate a rock song using their personal guitar playing style or a specific ethnic instrument? Motivated by the computer vision literature, we investigate text-to-music personalization by exploring two established methods, namely Textual Inversion and DreamBooth.

  • Release code!

  • Example code for training and evaluation

  • DreamBooth with AudioLDM2

  • Gradio app!

  • Release code for Personalized Style Transfer

Install the dependencies and download AudioLDM:

Use Python 3.10.13.

pip install -r requirements.txt
git clone https://huggingface.co/cvssp/audioldm-m-full

You need Git Large File Storage (git-lfs) to clone the Hugging Face model; run git lfs install before cloning.

Training Examples

DreamBooth:

To train the personalization methods on, e.g., a short collection of guitar recordings, you can choose "guitar" or "string instrument" as the class and a rarely used word like "sks" as the instance word.

export MODEL_NAME="audioldm-m-full"
export DATA_DIR="path/to/concept/audios"
export OUTPUT_DIR="path/to/output/dir"
accelerate launch dreambooth_audioldm.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--instance_word="sks" \
--object_class="guitar" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=300 \
--learning_rate=1.0e-06 \
--output_dir=$OUTPUT_DIR \
--num_vectors=1 \
--save_as_full_pipeline 
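With --save_as_full_pipeline, the fine-tuned model should be written to $OUTPUT_DIR as a complete pipeline, so it can be loaded directly with AudioLDMPipeline.from_pretrained as in the inference example below.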

Textual Inversion:

In Textual Inversion you only need to specify a placeholder token, which can be any rarely used string. Here we use "<guitar>" as the placeholder token.

export MODEL_NAME="audioldm-m-full"
export DATA_DIR="path/to/concept/audios"
export OUTPUT_DIR="path/to/output/dir"
accelerate launch textual_inversion_audioldm.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<guitar>" \
  --validation_prompt="a recording of a <guitar>" \
  --initializer="mean" \
  --initializer_token="" \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --output_dir=$OUTPUT_DIR
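After training, the learned embedding should be saved in $OUTPUT_DIR (typically as learned_embeds.bin, following the diffusers textual-inversion convention); this is the file you pass to load_textual_inversion in the inference example below.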

Example Inference

For Textual Inversion, load the learned embedding into the base model and then use the placeholder token from training in your prompts. For DreamBooth, load the fine-tuned model and use "[instance word] [class word]" in the inference prompt.

import torch

from pipeline.pipeline_audioldm import AudioLDMPipeline

# Textual Inversion: load the base model, then the learned embedding
pipe = AudioLDMPipeline.from_pretrained("audioldm-m-full", torch_dtype=torch.float16).to("cuda")
learned_embedding = "path/to/learnedembedding"
prompt = "A recording of <guitar>"
pipe.load_textual_inversion(learned_embedding)
waveform = pipe(prompt).audios

# DreamBooth: load the fine-tuned pipeline and prompt with the instance word and class
pipe = AudioLDMPipeline.from_pretrained("path/to/dreambooth/model", torch_dtype=torch.float16).to("cuda")
prompt = "A recording of a sks guitar"
waveform = pipe(prompt).audios
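The pipeline returns the generated audio as NumPy waveforms. A minimal sketch for writing the first sample to disk, assuming AudioLDM's 16 kHz output sample rate and using scipy (an extra dependency not shown in the original instructions):

import scipy.io.wavfile

# `waveform` holds one array per generated sample; we assume 16 kHz mono
# output here and cast to float32, which scipy's WAV writer supports.
scipy.io.wavfile.write("output.wav", rate=16000, data=waveform[0].astype("float32"))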

AudioLDM2 DreamBooth

To train AudioLDM2 DreamBooth:

export MODEL_NAME="cvssp/audioldm2"
export DATA_DIR="dataset/concepts/oud"
export OUTPUT_DIR="oud_ldm2_db_string_instrument"
accelerate launch dreambooth_audioldm2.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--instance_word="sks" \
--object_class="string instrument" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=300 \
--learning_rate=1.0e-05 \
--output_dir=$OUTPUT_DIR \
--validation_steps=50 \
--num_validation_audio_files=3 \
--num_vectors=1

And for inference:

import torch

from pipeline.pipeline_audioldm2 import AudioLDM2Pipeline

pipeline = AudioLDM2Pipeline.from_pretrained("path/to/dreambooth/model", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

prompt = "a recording of a sks string instrument"
waveform = pipeline(prompt, num_inference_steps=50, num_waveforms_per_prompt=1, audio_length_in_s=5.12).audios[0]
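As with the AudioLDM example above, the returned waveform is a NumPy array and can be written to a WAV file with scipy.io.wavfile.write (again assuming a 16 kHz sample rate).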

Citation

If you use this code, please cite:

@article{plitsis2023investigating,
  title={Investigating Personalization Methods in Text to Music Generation},
  author={Plitsis, Manos and Kouzelis, Theodoros and Paraskevopoulos, Georgios and Katsouros, Vassilis and Panagakis, Yannis},
  journal={arXiv preprint arXiv:2309.11140},
  year={2023}
}

Acknowledgments

This code is heavily based on AudioLDM and Diffusers.


dreamsound's Issues

Feature request: Add AudioLDM_48Khz to DreamSound implementation

Hello,

Thanks for this great repo, been having a lot of fun with it.

I'm wondering if it's possible to implement the newer audioldm_48khz checkpoint for finetuning?
It seems it currently does not work with the standard Diffusers pipeline, but it can be used with the usual AudioLDM2 inference.

Thanks!

Best,
M

questions about the dataset

Dear author,
I have a few questions about the dataset:

  1. Is there a YouTube ID for the Obama and Trump concepts?
  2. The YouTube ID for agni_parthene is no longer valid.
  3. Is the YouTube ID for reggae BSJRvDR9RhQ or r982EVUnIc4?

Thanks

questions about the hyper parameter of the DB setting

Dear authors,
You claimed in the paper that you used "a single NVIDIA RTX-3090 GPU with a training batch size of 4, employing learning rates of 4 × 10⁻⁶ for DB"; is that correct?
I found training slow on my V100 GPU with this hyperparameter setting (two and a half hours to train a single concept); my gradient accumulation steps is set to 1 and max_train_steps is 1500. I would appreciate it if you could help me.

prior preservation loss in Dreambooth

Just want to know: do the results in the paper use the prior preservation loss in DreamBooth? I saw that your code has a prior preservation loss; when I run it in dreambooth_audioldm.py it works, but in dreambooth_audioldm2.py there are some bugs.
