
zaidrz / clsurvey


This project forked from mattdl/clsurvey


Continual Hyperparameter Selection Framework. Compares 11 state-of-the-art Lifelong Learning methods and 4 baselines. Official Codebase of "A continual learning survey: Defying forgetting in classification tasks." in IEEE TPAMI.

Home Page: https://ieeexplore.ieee.org/abstract/document/9349197

License: Other



A continual learning survey: Defying forgetting in classification tasks

This is the original source code for the continual learning survey paper "A continual learning survey: Defying forgetting in classification tasks", published in IEEE TPAMI [TPAMI paper] [Open-Access paper].

This work enables a fair comparison of the state of the art through the Continual Hyperparameter Framework, which sets hyperparameters dynamically based on the stability-plasticity dilemma. It addresses a longstanding problem in the literature: how to set hyperparameters fairly across methods using ONLY the current task's data (i.e., without i.i.d. validation data over all tasks, which is not available in continual learning).
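As we read the paper, the framework runs two phases per new task: a maximal plasticity search (finetune on the new task over a coarse learning-rate grid and record the best accuracy A*), followed by a stability decay (start the method's stability hyperparameter at a high value and decay it until the new-task accuracy is within a tolerated drop from A*). The sketch below illustrates that idea only; it is not the repository's implementation. The function names, the toy accuracy callables, and the default values of `drop_margin` and `decay` are hypothetical placeholders.

```python
from typing import Callable, Optional, Sequence, Tuple


def maximal_plasticity_search(
    finetune_acc: Callable[[float], float],
    lr_grid: Sequence[float],
) -> Tuple[float, float]:
    """Phase 1: finetune on the new task for every learning rate in the grid
    and return (best_lr, A*), the best new-task accuracy reached."""
    best_lr: Optional[float] = None
    best_acc = float("-inf")
    for lr in lr_grid:
        acc = finetune_acc(lr)  # accuracy on held-out data of the CURRENT task only
        if acc > best_acc:
            best_lr, best_acc = lr, acc
    if best_lr is None:
        raise ValueError("lr_grid must not be empty")
    return best_lr, best_acc


def stability_decay(
    method_acc: Callable[[float], float],
    ref_acc: float,
    init_strength: float,
    drop_margin: float = 0.2,  # hypothetical default: tolerated relative accuracy drop
    decay: float = 0.5,        # hypothetical default: multiplicative decay factor
    max_steps: int = 10,
) -> Tuple[float, float]:
    """Phase 2: start from a strongly stable setting and decay the stability
    hyperparameter until new-task accuracy >= (1 - drop_margin) * A*.
    If the threshold is never met, the last setting tried is returned."""
    threshold = (1.0 - drop_margin) * ref_acc
    strength = init_strength
    acc = method_acc(strength)
    for _ in range(max_steps):
        if acc >= threshold:
            break
        strength *= decay           # give up some stability for plasticity
        acc = method_acc(strength)  # retrain the method with the weaker setting
    return strength, acc


if __name__ == "__main__":
    # Toy stand-ins for "train on the current task and report its accuracy".
    finetune_table = {1e-1: 0.70, 1e-2: 0.80, 1e-3: 0.75}
    lr, a_star = maximal_plasticity_search(finetune_table.__getitem__, [1e-1, 1e-2, 1e-3])
    # Toy method: a higher regularization strength lowers new-task accuracy.
    strength, acc = stability_decay(lambda s: 0.9 - 0.05 * s, a_star, init_strength=8.0)
    print(lr, a_star, strength, round(acc, 2))  # 0.01 0.8 4.0 0.7
```

Passing the accuracy functions as callables keeps the search logic independent of any particular method or training loop; in a real run each call would retrain a model on the current task.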

The code provides a general framework covering 11 SOTA methods and 4 baselines, implemented in PyTorch.
The implemented task-incremental methods are:

SI | EWC | MAS | mean/mode-IMM | LWF | EBLL | PackNet | HAT | GEM | iCaRL

These are compared with 4 baselines:

Joint | Finetuning | Finetuning-FM | Finetuning-PM

  • Joint: learn from all task data at once with a single head (multi-task learning baseline).
  • Finetuning: standard SGD on each new task in sequence, with no mechanism to counter forgetting.
  • Finetuning-FM (full memory replay): the replay memory is allocated dynamically to incoming tasks.
  • Finetuning-PM (partial memory replay): the replay memory is divided a priori over all tasks.
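The difference between the two replay baselines can be sketched as an exemplar budget per task. This is a minimal illustration, assuming an equal split of a fixed total memory; the function names, the 4500-exemplar budget, and the 10-task sequence are hypothetical and not taken from the repository.

```python
from typing import List


def full_memory_allocation(total_memory: int, tasks_seen: int) -> List[int]:
    """FM sketch: the whole memory is re-split over the tasks seen so far,
    so each task's share shrinks as new tasks arrive."""
    base, extra = divmod(total_memory, tasks_seen)
    # Spread the remainder over the first tasks so no exemplar slot is wasted.
    return [base + (1 if i < extra else 0) for i in range(tasks_seen)]


def partial_memory_allocation(total_memory: int, total_tasks: int, tasks_seen: int) -> List[int]:
    """PM sketch: the memory is divided a priori over ALL tasks in the sequence,
    so each task keeps a fixed share and the slots of future tasks stay empty."""
    per_task = total_memory // total_tasks
    return [per_task] * tasks_seen


if __name__ == "__main__":
    # Hypothetical setting: 4500 exemplars, a 10-task sequence, 2 tasks seen so far.
    print(full_memory_allocation(4500, 2))         # [2250, 2250]
    print(partial_memory_allocation(4500, 10, 2))  # [450, 450]
```

Early in the sequence FM can therefore store many more exemplars per task than PM, while both converge to the same per-task share once every task has been seen.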

This source code is released under an Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license; see the LICENSE file for details.

Pipeline

Reproducibility: results from the paper can be reproduced with the src/main_'dataset'.sh scripts. A full pipeline example is given in src/main_tinyimagenet.sh.

Pipeline: Constructing a custom pipeline typically requires the following steps.

  1. Project Setup
    1. For all requirements, see requirements.txt. The main packages can be installed as follows:
      conda create --name <ENV-NAME> python=3.7
      conda activate <ENV-NAME>
      
      # Main packages
      conda install -c conda-forge matplotlib tqdm
      conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
      
      # For GEM QP
      conda install -c omnia quadprog
      
      # For PackNet: torchnet 
      pip install git+https://github.com/pytorch/tnt.git@master
      
    2. Set paths in 'config.init' (or keep the defaults):
      1. '{tr,test}_results_root_path': where to save training/testing results.
      2. 'models_root_path': where to store initial models (so that all methods start from the same initial model).
      3. 'ds_root_path': root path of your datasets.
    3. Prepare the dataset: see src/data/"dataset"_dataprep.py (e.g., src/data/tinyimgnet_dataprep.py).
  2. Train any of the 11 SOTA methods or 4 baselines.
    1. Regularization-based/replay methods: we first train on the first task with Synaptic Intelligence (SI) and dump that model, because SI acquires its importance weights during training. All other methods then start from this same first-task model.
    2. Baselines/parameter isolation methods: the training sequence starts from scratch.
  3. Evaluate performance: the test results for each task's evaluation sequence are saved in dictionary format under the test_results_root_path defined in config.init.
  4. Plot the evaluation results using one of the configuration files in utilities/plot_configs.
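The saved evaluation results are typically summarized with average accuracy and forgetting. The sketch below uses commonly cited definitions and assumes, hypothetically, that the per-task dictionaries have already been collected into a lower-triangular matrix `acc`, where `acc[i][j]` is the accuracy on task j after training up to task i. It is not the repository's plotting code, and the exact metrics used there may differ.

```python
from typing import List, Sequence


def average_accuracy(acc: Sequence[Sequence[float]]) -> float:
    """Mean accuracy over all tasks after training on the last task.
    acc[i][j] = accuracy on task j after training up to task i (j <= i)."""
    final = acc[-1]
    return sum(final) / len(final)


def average_forgetting(acc: Sequence[Sequence[float]]) -> float:
    """Mean drop from each earlier task's best accuracy to its final accuracy.
    The last task is excluded because it cannot have been forgotten yet."""
    num_tasks = len(acc)
    if num_tasks < 2:
        return 0.0
    drops: List[float] = []
    for j in range(num_tasks - 1):
        # Best accuracy on task j before the final task was learned.
        best = max(acc[i][j] for i in range(j, num_tasks - 1))
        drops.append(best - acc[-1][j])
    return sum(drops) / len(drops)


if __name__ == "__main__":
    # Hypothetical 3-task accuracy matrix (row i: evaluation after task i).
    acc = [[0.90], [0.70, 0.80], [0.60, 0.70, 0.85]]
    print(round(average_accuracy(acc), 4), round(average_forgetting(acc), 4))  # 0.7167 0.2
```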

Implement Your Method

  1. Find class "YourMethod" in methods/method.py. Implement the framework phases (documented in code).
  2. Implement your task-based training script in its own directory under methods (methods/"YourMethodDir"). The "YourMethod" class calls this code to train, evaluate, and process a single task.
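The division of labour between the wrapper class and the per-task script can be pictured with a generic task-incremental loop. This is a hypothetical sketch only: `ContinualMethod`, `train_task`, `eval_task`, `run_sequence`, and `DummyMethod` are illustrative names, not the framework's API; the real phases are documented in methods/method.py.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class ContinualMethod(ABC):
    """Hypothetical interface; the real phase names live in methods/method.py."""

    @abstractmethod
    def train_task(self, task_id: int) -> None:
        """Train on a single task (this is where a methods/<YourMethodDir> script would run)."""

    @abstractmethod
    def eval_task(self, task_id: int) -> float:
        """Return the accuracy on one task's test data."""


def run_sequence(method: ContinualMethod, num_tasks: int) -> List[Dict[int, float]]:
    """Train tasks in order; after each one, evaluate on all tasks seen so far."""
    history: List[Dict[int, float]] = []
    for t in range(num_tasks):
        method.train_task(t)
        history.append({j: method.eval_task(j) for j in range(t + 1)})
    return history


class DummyMethod(ContinualMethod):
    """Toy method: a task scores 1.0 right after training and loses 0.1 per later task."""

    def __init__(self) -> None:
        self.trained = 0

    def train_task(self, task_id: int) -> None:
        self.trained = task_id + 1

    def eval_task(self, task_id: int) -> float:
        return 1.0 - 0.1 * (self.trained - 1 - task_id)


if __name__ == "__main__":
    print(run_sequence(DummyMethod(), 2))  # [{0: 1.0}, {0: 0.9, 1: 1.0}]
```

Keeping the per-task logic behind a small interface lets the same sequence driver run every method, which is the property a fair comparison relies on.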

Project structure

  • src/data: datasets and automated preparation scripts for Tiny ImageNet and iNaturalist.
  • src/framework: the novel task-incremental continual learning framework. main.py starts the training pipeline; pass the --test argument to run evaluation with eval.py.
  • src/methods: source code of all methods and the method.py wrapper.
  • src/models: net.py, containing all model preprocessing.
  • src/utilities: utilities used across all modules, including plotting.
  • Config: config.init, which sets the results, model, and dataset root paths.

Credits

Support

  • If you run into problems, please open a GitHub issue.
  • Have you defined your method in the framework and want to share it with the community? Send a pull request!

Contributors

mattdl
