HSSS

Implementation of a Hierarchical Mamba as described in the paper "Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling", but using Mamba blocks in place of traditional SSMs. The flow is: single input -> low-level Mambas -> concat -> high-level SSM -> multiple outputs.

Paper link: https://arxiv.org/abs/2402.10211

I believe in this architecture a lot, as it separates local and global learning.
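
Conceptually, the flow can be sketched with plain PyTorch. This is a minimal, shape-level illustration: nn.Identity stands in for the low-level Mambas and the high-level SSM, and all shapes here are assumptions for the example, not the library's internals.

import torch
import torch.nn as nn

# Single input: batch 1, sequence 100, feature dim 12 (illustrative shapes)
x = torch.randn(1, 100, 12)

# Stand-ins for three low-level Mambas, each seeing the same input
low_level = [nn.Identity() for _ in range(3)]
low_outs = [m(x) for m in low_level]  # three tensors of shape (1, 100, 12)

# Concatenate the low-level outputs along the feature dimension
fused = torch.cat(low_outs, dim=-1)  # (1, 100, 36)

# Project the fused features to the high-level model's dimension
proj = nn.Linear(fused.shape[-1], 128)
high_in = proj(fused)  # (1, 100, 128)

# Stand-in for the high-level SSM that produces the outputs
high_level = nn.Identity()
out = high_level(high_in)  # (1, 100, 128)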

Install

pip install hsss

Usage

import torch
from hsss.model import LowLevelMamba, HSSS


# Random input text tokens: batch of 1, sequence length 100, vocabulary of 10
text = torch.randint(0, 10, (1, 100)).long()

# Low-level model
mamba = LowLevelMamba(
    dim=12,  # model (feature) dimension
    depth=6,  # number of Mamba layers
    dt_rank=4,  # rank of the delta (time step) projection
    d_state=4,  # SSM state dimension
    expand_factor=4,  # block expansion factor
    d_conv=6,  # local convolution kernel size
    dt_min=0.001,  # lower bound for delta initialization
    dt_max=0.1,  # upper bound for delta initialization
    dt_init="random",  # delta initialization method
    dt_scale=1.0,  # delta initialization scale
    bias=False,  # whether to use bias in the linear projections
    conv_bias=True,  # whether to use bias in the convolution
    pscan=True,  # use the parallel scan instead of the sequential one
)


# Two more low-level Mambas with the same hyperparameters
mamba2 = LowLevelMamba(
    dim=12, depth=6, dt_rank=4, d_state=4, expand_factor=4, d_conv=6,
    dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0,
    bias=False, conv_bias=True, pscan=True,
)

mamba3 = LowLevelMamba(
    dim=12, depth=6, dt_rank=4, d_state=4, expand_factor=4, d_conv=6,
    dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0,
    bias=False, conv_bias=True, pscan=True,
)


# High-level HSSS model that fuses the low-level Mambas
hsss = HSSS(
    layers=[mamba, mamba2, mamba3],  # low-level models whose outputs are fused
    num_tokens=10,  # vocabulary size
    seq_length=100,  # input sequence length
    dim=128,  # hidden dimension of the high-level SSM
    depth=3,  # number of layers in the high-level SSM
    dt_rank=2,  # rank of the delta (time step) projection
    d_state=2,  # SSM state dimension
    expand_factor=2,  # block expansion factor
    d_conv=3,  # local convolution kernel size
    dt_min=0.001,  # lower bound for delta initialization
    dt_max=0.1,  # upper bound for delta initialization
    dt_init="random",  # delta initialization method
    dt_scale=1.0,  # delta initialization scale
    bias=False,  # whether to use bias in the linear projections
    conv_bias=True,  # whether to use bias in the convolution
    pscan=True,  # use the parallel scan instead of the sequential one
    proj_layer=True,  # project the concatenated low-level outputs to `dim`
)


# Forward pass
out = hsss(text)
print(out)
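
Note that the three low-level Mambas in this example share hyperparameters only for brevity. Since HSSS takes them as a plain list, each one can presumably be configured independently, letting the local models be sized to local structure while the high-level SSM integrates across them.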

Citation

@misc{bhirangi2024hierarchical,
      title={Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling}, 
      author={Raunaq Bhirangi and Chenyu Wang and Venkatesh Pattabiraman and Carmel Majidi and Abhinav Gupta and Tess Hellebrekers and Lerrel Pinto},
      year={2024},
      eprint={2402.10211},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

License

MIT

Todo

  • Implement chunking of the input tokens by splitting them along the sequence dimension (see the sketch after this list).

  • Make the fusion projection layer dynamic: support not just a linear layer but also an FFN, cross-attention, or an output head.
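
For the first item, chunking along the sequence dimension could look something like the following sketch; chunk_sequence is a hypothetical helper, not part of the library.

import torch

def chunk_sequence(x: torch.Tensor, num_chunks: int) -> list:
    """Split a (batch, seq_len, dim) tensor into chunks along the sequence axis."""
    return list(torch.chunk(x, num_chunks, dim=1))

x = torch.randn(1, 100, 12)  # (batch, seq_len, dim)
chunks = chunk_sequence(x, 4)
print([tuple(c.shape) for c in chunks])  # four chunks of shape (1, 25, 12)

Each chunk could then be routed to its own low-level Mamba before fusion.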
