JM-NAT

Code for our ACL 2020 paper, "Jointly Masked Sequence-to-Sequence Model for Non-Autoregressive Neural Machine Translation". Please cite our paper if you find this repository helpful in your research:

@inproceedings{guo2020jointly,
    title = {Jointly Masked Sequence-to-Sequence Model for Non-Autoregressive Neural Machine Translation},
    author = {Guo, Junliang and Xu, Linli and Chen, Enhong},
    booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
    year = {2020},
    publisher = {Association for Computational Linguistics},
    pages = {376--385},
}

Requirements

The code is based on fairseq-0.6.2 and requires PyTorch 1.2.0 and CUDA 9.2.
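
A minimal environment setup sketch is shown below; it assumes the repository has already been cloned and that it bundles the modified fairseq-0.6.2 code, so it is installed in editable mode from the repository root (the exact install commands are not specified in this README):

pip install torch==1.2.0   # PyTorch version listed above; a CUDA 9.2 build may require PyTorch's archived wheel index
pip install -e .           # install the bundled fairseq fork and its dependencies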

Training Steps

To train a non-autoregressive machine translation model, please follow the three steps listed below:

  • First, follow the instructions in fairseq to train an autoregressive model.
  • Next, generate distilled target samples with the trained autoregressive model, i.e., set --gen-subset train while decoding (a sketch of this step follows the training command below).
  • Finally, train our model on the distilled training set. For example, on the IWSLT14 De-En task:
python train.py $DATA_DIR \
  --task xymasked_seq2seq \
  -a transformer_nat_ymask_pred_len_deep_small --share-all-embeddings \
  --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
  --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr '1e-07' \
  --lr 0.0007 --min-lr '1e-09' \
  --criterion label_smoothed_length_cross_entropy --label-smoothing 0.1 \
  --weight-decay 0.0 --max-tokens 4096 --max-update 500000 --mask-source-rate 0.1
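
For the second step above (generating the distilled training set), a possible invocation of the standard fairseq generate.py is sketched below; $AR_CHECKPOINT, the beam size, and the output file are illustrative assumptions rather than settings prescribed by the paper:

python generate.py $DATA_DIR \
  --path $AR_CHECKPOINT --gen-subset train \
  --batch-size 64 --beam 4 --remove-bpe > distilled_train.out

The decoded hypotheses then replace the original target side of the training set before training our model.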

At inference time, our model uses a decoding algorithm similar to the one proposed in Mask-Predict, and we average the last 10 checkpoints to obtain the reported results:

python generate.py $DATA_DIR \
  --task xymasked_seq2seq --path checkpoint_aver.pt --mask_pred_iter 10 \
  --batch-size 64 --beam 4 --lenpen 1.1 --remove-bpe
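
The averaged checkpoint checkpoint_aver.pt can be produced with fairseq's checkpoint-averaging script; the invocation below is a sketch assuming the standard scripts/average_checkpoints.py from fairseq-0.6.2 and a checkpoint directory $CHECKPOINT_DIR, neither of which is named in this README:

python scripts/average_checkpoints.py \
  --inputs $CHECKPOINT_DIR \
  --num-epoch-checkpoints 10 \
  --output checkpoint_aver.pt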
