anthonydickson / learning2write

Teaching a neural network how to write letters and digits with reinforcement learning.

License: GNU General Public License v3.0

Python 96.68% Shell 3.32%
ai machine-learning reinforcement-learning openai-gym-environment acktr ppo ppo2 emnist mnist letters

Learning2write

This repo contains code for teaching a neural-network-based reinforcement learning agent how to write characters (see Figure 1 below).

Fig 1. Example of an ACKTR reinforcement learning agent trained on 5x5 patterns for about 44 million steps.

To train the reinforcement learning agents I use stable-baselines, and for the environment I use my own custom gym environment. The environment provides three sets of patterns (mostly letters and digits) that an agent can be trained on:

  • A set of simple 3x3 patterns
  • A set of 5x5 patterns
  • Letters and digits from the EMNIST (Extended MNIST) dataset. This is essentially MNIST extended to cover letters as well as digits (see Figure 2 below).

Fig 2. Sample of images from the EMNIST dataset.

The goal is for the agent to fill in the squares in a grid to reproduce the pattern that it has been presented as accurately as possible.
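
As a rough illustration of that objective, the sketch below scores a drawing against a target pattern by the fraction of grid cells that match. The function name `pattern_accuracy`, the 0/1 cell encoding, and the example pattern are all assumptions made for illustration; the reward that the environment in this repo actually uses may be computed differently.

```python
from typing import List

Grid = List[List[int]]  # assumed encoding: 1 = filled square, 0 = empty square


def pattern_accuracy(target: Grid, drawing: Grid) -> float:
    """Return the fraction of cells where the drawing matches the target."""
    same_shape = len(target) == len(drawing) and all(
        len(t_row) == len(d_row) for t_row, d_row in zip(target, drawing)
    )
    if not same_shape:
        raise ValueError("target and drawing must have the same shape")

    total = sum(len(row) for row in target)
    if total == 0:
        raise ValueError("grids must contain at least one cell")

    matches = sum(
        t == d
        for t_row, d_row in zip(target, drawing)
        for t, d in zip(t_row, d_row)
    )
    return matches / total


# Hypothetical 3x3 target (a plus sign) and an attempt that misses one square.
target = [
    [0, 1, 0],
    [1, 1, 1],
    [0, 1, 0],
]
drawing = [
    [0, 1, 0],
    [1, 1, 0],
    [0, 1, 0],
]

print(pattern_accuracy(target, drawing))  # 8 of the 9 cells match
```

A perfect reproduction scores 1.0 under this measure, and every wrong or missing square lowers the score by the same amount.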

Getting started

The convenience script setup.sh automates most of the setup process; it assumes a UNIX-based system. If you use it, restart your terminal after it finishes successfully (so that conda is set up correctly) and skip to step 3.

  1. Install the required system packages:

    sudo apt-get update && sudo apt-get install cmake libopenmpi-dev python3-dev zlib1g-dev unzip xvfb python-opengl

    See the prerequisites section of stable-baselines.readthedocs.io for instructions for other operating systems.

  2. Set up the Python environment using conda:

    conda env create -f environment.yml

    If you are not using conda, make sure you have a Python environment with all of the packages listed in environment.yml installed.

  3. Activate the conda environment:

    conda activate learning2write

  4. If you want to train an agent on images from the EMNIST dataset, you will need a local copy of it. You can download the dataset by running the following:

    mkdir emnist_data
    cd emnist_data
    wget http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip
    unzip gzip.zip
    mv gzip/* .
    rmdir gzip
    cd ..

    The download is about 550MB.

  5. Train a model:

    python train.py -steps 1000000

    Training the agent for more steps usually gives better results. An ACKTR agent requires roughly 5 to 10 million steps for 3x3 patterns, and 10 to 20 million steps for 5x5 patterns. I have not measured how many steps the EMNIST dataset needs, but it is probably a lot more.

  6. Test a previously trained model:

    python test.py models/acktr_mlp_5x5.pkl acktr

    This opens a window that displays the environment: the reference/target pattern on the left, alongside the agent's drawing and its current location (the big red dot).

    There are a couple of pretrained models in the models/ directory.

  7. You can see the help text for these scripts by adding the flag -h or --help.
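
If you want to inspect the EMNIST files downloaded in step 4 yourself, the sketch below reads a gzip-compressed IDX image file using only the Python standard library. It assumes the files in the archive follow the same IDX layout as the original MNIST files (a 16-byte big-endian header followed by one unsigned byte per pixel). The function name `read_idx_images` is hypothetical, and this is not necessarily how the code in this repo loads the data.

```python
import gzip
import struct
from typing import List

IDX_IMAGE_MAGIC = 2051  # 0x00000803: IDX file holding 3-D unsigned-byte data


def read_idx_images(path: str) -> List[List[List[int]]]:
    """Read a gzip-compressed IDX image file into nested lists of pixel values."""
    with gzip.open(path, "rb") as f:
        # Header: magic number, image count, rows, and columns,
        # each stored as a big-endian unsigned 32-bit integer.
        magic, count, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != IDX_IMAGE_MAGIC:
            raise ValueError(f"unexpected magic number {magic} in {path}")
        pixels = f.read(count * rows * cols)

    if len(pixels) != count * rows * cols:
        raise ValueError("file is shorter than its header claims")

    images = []
    for i in range(count):
        start = i * rows * cols
        # Slice out one row of pixels at a time for image i.
        images.append(
            [
                list(pixels[start + r * cols : start + (r + 1) * cols])
                for r in range(rows)
            ]
        )
    return images
```

For example, `read_idx_images("emnist_data/<images file>.gz")` returns one nested list per image, where `<images file>` stands for whichever image file from the unpacked archive you want to read. Nested lists keep the sketch dependency-free; for the full dataset a NumPy array would be far more memory-efficient.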
