
andrewfavor95 / jax-unirep


This project forked from elarkk/jax-unirep


Reimplementation of the UniRep protein featurization model.

License: GNU General Public License v3.0

Makefile 0.51% Python 99.49%



jax-unirep

Reimplementation of the UniRep protein featurization model in JAX.

The UniRep model was developed in George Church's lab, see the original publication here (bioRxiv) or here (Nature Methods), as well as the repository containing the original model.

The idea of reimplementing the TF-based model in the much lighter JAX framework came from Eric Ma, who also developed a first version of it inside his functional deep-learning library fundl.

This repo is a self-contained version of the UniRep model (so far only the 1900 hidden-unit mLSTM), adapted and extended from fundl.

Installation

Ensure that your compute environment can run JAX code. (A modern Linux with GLIBC>=2.23, or a modern macOS, is probably necessary.)
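If you are unsure whether a Linux machine meets the GLIBC requirement, a small check along these lines may help. It is a sketch built on Python's standard-library platform.libc_ver(); the helper name glibc_at_least is hypothetical and not part of jax-unirep. On macOS, which does not use glibc, libc_ver() typically returns empty strings, so the check does not apply there.

```python
import platform


def glibc_at_least(version: str, minimum: tuple = (2, 23)) -> bool:
    """Return True if a glibc version string such as "2.31" is >= `minimum`.

    Hypothetical helper for illustration; not part of jax-unirep.
    """
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    if not parts:
        return False
    # Tuple comparison orders versions numerically, e.g. (2, 31) > (2, 23).
    return tuple(parts) >= minimum


if __name__ == "__main__":
    lib, version = platform.libc_ver()
    if lib == "glibc":
        print(f"glibc {version}, meets GLIBC>=2.23: {glibc_at_least(version)}")
    else:
        print("glibc not detected (e.g. macOS); this check does not apply.")
```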

For now, jax-unirep can be installed with pip from source.

Installation from GitHub:

pip install git+https://github.com/ElArkk/jax-unirep.git

Usage

Getting UniReps

To generate representations of protein sequences, pass a list of sequences as strings or a single sequence to jax_unirep.get_reps. It will return a tuple consisting of the following representations for each sequence:

  • h_avg: Average hidden state of the mLSTM over the whole sequence.
  • h_final: Final hidden state of the mLSTM.
  • c_final: Final cell state of the mLSTM.

From the original paper, h_avg is considered the "representation" (or "rep") of the protein sequence.
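To make the three outputs concrete, here is a toy NumPy sketch of a multiplicative LSTM (mLSTM) run over a short sequence of input vectors: it collects the hidden state at every position and then derives h_avg, h_final and c_final from them. This is an illustration under stated assumptions, not jax-unirep's implementation: the weights are random, the hidden size is 8 instead of 1900, the weight normalization that the UniRep mLSTM applies is omitted, and the function name mlstm_reps is hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def mlstm_reps(xs, hidden=8, seed=0):
    """Run a toy mLSTM over `xs` (shape: seq_len x input_dim).

    Returns (h_avg, h_final, c_final), mirroring the three outputs that
    get_reps returns per sequence. Weights are random and purely illustrative.
    """
    rng = np.random.default_rng(seed)
    d = xs.shape[1]
    wmx = rng.normal(scale=0.1, size=(d, hidden))          # input -> multiplicative state
    wmh = rng.normal(scale=0.1, size=(hidden, hidden))     # hidden -> multiplicative state
    wx = rng.normal(scale=0.1, size=(d, 4 * hidden))       # input -> gates
    wh = rng.normal(scale=0.1, size=(hidden, 4 * hidden))  # multiplicative state -> gates
    b = np.zeros(4 * hidden)

    h = np.zeros(hidden)
    c = np.zeros(hidden)
    hs = []
    for x in xs:
        m = (x @ wmx) * (h @ wmh)       # multiplicative intermediate state
        z = x @ wx + m @ wh + b
        i, f, o, u = np.split(z, 4)     # input, forget, output gates and candidate
        i, f, o, u = sigmoid(i), sigmoid(f), sigmoid(o), np.tanh(u)
        c = f * c + i * u               # cell state update
        h = o * np.tanh(c)              # hidden state for this position
        hs.append(h)

    hs = np.stack(hs)                   # (seq_len, hidden)
    return hs.mean(axis=0), hs[-1], c   # h_avg, h_final, c_final


xs = np.random.default_rng(1).normal(size=(5, 10))  # 5 positions, 10-dim inputs
h_avg, h_final, c_final = mlstm_reps(xs)
print(h_avg.shape, h_final.shape, c_final.shape)  # (8,) (8,) (8,)
```

The key point is that h_avg summarizes every position of the sequence, while h_final and c_final only describe the state after the last residue.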

Only letters from the following set are allowed as inputs to get_reps:

MRHKDESTNQCUGPAVIFYWLOXZBJ

Sequences may be passed in as a single string or an iterable of strings, and they do not need to have the same length.
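Because get_reps only accepts letters from this set, it can be convenient to check inputs up front. A minimal sketch follows; the helper validate_sequences is hypothetical (not part of jax-unirep) and, like get_reps, it accepts either a single string or an iterable of strings.

```python
from typing import Iterable, List, Union

# Allowed letters, copied from the set listed above.
VALID_LETTERS = set("MRHKDESTNQCUGPAVIFYWLOXZBJ")


def validate_sequences(seqs: Union[str, Iterable[str]]) -> List[str]:
    """Return the sequences as a list, raising ValueError on invalid letters.

    Hypothetical helper for illustration; not part of jax-unirep.
    """
    if isinstance(seqs, str):
        seqs = [seqs]  # a single sequence becomes a one-element list
    seqs = list(seqs)
    for idx, seq in enumerate(seqs):
        bad = sorted(set(seq) - VALID_LETTERS)
        if bad:
            raise ValueError(f"sequence {idx} contains invalid letters: {bad}")
    return seqs
```

Note that the check is case-sensitive, so lowercase sequences would need to be uppercased first.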

In Python code, for a single sequence:

from jax_unirep import get_reps

sequence = "ASDFGHJKL"

# h_avg is the canonical "reps"
h_avg, h_final, c_final = get_reps(sequence)

And for multiple sequences:

from jax_unirep import get_reps

sequences = ["ASDF", "YJKAL", "QQLAMEHALQP"]

# h_avg is the canonical "reps"
h_avg, h_final, c_final = get_reps(sequences)

# each of the arrays will be of shape (len(sequences), 1900),
# with the correct order of sequences preserved

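For long lists of sequences you may prefer to compute representations in chunks and stack the results, for example to keep memory use bounded. The sketch below assumes only what is stated above: that get_reps takes a list of sequences and returns three arrays with one row per sequence, in input order. The helper reps_in_chunks and its chunk_size parameter are hypothetical; embed_fn is a parameter so that get_reps itself can be passed in.

```python
from typing import Callable, List, Sequence, Tuple

import numpy as np


def reps_in_chunks(
    sequences: Sequence[str],
    embed_fn: Callable,
    chunk_size: int = 100,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Call `embed_fn` (e.g. jax_unirep.get_reps) on consecutive chunks.

    Rows of each returned array follow the order of `sequences`.
    Hypothetical helper for illustration; not part of jax-unirep.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")
    if len(sequences) == 0:
        raise ValueError("no sequences given")
    outs: List[Tuple[np.ndarray, ...]] = []
    for start in range(0, len(sequences), chunk_size):
        chunk = list(sequences[start:start + chunk_size])
        outs.append(tuple(np.asarray(a) for a in embed_fn(chunk)))
    # Concatenate h_avg, h_final and c_final separately along the row axis.
    h_avg, h_final, c_final = (np.concatenate(parts, axis=0) for parts in zip(*outs))
    return h_avg, h_final, c_final
```

With jax-unirep installed, a call would look like reps_in_chunks(sequences, get_reps, chunk_size=2); each returned array should then have one row per input sequence.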
Evotuning

The original paper introduces 'evolutionary fine-tuning' (evotuning): the pre-trained mLSTM weights are further updated by training on homologs of a protein of interest. This feature is available in jax-unirep as well. Given a set of starting weights for the mLSTM (by default, the weights from the paper) and a set of sequences, the weights are fine-tuned to minimize the test-set loss on the 'next-aa prediction task', i.e. predicting each amino acid from the ones before it. Two functions are available, offering different levels of control.
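As a rough illustration of the quantity being minimized, the average next-amino-acid cross-entropy for one sequence could be computed as below. This is a NumPy sketch under assumptions: the logits are made up, the vocabulary size of 4 is a toy value, the function name next_aa_loss is hypothetical, and jax-unirep's actual loss (alphabet, padding, masking, batching) may differ.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def next_aa_loss(logits: np.ndarray, token_ids: np.ndarray) -> float:
    """Mean cross-entropy of predicting token t+1 from the output at position t.

    logits:    (seq_len, vocab) model outputs, one row per input position.
    token_ids: (seq_len,) integer-encoded sequence.
    """
    logp = log_softmax(logits[:-1])  # predictions made at positions 0..L-2
    targets = token_ids[1:]          # the residue that actually comes next
    picked = logp[np.arange(len(targets)), targets]
    return float(-picked.mean())


# Uniform logits give a loss of log(vocab) regardless of the targets.
uniform = np.zeros((5, 4))
print(round(next_aa_loss(uniform, np.array([0, 1, 2, 3, 0])), 4))  # 1.3863
```

Lower values mean the model assigns higher probability to the residues that actually follow, which is what evotuning pushes toward on the held-out sequences.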

The evotune function uses optuna under the hood to automatically find:

  1. the optimal number of epochs to train for, and
  2. the optimal learning rate,

given a set of sequences. The optuna study object returned by evotune contains information about the training process of each trial, and evotuned_params contains the fine-tuned mLSTM and dense weights from the trial with the lowest test-set loss.
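evotune delegates this search to optuna. Purely to illustrate the idea, the sketch below swaps optuna's sampler for plain random search from the standard library: it samples an epoch count and a log-uniform learning rate, scores each trial, and keeps the configuration with the lowest loss. random_search, its search ranges and toy_loss are all hypothetical; in real use the objective would train on your sequences and return the held-out loss.

```python
import math
import random
from typing import Callable, Dict, Tuple


def random_search(
    objective: Callable[[int, float], float],
    n_trials: int = 20,
    epoch_range: Tuple[int, int] = (1, 10),
    lr_range: Tuple[float, float] = (1e-5, 1e-3),
    seed: int = 0,
) -> Dict[str, float]:
    """Return the trial with the lowest objective value (a stand-in for optuna)."""
    rng = random.Random(seed)
    best = {"loss": math.inf}
    for _ in range(n_trials):
        n_epochs = rng.randint(*epoch_range)  # inclusive on both ends
        # Sample the learning rate uniformly in log space.
        log_lr = rng.uniform(math.log10(lr_range[0]), math.log10(lr_range[1]))
        lr = 10 ** log_lr
        loss = objective(n_epochs, lr)
        if loss < best["loss"]:
            best = {"loss": loss, "n_epochs": n_epochs, "learning_rate": lr}
    return best


def toy_loss(n_epochs: int, lr: float) -> float:
    # Made-up objective with its minimum at 5 epochs and lr = 1e-4.
    return (n_epochs - 5) ** 2 + (math.log10(lr) + 4) ** 2


print(random_search(toy_loss, n_trials=200))
```

Sampling the learning rate in log space mirrors the common practice of searching learning rates across orders of magnitude rather than linearly.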

If you want to directly fine-tune the weights for a fixed number of epochs while using a fixed learning rate, you should use the fit function instead. The fit function has further customization options, such as different batching strategies. Please see the function docstring for more information.
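As an example of what a batching strategy can look like, the sketch below groups sequences of equal length into the same batch, so no batch needs padding, and contrasts it with random batching. It is a hypothetical illustration: the function names and the batch_size parameter are assumptions, and the strategies fit actually offers are documented in its docstring.

```python
import random
from collections import defaultdict
from typing import Dict, List


def length_batches(sequences: List[str], batch_size: int) -> List[List[str]]:
    """Group sequences by length, then split each group into batches."""
    by_len: Dict[int, List[str]] = defaultdict(list)
    for seq in sequences:
        by_len[len(seq)].append(seq)
    batches = []
    for length in sorted(by_len):
        group = by_len[length]
        for start in range(0, len(group), batch_size):
            batches.append(group[start:start + batch_size])
    return batches


def random_batches(sequences: List[str], batch_size: int, seed: int = 0) -> List[List[str]]:
    """Shuffle once, then cut into consecutive batches (lengths may be mixed)."""
    shuffled = list(sequences)
    random.Random(seed).shuffle(shuffled)
    return [shuffled[i:i + batch_size] for i in range(0, len(shuffled), batch_size)]


seqs = ["HASTA", "VISTA", "ALAVA", "HAST", "HAS", "LIMED"]
print(length_batches(seqs, batch_size=2))
# [['HAS'], ['HAST'], ['HASTA', 'VISTA'], ['ALAVA', 'LIMED']]
```

Length-based batches avoid padding but can be small and uneven when few sequences share a length; random batches are more uniform in size but mix lengths.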

You can find example usages of both evotune and fit here.

If you want to pass a set of mLSTM and dense weights that were dumped in an earlier run, create params as follows:

from jax_unirep.utils import load_params

params = load_params(folderpath="path/to/params/folder")

If you want to start from randomly initialized mLSTM and dense weights instead:

from jax_unirep.evotuning import init_fun
from jax.random import PRNGKey

_, params = init_fun(PRNGKey(0), input_shape=(-1, 10))

The weights used in the 10-dimensional embedding of the input sequences always default to the weights from the paper, since they do not get updated during evotuning.
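Conceptually, the embedding step turns each letter into an integer index and looks up the corresponding row of a (vocabulary_size, 10) matrix, which yields the (sequence_length, 10) input the mLSTM consumes, consistent with input_shape=(-1, 10) above. In the sketch below the letter-to-index mapping and the random matrix are hypothetical stand-ins; jax-unirep uses its own vocabulary ordering and the embedding weights from the paper.

```python
import numpy as np

# Hypothetical letter -> index mapping built from the allowed letters.
# Index 0 is reserved here (e.g. for padding); the real mapping may differ.
ALPHABET = "MRHKDESTNQCUGPAVIFYWLOXZBJ"
AA_TO_IDX = {aa: i + 1 for i, aa in enumerate(ALPHABET)}

# Stand-in for the pre-trained embedding matrix: one 10-dim row per index.
rng = np.random.default_rng(0)
embedding = rng.normal(size=(len(ALPHABET) + 1, 10))


def embed(sequence: str) -> np.ndarray:
    """Map a sequence to a (len(sequence), 10) array of embedding rows."""
    idx = np.array([AA_TO_IDX[aa] for aa in sequence])
    return embedding[idx]


x = embed("ASDFGHJKL")
print(x.shape)  # (9, 10)
```

Because this lookup table is frozen during evotuning, identical letters always map to identical input vectors, and only the mLSTM and dense weights change.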

UniRep stax

We implemented the mLSTM layers in such a way that they are compatible with jax.experimental.stax. This means that they can easily be plugged into a stax.serial model, e.g. to train both the mLSTM and a top-model at once:

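# Note: newer JAX releases moved stax to jax.example_libraries.stax.
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu

from jax_unirep.layers import mLSTM1900, mLSTM1900_AvgHidden

init_fun, apply_fun = stax.serial(
    mLSTM1900(),            # 1900-unit mLSTM layer
    mLSTM1900_AvgHidden(),  # average the hidden states over the sequence
    Dense(512), Relu(),     # top model on the averaged representation
    Dense(1)
)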

Have a look at the documentation and examples for more information about how to implement a model in jax.
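Being compatible with stax means each layer is a pair of functions: an init_fun(rng, input_shape) that returns (output_shape, params), and an apply_fun(params, inputs) that computes the output; stax.serial chains such pairs. The NumPy sketch below reimplements that convention in miniature to show how the pieces fit. dense, relu and serial here are simplified stand-ins written for illustration, not the stax or jax-unirep implementations (real stax apply functions also accept an rng keyword, and stax initializers take a JAX PRNG key rather than a NumPy generator).

```python
import numpy as np


def dense(out_dim):
    def init_fun(rng, input_shape):
        w = rng.normal(scale=0.1, size=(input_shape[-1], out_dim))
        b = np.zeros(out_dim)
        return input_shape[:-1] + (out_dim,), (w, b)

    def apply_fun(params, inputs):
        w, b = params
        return inputs @ w + b

    return init_fun, apply_fun


def relu():
    def init_fun(rng, input_shape):
        return input_shape, ()  # no parameters; shape unchanged

    def apply_fun(params, inputs):
        return np.maximum(inputs, 0.0)

    return init_fun, apply_fun


def serial(*layers):
    init_funs, apply_funs = zip(*layers)

    def init_fun(rng, input_shape):
        params = []
        for init in init_funs:
            input_shape, p = init(rng, input_shape)  # thread shapes through
            params.append(p)
        return input_shape, params

    def apply_fun(params, inputs):
        for apply, p in zip(apply_funs, params):
            inputs = apply(p, inputs)
        return inputs

    return init_fun, apply_fun


# A top model on 1900-dim reps, mirroring Dense(512), Relu, Dense(1) above.
init_fun, apply_fun = serial(dense(512), relu(), dense(1))
out_shape, params = init_fun(np.random.default_rng(0), (-1, 1900))
preds = apply_fun(params, np.ones((3, 1900)))
print(out_shape, preds.shape)  # (-1, 1) (3, 1)
```

Because every layer exposes the same two-function interface, the mLSTM layers can sit in front of ordinary stax layers, and training both at once only requires differentiating the composed apply_fun.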

More Details

We wrote up how we reimplemented the model in JAX; the write-up is available as both HTML and PDF.

License

All the model weights are licensed under the terms of Creative Commons Attribution-NonCommercial 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

Apart from the model weights, the code in this repository is licensed under the terms of GPL v3.

Contributors

elarkk, ericmjl, ivanjayapurna
