asgaardlab / overfitguard

Replication package of the paper "Keeping Deep Learning Models in Check: A History-Based Approach to Mitigate Overfitting"

This repository provides the data, code, notebook, and trained models required to replicate our paper:

  • The datasets are under the ./data folder, where:
    • the training folder contains the simulated dataset
    • the testing folder contains the real-world dataset
  • Use ./train.py for training and ./predict.py for prediction
  • Use the notebook ./reproduce.ipynb to reproduce the results and figures of our paper
  • The trained models are under the models folder
  • The full list of surveyed papers is in surveyed_paper.md

Setup environment

This repository was developed with Python 3.8.13.

Conda

conda env create -f environment.yml

pip

pip install -r requirements.txt

Data preparation

This project detects overfitting from training logs. Each training log should be a JSON file that contains:

  • Training loss
  • Validation loss

The train_metric and monitor_metric fields give the names of the keys that hold these two series. For example, a training log that stores the training loss under the key training_loss and the validation loss under the key validation_loss looks like this:

{
  "training_loss": [0.720, 0.716, ...],
  "validation_loss": [0.707, 0.706, ...],
  "train_metric": "training_loss",
  "monitor_metric": "validation_loss"
}

Example training logs can be found in the ./data/training/dataset_exp4 folder.
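As a rough illustration of the format above, the following sketch writes and reads a log with the two series and the two key-name fields. The helper names write_training_log and load_training_log are hypothetical and are not part of this repository; train.py and predict.py may parse the files differently.

```python
import json
from pathlib import Path


def write_training_log(path, training_loss, validation_loss):
    """Write a log in the format described above (hypothetical helper)."""
    log = {
        "training_loss": list(training_loss),
        "validation_loss": list(validation_loss),
        # These two fields name the keys that hold the series.
        "train_metric": "training_loss",
        "monitor_metric": "validation_loss",
    }
    Path(path).write_text(json.dumps(log, indent=2))


def load_training_log(path):
    """Return (training series, monitored series) from a log file."""
    log = json.loads(Path(path).read_text())
    # Resolve the series through the key names stored in the log itself.
    train_key = log["train_metric"]
    monitor_key = log["monitor_metric"]
    return log[train_key], log[monitor_key]
```

Because the series are looked up through train_metric and monitor_metric, a log that uses other key names (for example loss and val_loss) can be read by the same loader as long as the two fields point at them.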

Training Overfitting Detection Methods

Correlation-based Methods

We provide three methods:

  • Spearman (recommended)
  • Pearson
  • Autocorrelation

Train them with:

python train.py spearman ./data/training/dataset_exp4 ./out
python train.py pearson ./data/training/dataset_exp4 ./out
python train.py autocorr ./data/training/dataset_exp4 ./out
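For readers unfamiliar with the Spearman coefficient, here is a minimal pure-Python sketch that computes it between a training-loss series and a validation-loss series. It only illustrates the statistic. How train.py turns a correlation into an overfitting decision is not described in this README, so nothing below should be read as the repository's implementation.

```python
def average_ranks(values):
    """Return 1-based ranks; tied values share their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def pearson(x, y):
    """Pearson correlation; raises ZeroDivisionError for a constant series."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = sum((a - mx) ** 2 for a in x) ** 0.5
    sy = sum((b - my) ** 2 for b in y) ** 0.5
    return cov / (sx * sy)


def spearman(x, y):
    """Spearman correlation: Pearson correlation of the rank vectors."""
    return pearson(average_ranks(x), average_ranks(y))
```

With a training loss that falls at every epoch and a validation loss that rises at every epoch, spearman returns a value at (or within floating-point error of) -1. When both fall together it returns a value near 1.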

Time series classifiers

We provide six time series classifiers:

  • KNN-DTW: K-Nearest Neighbors with Dynamic Time Warping (recommended)
  • TSF: Time Series Forest (recommended)
  • TSBF: Time Series Bag-of-Features
  • HMM-GMM: Hidden Markov Model with Gaussian Mixture Model emissions
  • BOSSVS: Bag-of-SFA Symbols in Vector Space
  • SAX-VSM: Symbolic Aggregate approXimation and Vector Space Model

KNN-DTW achieved the highest F1-score in our experiments, but it takes longer to compute than the other methods. TSF has a lower F1-score than KNN-DTW but is faster.

python train.py knndtw ./data/training/dataset_exp4 ./out --zero_pad
python train.py tsf ./data/training/dataset_exp4 ./out
python train.py tsbf ./data/training/dataset_exp4 ./out
python train.py hmmgmm ./data/training/dataset_exp4 ./out
python train.py bossvs ./data/training/dataset_exp4 ./out
python train.py saxvsm ./data/training/dataset_exp4 ./out

The trained models are saved under the ./out folder for later use.
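To give a feel for why KNN-DTW is comparatively slow, the sketch below implements the textbook dynamic time warping distance and a 1-nearest-neighbor classifier on top of it. It is a generic illustration, not the code in train.py. The function names, the label strings, and the choice of k = 1 are assumptions made for the example.

```python
def dtw_distance(a, b):
    """Textbook DTW distance; cost grows with len(a) * len(b)."""
    inf = float("inf")
    n, m = len(a), len(b)
    # cost[i][j] = best alignment cost of a[:i] with b[:j].
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = abs(a[i - 1] - b[j - 1])
            cost[i][j] = d + min(cost[i - 1][j],      # repeat b[j-1]
                                 cost[i][j - 1],      # repeat a[i-1]
                                 cost[i - 1][j - 1])  # advance both
    return cost[n][m]


def predict_1nn(train_series, train_labels, query):
    """Return the label of the training series closest to query under DTW."""
    distances = [dtw_distance(s, query) for s in train_series]
    return train_labels[distances.index(min(distances))]


# Hypothetical labeled loss curves, for illustration only.
curves = [[0.9, 0.7, 0.5, 0.4], [0.7, 0.6, 0.7, 0.9]]
labels = ["non-overfit", "overfit"]
prediction = predict_1nn(curves, labels, [0.7, 0.65, 0.8, 0.95])
```

Each prediction compares the query against every stored curve, and each comparison fills a table whose size is the product of the two lengths. This is the usual reason nearest-neighbor DTW costs more at prediction time than tree-based methods.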

Using the Trained Detection Methods

Overfitting detection

Prepare the training logs (one or more) that require overfitting detection and run the following command:

python predict.py TRAINED_METHOD_PATH DATA_DIR OUTPUT_DIR
# examples
python predict.py ./out/spearman_{DATE}.pkl ./data/testing/real_world_data/ ./out
python predict.py ./out/knndtw_{DATE}.pkl ./data/testing/real_world_data/ ./out
python predict.py ./out/tsf_{DATE}.pkl ./data/testing/real_world_data/ ./out
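Because the saved file name contains a date, it can be convenient to look up the newest model for a method instead of typing the name by hand. The sketch below assumes only the METHOD_{DATE}.pkl naming shown in the examples above. The helpers latest_model and predict_command are hypothetical and are not part of this repository.

```python
from pathlib import Path


def latest_model(out_dir, method):
    """Return the most recently modified '<method>_*.pkl' in out_dir, or None."""
    candidates = sorted(Path(out_dir).glob(f"{method}_*.pkl"),
                        key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None


def predict_command(model_path, data_dir, output_dir):
    """Build the predict.py argument list in the order shown above."""
    return ["python", "predict.py", str(model_path), str(data_dir), str(output_dir)]
```

The returned list can be passed to subprocess.run, for example subprocess.run(predict_command(latest_model("./out", "tsf"), "./data/testing/real_world_data/", "./out"), check=True), after checking that latest_model did not return None.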

Overfitting prevention

The trained model can be used for overfitting prevention:

Contributors

leo-lihao