
This repository was forked from arm-software/keyword-transformer.

License: Apache License 2.0


Keyword Transformer: A Self-Attention Model for Keyword Spotting


This is the official repository for the paper Keyword Transformer: A Self-Attention Model for Keyword Spotting, presented at Interspeech 2021. Consider citing our paper if you find this work useful.

@inproceedings{berg21_interspeech,
  author={Axel Berg and Mark O’Connor and Miguel Tairum Cruz},
  title={{Keyword Transformer: A Self-Attention Model for Keyword Spotting}},
  year=2021,
  booktitle={Proc. Interspeech 2021},
  pages={4249--4253},
  doi={10.21437/Interspeech.2021-1286}
}

Setup

Download Google Speech Commands

There are two versions of the dataset, V1 and V2. To download and extract dataset V2, run:

wget https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz
mkdir data2
mv ./speech_commands_v0.02.tar.gz ./data2
cd ./data2
tar -xf ./speech_commands_v0.02.tar.gz
cd ../

And similarly for V1:

wget http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz
mkdir data1
mv ./speech_commands_v0.01.tar.gz ./data1
cd ./data1
tar -xf ./speech_commands_v0.01.tar.gz
cd ../

Install dependencies

Set up a new virtual environment:

pip install virtualenv
virtualenv --system-site-packages -p python3 ./venv3
source ./venv3/bin/activate

To install dependencies, run

pip install -r requirements.txt

Tested with TensorFlow 2.4.0rc1 and CUDA 11.

Note: installing the correct TensorFlow version is important for reproducibility! More recent versions of TensorFlow produce small accuracy differences each time the model is evaluated. This is likely due to a change in how the random seed generator is implemented, which alters the sampling of the "unknown" keyword class.

Model

The Keyword-Transformer model is defined here. It takes a mel-scale spectrogram as input, which has shape 98 x 40 under the default settings, corresponding to 98 time windows with 40 frequency coefficients each.
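A quick sketch of where the default 98 x 40 shape comes from, assuming the usual Speech Commands preprocessing (1-second clips at 16 kHz, 30 ms analysis windows with a 10 ms stride, 40 mel bins); these defaults are an assumption, not taken from the repository's code:

```python
# Assumed preprocessing defaults (not read from the repository's flags).
clip_ms = 1000    # 1-second audio clips
window_ms = 30    # analysis window length
stride_ms = 10    # hop between windows
n_mels = 40       # mel frequency coefficients per window

# Number of full windows that fit in the clip.
n_frames = (clip_ms - window_ms) // stride_ms + 1
print((n_frames, n_mels))  # (98, 40)
```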

There are four variants of the Keyword-Transformer model:

  • Time-domain attention: each time window is treated as a patch, and self-attention is computed between time windows
  • Frequency-domain attention: each frequency is treated as a patch, and self-attention is computed between frequencies
  • Combination of both: the signal is fed into both a time- and a frequency-domain transformer, and the outputs are combined
  • Patch-wise attention: similar to the vision transformer, rectangular patches are extracted from the spectrogram, so attention happens in the time and frequency domains simultaneously
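The variants above differ only in how the spectrogram is cut into patches before attention. A shape-level sketch (illustrative only, not the repository's implementation; the 1 x 40 patch size is an assumed example matching the default `--patch_size` below):

```python
import numpy as np

# A dummy mel spectrogram: 98 time windows x 40 frequency coefficients.
spec = np.zeros((98, 40))

# Time attention: each time window is one patch -> 98 patches of dim 40.
time_patches = spec

# Frequency attention: each frequency bin is one patch -> 40 patches of dim 98.
freq_patches = spec.T

# Patch-wise attention with an assumed 1 x 40 patch: split the spectrogram
# into a grid of (ph x pw) rectangles and flatten each rectangle.
ph, pw = 1, 40
grid = spec.reshape(98 // ph, ph, 40 // pw, pw)
patches = grid.transpose(0, 2, 1, 3).reshape(-1, ph * pw)
print(patches.shape)  # (98, 40)
```

Note that a 1 x 40 patch spans one time window and all frequencies, so patch-wise attention with that size reduces to time-domain attention.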

Training a model from scratch

To train KWT-3 from scratch on Speech Commands V2, run

sh train.sh

Please note that the train directory (given by the argument --train_dir) must not exist before the script is started.

The model-specific arguments for KWT are:

--num_layers 12 \         # number of sequential transformer encoders
--heads 3 \               # number of attention heads
--d_model 192 \           # embedding dimension
--mlp_dim 768 \           # mlp dimension
--dropout1 0. \           # dropout in mlp/multi-head attention blocks
--attention_type 'time' \ # attention type: 'time', 'freq', 'both' or 'patch'
--patch_size '1,40' \     # spectrogram patch size, if patch attention is used
--prenorm False \         # if False, use postnorm

Training with distillation

We employ hard distillation from a convolutional model (Att-MH-RNN), similar to the approach in DeiT.

To train KWT-3 with hard distillation from a pre-trained model, run

sh distill.sh
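In DeiT-style hard distillation, the teacher's argmax prediction serves as a hard label for a dedicated distillation token, alongside the usual cross-entropy on the ground truth. A minimal numerical sketch (illustrative only; the logits and equal loss weighting are assumptions, not the repository's code):

```python
import numpy as np

def cross_entropy(logits, label):
    # Numerically stable cross-entropy for a single example.
    logits = logits - logits.max()
    log_probs = logits - np.log(np.exp(logits).sum())
    return -log_probs[label]

teacher_logits = np.array([0.1, 2.5, -1.0])  # pre-trained teacher output (made up)
student_cls = np.array([0.2, 1.8, -0.5])     # student class-token logits (made up)
student_dist = np.array([0.0, 2.0, -0.8])    # student distillation-token logits (made up)
true_label = 1

# Hard distillation: the teacher's argmax acts as the label for the
# distillation head; both losses are averaged with equal weight.
hard_label = int(np.argmax(teacher_logits))
loss = 0.5 * cross_entropy(student_cls, true_label) \
     + 0.5 * cross_entropy(student_dist, hard_label)
```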

Run inference using a pre-trained model

Pre-trained weights for KWT-3, KWT-2 and KWT-1 are provided in ./models_data_v2_12_labels.

| Model name | embedding dim | mlp dim | heads | depth | # params | V2-12 accuracy | pre-trained |
|------------|---------------|---------|-------|-------|----------|----------------|-------------|
| KWT-1      | 64            | 128     | 1     | 12    | 607K     | 97.7           | here        |
| KWT-2      | 128           | 256     | 2     | 12    | 2.4M     | 98.2           | here        |
| KWT-3      | 192           | 768     | 3     | 12    | 5.5M     | 98.7           | here        |

To perform inference on Google Speech Commands v2 with 12 labels, run

sh eval.sh

Acknowledgements

The code heavily borrows from the KWS streaming work by Google Research. For a more detailed description of the code structure, see the original authors' README.

We also exploit training techniques from DeiT.

We thank the authors for sharing their code. Please consider citing them as well if you use our code.

License

The source files in this repository are released under the Apache 2.0 license.

Some source files are derived from the KWS streaming repository by Google Research. These are also released under the Apache 2.0 license, the text of which can be seen in the LICENSE file on their repository.
