GithubHelp home page GithubHelp logo

wassname / recurrentshop Goto Github PK

View Code? Open in Web Editor NEW

This project forked from farizrahman4u/recurrentshop

1.0 2.0 0.0 65 KB

Framework for building complex recurrent neural networks with Keras

License: MIT License

Python 100.00%

recurrentshop's Introduction

Recurrent Shop

Framework for building complex recurrent neural networks with Keras

Ability to easily iterate over different neural network architectures is key to doing machine learning research. While deep learning libraries like Keras makes it very easy to prototype new layers and models, writing custom recurrent neural networks is harder than it needs to be in almost all popular deep learning libraries available today. One key missing feature in these libraries is reusable RNN cells. Most libraries provide layers (such as LSTM, GRU etc), which can only be used as is, and not be easily embedded in a bigger RNN. Writing the RNN logic itself can be tiresome at times. For example in Keras, information about the states (shape and initial value) are provided by writing two seperate functions, get_initial_states and reset_states (for stateful version). There are many architectures whose implementation is not trivial using modern deep learning libraries, such as:

  • Synchronising the states of all the layers in a RNN stack.
  • Initializing the hidden state of a RNN with the output of a previous layer.
  • Feeding back the output of the last layer of a RNN stack to the first layer in next time step (readout).
  • Decoders : RNNs who can look at the whole of the input sequence / vector at every time step.

Recurrent shop adresses these issues by providing a set of RNNCells, which can be added sequentially to a special layer called RecurrentContainer along with other layers such as Dropout and Activation, very similar to adding layers to Sequential model in Keras. The RecurrentContainer then behaves like a standard Keras Recurrent instance. In case of RNN stacks, the computation is done depth-first, which results in significant speed ups.

Writing the RNN logic itself has been simplified to a great extend. The user is only required to provide the step function and the shapes for the weights and the states. Default initialization for weights is glorot uniform. States are initialized by zeros, unless specified otherwise.


Writing a Simple RNN cell

 # This is only to demonstrate how easy it is to write a RNNCell.
 # See recurrentshop/recurrentshop/cells.py for a better version of SimpleRNNCell with more options.
 
 class SimpleRNNCell(RNNCell):
 
  def build(self, input_shape):
    input_dim = input_shape[-1]
    output_dim = self.output_dim
    h = (-1, output_dim)  # -1 = batch size
    W = (input_dim, output_dim)
    U = (output_dim, output_dim)
    b = (self.output_dim,)
   
      def step(x, states, weights):
        h_tm1 = states[0]
        W, U, b = weights
        h = K.dot(x, W) + K.dot(h_tm1, U) + b

    self.step = step
    self.weights = [W, U, b]
    self.states = [h]

    super(SimpleRNNCell, self).build(input_shape)

Recurrent container

rc = RecurrentContainer()
rc.add(SimpleRNNCell(10, input_dim=20))
rc.add(Activation('tanh'))

Stacking RNN cells

rc = RecurrentContainer()
rc.add(SimpleRNNCell(10, input_dim=20))
rc.add(SimpleRNNCell(10))
rc.add(SimpleRNNCell(10))
rc.add(SimpleRNNCell(10))
rc.add(Activation('tanh'))

State synchronization

# All cells will use the same state(s)

rc = RecurrentContainer(state_sync=True)
rc.add(SimpleRNNCell(10, input_dim=20))
rc.add(SimpleRNNCell(10))
rc.add(SimpleRNNCell(10))
rc.add(SimpleRNNCell(10))
rc.add(Activation('tanh'))

Readout

# Output of the final layer in the previous time step is available to the first layer(added to the input by default)

rc = RecurrentContainer(readout=True)
rc.add(SimpleRNNCell(10, input_dim=20))
rc.add(SimpleRNNCell(10))
rc.add(SimpleRNNCell(10))
rc.add(SimpleRNNCell(10))
rc.add(Activation('tanh'))

Decoder

# Here we decode a vector into a sequence of vectors. The input could also be a sequence, such as in the case of Attention models, where the whole input sequence is available to the RNN at every time step

# In this case, input to rc is a 2d vector, not a sequence

rc = RecurrentContainer(decode=True, output_length=10)
rc.add(SimpleRNNCell(10, input_dim=20))

LSTM and GRU

Recurrent Shop comes with LSTMCell and GRUCell built-in, which can be added to RecurrentContainers using the same API discussed above.

Finalizing your model

Once your RecurrentContainer is ready, you can add it to a Sequential model, or call it using functional API like any other layer:

model = Sequential()
# Add layers, if any
model.add(rc)
# Add more layers, if any
model.compile(loss='mse', optimizer='sgd')
a = Input((None, 20))
b = rc(a)
model = Model(a, b)
model.compile(loss='mse', optimizer='sgd')

Installation

git clone https://www.github.com/datalogai/recurrentshop.git
cd recurrentshop
python setup.py install

Contribute

Pull requests are highly welcome.

Need help?

Create an issue, with a minimal script to reproduce the problem you are facing.

Have questions?

Create an issue or drop me an email ([email protected]).

recurrentshop's People

Contributors

farizrahman4u avatar malaikannan avatar meain avatar wassname avatar

Stargazers

 avatar

Watchers

 avatar  avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.