
Dynamic Memory Tensor Networks in Theano

This project is forked from https://github.com/YerevaNN/Dynamic-memory-networks-in-Theano. The aim of this repository is to implement Dynamic Memory Tensor Networks (DMTN), in addition to the Dynamic Memory Networks (DMN) already covered in the parent repository.

DMTN is described in https://arxiv.org/abs/1703.03939: Ramachandran, Govardana Sachithanandam, and Ajay Sohmshetty. "Ask Me Even More: Dynamic Memory Tensor Networks (Extended Model)." arXiv preprint arXiv:1703.03939 (2017).

Originally published as http://cs224d.stanford.edu/reports/SohmshettyRamachandran.pdf: Sohmshetty, Ajay, and Govardana Sachithanandam Ramachandran. "Ask Me Even More: Dynamic Memory Tensor Networks (Extended Model)." http://cs224d.stanford.edu/reports_2016.html (June 2016).

Abstract: We examine Memory Networks for the task of question answering (QA) under the common real-world scenario where training examples are scarce, and under a weakly supervised scenario, that is, one where only extrinsic labels are available for training. We propose extensions to the Dynamic Memory Network (DMN), specifically within the attention mechanism, and we call the resulting neural architecture the Dynamic Memory Tensor Network (DMTN). Ultimately, we see that our proposed extensions result in an over 80% improvement in the number of tasks passed compared with the baseline standard DMN, and 20% more tasks passed compared with the state-of-the-art End-to-End Memory Network, on Facebook's single-task, weakly trained 1K bAbI dataset.
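The abstract measures progress as the number of bAbI tasks passed. A minimal sketch of that bookkeeping follows. It assumes the conventional bAbI criterion that a task counts as passed at 95% test accuracy or higher (the threshold is an assumption, not restated in this README) and reads "80% improvement" as a relative gain in the pass count. The helper names are hypothetical, and the toy accuracies are placeholders, not results from Table 1.

```python
PASS_THRESHOLD = 0.95  # assumption: the usual bAbI pass criterion (>= 95% test accuracy)


def tasks_passed(accuracies, threshold=PASS_THRESHOLD):
    """Count tasks whose test accuracy (a fraction in [0, 1]) meets the threshold."""
    return sum(1 for acc in accuracies.values() if acc >= threshold)


def relative_gain(new_count, base_count):
    """(new - base) / base, so 0.8 means 80% more tasks passed than the baseline."""
    if base_count == 0:
        raise ValueError("baseline passed no tasks; relative gain is undefined")
    return (new_count - base_count) / base_count


if __name__ == "__main__":
    # Hypothetical toy accuracies keyed by bAbI task id; NOT the numbers from Table 1.
    baseline = {1: 1.00, 2: 0.40, 3: 0.30}
    extended = {1: 1.00, 2: 0.96, 3: 0.30}
    print(tasks_passed(baseline), tasks_passed(extended))
    print(relative_gain(tasks_passed(extended), tasks_passed(baseline)))
```

Counting passes rather than averaging accuracies is what makes the abstract's comparison a statement about tasks, not about mean accuracy.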


[Table:1] Accuracies across all tasks for MemN2N, DMN, and DMTN. DMN baseline serves as the baseline for DMTN, to measure the lift from the proposed changes. DMN best* is the best documented performance of DMN with optimal hyperparameter tuning on the weakly trained bAbI dataset: http://yerevann.github.io/2016/02/05/implementing-dynamic-memory-networks

The results in Table 1 were obtained using the following hyperparameters, shared by the DMN baseline and DMTN. Please note that, due to lack of time and resources, no hyperparameter tuning was done; hence we recommend experimenting with the hyperparameters for even better results.


[Table:2] Hyperparameters used for DMN baseline and DMTN.

The aim of the parent repository is to implement the DMN as described in the paper by Kumar et al. and to experiment with its various extensions.
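For orientation, the attention gate that DMTN extends is the DMN gate from Kumar et al.: each fact c is scored against the memory m and the question q through a feature vector z(c, m, q) and a two-layer network, G = sigmoid(w2 · tanh(W1 z + b1) + b2). The pure-Python sketch below follows that published DMN form, with one bilinear matrix W_b shared by the c^T W_b q and c^T W_b m terms. It is not the DMTN tensor extension from the paper and not the code in dmn_basic.py; the parameter names and the random initialization are illustrative assumptions.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(M, v):
    return [dot(row, v) for row in M]


def gate_features(c, m, q, W_b):
    """z(c, m, q): concatenated similarity features for one fact c (7 * dim + 2 entries)."""
    z = list(c) + list(m) + list(q)
    z += [ci * qi for ci, qi in zip(c, q)]        # c * q, element-wise
    z += [ci * mi for ci, mi in zip(c, m)]        # c * m, element-wise
    z += [abs(ci - qi) for ci, qi in zip(c, q)]   # |c - q|
    z += [abs(ci - mi) for ci, mi in zip(c, m)]   # |c - m|
    z += [dot(c, matvec(W_b, q)), dot(c, matvec(W_b, m))]  # c^T W_b q, c^T W_b m
    return z


def attention_gate(c, m, q, params):
    """G(c, m, q) = sigmoid(w2 . tanh(W1 z + b1) + b2), a scalar in (0, 1)."""
    W_b, W1, b1, w2, b2 = params
    z = gate_features(c, m, q, W_b)
    hidden = [math.tanh(h + b) for h, b in zip(matvec(W1, z), b1)]
    return 1.0 / (1.0 + math.exp(-(dot(w2, hidden) + b2)))


def init_params(dim, hidden, seed=0):
    """Small random weights (illustrative only); W1 maps the 7 * dim + 2 features to `hidden` units."""
    rng = random.Random(seed)

    def mat(rows, cols):
        return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

    W_b = mat(dim, dim)
    W1 = mat(hidden, 7 * dim + 2)
    b1 = [0.0] * hidden
    w2 = [rng.uniform(-0.1, 0.1) for _ in range(hidden)]
    return W_b, W1, b1, w2, 0.0


if __name__ == "__main__":
    dim = 4
    params = init_params(dim, hidden=8)
    rng = random.Random(1)
    c, m, q = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(3)]
    print(attention_gate(c, m, q, params))
```

In a full episodic memory module this gate is computed for every fact at every pass, and the resulting scores decide how strongly each fact updates the episode.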

Pretrained models on bAbI tasks can be tested online.

We will cover the process in a series of blog posts.

Repository contents

| File | Description |
|---|---|
| main.py | The main entry point to train and test the available network architectures on bAbI-like tasks |
| dmn_basic.py | Our baseline implementation. It is as close to the original as we could understand the paper, except that the number of steps in the main memory GRU is fixed. The attention module uses the T.abs_ function as a distance between two vectors, which randomly causes gradients to become NaN. The results reported in the parent project's blog post are based on this network |
| dmn_smooth.py | Uses the square of the Euclidean distance instead of abs in the attention module. Training is very stable, and performance on bAbI is slightly better |
| dmtn.py | DMTN implementation |
| dmn_batch.py | dmn_smooth with minibatch training support. The batch size cannot be set to 1 because of a Theano bug |
| dmn_qa_draft.py | Draft version of a DMN designed for answering multiple-choice questions |
| utils.py | Tools for working with bAbI tasks and GloVe vectors |
| nn_utils.py | Helper functions on top of Theano and Lasagne |
| fetch_babi_data.sh | Shell script to fetch bAbI tasks (adapted from MemN2N) |
| fetch_glove_data.sh | Shell script to fetch GloVe vectors (by 5vision) |
| server/ | Contains a Flask-based RESTful API server |
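The notes for dmn_basic.py and dmn_smooth.py contrast an absolute difference with a squared Euclidean distance in the attention module. The small illustration below shows the smoothness difference per coordinate, with x standing for one coordinate of c - q: the derivative of |x| jumps from -1 to +1 at x = 0, that is, wherever a fact and the question agree in that coordinate, while the derivative of x² passes continuously through 0. It illustrates why the squared form is better behaved; it is not a diagnosis of the NaN gradients reported for dmn_basic.py, and the function names are hypothetical.

```python
def abs_term(x):
    """One coordinate of |c - q|, written as a function of x = c_i - q_i."""
    return abs(x)


def squared_term(x):
    """One coordinate of the squared Euclidean distance, (c_i - q_i) ** 2."""
    return x * x


def numeric_derivative(f, x, h=1e-6):
    """Central finite difference (f(x + h) - f(x - h)) / (2h)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


if __name__ == "__main__":
    for x in (-1e-3, 1e-3):
        print(
            f"x = {x:+.0e}: d|x|/dx ~ {numeric_derivative(abs_term, x):+.3f}, "
            f"d(x^2)/dx ~ {numeric_derivative(squared_term, x):+.4f}"
        )
```

Moving x from just below zero to just above it flips the first derivative from about -1 to about +1, while the second changes only from about -0.002 to about +0.002.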

Usage

This implementation is based on Theano and Lasagne. One way to install them is:

pip install -r https://raw.githubusercontent.com/Lasagne/Lasagne/master/requirements.txt
pip install https://github.com/Lasagne/Lasagne/archive/master.zip

The following bash scripts will download bAbI tasks and GloVe vectors.

./fetch_babi_data.sh
./fetch_glove_data.sh
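The fetched bAbI files use a simple line format: each line starts with an id that restarts at 1 for every new story, and question lines carry the question, the answer, and the ids of the supporting facts, separated by tabs. A minimal reader for that format is sketched below. It is a standalone illustration, not the repository's utils.py; the function name is hypothetical, and the sample lines are just a three-line example in that format.

```python
def parse_babi(lines):
    """Yield (story, question, answer, supporting_ids) from bAbI-format lines.

    story is the list of (line_id, sentence) pairs seen so far in the current story.
    """
    story = []
    for raw in lines:
        raw = raw.strip()
        if not raw:
            continue
        line_id, text = raw.split(" ", 1)
        line_id = int(line_id)
        if line_id == 1:  # ids restart at 1 when a new story begins
            story = []
        if "\t" in text:  # question line: question \t answer \t supporting fact ids
            parts = text.split("\t")
            question, answer = parts[0].strip(), parts[1].strip()
            supporting = [int(s) for s in parts[2].split()] if len(parts) > 2 else []
            yield list(story), question, answer, supporting
        else:  # plain fact sentence
            story.append((line_id, text))


if __name__ == "__main__":
    sample = [
        "1 Mary moved to the bathroom.\n",
        "2 John went to the hallway.\n",
        "3 Where is Mary? \tbathroom\t1\n",
    ]
    for story, question, answer, supporting in parse_babi(sample):
        facts = [s for i, s in story if i in supporting]
        print(question, "->", answer, "| supporting:", facts)
```

Keeping the line ids alongside each sentence is what lets the supporting-fact ids be mapped back to sentences.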

Use main.py to train a network:

python main.py --network dmtn --babi_id 1
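To train on several bAbI tasks in sequence, one option is a small driver that runs main.py once per task. This sketch uses only the --network and --babi_id flags shown above. The helper names are hypothetical, the default task range 1 to 20 assumes the standard 20 bAbI tasks, and it assumes the interpreter running the script is the one with Theano and Lasagne installed (this is a 2016-era codebase, so that may be an older Python environment). By default it only prints the commands.

```python
import subprocess
import sys


def train_command(network, babi_id):
    """Build the argument list for one training run of main.py."""
    return [sys.executable, "main.py", "--network", network, "--babi_id", str(babi_id)]


def train_all(network="dmtn", task_ids=range(1, 21), dry_run=True):
    """Train one model per bAbI task; with dry_run=True only print the commands."""
    commands = [train_command(network, babi_id) for babi_id in task_ids]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.check_call(cmd)  # raises CalledProcessError if a run fails
    return commands


if __name__ == "__main__":
    train_all(network="dmtn", task_ids=[1, 2, 3])
```

subprocess.check_call is used instead of subprocess.run so the driver also works on older Python versions.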

Network states are saved in the states/ folder. One pretrained state for the first bAbI task is included; it should give 100% accuracy on the test set:

python main.py --network dmn_basic --mode test --babi_id 1 --load_state states/dmn_basic.mh5.n40.babi1.epoch4.test0.00033.state
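State files encode their training configuration in the file name. The sketch below splits a name shaped like the one above into fields. The interpretation of the fields (mh as the number of memory hops, n as the hidden size, test as a test-set metric, probably the loss) is an assumption inferred from the name rather than documented here, and the pattern only covers names with exactly these fields. The leading field names the architecture that produced the state, which is worth checking against the --network flag before loading.

```python
import os
import re

# Assumed layout: <network>.mh<hops>.n<dim>.babi<task>.epoch<epoch>.test<value>.state
STATE_RE = re.compile(
    r"^(?P<network>[A-Za-z0-9_]+)"
    r"\.mh(?P<memory_hops>\d+)"
    r"\.n(?P<dim>\d+)"
    r"\.babi(?P<babi_id>\w+)"
    r"\.epoch(?P<epoch>\d+)"
    r"\.test(?P<test_value>\d+(?:\.\d+)?)"
    r"\.state$"
)


def parse_state_name(path):
    """Return the fields encoded in a state file name, or raise ValueError."""
    name = os.path.basename(path)
    match = STATE_RE.match(name)
    if match is None:
        raise ValueError("unrecognized state file name: %s" % name)
    fields = match.groupdict()
    return {
        "network": fields["network"],
        "memory_hops": int(fields["memory_hops"]),
        "dim": int(fields["dim"]),
        "babi_id": fields["babi_id"],
        "epoch": int(fields["epoch"]),
        "test_value": float(fields["test_value"]),
    }


if __name__ == "__main__":
    print(parse_state_name("states/dmn_basic.mh5.n40.babi1.epoch4.test0.00033.state"))
```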

Server

If you want to start a server that returns predictions for bAbI tasks, do the following:

  1. Generate UI files as described in YerevaNN/dmn-ui
  2. Copy the UI files to server/ui
  3. Run the server
cd server && python api.py

If you have Docker installed, you can pull our Docker image with a ready-to-run DMN server.

docker pull yerevann/dmn
docker run --name dmn_1 -it --rm -p 5000:5000 yerevann/dmn

Roadmap

  • Mini-batch training (done, 08/02/2016)
  • Web interface (done, 08/23/2016)
  • Visualization of episodic memory module (done, 08/23/2016)
  • Regularization (work in progress, L2 doesn't help at all, dropout and batch normalization help a little)
  • Support for multiple-choice questions (work in progress)
  • Evaluation on more complex datasets
  • Import some ideas from Neural Reasoner

License

The MIT License (MIT) Copyright (c) 2016 YerevaNN
