
jambo6 / online-neural-cdes


Code for: "Neural Controlled Differential Equations for Online Prediction Tasks"

Python 100.00%
time-series machine-learning deep-learning neural-networks signatures neural-cdes neural-odes medical-time-series

online-neural-cdes's Introduction

Neural Controlled Differential Equations for Online Prediction Tasks
[arXiv]

Overview

Neural controlled differential equations (Neural CDEs) are state-of-the-art models for irregular time series. However, because current implementations rely on non-causal interpolation schemes, Neural CDEs cannot currently be used in online prediction tasks, that is, in real time as data arrives. This is in contrast to similar ODE models such as the ODE-RNN, which can already operate in continuous time.

Here we introduce and benchmark new interpolation schemes, most notably rectilinear interpolation, which allows an online, everywhere-causal solution to be defined. This code reproduces the tables in the paper, which demonstrate that Neural CDEs can also be considered state-of-the-art irregular time series models for online prediction tasks.
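To make the rectilinear idea concrete, here is a minimal NumPy sketch in its spirit: between consecutive observations the path first moves time forward with the features held at their last observed values, then updates the features with time held fixed. It assumes time is carried as channel 0 and that the path is parametrised by knot index; `rectilinear_knots` and `evaluate_path` are hypothetical names for illustration, not functions from this repository or its modified torchcde.

```python
import numpy as np


def rectilinear_knots(times, values):
    """Knots of a rectilinear path through irregular observations.

    Channel 0 of each knot is time; the remaining channels are the features.
    """
    times = np.asarray(times, dtype=float)
    values = np.asarray(values, dtype=float)
    if values.ndim == 1:
        values = values[:, None]  # treat a 1-D series as a single feature
    knots = [np.concatenate(([times[0]], values[0]))]
    for i in range(1, len(times)):
        # Advance time to t_i while the features stay at the last observed value.
        knots.append(np.concatenate(([times[i]], values[i - 1])))
        # Hold time at t_i and move the features to the new observation x_i.
        knots.append(np.concatenate(([times[i]], values[i])))
    return np.stack(knots)


def evaluate_path(knots, s):
    """Piecewise-linear evaluation of the knot path at parameter s in [0, len(knots) - 1]."""
    grid = np.arange(len(knots), dtype=float)
    return np.array([np.interp(s, grid, knots[:, c]) for c in range(knots.shape[1])])
```

For `times = [0.0, 0.5, 2.0]` and `values = [1.0, 3.0, 2.0]` the knots are `[0, 1], [0.5, 1], [0.5, 3], [2, 3], [2, 2]`. Knot `2i` depends only on the first `i + 1` observations, so appending a new observation never changes the part of the path already traversed. Linear interpolation, by contrast, needs `x_{i+1}` to define the path anywhere inside `(t_i, t_{i+1})`.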


Reproducing the Experiments

Set up the Environment

Create a new environment with python==3.9; earlier versions of Python will not work. The package is built with pbr, which takes its version from git tags, so a git repository must first be initialised and tagged.

git init
git add .
git commit -m "First commit"
git tag v0.0.1

Then run

pip install -e .

to install the requirements and add the src directory to the path.
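The two prerequisites above (a Python version of at least 3.9 and a tagged git repository for pbr) can be checked with a small script like the following sketch. The function names are hypothetical, the check accepts any version from 3.9 upwards (the README only pins 3.9 and rules out earlier ones), and it assumes `git` is on the PATH.

```python
import subprocess
import sys


def python_ok(version=None):
    """False for the pre-3.9 interpreters the README says will not work."""
    version = tuple(version or sys.version_info[:2])
    return version >= (3, 9)


def latest_git_tag(repo_dir="."):
    """Most recent tag reachable from HEAD (pbr needs one), or None if there is none."""
    try:
        result = subprocess.run(
            ["git", "describe", "--tags", "--abbrev=0"],
            cwd=repo_dir,
            capture_output=True,
            text=True,
        )
    except (FileNotFoundError, NotADirectoryError):  # git missing or bad directory
        return None
    if result.returncode != 0:  # not a repository, or no tags yet
        return None
    return result.stdout.strip() or None
```

Run from the repository root after the commands above, `latest_git_tag()` should return `'v0.0.1'`; a `None` means `pip install -e .` is likely to fail at the pbr versioning step.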

This repository also relies on modified versions of the torchdiffeq and torchcde libraries. These live in modules/ and must also be installed with

pip install modules/torchdiffeq
pip install modules/torchcde

Reproducing the Brownian motion prediction example

Run the file sim_bm_toy_example.py. The results can then be found at /results/sim_bm/results_table.csv.
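A minimal way to inspect that table from Python, assuming the path is relative to the repository root; `load_results` is a hypothetical helper, and it makes no assumption about the table's column names.

```python
import csv
from pathlib import Path


def load_results(path="results/sim_bm/results_table.csv"):
    """Read the results table into a list of {column: value} dicts (values stay strings)."""
    with Path(path).open(newline="") as f:
        return list(csv.DictReader(f))
```

Printing `load_results()` row by row is enough to compare the interpolation schemes side by side.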

Downloading the regular datasets

Create the empty folders data/raw and data/processed. To download a regular dataset, run the corresponding download script in get_data/download. Note that the Beijing datasets are part of the TSR archive, hence their script is named tsr.py.

To process a downloaded dataset, run the script with the same name in get_data.
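The folder setup and the download-then-process convention above can be sketched as follows. The helper names are hypothetical, and it is an assumption that the scripts are meant to be launched from the repository root; adjust the working directory if they expect otherwise.

```python
import sys
from pathlib import Path


def prepare_data_dirs(root="."):
    """Create the empty data/raw and data/processed folders."""
    for sub in ("raw", "processed"):
        Path(root, "data", sub).mkdir(parents=True, exist_ok=True)


def dataset_commands(script_name, root="."):
    """Download command, then process command, for one dataset script (e.g. 'tsr.py')."""
    root = Path(root)
    return [
        [sys.executable, str(root / "get_data" / "download" / script_name)],
        [sys.executable, str(root / "get_data" / script_name)],
    ]
```

For the Beijing data this would be `prepare_data_dirs()` followed by `subprocess.run(cmd, check=True)` for each `cmd` in `dataset_commands("tsr.py")`, so that processing only starts once the download has succeeded.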

Downloading MIMIC-IV

The steps for processing the MIMIC-IV dataset are as follows:

  1. Get access by completing the human subjects research training course -> https://mimic.mit.edu/iii/gettingstarted/
  2. Run the saved query get_data/mimic-iv/query.sql on the data. This can be run either in Google BigQuery or in a local SQL database if the dataset has been downloaded. Save the output to data/raw/raw_combined.csv.
  3. Run get_data/mimic-iv/build_raw.py.
  4. Run get_data/mimic-iv/prepare.py.

The processed datasets will then be saved in data/processed.
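Steps 3 and 4 can be chained once step 2 has produced its CSV, as in the sketch below. `build_mimic` is a hypothetical helper, the `run` parameter exists only so the logic can be exercised without the real scripts, and launching from the repository root is an assumption.

```python
import subprocess
import sys
from pathlib import Path


def build_mimic(root=".", run=subprocess.run):
    """Run build_raw.py then prepare.py, after checking that step 2's output exists."""
    root = Path(root)
    raw = root / "data" / "raw" / "raw_combined.csv"
    if not raw.is_file():
        raise FileNotFoundError(f"{raw} not found: run get_data/mimic-iv/query.sql first")
    for script in ("build_raw.py", "prepare.py"):
        # check=True stops the pipeline if a step fails instead of running prepare.py on bad data.
        run([sys.executable, str(root / "get_data" / "mimic-iv" / script)], check=True)
```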

Set up a MongoDB instance

Experiments are run and saved into a MongoDB database. We used the MongoDB Atlas service; to use your own Atlas cluster, set the ATLAS_HOST variable in experiments/variables.py. If a different MongoDB deployment is wanted (for example, a local one), modify the get_client function in experiments/utils.py.
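For a local MongoDB, a replacement get_client could look like the sketch below. It assumes the original get_client takes no arguments (its signature is not shown here, so check experiments/utils.py), and the `MONGO_URI` environment variable is a hypothetical override, not something the repository reads.

```python
import os

# Default address of a local mongod; MONGO_URI is a hypothetical override.
DEFAULT_URI = "mongodb://localhost:27017"


def mongo_uri():
    """Connection string for the database, taken from MONGO_URI if set."""
    return os.environ.get("MONGO_URI", DEFAULT_URI)


def get_client():
    """Local-MongoDB variant of get_client (assumes the original takes no arguments)."""
    from pymongo import MongoClient  # imported lazily so mongo_uri works without pymongo

    return MongoClient(mongo_uri())
```

`MongoClient` connects lazily, so constructing the client does not fail immediately if the server is down; the first query will.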

Run the Experiments

Navigate to the experiments directory, start a Python shell and run

import runs
runs.run(CONFIGURATION_NAME, gpus=GPU_LIST)

This runs the configuration CONFIGURATION_NAME in parallel over the GPUs in GPU_LIST.

To reproduce the tables from the paper, the configurations hyperopt, interpolation and medical-sota must be run. hyperopt (the hyperparameter search) must be run first.
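The required ordering can be enforced with a small wrapper around `runs.run`, sketched below. `run_paper_configurations` is a hypothetical helper; it assumes `runs.run` blocks until a configuration has finished (otherwise hyperopt might not be complete before the next configuration starts) and that GPU_LIST is a list of device indices.

```python
CONFIGURATIONS = ("hyperopt", "interpolation", "medical-sota")


def run_paper_configurations(run, gpus, configurations=CONFIGURATIONS):
    """Call run(name, gpus=gpus) for each configuration, always starting with hyperopt."""
    # sorted() is stable: hyperopt moves to the front, the others keep their order.
    ordered = sorted(configurations, key=lambda name: name != "hyperopt")
    for name in ordered:
        run(name, gpus=gpus)
    return ordered
```

From the experiments directory this would be called as `run_paper_configurations(runs.run, gpus=[0, 1])`, or with a shorter `configurations` tuple to reproduce only part of the results.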

The runs took us around 3 weeks in total to complete, spread over multiple GPUs, so reproducing everything is not recommended. Instead, modify the configuration file to select only the datasets of interest.

Tables can be reproduced once the runs are complete by running experiments/analyse.py. However, this has not been tested when only a subset of the runs has completed, so it may break in that case.

Citation

@article{morrill2021neuralcontrolled,
  title={Neural Controlled Differential Equations for Online Prediction Tasks},
  author={Morrill, James and Kidger, Patrick and Yang, Lingyi and Lyons, Terry},
  journal={arXiv preprint arXiv:2106.11028},
  year={2021}
}

online-neural-cdes's People

Contributors

jambo6, lingyiyang


online-neural-cdes's Issues

sim_bm giving identical final results for every interpolation method

Hi, I downloaded the library and ran the provided example for simple Brownian motion. The results are puzzling: the summary statistics for all 4 interpolation methods show exactly the same training and test accuracy, as well as the same standard deviations.

I suspect there may be an error in how the results are printed; everything else runs fine. The final statistics section of the output is as follows:

[100, 4096, 10, 256]
Train Accuracy: 0.8145344853401184
Test Accuracy: 0.8037109375
[100, 4096, 10, 256]
Train Accuracy: 0.8245442509651184
Test Accuracy: 0.8294270634651184
[100, 4096, 10, 256]
Train Accuracy: 0.822265625
Test Accuracy: 0.8216145634651184
[100, 4096, 10, 256]
Train Accuracy: 0.824462890625
Test Accuracy: 0.822265625
[100, 4096, 10, 256]
Train Accuracy: 0.8116047978401184
Test Accuracy: 0.8108723759651184
Mean train: 0.8194824457168579, Mean test: 0.817578136920929
s.d. train: 0.005379822105169296, s.d. test: 0.009120622649788857
[100, 4096, 10, 256]
Train Accuracy: 0.8145344853401184
Test Accuracy: 0.8037109375
[100, 4096, 10, 256]
Train Accuracy: 0.8245442509651184
Test Accuracy: 0.8294270634651184
[100, 4096, 10, 256]
Train Accuracy: 0.822265625
Test Accuracy: 0.8216145634651184
[100, 4096, 10, 256]
Train Accuracy: 0.824462890625
Test Accuracy: 0.822265625
[100, 4096, 10, 256]
Train Accuracy: 0.8116047978401184
Test Accuracy: 0.8108723759651184
Mean train: 0.8194824457168579, Mean test: 0.817578136920929
s.d. train: 0.005379822105169296, s.d. test: 0.009120622649788857
[100, 4096, 10, 256]
Train Accuracy: 0.8145344853401184
Test Accuracy: 0.8037109375
[100, 4096, 10, 256]
Train Accuracy: 0.8245442509651184
Test Accuracy: 0.8294270634651184
[100, 4096, 10, 256]
Train Accuracy: 0.822265625
Test Accuracy: 0.8216145634651184
[100, 4096, 10, 256]
Train Accuracy: 0.824462890625
Test Accuracy: 0.822265625
[100, 4096, 10, 256]
Train Accuracy: 0.8116047978401184
Test Accuracy: 0.8108723759651184
Mean train: 0.8194824457168579, Mean test: 0.817578136920929
s.d. train: 0.005379822105169296, s.d. test: 0.009120622649788857
[100, 4096, 10, 256]
Train Accuracy: 0.8145344853401184
Test Accuracy: 0.8037109375
[100, 4096, 10, 256]
Train Accuracy: 0.8245442509651184
Test Accuracy: 0.8294270634651184
[100, 4096, 10, 256]
Train Accuracy: 0.822265625
Test Accuracy: 0.8216145634651184
[100, 4096, 10, 256]
Train Accuracy: 0.824462890625
Test Accuracy: 0.822265625
[100, 4096, 10, 256]
Train Accuracy: 0.8116047978401184
Test Accuracy: 0.8108723759651184
Mean train: 0.8194824457168579, Mean test: 0.817578136920929
s.d. train: 0.005379822105169296, s.d. test: 0.009120622649788857

Cannot install

I've been trying to follow the README instructions for installing this package, both locally and on the remote machine that I will use for experiments with the online NCDE. However, I cannot install the libraries that you've written (sacredex and autots) in the expected way (using requirements.txt through setup.py via the pip install -e . command).

The errors seem to correspond to a change in GitHub's security settings. They are as follows:

Collecting autots@ git+git://github.com/jambo6/autots@v0.0.8
  Cloning git://github.com/jambo6/autots (to revision v0.0.8) to /tmp/pip-install-do6rkh9r/autots_7cbd821d414844ec9e293c2bba5a389b
  Running command git clone -q git://github.com/jambo6/autots /tmp/pip-install-do6rkh9r/autots_7cbd821d414844ec9e293c2bba5a389b
  fatal: remote error:
    The unauthenticated git protocol on port 9418 is no longer supported.
  Please see https://github.blog/2021-09-01-improving-git-protocol-security-github/ for more information.
WARNING: Discarding git+git://github.com/jambo6/autots@v0.0.8. Command errored out with exit status 128: git clone -q git://github.com/jambo6/autots /tmp/pip-install-do6rkh9r/autots_7cbd821d414844ec9e293c2bba5a389b Check the logs for full command output.
Collecting sacredex@ git+git://github.com/jambo6/sacredex@v0.0.5
  Cloning git://github.com/jambo6/sacredex (to revision v0.0.5) to /tmp/pip-install-do6rkh9r/sacredex_9ff0834e97d7463d832fa49dcf71048f
  Running command git clone -q git://github.com/jambo6/sacredex /tmp/pip-install-do6rkh9r/sacredex_9ff0834e97d7463d832fa49dcf71048f
  fatal: remote error:
    The unauthenticated git protocol on port 9418 is no longer supported.
  Please see https://github.blog/2021-09-01-improving-git-protocol-security-github/ for more information.
WARNING: Discarding git+git://github.com/jambo6/sacredex@v0.0.5. Command errored out with exit status 128: git clone -q git://github.com/jambo6/sacredex /tmp/pip-install-do6rkh9r/sacredex_9ff0834e97d7463d832fa49dcf71048f Check the logs for full command output.
Collecting dnspython==2.1.0
  Using cached dnspython-2.1.0-py3-none-any.whl (241 kB)
Collecting pytest==6.2.2
  Using cached pytest-6.2.2-py3-none-any.whl (280 kB)
Collecting pre-commit==2.10.1
  Using cached pre_commit-2.10.1-py2.py3-none-any.whl (185 kB)
Collecting pandas==1.2.2
  Using cached pandas-1.2.2-cp39-cp39-manylinux1_x86_64.whl (9.7 MB)
Collecting scikit-learn==0.24.1
  Using cached scikit_learn-0.24.1-cp39-cp39-manylinux2010_x86_64.whl (23.8 MB)
Collecting tqdm==4.57.0
  Using cached tqdm-4.57.0-py2.py3-none-any.whl (72 kB)
ERROR: Could not find a version that satisfies the requirement autots (unavailable) (from online-neural-cdes) (from versions: 0.0.2, 0.0.3, 0.1.0, 0.1.1, 0.1.2, 0.1.5, 0.2.0a1, 0.2.0a3, 0.2.0a4, 0.2.0, 0.2.1, 0.2.2a1, 0.2.2, 0.2.3a1, 0.2.3, 0.2.4, 0.2.5, 0.2.6, 0.2.7, 0.2.8, 0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5, 0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11, 0.3.12, 0.3.13a9, 0.4.0)
ERROR: No matching distribution found for autots (unavailable)

Would it perhaps be possible to place these dependencies in the modules folder, as has been done with torchcde and torchdiffeq?

Any other advice?
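One possible workaround for the error above is to switch the affected requirements from the `git://` protocol to HTTPS before running `pip install -e .`. The sketch below assumes the requirement lines in requirements.txt look like the ones in the log and that the jambo6/autots and jambo6/sacredex repositories are still publicly available; the helper names are hypothetical. Git's own `url.<base>.insteadOf` setting achieves the same rewrite without editing any file.

```python
from pathlib import Path


def https_requirement(line):
    """Rewrite a git+git://github.com/ requirement to the git+https:// form GitHub still serves."""
    return line.replace("git+git://github.com/", "git+https://github.com/")


def patch_requirements(path="requirements.txt"):
    """Rewrite requirements in place; return how many lines changed."""
    p = Path(path)
    lines = p.read_text().splitlines(keepends=True)
    new = [https_requirement(line) for line in lines]
    p.write_text("".join(new))
    return sum(old != updated for old, updated in zip(lines, new))
```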
