clemkoa / ntm


Neural Turing Machines in pytorch

License: MIT License

Python 99.20% Dockerfile 0.80%
neural-turing-machines deep-learning pytorch turing neural-networks

ntm's Introduction

ntm - Neural Turing Machines in pytorch

A Neural Turing Machine (NTM) implementation in PyTorch.

The goal was to implement a simple NTM with 1 read head and 1 write head, to reproduce the original paper's results.

Copy task

The copy task tests whether an NTM can store and recall a long sequence of arbitrary information. The network is presented with an input sequence of random binary vectors followed by a delimiter flag. The target sequence is a copy of the input sequence. No inputs are presented to the model while it receives the targets, to ensure that there is no assistance.

The model is trained on sequences of 1 to 20 random 8-bit vectors. In fewer than 50k iterations, the model usually reproduces the target sequence with high accuracy.
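The data layout described above can be sketched in plain Python. This is only an illustration, not the repository's code: the function name `make_copy_example` is hypothetical, and placing the delimiter flag on an extra ninth input channel is an assumption. A real training loop would convert these lists to tensors.

```python
import random


def make_copy_example(seq_len, vector_size=8, rng=None):
    """Build one (inputs, targets) pair for the copy task.

    Hypothetical helper: the delimiter lives on an extra channel appended
    to each input vector (an assumption, not confirmed by the repository).
    """
    rng = rng or random.Random()
    # Random binary vectors that the model has to memorise.
    seq = [[rng.randint(0, 1) for _ in range(vector_size)] for _ in range(seq_len)]

    # Phase 1: present the sequence, delimiter channel set to 0.
    inputs = [vec + [0] for vec in seq]
    # Phase 2: a single delimiter step (all data bits 0, flag set to 1).
    inputs.append([0] * vector_size + [1])
    # Phase 3: blank inputs while the model is expected to emit the copy.
    inputs.extend([[0] * (vector_size + 1) for _ in range(seq_len)])

    # The target is simply the original sequence.
    targets = [list(vec) for vec in seq]
    return inputs, targets


# Sample a training length between 1 and 20, as in the text above.
length = random.randint(1, 20)
inputs, targets = make_copy_example(length)
```

Each example therefore has `2 * seq_len + 1` input steps and `seq_len` target steps.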

Here is the net output compared to the target for a sequence of 20.

Here is the net output compared to the target for a sequence of 100. Note that the network was only trained with sequences of 20 or less.

Here is an example (seed=1) of loss during training, with a batch size of 8.

Repeat copy task

As said in the paper, "the repeat copy task extends copy by requiring the network to output the copied sequence a specified number of times and then emit an end-of-sequence marker. [...] The network receives random-length sequences of random binary vectors, followed by a scalar value indicating the desired number of copies, which appears on a separate input channel. To emit the end marker at the correct time the network must be both able to interpret the extra input and keep count of the number of copies it has performed so far. As with the copy task, no inputs are provided to the network after the initial sequence and repeat number."

The model is trained on sequences of 1 to 10 8-bit random vectors, with a repeat between 1 and 10.
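A comparable sketch for the repeat copy task follows. The function name `make_repeat_copy_example` is hypothetical, and two details are assumptions rather than facts about this repository: the repeat count is fed as a raw (unnormalised) scalar on an extra input channel, and the end-of-sequence marker uses an extra output channel.

```python
import random


def make_repeat_copy_example(seq_len, repeats, vector_size=8, rng=None):
    """Build one (inputs, targets) pair for the repeat copy task.

    Hypothetical helper; channel layout and the unnormalised repeat
    count are assumptions made for illustration.
    """
    rng = rng or random.Random()
    seq = [[rng.randint(0, 1) for _ in range(vector_size)] for _ in range(seq_len)]

    # Present the sequence with the repeat channel set to 0.
    inputs = [vec + [0] for vec in seq]
    # One step carrying only the desired number of copies.
    inputs.append([0] * vector_size + [repeats])
    # No further inputs while the model emits the copies and the end marker.
    out_len = seq_len * repeats + 1
    inputs.extend([[0] * (vector_size + 1) for _ in range(out_len)])

    # Target: the sequence repeated `repeats` times, end channel at 0 ...
    targets = [vec + [0] for _ in range(repeats) for vec in seq]
    # ... followed by a single end-of-sequence marker step.
    targets.append([0] * vector_size + [1])
    return inputs, targets


# Lengths and repeat counts between 1 and 10, as in the text above.
inputs, targets = make_repeat_copy_example(random.randint(1, 10), random.randint(1, 10))
```

The target length grows as `seq_len * repeats + 1`, so examples are longer than in the copy task.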

Here is the model output for a sequence of 10 and a repeat of 10.

Here it is for a sequence of 10 and a repeat of 20. Note that the network was trained with a repeat of 10 max.

Here it is for a sequence of 20 and a repeat of 10. Maybe it needs a bit more training here! Note that the network was trained on sequences of 10 or less.

Training on the repeat copy task takes substantially longer than training on the copy task. It usually takes at least 100k iterations to start seeing good results.

Usage

# installation
pip install -r requirements.txt
# to train
python copy_task.py --train
# to evaluate
python copy_task.py --eval

References

  1. Graves, Alex, Greg Wayne, and Ivo Danihelka. "Neural turing machines." arXiv preprint arXiv:1410.5401 (2014).
  2. https://github.com/loudinthecloud/pytorch-ntm/
  3. https://github.com/MarkPKCollier/NeuralTuringMachine

ntm's People

Contributors

clement-inventia, clemkoa


ntm's Issues

Could this implementation be used to solve the parity problem?

Thank you for the nice work!

I am wondering if your implementation can be used to classify the parity (even or odd) of a bitstring.
For example,
input: [0, 1, 0, 1, 0, 0] --> output 0
input: [0, 1, 0, 1, 0, 1] --> output 1

If it can, how should I modify this code? Could you please give me an example? Thank you very much.
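For reference, the labels in the two examples above are the number of ones modulo 2. A minimal sketch of the data side only, with hypothetical helper names `parity` and `make_parity_example`; it says nothing about whether this NTM learns the task:

```python
import random


def parity(bits):
    """Return 1 if the bitstring holds an odd number of ones, else 0."""
    return sum(bits) % 2


def make_parity_example(length, rng=None):
    """Hypothetical helper: a random bitstring together with its parity label."""
    rng = rng or random.Random()
    bits = [rng.randint(0, 1) for _ in range(length)]
    return bits, parity(bits)


print(parity([0, 1, 0, 1, 0, 0]))  # 0
print(parity([0, 1, 0, 1, 0, 1]))  # 1
```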

Solution to a Potential Error in Windows

After following the installation steps on Windows, I got an error when I tried to run:
python copy_task.py --eval

For anybody experiencing this, a quick workaround is to paste the following into ntm/__init__.py:

import os
# Workaround: let the process continue when duplicate OpenMP runtimes are loaded.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

A better solution that worked for me was to use Python 3.8 and PyTorch 1.7.
