
This project is forked from stefanonardo/pytorch-esn.


An Echo State Network module for PyTorch, with evolutionary hyper-parameter tuning

License: MIT License



PyTorch-ESN, with Genetic Algorithms for Tuning

PyTorch-ESN is a PyTorch module, written in Python, implementing Echo State Networks with leaky-integrator units. The multi-layer implementation is based on DeepESN. The readout can be trained by ridge regression or by PyTorch's optimizers.

Its development started with Stefano's master thesis, "An Empirical Comparison of Recurrent Neural Networks on Sequence Modeling", supervised by Prof. Alessio Micheli and Dr. Claudio Gallicchio at the University of Pisa.

This fork, https://github.com/ByteSumoLtd/pytorch-esn, builds on Stefano's core library and adds the following:

  • examples of running a grid search over Mackey-Glass hyperparameters, illustrating that hyperparameters have a large effect on performance
  • conversion of the univariate Mackey-Glass example into a general command-line tool for any univariate timeseries stored as two CSV columns, 'Input, Target'; the file should not have a header (see the sample after this list)
  • the ability to manually set a seed in the command-line tool, for reproducible results
  • some hacky helpers to reinstall the modules during active development
  • inclusion of torchesn/optim and an optimise module that configures DEAP to search for good ESN hyperparameters; opinionated, but with good defaults
  • inclusion of the fn_autotune command-line tool, which automates discovery of good hyperparameters for your problem and so automates the whole ESN training process; the specific command-line tool to call is parameterised, so many different kinds of solutions can be optimised
  • inclusion of multiprocessing as a parallelisation mechanism to accelerate the genetic search for good models on a single GPU, with a parameter to set the number of workers (see the sketch after this list)
  • inclusion of a number_islands parameter, to automate deploying multi-deme solutions in DEAP using ring migration, as a way of preventing premature convergence; of note, the multi-island solution corrects and improves on common examples found online so that fitness evaluations are still parallelised to speed up the search
  • inclusion of the fn_cotrend command-line tool, a more dynamic timeseries ESN function supporting multiple input timeseries features to predict a single output
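
For the univariate tool, the input file is assumed to look like the two-column, headerless sample below (the values are illustrative only):

0.969725,0.969871
0.969871,0.970421
0.970421,0.971321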
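
The parallel genetic search follows the standard DEAP multiprocessing pattern. Below is a minimal sketch of that pattern, not the library's actual optimise module: the evaluate function is a dummy stand-in for training an ESN and returning its test loss, the two evolved hyperparameters are chosen for illustration, and the population size, generations, and crossover/mutation probabilities mirror the defaults shown in the example log further down.

import multiprocessing
import random

from deap import algorithms, base, creator, tools

# Minimise a single fitness value. The real tool uses a multi-objective
# FitnessMulti, as seen in the example output below.
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Individual", list, fitness=creator.FitnessMin)

toolbox = base.Toolbox()
toolbox.register("attr_spectral_radius", random.uniform, 0.5, 1.5)
toolbox.register("attr_leaking_rate", random.uniform, 0.1, 1.0)
toolbox.register("individual", tools.initCycle, creator.Individual,
                 (toolbox.attr_spectral_radius, toolbox.attr_leaking_rate), n=1)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

def evaluate(individual):
    # Dummy stand-in: train an ESN with these hyperparameters, return test MSE.
    spectral_radius, leaking_rate = individual
    return (abs(spectral_radius - 0.9) + abs(leaking_rate - 0.3),)

toolbox.register("evaluate", evaluate)
toolbox.register("mate", tools.cxBlend, alpha=0.5)
toolbox.register("mutate", tools.mutGaussian, mu=0, sigma=0.1, indpb=0.2)
toolbox.register("select", tools.selTournament, tournsize=3)

if __name__ == "__main__":
    # Parallel fitness evaluation: map evaluations across worker processes.
    with multiprocessing.Pool(processes=6) as pool:
        toolbox.register("map", pool.map)
        pop = toolbox.population(n=30)
        pop, log = algorithms.eaSimple(pop, toolbox, cxpb=0.7, mutpb=0.3,
                                       ngen=10, verbose=True)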

to do:

  • include a Python notebook with a worked example, illustrating the costs and benefits of genetic search
  • extend fn_cotrend to become a general solution for multivariate ESN pipelines

Examples: a) In the examples directory, test executing the Mackey-Glass example with a specific set of parameters:

 fn_mackey_glass --hidden_size 15 --input_size 1 --output_size 1 --spectral_radius 0.8981482392105415 --density 0.5411114043211104 --leaking_rate 0.31832532215823006 --lambda_reg 0.3881908755655027 --nonlinearity tanh --readout_training cholesky --w_io True --w_ih_scale 0.826708317582006 --seed 10

b) In the examples directory, test using genetic search to find good hyperparameters for Mackey-Glass:


> fn_autotune --population 30 --generations 10 --max_layers 1 --hidden_size_low 150 --hidden_size_high 150 --worker_pool 6 | tee logs/example.log 
> cat logs/example.log

2020-06-03 10:37:48.642212
gen	nevals	avg    	std    	min        	max    
0  	30    	2.83442	2.86397	2.09471e-05	6.65214
1  	20    	2.78962	2.84464	2.09471e-05	7.11487
2  	24    	2.76368	2.80678	1.33755e-05	6.85562
3  	22    	3.17741	3.22544	1.33755e-05	7.29831
4  	26    	3.09174	3.12538	1.33755e-05	7.23457
5  	25    	2.95496	3.02999	1.13259e-05	7.81776
6  	24    	3.25476	3.33798	1.13259e-05	7.61532
7  	23    	2.86864	2.95118	1.13259e-05	7.61532
8  	25    	2.81973	2.87699	1.13259e-05	7.12179
9  	27    	2.82086	2.86294	1.13259e-05	7.34068
10 	25    	2.98786	3.06309	1.13259e-05	7.82064
2020-06-03 10:46:13.766080
{   '_cmdline_tool': 'fn_mackey_glass',
    'attr_batch_first': False,
    'attr_density': 0.6905894761994907,
    'attr_hidden': 150,
    'attr_input_size': 1,
    'attr_lambda_reg': 0.6304891830771884,
    'attr_leaking_rate': 0.9440340390508313,
    'attr_nonlinearity': 'tanh',
    'attr_num_layers': 1,
    'attr_output_size': 1,
    'attr_output_steps': 'all',
    'attr_readout_training': 'cholesky',
    'attr_spectral_radius': 1.3467750025214633,
    'attr_w_ih_scale': 0.43537928601439935,
    'attr_w_io': True,
    'auto_crossover_probability': 0.7,
    'auto_generations': 10,
    'auto_mutation_probability': 0.3,
    'auto_population': 30,
    'run_end_time': datetime.datetime(2020, 6, 3, 10, 46, 13, 766080),
    'run_start_time': datetime.datetime(2020, 6, 3, 10, 37, 48, 642212),
    'run_training_loss': deap.creator.FitnessMulti((1.1325887772018666e-05, 4.748876571655273))}

# the command line view of the params is:
# fn_mackey_glass --hidden_size 150 --input_size 1 --output_size 1 --spectral_radius 1.3467750025214633 --density 0.6905894761994907 --leaking_rate 0.9440340390508313 --lambda_reg 0.6304891830771884 --nonlinearity tanh --readout_training cholesky --w_io True --w_ih_scale 0.43537928601439935 --seed 10

Prerequisites

  • PyTorch, deap, click (multiprocessing and pprint are part of the Python standard library)
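
The third-party packages can be installed from PyPI (the package names below are assumed):

pip install torch deap click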

Basic Usage

Offline training (ridge regression)

SVD

Mini-batch mode is not allowed with this method.

from torchesn.nn import ESN
from torchesn.utils import prepare_target

# prepare target matrix for offline training
flat_target = prepare_target(target, seq_lengths, washout)

model = ESN(input_size, hidden_size, output_size)

# train
model(input, washout, hidden, flat_target)

# inference
output, hidden = model(input, washout, hidden)
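
A minimal end-to-end sketch of the workflow above on synthetic data. It assumes the default batch_first=False layout, i.e. tensors of shape (seq_len, batch, features), with washout given as a list holding one washout length per batch sample:

import torch
from torchesn.nn import ESN
from torchesn.utils import prepare_target

# Toy next-step prediction task: predict sin(t + dt) from sin(t).
seq_len, washout_len = 400, 100
t = torch.linspace(0, 40, seq_len + 1)
input = torch.sin(t[:-1]).view(seq_len, 1, 1)   # (seq_len, batch=1, features=1)
target = torch.sin(t[1:]).view(seq_len, 1, 1)

washout = [washout_len]                         # one entry per batch sample
flat_target = prepare_target(target, [seq_len], washout)

model = ESN(1, 50, 1)                           # input_size, hidden_size, output_size
model(input, washout, None, flat_target)        # one pass trains the readout via SVD
output, hidden = model(input, washout)          # output excludes the washout steps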

Cholesky or inverse

from torchesn.nn import ESN
from torchesn.utils import prepare_target

# prepare target matrix for offline training
flat_target = prepare_target(target, seq_lengths, washout)

model = ESN(input_size, hidden_size, output_size, readout_training='cholesky')

# accumulate matrices for ridge regression
for batch in batch_iter:
    model(batch, washout[batch], hidden, flat_target)

# train
model.fit()

# inference
output, hidden = model(input, washout, hidden)

Classification tasks

For classification, just use one of the previous methods and pass 'mean' or 'last' to the output_steps argument.

model = ESN(input_size, hidden_size, output_size, output_steps='mean')

For more information see docstrings or section 4.7 of "A Practical Guide to Applying Echo State Networks" by Mantas Lukoševičius.

Online training (PyTorch optimizer)

Same as any other PyTorch module: train the readout with a standard PyTorch optimizer and loss, as sketched below.
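
A minimal sketch, assuming the 'gd' readout mode (which, per the docstrings, leaves the readout trainable by PyTorch optimizers) and reusing the synthetic data from the SVD sketch above:

import torch
from torchesn.nn import ESN

seq_len, washout_len = 400, 100
t = torch.linspace(0, 40, seq_len + 1)
input = torch.sin(t[:-1]).view(seq_len, 1, 1)
target = torch.sin(t[1:]).view(seq_len, 1, 1)
washout = [washout_len]

model = ESN(1, 50, 1, readout_training='gd')
criterion = torch.nn.MSELoss()
# Assumes the module keeps the reservoir weights frozen, so only the
# readout parameters actually receive updates.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(100):
    optimizer.zero_grad()
    output, hidden = model(input, washout)
    loss = criterion(output, target[washout_len:])  # output skips the washout
    loss.backward()
    optimizer.step()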

