torchdyn

A PyTorch based library for all things neural differential equations. Maintained by DiffEqML.

Installation

git clone https://github.com/DiffEqML/torchdyn.git

cd torchdyn

python setup.py install

Documentation

https://torchdyn.readthedocs.io/

Introduction

Interest in the blend of differential equations, deep learning and dynamical systems has been reignited by recent works [1,2]. Modern deep learning frameworks such as PyTorch, coupled with steady improvements in computational resources, have allowed continuous versions of neural networks, with roots dating back to the 80s [3], to finally come to life and provide a novel perspective on classical machine learning problems (e.g. density estimation [4]).

Since the introduction of the torchdiffeq library with the seminal work [1] in 2018, little effort has been made by the PyTorch research community towards a unified framework for neural differential equations. While significant progress is being made by the Julia community and SciML [5], we believe torchdyn, a native PyTorch library with a focus on deep learning, to be a valuable asset for the research ecosystem.

Central to the torchdyn approach are continuous neural networks, where width, depth (or both) are taken to their infinite limit. On the optimization front, we consider continuous "data-stream" regimes and gradient flow methods, where the dataset represents a time-evolving signal processed by the neural network to adapt its parameters.
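
As a reminder of the basic object, the continuous-depth model introduced in [1] replaces a stack of discrete layers with a vector field integrated over a depth variable by a numerical ODE solver:

$$
\frac{dz(s)}{ds} = f_\theta\big(s, z(s)\big), \qquad z(0) = x, \qquad \hat{y} = z(S),
$$

where $s \in [0, S]$ plays the role of depth and $f_\theta$ is a neural network.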

By providing a centralized, easy-to-access collection of model templates, tutorial and application notebooks, we hope to speed-up research in this area and ultimately contribute to turning neural differential equations into an effective tool for control, system identification and common machine learning tasks.

torchdyn is developed and maintained by the core DiffEqML team, with the generous support of the deep learning community.

Dependencies

torchdyn leverages modern PyTorch best practices and handles training with pytorch-lightning [6]. We build Graph Neural ODEs using the Graph Neural Network (GNN) API of dgl [6].
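
As an illustration of what this looks like in practice, here is a minimal pytorch-lightning training wrapper of the kind used throughout the tutorial notebooks. The `Learner` class below is an illustrative sketch (its name and structure are ours, not a fixed torchdyn API), and the model and dataloader are placeholders:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl

class Learner(pl.LightningModule):
    """Minimal lightning wrapper around a continuous-depth model (illustrative sketch)."""
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model  # e.g. a torchdyn Neural DE built as in the quickstart notebook

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        return nn.functional.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

# Usage sketch: plug in any nn.Module and a DataLoader, then let lightning run the loop.
# model, loader = ..., ...
# trainer = pl.Trainer(max_epochs=100)
# trainer.fit(Learner(model), loader)
```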

Goals of torchdyn

Our aim with torchdyn is to provide a unified, flexible API to the most recent advances in continuous deep learning. Examples include neural differential equation variants, e.g.

  • Neural Ordinary Differential Equations (Neural ODE) [1]
  • Neural Stochastic Differential Equations (Neural SDE) [7,8]
  • Graph Neural ODEs [9]
  • Hamiltonian Neural Networks [10]
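
To make the last of these concrete: a Hamiltonian Neural Network [10] parametrizes a scalar energy function and obtains the vector field from its gradients, so that the learned dynamics conserve energy by construction. A minimal plain-PyTorch sketch (not torchdyn's own implementation) is:

```python
import torch
import torch.nn as nn

class HamiltonianVectorField(nn.Module):
    """Vector field induced by a learned scalar Hamiltonian H(q, p) (illustrative sketch)."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # scalar energy function H: (q, p) -> R
        self.H = nn.Sequential(nn.Linear(2 * dim, 64), nn.Tanh(), nn.Linear(64, 1))

    def forward(self, x):
        # x concatenates q and p; x is expected to carry requires_grad=True
        grad = torch.autograd.grad(self.H(x).sum(), x, create_graph=True)[0]
        dHdq, dHdp = grad[..., :self.dim], grad[..., self.dim:]
        # symplectic structure: dq/dt = dH/dp, dp/dt = -dH/dq
        return torch.cat([dHdp, -dHdq], dim=-1)

# Usage sketch
field = HamiltonianVectorField(dim=1)
x = torch.randn(8, 2, requires_grad=True)   # batch of (q, p) states
dxdt = field(x)                             # energy-conserving dynamics at x
```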

Depth-variant versions,

  • ANODEv2 [11]
  • Galerkin Neural ODE [12]
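
The idea behind the Galerkin parametrization [12] is to let a layer's weights be functions of depth, expanded over a finite set of basis functions. The self-contained sketch below uses a truncated Fourier basis; it is only an illustration of the idea, and torchdyn's own depth-variant layers and their API differ:

```python
import math
import torch
import torch.nn as nn

class DepthVariantLinear(nn.Module):
    """Linear layer whose weight matrix is a truncated Fourier series in depth s (illustrative)."""
    def __init__(self, in_dim, out_dim, n_harmonics=3):
        super().__init__()
        n_basis = 1 + 2 * n_harmonics  # constant term plus sin/cos pairs
        self.n_harmonics = n_harmonics
        self.coeffs = nn.Parameter(0.1 * torch.randn(n_basis, out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, s, x):
        # evaluate the basis functions at depth s (a scalar in [0, 1])
        k = torch.arange(1, self.n_harmonics + 1, dtype=x.dtype, device=x.device)
        basis = torch.cat([torch.ones(1, dtype=x.dtype, device=x.device),
                           torch.sin(2 * math.pi * k * s),
                           torch.cos(2 * math.pi * k * s)])
        W = torch.einsum('b,boi->oi', basis, self.coeffs)  # depth-dependent weights
        return x @ W.t() + self.bias
```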

Recurrent or "hybrid" versions

  • ODE-RNN [13]
  • GRU-ODE-Bayes [14]
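
The hybrid idea behind ODE-RNN [13] is that the hidden state evolves continuously under a learned ODE between observations, and is updated discretely by an RNN cell at each observation. A minimal sketch using a fixed-step Euler flow (an assumption made here only to keep the example short; [13] uses an adaptive solver):

```python
import torch
import torch.nn as nn

class ODERNNSketch(nn.Module):
    """ODE-RNN-style loop: evolve the hidden state by an ODE between observations,
    then apply a GRU cell update at each observation (illustrative sketch)."""
    def __init__(self, input_dim, hidden_dim, n_euler_steps=5):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
                               nn.Linear(hidden_dim, hidden_dim))  # hidden-state dynamics
        self.cell = nn.GRUCell(input_dim, hidden_dim)
        self.n_euler_steps = n_euler_steps

    def forward(self, xs, dts):
        # xs: (T, batch, input_dim) observations; dts: (T,) gaps since the previous observation
        h = xs.new_zeros(xs.size(1), self.cell.hidden_size)
        hs = []
        for x_t, dt in zip(xs, dts):
            # continuous evolution of h over the gap dt (explicit Euler for brevity)
            step = dt / self.n_euler_steps
            for _ in range(self.n_euler_steps):
                h = h + step * self.f(h)
            # discrete update at the observation
            h = self.cell(x_t, h)
            hs.append(h)
        return torch.stack(hs)
```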

Augmentation strategies to relieve neural differential equations of their expressivity limitations and reduce the computational burden of the numerical solver

  • ANODE (0-augmentation) [15]
  • Input-layer augmentation [16]
  • Higher-order augmentation [17]
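
For instance, the simplest of these, 0-augmentation as proposed in ANODE [15], concatenates a few extra zero-valued channels to the input before handing it to the differential equation, so that the flow takes place in a higher-dimensional space where trajectories are free to avoid each other. A minimal sketch:

```python
import torch
import torch.nn as nn

class ZeroAugmenter(nn.Module):
    """0-augmentation: append `aug_dims` zero channels to the state (illustrative sketch)."""
    def __init__(self, aug_dims):
        super().__init__()
        self.aug_dims = aug_dims

    def forward(self, x):
        zeros = x.new_zeros(*x.shape[:-1], self.aug_dims)
        return torch.cat([x, zeros], dim=-1)

# Usage sketch: the vector field (and any downstream layers) must accept the augmented width.
aug = ZeroAugmenter(aug_dims=3)
x = torch.randn(8, 2)
print(aug(x).shape)  # torch.Size([8, 5])
```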

Alternative or modified adjoint training techniques

  • Integral loss adjoint [18]
  • Checkpointed adjoint [19]
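
For context on the first item: when the training objective is a cost accumulated along the trajectory rather than a terminal loss, the adjoint state picks up an extra source term. In the standard optimal-control form (the notation here is ours, not torchdyn's),

$$
\mathcal{L} = \int_0^S \ell\big(s, z(s)\big)\, ds, \qquad
\frac{da(s)}{ds} = -a(s)^\top \frac{\partial f}{\partial z} - \frac{\partial \ell}{\partial z}, \qquad a(S) = 0,
$$

integrated backward in $s$, with the parameter gradient recovered as $\frac{d\mathcal{L}}{d\theta} = \int_0^S a(s)^\top \frac{\partial f}{\partial \theta}\, ds$.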

Applications and tutorials

The current version of torchdyn contains the following self-contained quickstart examples / tutorials (with a lot more to come):

  • 00_quickstart: offers a quickstart guide for torchdyn and Neural DEs
  • 01_cookbook: here, we explore the API and how to define Neural DE variants within torchdyn
  • 02_classification: convolutional Neural DEs on MNIST
  • 03_crossing_trajectories: a standard benchmark problem, highlighting expressivity limitations of Neural DEs, and how they can be addressed.
  • 04_augmentation_strategies: augmentation API for Neural DEs

and the advanced tutorials

  • 05_integral_adjoint: minimize integral losses with torchdyn's special integral loss adjoint [18] to track a sinusoidal signal.
  • 06_hamiltonian_neural_network: learn dynamics of energy preserving systems with a simple implementation of Hamiltonian Neural Networks in torchdyn [10]
  • 07_neural_graph_de: first steps into the vast world of Neural GDEs [9], or ODEs on graphs parametrized by graph neural networks (GNN). Classification on Cora.

Features

Check our wiki for a full description of available features.

Feature roadmap

The current offering of torchdyn is limited compared to the rich ecosystem of continuous deep learning. If you are a researcher working in this space, and particularly if one of your previous works happens to be a WIP feature, feel free to reach out and help us implement it.

In particular, we are currently missing the following features, which will be added in order:

  • Latent variable variants: Latent Neural ODE, ODE2VAE
  • Advanced recurrent versions: GRU-ODE-Bayes
  • Alternative adjoint for Neural SDE and Jump Stochastic Neural ODEs, as in [16]
  • Lagrangian Neural Networks [17]

Contribute

torchdyn is meant to be a community effort: we welcome all contributions of tutorials, model variants, numerical methods and applications related to continuous deep learning.

Cite us

If you find torchdyn valuable for your research or applied projects:

@article{massaroli2020stable,
  title={Stable Neural Flows},
  author={Massaroli, Stefano and Poli, Michael and Bin, Michelangelo and Park, Jinkyoo and Yamashita, Atsushi and Asama, Hajime},
  journal={arXiv preprint arXiv:2003.08063},
  year={2020}
}
