Reimplementation of the paper Recurrent Independent Mechanisms (https://openreview.net/pdf?id=mLcmdlEUxy-).

License: MIT License

Recurrent Independent Mechanisms (RIMs)

Reimplementation of Recurrent Independent Mechanisms (Goyal et al. 2021) using PyTorch. This project was done purely for learning purposes.

RIMs are an RNN-based architecture that learns to modularise the input dynamics into independent mechanisms, with the aim of improving generalisation and the modelling of long-term dependencies. These mechanisms are reusable modules, where each module is represented by a subset of an LSTM’s weight matrices. The modules are independent of each other and interact only sparsely via attention. Specifically,

  1. The RIMs will compete for access to the input, from which only a subset (the top-k) will be selected.
  2. The selected RIMs update their knowledge with respect to the input.
  3. The RIMs communicate via information sharing. Only the top-k RIMs are allowed to access information from the other RIMs.
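The top-k selection in step 1 can be sketched as follows. This is a minimal illustration rather than the repo's code: the function name `select_top_k_rims` is hypothetical, and the per-RIM `scores` are assumed to have already been produced by the input attention.

```python
import torch


def select_top_k_rims(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Return a 0/1 mask of shape (batch, num_rims) with ones at the k
    highest-scoring RIMs of each batch element.

    `scores` is assumed to hold one input-attention score per RIM.
    """
    top_idx = scores.topk(k, dim=1).indices   # (batch, k) indices of the winners
    mask = torch.zeros_like(scores)
    mask.scatter_(1, top_idx, 1.0)            # set the winning positions to 1
    return mask


# Hypothetical scores for a batch of 2 examples and 3 RIMs.
scores = torch.tensor([[0.1, 0.9, 0.5],
                       [0.7, 0.2, 0.3]])
mask = select_top_k_rims(scores, k=2)
```

With these scores and k=2, the mask rows are [0, 1, 1] and [1, 0, 1]: in each example, only the two highest-scoring RIMs are marked as active.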

Implementation notes

Our implementation takes influence from the official RIMs repo and an implementation by dido1998.

The experiments are logged on W&B and the code is implemented using the PyTorch Lightning framework. The torchtyping library is used to annotate tensor types and shapes in function signatures. Unit tests have been created to do initial sanity checks for the model.

Installation

To clone the repo and create a conda environment with the relevant libraries installed, run:

git clone https://github.com/bmistry4/rims.git
cd rims
conda create --name rim-env --file requirements.txt

Copy task

Example models to run the copying task can be found in jobs/copying.sh.

Architecture notes

Rims

  • Analogous to rnn_models_wiki in original code
  • Deals with looping over layers (imagine a stacked LSTM) and looping over the timesteps.
  • layers:list[RimsCell]
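The double loop described above might look roughly like the sketch below. It is not the repo's `Rims` class: `StackedRecurrentSketch` is a hypothetical name, and `nn.LSTMCell` stands in for `RimsCell` so the example runs on its own. The loop order (timesteps outside, layers inside) is also an assumption; the source only says that both loops exist.

```python
import torch
from torch import nn


class StackedRecurrentSketch(nn.Module):
    """Hypothetical stand-in for Rims, with nn.LSTMCell in place of RimsCell."""

    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        # The first layer reads the input; later layers read the layer below.
        in_sizes = [input_size] + [hidden_size] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [nn.LSTMCell(size, hidden_size) for size in in_sizes]
        )

    def forward(self, x: torch.Tensor):
        # x: (seq_len, batch, input_size)
        batch = x.size(1)
        states = [
            (x.new_zeros(batch, cell.hidden_size),
             x.new_zeros(batch, cell.hidden_size))
            for cell in self.layers
        ]
        outputs = []
        for x_t in x:                                  # loop over timesteps
            layer_input = x_t
            for i, cell in enumerate(self.layers):     # loop over layers
                states[i] = cell(layer_input, states[i])
                layer_input = states[i][0]             # hidden state feeds the next layer
            outputs.append(layer_input)
        return torch.stack(outputs), states


torch.manual_seed(0)
model = StackedRecurrentSketch(input_size=4, hidden_size=8, num_layers=2)
out, states = model(torch.randn(5, 3, 4))  # 5 timesteps, batch of 3
```

Here `out` collects the top layer's hidden state at every timestep, and `states` holds the final (hidden, cell) pair of each layer.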

RimsCell

  • RimsCell is analogous to LSTMCell; the plural 'Rims' in the name reflects that one cell handles multiple RIMs.
  • Analogous to BlocksCore in original code
  • Performs the three steps: input attention, independent dynamics, and communication attention.
  • Contains a generic Cell class which is extended depending on the type of implementation you want.
    • Currently 2 implementations:

      1. the original paper's way (BlockCell), which calls blockify,
      2. our way (BatchedCell), which uses a batch * |RIMs| dimension and does BMM.

      (There was also a third option of a list of Cells which update in a for loop, but this would be too expensive with respect to time complexity.)

      Both implementations should work with both GRU and LSTM.
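To illustrate why batched matrix multiplication can replace a for loop over per-RIM cells, here is a small sketch of a linear layer with separate weights per RIM, applied with a single `torch.bmm` call. The class `PerRimLinear` and its shapes are hypothetical and are not taken from BatchedCell; the sketch only shows the general idea.

```python
import torch
from torch import nn


class PerRimLinear(nn.Module):
    """Hypothetical layer: one independent weight matrix per RIM."""

    def __init__(self, num_rims: int, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(
            0.1 * torch.randn(num_rims, in_features, out_features)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_rims, in_features)
        # bmm treats the RIM axis as its batch axis:
        # (num_rims, batch, in) @ (num_rims, in, out) -> (num_rims, batch, out)
        out = torch.bmm(x.transpose(0, 1), self.weight)
        return out.transpose(0, 1)            # back to (batch, num_rims, out)


torch.manual_seed(0)
layer = PerRimLinear(num_rims=3, in_features=4, out_features=5)
x = torch.randn(2, 3, 4)
batched = layer(x)

# Reference result: loop over the RIMs one at a time.
looped = torch.stack([x[:, r] @ layer.weight[r] for r in range(3)], dim=1)
```

Both `batched` and `looped` hold the same values up to floating-point error, but the batched version issues one matrix-multiply call instead of one per RIM.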

Cell

  • Highest level abstraction for a container representing a collection of recurrent cells
  • Cell(input:Tensor, states:list/tuple) -> states:list/tuple
  • state: assume the hidden state is at index 0, and that all states have the same type of masking applied as the hidden state
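One way to read the masking convention above is sketched below. It assumes that a RIM which was not selected keeps its previous state, and that the same mask is applied to every state tensor (for example, both the hidden and cell state of an LSTM). The helper `mask_states` is hypothetical and not part of the repo.

```python
import torch


def mask_states(new_states, old_states, mask):
    """Keep updated values for active RIMs and previous values otherwise.

    new_states, old_states: tuples of tensors shaped (batch, num_rims, hidden),
        with the hidden state at index 0.
    mask: (batch, num_rims), 1 for active RIMs and 0 for inactive ones.
    """
    m = mask.unsqueeze(-1)   # (batch, num_rims, 1), broadcasts over hidden
    return tuple(
        m * new + (1 - m) * old for new, old in zip(new_states, old_states)
    )


# One example, two RIMs, hidden size 3: only the first RIM is active.
old = (torch.zeros(1, 2, 3), torch.zeros(1, 2, 3))
new = (torch.ones(1, 2, 3), 2 * torch.ones(1, 2, 3))
mask = torch.tensor([[1.0, 0.0]])
h, c = mask_states(new, old, mask)
```

In this example the first RIM takes its new hidden and cell values, while the second RIM keeps its old (zero) values in both tensors.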

Attention Mechanism

  • Multi-headed Attention (MHA)
  • Scaled dot-product attention (SDPA)
  • Abstracts the sparse attention, which can be done at the SDPA level or at the MHA level. Use top_k=-1 (the default) for dense MHA and a value >0 for sparse attention.
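A minimal sketch of how a `top_k` switch could sit at the SDPA level is given below. It assumes that sparse attention means keeping only the `top_k` largest scores per query before the softmax; the function `sdpa` and this exact behaviour are assumptions, not the repo's implementation.

```python
import math

import torch


def sdpa(q, k, v, top_k: int = -1):
    """Scaled dot-product attention with optional top-k sparsification.

    q: (batch, n_queries, d), k: (batch, n_keys, d), v: (batch, n_keys, d_v).
    top_k=-1 gives dense attention; top_k>0 keeps the top_k scores per query.
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if top_k > 0:
        # Smallest score that still belongs to the top_k of each query row.
        kth = scores.topk(top_k, dim=-1).values[..., -1:]
        # Everything below that threshold gets zero weight after the softmax.
        scores = scores.masked_fill(scores < kth, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
q, k, v = torch.randn(2, 4, 8), torch.randn(2, 6, 8), torch.randn(2, 6, 8)
dense_out, dense_w = sdpa(q, k, v)             # every key gets some weight
sparse_out, sparse_w = sdpa(q, k, v, top_k=2)  # two keys per query get weight
```

Because the masking happens before the softmax, the surviving weights are renormalised and still sum to 1 for each query. Note that tied scores at the threshold would all be kept, so a row can have more than `top_k` non-zero weights in that case.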
