
ozzie00 / char-rnn


This project forked from karpathy/char-rnn


Multi-layer Recurrent Neural Networks (LSTM, GRU, RNN) for character-level language models in Torch

Lua 100.00%


char-rnn

This code implements multi-layer Recurrent Neural Networks (RNN, LSTM, and GRU) for training and sampling from character-level language models. The input is a single text file, and the model learns to predict the next character in the sequence.
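To make "predict the next character" concrete, here is a plain-Python sketch (not part of this Lua/Torch repository) of how training pairs can be formed: each target sequence is the input sequence shifted one character to the right. The function name make_pairs and the non-overlapping windowing are illustrative assumptions, not the repo's loader.

```python
def make_pairs(text, seq_length):
    """Split text into (input, target) chunks; target is input shifted by one char.

    Illustrative sketch only: windows are non-overlapping and any tail that
    cannot supply a full shifted target is dropped.
    """
    pairs = []
    # Each window needs seq_length input chars plus one extra char for the last target.
    for start in range(0, len(text) - seq_length, seq_length):
        x = text[start:start + seq_length]
        y = text[start + 1:start + seq_length + 1]
        pairs.append((x, y))
    return pairs
```

For example, make_pairs("hello world", 5) yields [('hello', 'ello '), (' worl', 'world')]: at every position the model is asked to output the character that comes next.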

The context of this code base is described in detail in my blog post.

There is also a project page that has some pointers and datasets.

This code is based on Oxford University Machine Learning class practical 6, which is in turn based on learning to execute code from Wojciech Zaremba. Chunks of it were also developed in collaboration with my labmate Justin Johnson.

Requirements

This code is written in Lua and requires Torch.

Additionally, you need to install nngraph using LuaRocks:

luarocks install nngraph

Usage

Data

All input data is stored inside the data/ directory. You'll notice that there is an example dataset included in the repo (in folder data/tinyshakespeare) which consists of a subset of works of Shakespeare. I'm providing a few more datasets on the project page.

Your own data: If you'd like to use your own data, create a single file input.txt and place it into a folder in data/. For example, data/some_folder/input.txt. The first time you run the training script, it will write two more convenience files into data/some_folder.
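This README does not spell out what the two convenience files contain, so the following is only a plain-Python sketch of the kind of preprocessing a character-level loader performs on input.txt: collect the distinct characters into a vocabulary and encode the text as integer indices. The helper names (load_text, build_vocab, encode, decode) are hypothetical.

```python
from pathlib import Path


def load_text(path):
    """Read the whole training file, e.g. data/some_folder/input.txt."""
    return Path(path).read_text(encoding="utf-8")


def build_vocab(text):
    """Map each distinct character to an integer index.

    Sorting the characters keeps the mapping deterministic across runs.
    """
    return {ch: i for i, ch in enumerate(sorted(set(text)))}


def encode(text, vocab):
    """Turn a string into the list of integer indices a network would consume."""
    return [vocab[ch] for ch in text]


def decode(indices, vocab):
    """Inverse of encode: map indices back to characters."""
    inverse = {i: ch for ch, i in vocab.items()}
    return "".join(inverse[i] for i in indices)
```

Usage: vocab = build_vocab(load_text("data/some_folder/input.txt")). Caching the vocabulary and the encoded text is one reason a loader might write extra files on the first run, so later runs can skip this step.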

Note that if your data is too small (1MB is already considered very, very small), the RNN won't learn very effectively. Remember that it has to learn everything completely from scratch. But if you insist on smaller datasets, you might want to decrease the batch size a bit and do many more epochs (hundreds perhaps).
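A little arithmetic shows why a smaller batch size helps on small data: assuming the loader only uses full minibatches of batch_size × seq_length characters (an assumption for this sketch, which also ignores any train/validation split), a smaller batch size gives more parameter updates per epoch.

```python
def batches_per_epoch(num_chars, batch_size, seq_length):
    """Number of full (batch_size x seq_length) minibatches that fit in the data.

    Sketch assumption: leftover characters that don't fill a batch are dropped.
    """
    return num_chars // (batch_size * seq_length)
```

With roughly 1,000,000 characters (about 1MB) and a sequence length of 50, a batch size of 50 gives 400 batches per epoch, while a batch size of 10 gives 2,000. These values are chosen for illustration only.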

Training

Start training the model using train.lua, for example:

$ th train.lua -data_dir data/some_folder -gpuid -1

The -data_dir flag is the most important one, since it specifies the dataset to use. Notice that in this example we're also setting -gpuid to -1, which tells the code to train on the CPU; otherwise it defaults to GPU 0. There are many other flags for various options. Run $ th train.lua -help for the full list of settings. Here's another example:

$ th train.lua -data_dir data/some_folder -rnn_size 512 -num_layers 2 -dropout 0.5
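Flags like -rnn_size and -num_layers determine how many parameters the model has. The sketch below estimates that count for a standard LSTM. It assumes one-hot character inputs to the first layer, four gates with one bias vector each, and a linear output layer with one score per character. Exact counts depend on the implementation (for example, whether each gate carries one or two bias vectors), so treat the result as an estimate.

```python
def lstm_param_count(vocab_size, rnn_size, num_layers):
    """Estimate the parameter count of a stacked LSTM character model (sketch)."""
    total = 0
    input_size = vocab_size  # assumption: the first layer sees one-hot characters
    for _ in range(num_layers):
        # 4 gates, each with input weights, recurrent weights and one bias vector.
        total += 4 * (input_size * rnn_size + rnn_size * rnn_size + rnn_size)
        input_size = rnn_size  # deeper layers consume the previous layer's hidden state
    # Output projection: top hidden state -> one score per character, plus bias.
    total += rnn_size * vocab_size + vocab_size
    return total
```

For a hypothetical 65-character vocabulary with -rnn_size 512 -num_layers 2, lstm_param_count(65, 512, 2) returns 3,316,289 under these assumptions. This count is also a rough guide to how large a network your GPU can hold.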

While the model is training it will periodically write checkpoint files to the cv folder. You can use these checkpoints to generate text:

Sampling

Given a checkpoint file (such as those written to cv) we can generate new text. For example:

$ th sample.lua cv/some_checkpoint.t7 -gpuid -1

Make sure you sample from a checkpoint on the same kind of device it was trained on: a checkpoint trained on a GPU must also be sampled on a GPU, and a CPU-trained checkpoint on the CPU. Otherwise the code will (currently) complain. As with the training script, see $ th sample.lua -help for the full list of options. One important option is -length; for example, -length 10000 generates 10,000 characters (default = 2000).

Temperature. An important parameter you may want to play with a lot is -temperature, which takes a number in the range (0, 1] (note that 0 is not included); the default is 1. The predicted log probabilities are divided by the temperature before the softmax, so a lower temperature makes the model choose more likely characters, at the price of more boring and conservative output. Higher temperatures cause the model to take more chances and increase the diversity of results, at the cost of more mistakes.
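The temperature mechanism described above can be sketched in plain Python (this is not the repo's Lua code): divide the log probabilities by the temperature, renormalize with a softmax, then draw a character index. The function names are illustrative.

```python
import math
import random


def softmax_with_temperature(log_probs, temperature):
    """Divide log probabilities by temperature, then renormalize with a softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    scaled = [lp / temperature for lp in log_probs]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(probs, rng=random):
    """Draw one index according to probs (inverse-CDF sampling)."""
    r = rng.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return i
    return len(probs) - 1  # guard against floating-point round-off
```

With probabilities [0.6, 0.3, 0.1], a temperature of 1 leaves them unchanged. A temperature of 0.5 squares and renormalizes them to about [0.78, 0.20, 0.02], which concentrates the choice on the most likely character.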

Priming. It's also possible to prime the model with some starting text using -primetext.
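Priming works by running the starting text through the model before sampling, so the first generated character is conditioned on it. The sketch below shows that loop. The step(state, ch) interface is hypothetical, and a toy bigram model stands in for a trained RNN. Its "state" is only the last character seen.

```python
import random
from collections import Counter, defaultdict


class BigramModel:
    """Hypothetical stand-in for a trained char-rnn model (toy bigram counts)."""

    def __init__(self, text):
        self.counts = defaultdict(Counter)
        for prev, nxt in zip(text, text[1:]):
            self.counts[prev][nxt] += 1

    def step(self, state, ch):
        """Consume one character; return (new_state, next-char probabilities)."""
        # A real RNN would update `state` here; this toy simply replaces it with ch.
        following = self.counts.get(ch, Counter())
        total = sum(following.values())
        probs = {c: n / total for c, n in following.items()} if total else {}
        return ch, probs


def generate(model, primetext, length, rng):
    """Prime the model with primetext, then sample up to `length` more characters."""
    if not primetext:
        raise ValueError("this sketch needs a non-empty primetext")
    state, probs = None, {}
    # Priming: feed every prime character through the model, keeping only the
    # state and the distribution produced after the last one.
    for ch in primetext:
        state, probs = model.step(state, ch)
    out = []
    for _ in range(length):
        if not probs:  # the toy model has no continuation for this character
            break
        chars = list(probs)
        ch = rng.choices(chars, weights=[probs[c] for c in chars])[0]
        out.append(ch)
        state, probs = model.step(state, ch)
    return primetext + "".join(out)
```

Usage: generate(BigramModel("abababab"), "a", 4, random.Random(0)) returns "ababa", because in this toy text every "a" is followed by "b" and vice versa.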

Happy sampling!

Tips and Tricks

If you're somewhat new to Machine Learning or Neural Networks, it can take a bit of expertise to get good models. The most important quantity to keep track of is the difference between your training loss (printed during training) and the validation loss (printed periodically, by default every 1000 iterations, when the RNN is run on the validation data). In particular:

  • If your training loss is much lower than your validation loss, the network is overfitting. Solutions are to decrease your network size or to increase dropout; for example, you could try a dropout of 0.5 and so on.
  • If your training and validation losses are about equal, your model is underfitting. Increase the size of your model (either the number of layers or the raw number of neurons per layer).
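The two rules above can be written as a small heuristic. This is a sketch only. The 5% relative-gap tolerance is an arbitrary assumption, so choose your own threshold based on how noisy your losses are.

```python
def diagnose(train_loss, val_loss, tolerance=0.05):
    """Rough over/underfitting check from the train/validation loss gap.

    Sketch assumption: a relative gap above `tolerance` counts as "much lower".
    """
    if val_loss <= 0:
        raise ValueError("validation loss must be positive")
    gap = (val_loss - train_loss) / val_loss
    if gap > tolerance:
        return "overfitting: shrink the network or increase dropout"
    return "possibly underfitting: try a larger network"
```

For example, a training loss of 1.0 against a validation loss of 1.5 reports overfitting, while 1.40 against 1.42 suggests the model has room to grow.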

The winning strategy for obtaining very good models (if you have the compute time) is to always make the network as large as fits on your GPU and then try different dropout values (between 0 and 1). Whichever model has the best validation performance is the one you should use in the end. If your GPU is not very big, you can make compromises: first reduce the batch size, and if that is not enough, shorten the sequence length.
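Picking the final model then comes down to choosing the checkpoint with the lowest validation loss. The sketch below assumes checkpoint filenames end in the validation loss (for example lm_lstm_epoch5.00_1.4321.t7). That naming pattern is an assumption, so check the files in your own cv folder and adjust the regular expression. Validation losses are only comparable if all runs used the same validation split.

```python
import re

# Assumed naming convention: the filename ends in "_<validation loss>.t7".
_VAL_LOSS = re.compile(r"_(\d+(?:\.\d+)?)\.t7$")


def best_checkpoint(filenames):
    """Return the filename with the lowest validation loss encoded in its name."""
    scored = []
    for name in filenames:
        match = _VAL_LOSS.search(name)
        if match:
            scored.append((float(match.group(1)), name))
    if not scored:
        raise ValueError("no checkpoint filenames matched the assumed pattern")
    return min(scored)[1]
```

Usage: best_checkpoint(os.listdir("cv")) returns the name of the checkpoint to pass to sample.lua.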

License

MIT


Contributors

josephmmisiti, karpathy, maurizi
