hubert-base-jax's Introduction

HuBERT Base JAX Implementation

An implementation of HuBERT Base in JAX. I want to be able to pretrain and finetune HuBERT efficiently on TPUs.

This repository is a work in progress and is not yet complete.

To Do List:

  • Build the model for inference
  • Map and import weights from bshall/hubert:main
  • Add padding mask
  • Test pretrained model ABX on LibriSpeech
  • Add masking strategy
  • Build dataset prepare scripts and loader
  • Build trainer module
  • Test pretraining on LibriSpeech dataset (single GPU)
  • Add LoRA
  • Test LoRA finetuning for phoneme recognition
  • Extend training to multiple TPUs with data parallelism
  • Clean up code and add documentation
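
As a rough illustration of the "Add masking strategy" item, here is a minimal sketch of span masking over feature frames. It is not this repository's implementation. The function name `span_mask` and its parameters are hypothetical. The defaults (8% of frames chosen as span starts, spans of 10 frames) follow the values reported in the HuBERT paper and should be treated as assumptions here. The sketch draws each start independently per frame, which is a simplification; fairseq's mask computation differs in detail.

```python
import random


def span_mask(num_frames, start_prob=0.08, span_length=10, seed=None):
    """Return a list of booleans, True where a frame is masked.

    Hypothetical helper: each frame is chosen as a span start with
    probability `start_prob`, and the next `span_length` frames
    (clipped at the end of the sequence) are masked. Spans may overlap.
    """
    rng = random.Random(seed)
    mask = [False] * num_frames
    for start in range(num_frames):
        if rng.random() < start_prob:
            end = min(start + span_length, num_frames)
            for t in range(start, end):
                mask[t] = True
    return mask


# Example: mask a 100-frame utterance with a fixed seed for repeatability.
mask = span_mask(num_frames=100, seed=0)
print(sum(mask), "of", len(mask), "frames masked")
```

The mask is built in plain Python so it can be computed in the data loader and converted to an array afterwards. A padding mask, once added, would need to be combined with it so that padded frames are never selected.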

This repository is based on the following work:

  • Benji van Niekerk's stripped-down implementation of HuBERT Base and its easily accessible weights.
  • Phillip Lippe's tutorial on building a transformer in JAX.
  • The HuBERT paper.
  • The fairseq repo.

Installation

Install JAX for your system by following these instructions. For example, for CUDA 12.0, you can run the following command:

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

I also use the CPU build of PyTorch for its datasets and dataloaders, and to load the weights from the PyTorch checkpoint:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Install the other requirements:

pip install -r requirements.txt
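
After the commands above, a quick check can confirm that JAX imports and sees the expected accelerator. This is a minimal sketch, not part of the repository; the helper name `describe_jax_install` is hypothetical. It relies only on `jax.__version__`, `jax.default_backend()` and `jax.devices()`.

```python
import importlib.util


def describe_jax_install():
    """Return a one-line summary of the local JAX installation.

    Hypothetical helper: reports that JAX is missing instead of raising,
    so it can be run before the installation has succeeded.
    """
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    return (
        f"jax {jax.__version__}, "
        f"backend={jax.default_backend()}, "
        f"devices={jax.devices()}"
    )


print(describe_jax_install())
```

If the reported backend is `cpu` on a machine with a GPU or TPU, the accelerator-specific JAX wheel was probably not picked up.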

Note: If you want to compute the ABX score, I recommend installing and using zrc_abx2 in a separate environment.

Results

ABX results for each layer on LibriSpeech when using the weights from bshall/hubert:main.

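| Layer Index | ABX Within Within | ABX Any Within | ABX Within Across | ABX Any Across |
|-------------|-------------------|----------------|-------------------|----------------|
| 0           | 6.15              |                |                   |                |
| 1           | 6.03              |                |                   |                |
| 2           | 5.13              |                |                   |                |
| 3           | 4.20              |                |                   |                |
| 4           | 3.41              |                |                   |                |
| 5           | 2.77              | 10.04          | 3.67              | 10.74          |
| 6           | 2.38              | 9.75           | 3.10              | 10.33          |
| 7           | 2.32              | 10.24          | 3.20              | 10.81          |
| 8           | 2.39              | 10.23          | 3.16              | 10.76          |
| 9           | 1.97              | 8.77           | 2.74              | 9.20           |
| 10          | 1.91              | 8.52           | 2.60              | 9.12           |
| 11          | 2.12              | 8.79           | 2.94              | 9.34           |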

Contributors

  • nicolvisser