Auto-Encoding Variational Neural Machine Translation (PyTorch)

This repository contains a PyTorch implementation of our Auto-Encoding Variational Neural Machine Translation paper published at the 4th Workshop on Representation Learning for NLP (RepL4NLP). Note that the results in the paper are based on a TensorFlow implementation.

Installation

You will need Python 3.6 or newer:

virtualenv -p python3.6 ~/envs/aevnmt.pt
source ~/envs/aevnmt.pt/bin/activate

You will need two PyTorch extensions from probabll, dists.pt (extensions to torch.distributions) and dgm.pt (building blocks for deep generative models), both of which you can install easily:

git clone https://github.com/probabll/dists.pt.git
cd dists.pt
pip install -r requirements.txt
pip install .
cd ..

git clone https://github.com/probabll/dgm.pt.git
cd dgm.pt
pip install -r requirements.txt
pip install .

Then you will need AEVNMT.pt:

git clone https://github.com/Roxot/AEVNMT.pt.git 
cd AEVNMT.pt
pip install -r requirements.txt

For developers, we recommend

pip install --editable .

For other users, we recommend

pip install .
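
To check that the package is importable afterwards (a quick sanity check, not an official install step):

python -c "import aevnmt"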

Command line interface

python -u -m aevnmt.train \
    --hparams_file HYPERPARAMETERS.json \
    --training_prefix BILINGUAL-DATA \
    --validation_prefix VALIDATION-DATA \
    --src SRC --tgt TGT \
    --output_dir OUTPUT-DIRECTORY
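
HYPERPARAMETERS.json maps hyperparameter names to values. A hypothetical minimal example (all key names below are illustrative; the supported arguments are defined in hparams/hparams.py, and note the precedence caveat discussed in the issues below):

{
    "batch_size": 64,
    "hidden_size": 256,
    "num_layers": 1
}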

python -u -m aevnmt.translate \
    --output_dir OUTPUT-DIRECTORY \
    --verbose true \
    --translation_input_file INPUT \
    --translation_output_file TRANSLATION \
    --translation_ref_file REFERENCE

There is also a sentence VAE mode:

python -u -m aevnmt.train_monolingual \
    --hparams_file HYPERPARAMETERS.json \
    --training_prefix MONOLINGUAL-DATA \
    --validation_prefix VALIDATION-DATA \
    --src SRC \
    --output_dir OUTPUT-DIRECTORY

python -u -m aevnmt.generate \
    --output_dir OUTPUT-DIRECTORY \
    --verbose true \
    --translation_output_file SAMPLED_TEXT \
    --decoding.sample true \
    --translation.num_prior_samples 100

Demos

See the example training and translation scripts and the demo notebook included in the repository.

Experiments

Multi30k English-German

  • Development (BLEU; computed on de-BPE'd outputs only):

    Model        English-German  German-English
    Conditional  40.1            43.5
    AEVNMT       40.9            43.4

  • Test (BLEU; post-processed outputs):

    Model        English-German  German-English
    Conditional  38.0            40.9
    AEVNMT       38.5            40.9

Contributors

eelcovdw, goncalomcorreia, roxot, vitaka, vmsanchez-dlsi, wilkeraziz


aevnmt.pt's Issues

initialization.initialize_model does not support different cell_types

The RNN initialization in initialize_model uses different parameters for cell_type=lstm and initializes every parameter whose name contains "rnn.". The new hparam format supports different cell types for the inference model, TM, and LM, which causes a problem for this method.

Possible solution: split initialize_model into initialize_tm, initialize_lm, and initialize_inf, as sketched below.
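
A rough sketch of what the split could look like (the function names follow the suggestion above; the attribute paths and initialization choices are illustrative, not the actual aevnmt code):

import torch.nn as nn

def initialize_rnn(rnn, cell_type):
    # Shared helper: initialize one RNN, handling LSTM-specific parameters.
    for name, param in rnn.named_parameters():
        if "weight" in name:
            nn.init.xavier_uniform_(param)
        elif "bias" in name:
            nn.init.zeros_(param)
            if cell_type == "lstm":
                # PyTorch packs LSTM biases as [input, forget, cell, output] gates;
                # setting the forget-gate bias to 1 is a common LSTM trick.
                n = param.size(0)
                param.data[n // 4:n // 2].fill_(1.0)

def initialize_tm(tm, cell_type):
    initialize_rnn(tm.decoder.rnn, cell_type)  # illustrative attribute path

# initialize_lm and initialize_inf would do the same, each with its own cell_type.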

new hparams need more modularity

While the new hparams are much better than before, they still get very verbose. One reason is that we have a single argparser for every model and evaluation mode, which results in a lot of unused and confusing arguments.

I've added support for adding/removing arguments via arg groups (see the bottom of args.py). These could be used to construct different argparsers based on which model is used in which context.

Possible solution: each model in the library gets its own arg groups with model-specific arguments, which are combined with train/eval-specific arg groups and (if needed) other general arguments. These can already be combined by the existing argparser in hparams/hparams.py; a sketch follows below.
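
With argparse-style argument groups, composing a per-model parser could look roughly like this (a sketch; the group and argument names are illustrative, not the actual contents of args.py):

import argparse

def add_tm_args(parser):
    # Illustrative model-specific group.
    group = parser.add_argument_group("translation model")
    group.add_argument("--gen.tm.dec.feed_z", action="store_true")

def add_train_args(parser):
    # Illustrative train-specific group.
    group = parser.add_argument_group("training")
    group.add_argument("--batch_size", type=int, default=64)

def build_parser(groups):
    # Combine whichever groups apply to the model and context at hand.
    parser = argparse.ArgumentParser()
    for add_group in groups:
        add_group(parser)
    return parser

parser = build_parser([add_tm_args, add_train_args])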

Testing functionalities for base components

For the library to grow, it really needs some testing functionality.

  • Some parts of the code are out of date and will not work with the current codebase (like the auxiliary likelihoods), but it is not always clear what works and what doesn't.
  • Unit testing would speed up development a lot: going back to bugfix parts of the code would become much less common, and updating older components would be easier (a hypothetical example follows below).
  • The flickr tests have to be updated (we can just copy the IWSLT configs), and some other benchmarks would be nice (IWSLT14 might be too large for a quick test benchmark).
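
A hypothetical example of the kind of test that would help (the component here is a stand-in; an actual test would import an aevnmt module instead):

import torch

def test_encoder_output_shape():
    # Stand-in component; replace with an actual aevnmt base component.
    encoder = torch.nn.GRU(input_size=32, hidden_size=64, batch_first=True)
    x = torch.randn(8, 10, 32)  # [batch, time, features]
    output, state = encoder(x)
    assert output.shape == (8, 10, 64)
    assert torch.isfinite(output).all()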

Switch multi-GPU training to DistributedDataParallel

The current implementation of distributed training is done with DataParallel, which has some drawbacks.

  • PyTorch is moving to DistributedDataParallel, which is much better but requires us to rewrite most of the training code (and possibly some model parts).
  • DataParallel is not compatible with RNN cells, and can give worse results with the Transformer (due to normalization?).
  • My suggestion: rewrite the training code (possibly with PyTorch Lightning) and build it around DistributedDataParallel; a generic setup sketch follows below.
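
For reference, the standard DistributedDataParallel setup looks roughly like this (a generic PyTorch sketch, not aevnmt code):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def train(rank, world_size, model, dataset):
    # One process per GPU: join the process group and pin this process to one device.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = DDP(model.cuda(rank), device_ids=[rank])
    # Each process sees a disjoint shard of the data.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)
    optimizer = torch.optim.Adam(model.parameters())
    for epoch in range(10):
        sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        for batch in loader:
            loss = model(batch)   # placeholder: the model returns its own loss
            optimizer.zero_grad()
            loss.backward()       # DDP synchronizes gradients across processes here
            optimizer.step()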

Have prior.size arg support multiple sizes, and remove prior.sizes

When prior.sizes is defined, prior.size is not needed (since sum(prior.sizes) == prior.size).

Make prior.size behave the same way as prior.params: multiple semicolon-separated values, which the argument parser converts to a list of ints (see the sketch below). Changes need to be applied in the arg parser and wherever hparams.prior.size / hparams.prior.sizes is used in the code base.
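
The conversion itself is a one-liner; a sketch of the parser hook (the option name matches the proposal above, the rest is illustrative):

import argparse

def int_list(value):
    # "64;32" -> [64, 32]
    return [int(v) for v in value.split(";")]

parser = argparse.ArgumentParser()
parser.add_argument("--prior.size", type=int_list, default=[32])
# sum(hparams.prior.size) then replaces the old scalar prior.size wherever it was used.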

Training scripts need refactoring

  • Currently our training scripts are a bit of a mess, because we combine multiple models and different diagnostics (TensorBoard, etc.).
  • Gonçalo suggests PyTorch Lightning (https://github.com/PyTorchLightning/pytorch-lightning), which seems really nice; a minimal sketch follows below.
  • Wilker's suggestion for experiment management: wandb. This might not be compatible with Lightning.
  • TensorBoard logging could be more flexible: let the Trainer object decide what to save during training for each model, not the train.py script.
  • Why are we using TensorboardX and not PyTorch's built-in TensorBoard support?
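
To give an idea of the Lightning route, a minimal generic LightningModule (a sketch only; aevnmt's actual models and losses would replace the placeholders):

import pytorch_lightning as pl
import torch

class TrainerModule(pl.LightningModule):
    """Hypothetical wrapper: assumes the wrapped model computes its own loss from a batch."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        loss = self.model(batch)      # placeholder call
        self.log("train/loss", loss)  # Lightning routes this to TensorBoard by default
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=3e-4)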

Vocabulary not reading all lines from vocab file

When translating, Vocabulary imports the vocab file with from_file(), which opens the file with ISO-8859-2 encoding. When this encoding is removed, the vocab file is read correctly; a sketch of the fix follows below.
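
The fix is likely a one-line change in Vocabulary.from_file (a sketch; the actual signature and return value may differ):

def from_file(path):
    # was: open(path, encoding="ISO-8859-2"), which silently mangles UTF-8 vocab files
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]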

Expand generative model feed_z options

feed_z arguments (for example, gen.tm.dec.feed_z) are currently booleans, which mirrors the RNN implementations, but the Transformer architectures support multiple options.

Additionally, we could support feed_z methods for TM encoders: RNN architectures always initialize the hidden state with z, but for Transformers there are multiple options (a sketch follows below).
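
A sketch of what a multi-option feed_z could look like for a Transformer (the mode names and function are illustrative, not the actual aevnmt API):

import torch

def apply_feed_z(h, z, z_proj, mode):
    """Hypothetical: combine states h [batch, time, dim] with latent z [batch, z_dim]."""
    if mode == "none":
        return h
    if mode == "add":
        return h + z_proj(z).unsqueeze(1)  # project z to dim and broadcast over time
    if mode == "concat":
        zt = z_proj(z).unsqueeze(1).expand(-1, h.size(1), -1)
        return torch.cat([h, zt], dim=-1)  # downstream layers must expect the wider dim
    raise ValueError(f"unknown feed_z mode: {mode}")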

Hyperparameters requires hparams_file as first argument

Jsonargparse changes its behaviour depending on the position of the --hparams_file argument: any command line arguments that appear before --hparams_file are overridden by the contents of the hparams file.

This can cause issues that are hard to track down; the preferred behavior is that command line arguments always take precedence over config-file arguments.

I've added a temporary solution that checks for the first argument in hparams.py (sketched below), but a better solution would be to somehow disable this feature in jsonargparse.
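
A sketch of that kind of temporary workaround: hoist --hparams_file (and its value) to the front of argv before parsing, so every other argument comes after it and takes precedence (illustrative, not the exact code in hparams.py):

import sys

def hoist_hparams_file(argv):
    # Move --hparams_file and its value to the front so the config
    # file can no longer override the remaining command line arguments.
    if "--hparams_file" in argv:
        i = argv.index("--hparams_file")
        return argv[i:i + 2] + argv[:i] + argv[i + 2:]
    return argv

argv = hoist_hparams_file(sys.argv[1:])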

Transformer training needs more improvements to catch up to SOTA

Transformer training still needs some improvements:

  • The Noam training schedule does not work great. fairseq seems to have another schedule that works better (inverse sqrt); both are written out below.
  • Batching based on the number of tokens instead of the number of sentences would let us use larger batch sizes on average.
  • Possibly another optimizer than Adam.
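
For concreteness, the two schedules compared above, as per-step learning-rate functions (default values are typical choices, not tuned for this codebase):

def noam_lr(step, d_model=512, warmup=4000):
    # Vaswani et al. (2017): linear warmup, then decay proportional to step**-0.5.
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

def inverse_sqrt_lr(step, peak_lr=5e-4, warmup=4000):
    # fairseq-style inverse_sqrt: linear warmup to peak_lr, then sqrt decay.
    if step < warmup:
        return peak_lr * step / warmup
    return peak_lr * (warmup / step) ** 0.5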
