moon0316 / t2a

Project page for "Improving Few-shot Learning for Talking Face System with TTS Data Augmentation" for ICASSP2023


t2a's Introduction

Improving Few-shot Learning for Talking Face System with TTS Data Augmentation

Statements

  • This repository is for academic research only; any commercial use is prohibited.
  • The copyright of the digital human presented in our demo is reserved by SMT.

Acknowledgements

  • Thanks to Shanghai Media Tech (SMT) for providing the dataset and rendering service.
  • We use the pre-trained HuBERT model from this repository.
  • We use the implementation of the soft-DTW loss from this repository.
  • We use the implementation of the Transformer from this repository.

Thanks to the authors of the above repositories.

Demos

TTS Data Augmentation

contrast_TTSaug_record.mp4
contrast_TTSaug_Obama.mp4
contrast_TTSaug_News_f.mp4
contrast_TTSaug_News_m.mp4

TTS-driven Talking Face

contrast_T2A_record.mp4
contrast_T2A_Obama.mp4
contrast_T2A_News_f.mp4
contrast_T2A_News_m.mp4

Different Audio Features

contrast_features.mp4

Different Loss Functions

contrast_loss.mp4

Different Data Resources

contrast_resource.mp4

Pre-trained model and tools preparation

Download pre-trained HuBERT model

The pre-trained HuBERT model is obtained from this repository.

Please download the Chinese HuBERT model and put it in the directory ./data/pretrained_models/ by executing the following command:

wget -P ./data/pretrained_models/ https://huggingface.co/TencentGameMate/chinese-hubert-large/resolve/main/chinese-hubert-large-fairseq-ckpt.pt

Download fairseq tool

git clone git@github.com:facebookresearch/fairseq.git
cd fairseq
git checkout acd9a53
pip install --editable ./
cd ..
cp hubert.py ./fairseq/fairseq/models/hubert/
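
After installation, a quick check that fairseq imports and that the HuBERT model definition is visible can save debugging time later. This is a generic sketch based on fairseq's standard module layout, not a command from this repository:

python -c "import fairseq; from fairseq.models.hubert import HubertModel; print(fairseq.__version__)"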

Feature extraction

Extract HuBERT feature

python utils/generate_hubert.py --input_dir ./data/wavs/[speaker name] --output_dir ./data/wav_features/[speaker name]

Extract MFCC feature

python utils/generate_mfcc.py --input_dir ./data/wavs/[speaker name] --output_dir ./data/wav_features/[speaker name]
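
Both extraction scripts process one speaker directory at a time, so multiple speakers can be handled with a simple loop. A minimal sketch, assuming one sub-directory of ./data/wavs/ per speaker (speakerA and speakerB are placeholder names):

for speaker in speakerA speakerB; do
    python utils/generate_hubert.py --input_dir ./data/wavs/$speaker --output_dir ./data/wav_features/$speaker
done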

Train

Run bash train.sh to train.

Important arguments for main.py (an example invocation is sketched after this list):

  • arch: chinese_hubert_large | mfcc | ppg
  • feature_combine: set to True to use a weighted sum of the HuBERT features
  • output_path: "result" to generate output for the test set | [other name] to generate output for other data
  • test_input_path: must be assigned explicitly when output_path != "result"; it is the directory containing the csv files
  • test_epoch: no need to assign explicitly; the model with the best validation performance is selected automatically
  • root_dir: root directory of the dataset
  • feature_dir: hubert_large | mfcc | ppg
  • train_speaker_list: one or more speaker names used for training
  • train_json: used to change the data resource; path of a json file listing the audio names in the training set
  • freq: 50 if the feature is chinese_hubert_large or ppg, 100 if the feature is mfcc
  • input_dim: 39 for mfcc, 128 for ppg
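
For illustration, here is a hypothetical direct invocation of main.py combining the arguments above. The speaker names and paths are placeholders, and train.sh remains the authoritative entry point, so treat this as a sketch rather than a verified command:

python main.py \
    --arch chinese_hubert_large \
    --feature_dir hubert_large \
    --freq 50 \
    --root_dir ./data \
    --train_speaker_list speakerA speakerB \
    --output_path result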

Validate

Run bash validate.sh to pick the best model by validating on the validation set of a given speaker; change --val_speaker to select the speaker used for validation.
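
For example, to validate on a placeholder speaker named speakerA (this assumes validate.sh forwards extra arguments to main.py, which the repository does not document; otherwise edit the flag inside the script):

bash validate.sh --val_speaker speakerA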

Test

Run bash test.sh to test.
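
To generate outputs for data other than the test set, the argument list above suggests overriding output_path and test_input_path. A hedged sketch, where my_eval and ./data/my_eval_csv are placeholder names:

python main.py --output_path my_eval --test_input_path ./data/my_eval_csv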

Citation

@inproceedings{chen2023improving,
  title={Improving Few-Shot Learning for Talking Face System with TTS Data Augmentation},
  author={Chen, Qi and Ma, Ziyang and Liu, Tao and Tan, Xu and Lu, Qu and Yu, Kai and Chen, Xie},
  booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  year={2023}
}

t2a's People

Contributors

ddlbojack, moon0316


t2a's Issues

time sequence len diff

Hello, while using your code I noticed that the layer-24 features produced by HuBERT are at 150 fps, while the labels you provide are at 25 fps. I can see a rate division in your code, but the predicted and label sequences still end up with different lengths. My question: when using DTW as the loss, do the sequence lengths have to match? And is feeding in the 150 fps features correct?

sdtw loss error

Hello, your work is excellent.
While trying to train the network you provide, the soft-DTW loss keeps raising the following error:
numba.cuda.cudadrv.driver.CudaAPIError: [1] Call to cuLaunchKernel results in CUDA_ERROR_INVALID_VALUE
The error is raised in sdtw_cuda_loss.py at:
compute_softdtw_cuda[B, threads_per_block](cuda.as_cuda_array(D.detach()),
gamma.item(), bandwidth.item(), N, M, n_passes,
cuda.as_cuda_array(R))
With a training batch size of 1, the error appears after 35 batches; with a batch size of 8, it appears at the 6th batch.
After searching on Google, I suspect the soft-DTW loss needs more CUDA threads than my GPU can support, but I do not understand the soft-DTW code well enough to know how to modify it. My GPU is an NVIDIA GeForce RTX 3080 with 10 GB of memory.

How to train the model with 30/60 fps blendshape labels

Our blendshape label data is at 30/60 fps rather than 25 fps, so the ground-truth labels do not align with the model output at all. We would like to adapt your model for training. I tried changing the pre-trained HuBERT model's label_rate to 60 Hz, but the pre-trained model does not seem to be modifiable. Do you have a better adaptation strategy?

Beginner asking for help

input_values = torch.stack(input_values_new).squeeze(2) #12, 251, 768

This line fails for me: the False entry is missing one dimension. Is the goal to produce rows like [[-0.2113, 0.1069, 0.1805, ..., -0.0475, -0.1105, -0.1938], [False]]?
What is the purpose of the False appended during HuBERT audio-feature extraction? Can it simply be dropped?

Audio stream inference

The model is currently run on a long audio clip at once. Have you considered streaming inference, e.g., emitting output frame by frame? I tried feeding the model 200 ms audio chunks shifted by 40 ms at a time, but the result jitters badly. Do you have any suggestions for streaming inference?

About generate_hubert

Hello, thank you very much for your work.
However, when generating HuBERT features from the A1.wav file you provide, I ran into the following problem:

Traceback (most recent call last):
File "utils/generate_hubert.py", line 76, in
handle_wav(args.input_dir, os.path.join(args.output_dir, model_short_alias))
File "utils/generate_hubert.py", line 54, in handle_wav
torch.save(logits[2], os.path.join(saved_to_path, wav_file_dir, wav_file.split(".")[0] +".pt"))
IndexError: tuple index out of range

This causes the generated A1.pt to fail with a dimension mismatch when running test.sh.
Should logits[2] be modified here?
Thanks

pretrained model

Can you share the model for testing?

I also want to train a new model myself; how should the training dataset be prepared?

What is included in the blendshape pkl?
I know it contains audio_data, vertice......
