awslabs / renate

Library for automatic retraining and continual learning

Home Page: https://renate.readthedocs.io

License: Apache License 2.0

Shell 0.02% Python 99.98%
aws continual-learning machine-learning machine-learning-algorithms neural neural-network pytorch pytorch-lightning sagemaker hyperparameter-tuning

renate's Introduction

Renate: Automatic Neural Networks Retraining and Continual Learning in Python

Renate is a Python package for automatic retraining of neural network models. It uses advanced Continual Learning and Lifelong Learning algorithms to achieve this purpose. The implementation is based on PyTorch and Lightning for deep learning, and Syne Tune for hyperparameter optimization.

Who needs Renate?

In many applications, data is made available over time, and retraining from scratch for every new batch of data is prohibitively expensive. In these cases, we would like to use the new batch of data to update our previous model at limited cost. Unfortunately, since data in different chunks is not sampled according to the same distribution, just fine-tuning the old model creates problems like catastrophic forgetting. The algorithms in Renate help mitigate the negative impact of forgetting and increase the model performance overall.

Renate's update mechanisms improve over naive fine-tuning approaches. [1]

Renate also offers hyperparameter optimization (HPO), a functionality that can heavily impact the performance of the model when it is continuously updated. To do so, Renate employs Syne Tune under the hood, and can offer advanced HPO methods such as multi-fidelity algorithms (ASHA) and transfer learning algorithms (useful for speeding up the retuning).

Renate will benefit from hyperparameter tuning compared to Renate with default settings. [2]

Key features

  • Easy to scale and run in the cloud
  • Designed for real-world retraining pipelines
  • Advanced HPO functionalities available out-of-the-box
  • Open for experimentation

Resources

Cite Renate

@misc{renate2023,
  title           = {Renate: A Library for Real-World Continual Learning}, 
  author          = {Martin Wistuba and
                     Martin Ferianc and
                     Lukas Balles and
                     Cedric Archambeau and
                     Giovanni Zappella},
  year            = {2023},
  eprint          = {2304.12067},
  archivePrefix   = {arXiv},
  primaryClass    = {cs.LG}
}

What are you looking for?

If you did not find what you were looking for, open an issue and we will do our best to improve the documentation.


  1. To create this plot, we simulated domain-incremental learning with CLEAR-100. The training data was divided by year, and we trained sequentially on them. Fine-tuning refers to the strategy to learn on the first partition from scratch, and train on each of the subsequent partitions for few epochs only. We compare to Experience Replay with an infinite memory size. For both methods we use the same amount of training time and choose the best checkpoint using a validation set. Results reported are on the test set.

  2. In this experiment, we consider class-incremental learning on CIFAR-10. We compare Experience Replay against a version in which its hyperparameters were tuned.

renate's People

Contributors

610v4nn1, amazon-auto, dependabot[bot], geoalgo, lballes, prabhuteja12, wesk, wistuba

renate's Issues

Passing of booleans in the config file

I tried to pass a boolean via the config file, i.e.,

config = {..., "disable_grad_weight": False, ... }
execute_experiment_job(..., config_space = config, ....)

However, the value is set to True by the time it arrives at the learner. Is this a known bug?
I could work around this by passing an int (and modifying the argument's type accordingly).
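
One possible explanation, which the report above does not confirm, is that config values are forwarded as command-line strings and parsed with `type=bool`. In that case any non-empty string, including "False", evaluates to True. The sketch below reproduces this behaviour with plain argparse; `str_to_bool` is a hypothetical helper and not part of Renate.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Hypothetical helper: parse common textual spellings of a boolean."""
    normalized = value.strip().lower()
    if normalized in {"true", "1", "yes"}:
        return True
    if normalized in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean, got {value!r}.")


# Naive parsing: bool("False") is True because the string is non-empty.
naive = argparse.ArgumentParser()
naive.add_argument("--disable_grad_weight", type=bool)
naive_args = naive.parse_args(["--disable_grad_weight", "False"])

# Explicit parsing keeps the intended value.
fixed = argparse.ArgumentParser()
fixed.add_argument("--disable_grad_weight", type=str_to_bool)
fixed_args = fixed.parse_args(["--disable_grad_weight", "False"])
```

With the naive parser the flag ends up truthy, which matches the symptom described; the int workaround avoids the problem because int("0") parses to a falsy value.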

Example config file for nlp datasets

Hi authors, is there an example of a config file for NLP datasets, especially around applying transformations? I'm trying to create a minimal example. Here's the config file I'm using:

# renate_config.py
from pathlib import Path
from typing import Callable, Optional, Union

import torch
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torchvision.transforms import Lambda

from renate.benchmark.models.mlp import MultiLayerPerceptron

from renate.models import RenateModule

from renate import defaults
from renate.benchmark.datasets.nlp_datasets import TorchTextDataModule
from renate.benchmark.scenarios import ClassIncrementalScenario, Scenario


def data_module_fn(data_path: Union[Path, str], chunk_id: int, seed: int = defaults.SEED) -> Scenario:
    """Returns a class-incremental scenario instance.
    The transformations passed to prepare the input data are required to convert the data to
    PyTorch tensors.
    """
    data_module = TorchTextDataModule(
        str(data_path),
        dataset_name="AG_news",
        val_size=0.1,
        seed=seed,
    )

    class_incremental_scenario = ClassIncrementalScenario(
        data_module=data_module,
        class_groupings=[[0, 1], [2, 3]],
        chunk_id=chunk_id,
    )
    return class_incremental_scenario


def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
    """Returns a model instance."""
    if model_state_url is None:
        model = MultiLayerPerceptron(
            num_inputs=256, num_outputs=4, num_hidden_layers=2, hidden_size=64
        )
    else:
        state_dict = torch.load(str(model_state_url))
        model = MultiLayerPerceptron.from_state_dict(state_dict)
    return model


def train_transform() -> Callable:
    """Returns a transform function to be used in the training."""
    padding_idx = 1
    bos_idx = 0
    eos_idx = 2
    max_seq_len = 256
    xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
    xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
    text_transform = T.Sequential(
        T.SentencePieceTokenizer(xlmr_spm_model_path),
        T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)),
        T.Truncate(max_seq_len - 2),
        T.AddToken(token=bos_idx, begin=True),
        T.AddToken(token=eos_idx, begin=False),
    )

    return Lambda(lambda x: text_transform(x))

The training job is below:

from renate.tuning import execute_tuning_job

config_space = {
    "optimizer": "SGD",
    "momentum": 0.0,
    "weight_decay": 1e-2,
    "learning_rate": 0.05,
    "batch_size": 32,
    "max_epochs": 5,
    "memory_batch_size": 32,
    "memory_size": 500,
}

if __name__ == "__main__":

    execute_tuning_job(
        config_space=config_space,
        mode="max",
        metric="val_accuracy",
        updater="ER",
        max_epochs=5,
        chunk_id=0,  # this selects the first chunk of the dataset
        config_file="renate_config.py",
        next_state_url="./output_folder/",  # this is where the model will be stored
        backend="local",  # the training job will run on the local machine
    )

This is the stack trace of the error I'm getting:

Logs (stderr):

Global seed set to 0
[2023-01-26 19:28:50,625] INFO [renate.updaters.model_updater:183] No location for current updater state provided. Updating will start from scratch.
[2023-01-26 19:28:50,625] WARNING [renate.updaters.model_updater:254] No updater state available. Updating from scratch.
Traceback (most recent call last):
  File "/Users/mridulgr/opt/anaconda3/envs/cl/lib/python3.9/site-packages/renate/cli/run_training.py", line 249, in <module>
    ModelUpdaterCLI().run()
  File "/Users/mridulgr/opt/anaconda3/envs/cl/lib/python3.9/site-packages/renate/cli/run_training.py", line 238, in run
    model_updater.update(
  File "/Users/mridulgr/opt/anaconda3/envs/cl/lib/python3.9/site-packages/renate/updaters/model_updater.py", line 331, in update
    train_loader, val_loader = self._learner.on_model_update_start(
  File "/Users/mridulgr/opt/anaconda3/envs/cl/lib/python3.9/site-packages/renate/updaters/experimental/er.py", line 83, in on_model_update_start
    train_loader, val_loader = super().on_model_update_start(
  File "/Users/mridulgr/opt/anaconda3/envs/cl/lib/python3.9/site-packages/renate/updaters/learner.py", line 240, in on_model_update_start
    train_loader = DataLoader(
  File "/Users/mridulgr/opt/anaconda3/envs/cl/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 344, in __init__
    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
  File "/Users/mridulgr/opt/anaconda3/envs/cl/lib/python3.9/site-packages/torch/utils/data/sampler.py", line 107, in __init__
    raise ValueError("num_samples should be a positive integer "
ValueError: num_samples should be a positive integer value, but got num_samples=0

Target column name is unused in HuggingfaceTextDataModule

In HuggingfaceTextDataModule, the target_column attribute is not passed to the _InputTargetWrapper dataset class. As a result, the wrapper always falls back to its default, target_column: str = "label", and fails when that column is not present in a dataset.

Code to reproduce:

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-large-uncased")
dataset_name = "knkarthick/dialogsum"
data_module = HuggingfaceTextDataModule(
        str(data_path),
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        input_column="dialogue",
        target_column="summary",
        val_size=0.05,
        seed=seed,
    )
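
The fix suggested by this report amounts to forwarding the column names to the wrapper. The following is a minimal sketch with a hypothetical stand-in class; the constructor signature is an assumption for illustration and does not reproduce Renate's actual _InputTargetWrapper.

```python
from typing import Any, Dict, Sequence, Tuple


class InputTargetWrapper:
    """Hypothetical stand-in for a wrapper that splits rows into (input, target)."""

    def __init__(
        self,
        rows: Sequence[Dict[str, Any]],
        input_column: str = "text",
        target_column: str = "label",
    ):
        self._rows = rows
        self._input_column = input_column
        self._target_column = target_column

    def __len__(self) -> int:
        return len(self._rows)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        row = self._rows[idx]
        return row[self._input_column], row[self._target_column]


rows = [{"dialogue": "A: hi\nB: hello", "summary": "A greets B."}]

# Forwarding the column names works; relying on the defaults raises KeyError.
wrapped = InputTargetWrapper(rows, input_column="dialogue", target_column="summary")
```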

Simplify Benchmarking

Passing arguments when benchmarking can be improved. Argument names are too long, the same values need to be passed multiple times, and values are limited to primitive types.

Test Examples

We never run our examples ourselves, so they may stop working without us noticing. Since they are an integral part of the onboarding process for new users, we should cover them in our integration tests to make sure that they are up-to-date.

Reduce the amount of dependencies

Is your feature request related to a problem? Please describe.
Currently Renate 0.1 installs a large number of dependencies, probably due to the options used to install SyneTune.
This is the list of installed dependencies after a pip install in a fresh virtual environment:

aiohttp==3.8.3
aiosignal==1.3.1
alabaster==0.7.13
async-timeout==4.0.2
attrs==22.2.0
autograd==1.5
Babel==2.11.0
black==22.3.0
boto3==1.26.58
botocore==1.29.58
botorch==0.8.1
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.3
coloredlogs==15.0.1
ConfigSpace==0.6.1
contextlib2==21.6.0
cramjam==2.6.2
dill==0.3.6
distlib==0.3.6
docutils==0.17.1
exceptiongroup==1.1.0
fastparquet==0.8.3
filelock==3.9.0
flake8==6.0.0
flatbuffers==23.1.21
frozenlist==1.3.3
fsspec==2023.1.0
future==0.18.3
google-pasta==0.2.0
gpytorch==1.9.1
grpcio==1.51.1
h5py==3.8.0
humanfriendly==10.0
idna==3.4
imagesize==1.4.1
importlib-metadata==4.13.0
iniconfig==2.0.0
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.2.0
jsonschema==4.17.3
latexcodec==2.0.1
lightning-utilities==0.6.0.post0
linear-operator==0.3.0
markdown-it-py==2.1.0
MarkupSafe==2.1.2
mccabe==0.7.0
mdit-py-plugins==0.3.3
mdurl==0.1.2
more-itertools==9.0.0
mpmath==1.2.1
msgpack==1.0.4
multidict==6.0.4
multipledispatch==0.6.0
multiprocess==0.70.14
mypy-extensions==0.4.3
myst-parser==0.18.1
numpy==1.23.5
onnxruntime==1.13.1
opt-einsum==3.3.0
packaging==23.0
pandas==1.4.4
pathos==0.3.0
pathspec==0.11.0
patsy==0.5.3
Pillow==9.4.0
platformdirs==2.6.2
pluggy==1.0.0
portalocker==2.7.0
pox==0.3.2
ppft==1.7.6.6
protobuf==3.20.1
protobuf3-to-dict==0.1.5
pyaml==21.10.1
pybtex==0.24.0
pybtex-docutils==1.0.2
pycodestyle==2.10.0
pyflakes==3.0.1
Pygments==2.14.0
pyparsing==3.0.9
pyro-api==0.1.2
pyro-ppl==1.8.4
pyrsistent==0.19.3
pytest==7.2.1
pytest-timeout==2.1.0
python-dateutil==2.8.2
pytorch-lightning==1.8.6
pytz==2022.7.1
PyYAML==6.0
ray==2.2.0
Renate==0.1.0
requests==2.28.2
s3fs==0.4.2
s3transfer==0.6.0
sagemaker==2.130.0
schema==0.7.5
scikit-learn==1.2.1
scikit-optimize==0.9.0
scipy==1.10.0
six==1.16.0
smdebug-rulesconfig==1.0.1
snowballstemmer==2.2.0
Sphinx==5.3.0
sphinx-copybutton==0.5.1
sphinx-rtd-theme==1.1.1
sphinx_autodoc_typehints==1.21.8
sphinxcontrib-applehelp==1.0.4
sphinxcontrib-bibtex==2.5.0
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
statsmodels==0.13.5
sympy==1.11.1
syne-tune==0.3.4
tabulate==0.9.0
tensorboardX==2.5.1
threadpoolctl==3.1.0
tomli==2.0.1
torch==1.13.1
torchdata==0.5.1
torchmetrics==0.10.3
torchtext==0.14.1
torchvision==0.14.1
tqdm==4.64.1
typing_extensions==4.4.0
ujson==5.7.0
urllib3==1.26.14
virtualenv==20.17.1
xgboost==1.7.3
yahpo-gym==1.0.1
yarl==1.8.2
zipp==3.11.0

Describe the solution you'd like
Avoid (or make optional) unnecessary dependencies.
For example, it's unlikely that renate users will need yahpo-gym or ray in the current state of things.
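
One common way to make a dependency optional is to import it lazily and fail with an actionable message only when the feature is used. The sketch below illustrates the idea; `optional_import` and the extra name passed to it are hypothetical and not part of Renate's packaging.

```python
import importlib
from types import ModuleType


def optional_import(module_name: str, extra: str) -> ModuleType:
    """Import a module lazily and explain how to install it if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        # The extra name is a placeholder for whatever the package would define.
        raise ModuleNotFoundError(
            f"'{module_name}' is required for this feature. "
            f"Install it with: pip install 'renate[{extra}]'"
        ) from err
```

A benchmark-only code path could then call, say, optional_import("yahpo_gym", "benchmark") at the point of use, so a plain install does not need to pull the package in.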

Benchmarks as reference point for users

Currently, users have a hard time understanding what can be achieved using continual learning methods instead of ad-hoc heuristics. It would be great to provide some benchmarks against simple baselines, tracking both accuracy and training time.

Suggested baselines: retrain from scratch, fine-tune on the latest data chunk.

Change LocalBackend to use num_gpus_per_trial for SyneTune > 0.6

SyneTune has a feature for selecting the number of GPUs, released in 0.6.0. So far, multi-GPU training has been accomplished by setting rotate_gpus=False in LocalBackend. Once we upgrade SyneTune, we should modify the _execute_training_and_tuning_job_locally function in src/renate/training/training.py to use that feature instead of rotate_gpus.
This is currently flagged with a TODO in that file.

Improve Main API

Some arguments and imports are a bit confusing, e.g. importing from tuning when doing an update or calling execute_tuning_job when running a simple update job with fixed hyperparameters.

Avoid setting transforms in the buffer

Currently, the DataBuffer receives transform and target_transform and applies them when accessing an item. When deserializing the Learner state, this requires resetting the transforms (see DataBuffer.set_transforms) and generally bloats the code of the buffer.

I would propose to remove the transforms from the buffer and instead generalize the _TransformedDataset such that it can be applied to DataBuffers as well. This way, transforms are handled in a unified way by wrapping all relevant datasets in _TransformedDataset in the Learner.

To make _TransformedDataset more general, instead of passing transform and target_transform, we could pass a single transform that operates on the entire data point. When wrapping a buffer in _TransformedDataset, we just need to make sure that the metadata are "ignored" (e.g., transform=lambda item: (my_transform(item[0]), item[1]), where item is the (data_point, metadata) pair).
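
The proposal can be sketched in a few lines of plain Python. The class and helper below are hypothetical illustrations of the idea, not Renate's actual _TransformedDataset or DataBuffer.

```python
from typing import Any, Callable, Sequence


class TransformedDataset:
    """Hypothetical wrapper applying one transform to each whole data point."""

    def __init__(self, dataset: Sequence[Any], transform: Callable[[Any], Any]):
        self._dataset = dataset
        self._transform = transform

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, idx: int) -> Any:
        return self._transform(self._dataset[idx])


def ignore_metadata(transform: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Lift a data-point transform so that buffer metadata passes through unchanged."""

    def wrapped(item):
        data_point, metadata = item
        return transform(data_point), metadata

    return wrapped


def double_input(point):
    """Toy transform acting on the (input, target) pair as a whole."""
    x, y = point
    return 2 * x, y


# Toy buffer: each entry is ((input, target), metadata).
buffer = [((1.0, 0), {"loss": 0.3}), ((2.0, 1), {"loss": 0.1})]
transformed = TransformedDataset(buffer, ignore_metadata(double_input))
```

Because the wrapper only needs indexing and a length, the same class covers ordinary datasets (with the bare transform) and buffers (with the lifted one), and the buffer no longer has to store or reset transforms.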

Dashes in reference parse_by_updater

Hello,
I spotted some inconsistency in how you define class names and how you reference them in the parsing.
For example in

"Offline-ER": _add_offline_er_arguments,

you are using "Offline-ER" with a dash whereas the class name is camel case in
class OfflineExperienceReplayLearner(ReplayLearner):

Would it be possible to unify the parsed names? For example, when I wanted to run a job using the run_training_job function, I was not sure how to set the updater.
Thanks and great library!

Forward Transfer metric in metrics_summary.csv is the same across methods

The forward transfer metric in the metrics_summary.csv is the same across different methods.

To reproduce: run with the Offline-ER and Joint updaters on MNIST.

Expected behavior
I would not expect this to be (strictly) same for different methods.

Example
Joint training

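| Task ID | Average Accuracy | Forgetting | Forward Transfer | Backward Transfer |
|---------|------------------|------------|------------------|-------------------|
| 1       | 0.999054         | 0.000000   | 0.000000         | 0.000000          |
| 2       | 0.968869         | 0.010875   | -0.206170        | -0.010875         |
| 3       | 0.913238         | 0.043618   | -0.103085        | -0.043618         |
| 4       | 0.853757         | 0.092997   | -0.190409        | -0.092997         |
| 5       | 0.646799         | 0.273982   | -0.142806        | -0.273982         |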

Offline-ER

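| Task ID | Average Accuracy | Forgetting | Forward Transfer | Backward Transfer |
|---------|------------------|------------|------------------|-------------------|
| 1       | 0.998582         | 0.000000   | 0.000000         | 0.000000          |
| 2       | 0.977127         | 0.006619   | -0.206170        | -0.006619         |
| 3       | 0.956068         | 0.016186   | -0.103085        | -0.016186         |
| 4       | 0.948742         | 0.013383   | -0.190409        | -0.013383         |
| 5       | 0.932628         | 0.013746   | -0.142806        | -0.013746         |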

Make experiment evaluation a separate process

The current experimentation code in benchmarking runs evaluation in the same thread as the (subsequent) trainings. This is a problem when using DDP, because the first evaluation (Line 295) creates several processes (one per GPU) and each of them tries to spawn training processes, causing DDP port clashes.

Describe the solution you'd like
The evaluation/testing should run in a separate process a la run_training_job and this issue wouldn't occur.

Additional info:
See the discussion on Lightning forum: Lightning-AI/pytorch-lightning#2537
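
As a rough illustration of the proposed isolation, the evaluation could be launched in a fresh interpreter whose result is read back by the parent. The sketch below uses only the standard library; the snippet string and function name are placeholders, and it does not reflect how run_training_job is actually implemented.

```python
import json
import subprocess
import sys

# Stand-in for an evaluation script; a real job would load the model and data.
EVAL_SNIPPET = "import json; print(json.dumps({'accuracy': 0.5}))"


def run_evaluation_in_subprocess(snippet: str = EVAL_SNIPPET) -> dict:
    """Run the evaluation code in a fresh interpreter and return its JSON output."""
    completed = subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        text=True,
        check=True,  # raise if the child process exits with an error
    )
    return json.loads(completed.stdout)


metrics = run_evaluation_in_subprocess()
```

Since the child process exits before any training starts, whatever worker processes it created are gone by the time the next job is launched.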

Improve Buffer Scalability

Currently, the buffer data is kept in memory which limits its size. We aim to push this limit by storing the buffer on disk.
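
A minimal sketch of what a disk-backed buffer could look like, assuming one pickle file per data point. `DiskBuffer` is a hypothetical class for illustration; it ignores eviction, metadata, and concurrent access.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


class DiskBuffer:
    """Hypothetical buffer that stores each data point as a pickle file on disk."""

    def __init__(self, directory: str):
        self._directory = Path(directory)
        self._directory.mkdir(parents=True, exist_ok=True)
        self._size = 0

    def append(self, item: Any) -> None:
        with open(self._directory / f"{self._size}.pkl", "wb") as f:
            pickle.dump(item, f)
        self._size += 1

    def __len__(self) -> int:
        return self._size

    def __getitem__(self, idx: int) -> Any:
        if not 0 <= idx < self._size:
            raise IndexError(idx)
        # Only the requested item is loaded into memory.
        with open(self._directory / f"{idx}.pkl", "rb") as f:
            return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp_dir:
    disk_buffer = DiskBuffer(tmp_dir)
    disk_buffer.append(([0.1, 0.2], 1))
    first_item = disk_buffer[0]
```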

ModuleNotFoundError: No module named 'wild_time_data'

After installing Renate via pip install renate in a new Python 3.8 environment,
running from renate.benchmark.scenarios import ClassIncrementalScenario
raises the error mentioned in the title.

I had to run pip install wild-time-data to work around the problem.

Fully move to `torch>=1.13`

Avalanche currently has problems loading checkpoints. Therefore, we made some temporary changes (#200). We will need to undo these as soon as the problems are fixed (June 2023, avalanche_lib==0.4.0).

Use more workers in DataLoader

Most of the DataLoader instances use no additional workers for data loading. This can be a bottleneck when performing non-trivial data augmentations on images with a large enough batch size, or with a slower disk.

There are some crude heuristics like 2 x num_gpus which might serve as a reasonable starting point. See
https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813
https://chtalhaanwar.medium.com/pytorch-num-workers-a-tip-for-speedy-training-ed127d825db7#:~:text=Num_workers%20tells%20the%20data%20loader,the%20GPU%20has%20to%20wait.
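
The 2 x num_gpus heuristic can be written as a small helper. Capping the result at the number of CPU cores is an added assumption, and `suggest_num_workers` is a hypothetical name.

```python
import os
from typing import Optional


def suggest_num_workers(num_gpus: int, num_cpus: Optional[int] = None) -> int:
    """Heuristic: two workers per GPU, capped by the number of CPU cores."""
    if num_cpus is None:
        num_cpus = os.cpu_count() or 1
    return max(0, min(2 * num_gpus, num_cpus))
```

The returned value would be passed as num_workers when constructing a DataLoader; it is only a starting point and should be tuned per workload.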

Tools for detecting shifts in data distribution

There are a number of situations in which detecting shifts in data distribution is useful (e.g., debug model behavior, decide to retrain). It would be great to have shift detectors available directly in Renate.
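
To make the request concrete, here is a sketch of one simple detector for a single scalar feature: a two-sample Kolmogorov-Smirnov test using the large-sample critical value. It is an illustration only, not a proposal for Renate's API, and the function names are hypothetical.

```python
import math
from bisect import bisect_right
from typing import Sequence


def ks_statistic(reference: Sequence[float], current: Sequence[float]) -> float:
    """Largest gap between the empirical CDFs of the two samples."""
    ref, cur = sorted(reference), sorted(current)
    return max(
        abs(bisect_right(ref, x) / len(ref) - bisect_right(cur, x) / len(cur))
        for x in ref + cur
    )


def shift_detected(
    reference: Sequence[float], current: Sequence[float], c_alpha: float = 1.358
) -> bool:
    """Flag a shift if the statistic exceeds the large-sample critical value.

    c_alpha = 1.358 corresponds to a significance level of about 0.05.
    """
    n, m = len(reference), len(current)
    threshold = c_alpha * math.sqrt((n + m) / (n * m))
    return ks_statistic(reference, current) > threshold
```

For high-dimensional inputs such as images, a test like this would typically be applied to a low-dimensional summary (for example, model outputs) rather than to raw pixels.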

`run_training_job` allows passing `chunk_id` to switch the data batch, but our functions are not more generic than that

I changed back to using `chunk_id` to switch between the two datasets. I think it is a bit unintuitive for an example, when the `dataset_name` is part of the `config_space`. We explicitly introduced this `chunk_id` parameter in the `run_training_job` function to switch between chunks in the examples, so I think we should keep it consistent.

Originally posted by @lballes in #171 (comment)

We could generalize the function to be more flexible.

Standardize usage of `batch_size`

Currently, our ER-based methods use a batch that consists of batch_size points from the current task and memory_batch_size points from the memory. This is inconvenient to compare/standardize to some other learners (e.g., Joint, GDumb, Fine-tuning) and will also result in a smaller total batch size in the first training stage (when no memory exists).

I propose that all methods use batches of size batch_size. ER-based methods can have an additional argument called memory_frac that determines which fraction of the batch will be filled with points from the memory.
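
The proposed split can be sketched as follows. The function name, the rounding rule, and the fallback to a full batch of current-task points when no memory exists yet are assumptions made for illustration.

```python
from typing import Tuple


def split_batch(batch_size: int, memory_frac: float, memory_available: bool) -> Tuple[int, int]:
    """Return (num_current_points, num_memory_points) summing to batch_size."""
    if not 0.0 <= memory_frac <= 1.0:
        raise ValueError("memory_frac must lie in [0, 1].")
    # Without a memory (first training stage), the whole batch is current data.
    num_memory = int(round(batch_size * memory_frac)) if memory_available else 0
    return batch_size - num_memory, num_memory
```

With batch_size=32 and memory_frac=0.25, a batch would contain 24 current points and 8 memory points, and the total stays at 32 in the first stage as well.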

Registering buffers in Components

The Component class implements a _register_parameters method that registers the loss' constructor arguments.

This currently works without glitches where Renate manages the state_dict of Learner, Model and others. However, once the move to relying on Lightning for all checkpointing is complete (in #218), this fails: registering the constructor args (say _weight) returns them in the state_dict and causes an error when restoring the model from a checkpoint. The 'fix' for now is to change RenateModule's load_state_dict to use strict=False.

This issue should help clean up the components code too.

Create a Renate HPO overview under Getting Started

Create a page providing:

  • brief introduction to HPO
  • jargon used to discuss HPO problems (can be part of the point above)
  • different strategies that can be used to do HPO in continual learning (first only, standards, transfer)

Inconsistency in paths between local jobs and training jobs

When switching from a local job to a SM training job, there is an inconsistency in the path used for the config file.
This happens because the SM job uses the source_dir and copies the file into it, while the source directory path is not used in local training jobs.

Improve documentation for max_num_trials*

Currently the definition of the different max_num_trials_* options is not clear and it is also not clear if they interact with each other or can be used simultaneously.

pip install renate not working with python 3.10.9

Describe the bug
pip install renate not working with python 3.10.9

To Reproduce (skip if you provide an example below)
Install python 3.10.9 and try to pip install renate

Expected behavior
Installing renate without problems

Output

ERROR: Ignored the following versions that require a different python version: 0.1.0 Requires-Python <=3.10,>=3.8
ERROR: Could not find a version that satisfies the requirement renate (from versions: none)
ERROR: No matching distribution found for renate

Desktop (please complete the following information):

  • OS: MacOS
  • Version: python 3.10.9
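
The error message suggests why the install fails: under version-specifier comparison, <=3.10 is treated as <=3.10.0, so 3.10.9 falls outside the allowed range. The simplified sketch below mimics that comparison with zero-padded tuples. It handles plain numeric releases only, and the function names are hypothetical.

```python
from typing import Tuple


def parse_version(text: str, width: int = 3) -> Tuple[int, ...]:
    """Parse a dotted release such as '3.10' and pad it with zeros to `width` parts."""
    parts = [int(p) for p in text.split(".")]
    return tuple(parts + [0] * (width - len(parts)))


def satisfies(python_version: str, lower: str = "3.8", upper: str = "3.10") -> bool:
    """Check lower <= python_version <= upper on zero-padded release tuples."""
    version = parse_version(python_version)
    return parse_version(lower) <= version <= parse_version(upper)
```

Under this reading, 3.10.0 is accepted but every later 3.10 patch release is rejected; an upper bound written as <3.11 would admit the whole 3.10 series.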

Dev Documentation not up-to-date

I noticed that our dev documentation is not up-to-date. It appears that, since the Avalanche integration, building the documentation no longer works. Our test failed to spot this because it installed dependencies differently from how RTD does.

I'm working on rectifying the test and fixing all the issues we've added over time as part of this PR: #195

[discussion] HPO in a CL scenario

Hi authors, thanks for the awesome library!

I was wondering how hyperparameter optimization (HPO) is implemented in a continual learning scenario.
From what I found in your documentation, it seems the stopping criterion is based on max_time. From my understanding, this means that the best configuration is chosen as the one with the best metric up to the current task, which may result in configurations that have a "slower" learning behavior being dropped in the early stages of training.

For example, if I tend to learn fast on the current tasks I may have a good accuracy at task 2, but this does not necessarily mean that I will have a good accuracy on task 10 due to forgetting.
