releaunifreiburg / welltunedsimplenets

[NeurIPS 2021] Well-tuned Simple Nets Excel on Tabular Datasets

License: Apache License 2.0

Python 100.00%
automl deep-learning feedforward-neural-network hpo hyperparameter-optimization regularization state-of-the-art neurips-2021 tabular-data

Introduction

This repo contains the source code accompanying the paper:

Well-tuned Simple Nets Excel on Tabular Datasets

Authors: Arlind Kadra, Marius Lindauer, Frank Hutter, Josif Grabocka

Tabular datasets are the last "unconquered castle" for deep learning, with traditional ML methods like Gradient-Boosted Decision Trees still performing strongly even against recent specialized neural architectures. In this paper, we hypothesize that the key to boosting the performance of neural networks lies in rethinking the joint and simultaneous application of a large set of modern regularization techniques. As a result, we propose regularizing plain Multilayer Perceptron (MLP) networks by searching for the optimal combination/cocktail of 13 regularization techniques for each dataset using a joint optimization over the decision on which regularizers to apply and their subsidiary hyperparameters.

We empirically assess the impact of these regularization cocktails for MLPs in a large-scale study comprising 40 tabular datasets and demonstrate that: (i) well-regularized plain MLPs significantly outperform recent state-of-the-art specialized neural network architectures, and (ii) they even outperform strong traditional ML methods, such as XGBoost.

News: Our work has been accepted at the Thirty-fifth Conference on Neural Information Processing Systems (NeurIPS 2021).

Setting up the virtual environment

Our work is built on top of AutoPyTorch. To look at our implementation of the regularization cocktail ingredients, you can do the following:

git clone https://github.com/automl/Auto-PyTorch.git
cd Auto-PyTorch/
git checkout regularization_cocktails

To install the version of AutoPyTorch that features our work, you can use these additional commands:

# The following commands assume the user is in the cloned directory
conda create -n reg_cocktails python=3.8
conda activate reg_cocktails
conda install gxx_linux-64 gcc_linux-64 swig
cat requirements.txt | xargs -n 1 -L 1 pip install
python setup.py install
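The xargs pipeline above calls pip once per line of requirements.txt, and xargs keeps going when one of those calls fails, so a single problematic requirement does not stop the remaining ones from being attempted. The same one-requirement-at-a-time idea is sketched below in Python. build_install_commands and install_one_by_one are hypothetical helper names, and unlike the shell pipeline this sketch skips blank lines and full-line # comments and reports which requirements failed.

```python
import subprocess
import sys
from typing import Callable, List


def build_install_commands(requirements_text: str) -> List[List[str]]:
    """Turn requirements.txt content into one pip command per requirement line."""
    commands = []
    for raw_line in requirements_text.splitlines():
        line = raw_line.strip()
        # Skip blank lines and full-line comments (a choice made for this sketch).
        if not line or line.startswith("#"):
            continue
        # Use the current interpreter's pip so the active conda env is targeted.
        commands.append([sys.executable, "-m", "pip", "install", line])
    return commands


def install_one_by_one(
    commands: List[List[str]],
    runner: Callable[..., subprocess.CompletedProcess] = subprocess.run,
) -> List[str]:
    """Run each pip command in turn and return the requirements that failed."""
    failed = []
    for command in commands:
        result = runner(command)
        if result.returncode != 0:
            # Keep going, like the xargs pipeline, but remember the failure.
            failed.append(command[-1])
    return failed
```

For example, reading requirements.txt from the cloned directory and passing build_install_commands(text) to install_one_by_one returns the list of requirements that failed to install. The runner parameter exists only so the loop can be exercised without calling pip.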

Running the Regularization Cocktail code

The main files for running the regularization cocktails are main_experiment.py and refit_experiment.py, both in the cocktails folder. The first module starts a full HPO search, while the second refits the incumbent hyperparameter configuration on datasets where the time budget was not sufficient to complete both the full HPO search and the final refit.

The main arguments for main_experiment.py:

  • --task_id: The OpenML task id, which essentially determines the dataset used in the experiment.
  • --wall_time: The total runtime budget in seconds, covering both the HPO search and the final refit.
  • --func_eval_time: The maximum time in seconds for a single function evaluation, i.e. for evaluating one hyperparameter configuration.
  • --epochs: The number of epochs for which each hyperparameter configuration is evaluated.
  • --seed: The seed used for the run.
  • --tmp_dir: The temporary directory in which the results are stored.
  • --output_dir: The output directory in which the results are stored.
  • --nr_workers: The number of workers, which corresponds to the number of hyperparameter configurations evaluated in parallel.
  • --nr_threads: The number of threads.
  • --cash_cocktail: An important flag that activates the regularization cocktail formulation.
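The argument list above can be mirrored with argparse, as in the sketch below. It is not the parser in main_experiment.py: the types are inferred from the example commands (integer seconds for the time budgets), defaults are omitted because this README does not document them, and regularizer-specific flags such as --use_weight_decay are left out. The str_to_bool helper is hypothetical and shows one way to accept the "True" value passed to --cash_cocktail in the examples.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings, as in '--cash_cocktail True'."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Sketch of a parser for the documented main_experiment.py arguments."""
    parser = argparse.ArgumentParser(description="Regularization cocktail run (sketch)")
    parser.add_argument("--task_id", type=int, required=True, help="OpenML task id")
    parser.add_argument("--wall_time", type=int, help="total budget in seconds (HPO + refit)")
    parser.add_argument("--func_eval_time", type=int, help="per-configuration limit in seconds")
    parser.add_argument("--epochs", type=int, help="epochs per evaluated configuration")
    parser.add_argument("--seed", type=int)
    parser.add_argument("--tmp_dir", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--nr_workers", type=int, help="configurations evaluated in parallel")
    parser.add_argument("--nr_threads", type=int)
    parser.add_argument("--cash_cocktail", type=str_to_bool)
    return parser
```

A custom converter is used because argparse's type=bool would call bool() on the raw string, so any non-empty value, including "False", would become True.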

A minimal example of running the regularization cocktails:

python main_experiment.py --task_id 233088 --wall_time 600 --func_eval_time 60 --epochs 10 --seed 42 --cash_cocktail True

The example above runs the regularization cocktails on task 233088 for 10 minutes (600 seconds), with a limit of 60 seconds per function evaluation. Every hyperparameter configuration is evaluated for 10 epochs, and seed 42 is used for both the experiment and the data splits.

A minimal example of running only one regularization method:

python main_experiment.py --task_id 233088 --wall_time 600 --func_eval_time 60 --epochs 10 --seed 42 --use_weight_decay

If you would like to investigate individual regularization methods, see the arguments that control them in main_experiment.py. Additionally, to remove the limit on the number of hyperparameter configurations, delete the following lines:

smac_scenario_args={
    'runcount_limit': number_of_configurations_limit,
}
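Instead of deleting those lines, the limit could also be made optional. In the sketch below, make_smac_scenario_args is a hypothetical helper: it returns None when no limit is wanted, which matches deleting the lines only under the assumption that the receiving call's default for smac_scenario_args is None (worth verifying in main_experiment.py).

```python
from typing import Any, Dict, Optional


def make_smac_scenario_args(
    number_of_configurations_limit: Optional[int],
) -> Optional[Dict[str, Any]]:
    """Return SMAC scenario overrides, or None to leave the run count unlimited."""
    if number_of_configurations_limit is None:
        # Assumed equivalent to omitting the smac_scenario_args argument.
        return None
    if number_of_configurations_limit <= 0:
        raise ValueError("the configuration limit must be a positive integer")
    return {"runcount_limit": number_of_configurations_limit}
```

The call site would then pass smac_scenario_args=make_smac_scenario_args(number_of_configurations_limit), so switching between a limited and an unlimited search only requires changing one value.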

Plots

The plots included in our paper were generated with the functions in the module results.py. As noted in most of the function docstrings, most of the functions that plot the baseline diagrams expect the following folder structure:

common_result_folder/baseline/results.csv

There are functions inside the module itself that generate the results.csv files.
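Before plotting, it can help to check that a results folder matches this layout. The sketch below relies only on the folder structure described above; find_baseline_results is a hypothetical name, not a function from results.py, and it does not read or validate the columns inside each results.csv, since those are not documented here.

```python
from pathlib import Path
from typing import Dict


def find_baseline_results(common_result_folder: str) -> Dict[str, Path]:
    """Map each baseline name to <common_result_folder>/<baseline>/results.csv."""
    root = Path(common_result_folder)
    results = {}
    for baseline_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        csv_path = baseline_dir / "results.csv"
        # Folders without a results.csv are skipped rather than treated as errors.
        if csv_path.is_file():
            results[baseline_dir.name] = csv_path
    return results
```

Comparing the returned keys against the baselines you expect to plot makes a missing results.csv visible before any plotting function fails on it.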

Baselines

The code for running the baselines can be found in the baselines folder.

  • TabNet, XGBoost and CatBoost can be found in the baselines/bohb folder.
  • The other baselines, such as AutoGluon, auto-sklearn and Node, can be found in folders with the corresponding names.

TabNet, XGBoost, CatBoost and AutoGluon have the same two main files as our regularization cocktails, main_experiment.py and refit_experiment.py.

Citation

@inproceedings{kadra2021well,
  title={Well-tuned Simple Nets Excel on Tabular Datasets},
  author={Kadra, Arlind and Lindauer, Marius and Hutter, Frank and Grabocka, Josif},
  booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
  year={2021}
}

welltunedsimplenets's People

Contributors

arlindkadra

welltunedsimplenets's Issues

Double definitions in search space

I ran into problems that went away when I replaced this if with an elif. I only understand the code at a surface level, though, so I'm not sure this is the right fix. It seems that two things happen: (i) the same node (the imputer) is assigned twice, and (ii) the one-hot encoder is not supported for non-categorical features.

if has_cat_features:

No requirements.txt?

Hi,
I'm looking to run a comparison of your model against a few of the current transformer models on some tabular data that I have.
In the installation instructions, after the gxx and gcc install, it says to install requirements.txt, but I don't see one in the directory. Am I missing something?
Note: I am on Windows and was unable to install gxx and gcc; I'm not sure whether that has an effect on this.
thanks,
Jonathan

Package installation

Why do you do cat requirements.txt | xargs -n 1 -L 1 pip install instead of pip install -r requirements.txt?

AttributeError: 'NoneType' object has no attribute 'predict'

Hi there,

I stumbled across your project and paper, and they seemed highly interesting. Unfortunately, I cannot successfully run your project on my dataset. The following error occurs consistently:

Traceback (most recent call last):
  File "cocktails/main_experiment.py", line 399, in <module>
    train_predictions = fitted_pipeline.predict(X_train)
AttributeError: 'NoneType' object has no attribute 'predict'

I have checked and double-checked the dataset but cannot find "the problem". All values are fine, properly scaled, etc.
The only thing that may be remarkable is that the dataset is a bit wide: 80000 x 650, but nothing too out of bounds.

Do you have an idea what I should check?
Thanks for your support

Kind regards

How data augmentation is applied to tabular data?

Hi author, thanks for sharing the code. I'm wondering how data augmentation strategies like mix-up, cut-out, cut-mix, etc., can be applied to tabular data (I understand they are usually applied to images though). Please advise, many thanks.
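For context on this question, one common way to adapt these image augmentations is to treat each tabular row as a flat feature vector. The stdlib-only sketch below illustrates that general idea and is not necessarily how this repository implements it; the function names, the Beta(alpha, alpha) mixing weight for MixUp, the fraction-based feature masks and the zero fill value for cutout are assumptions, and the rows are assumed to be numerical and already encoded.

```python
import random
from typing import List, Optional, Sequence, Set, Tuple


def _pick_features(n: int, fraction: float, rng: random.Random) -> Set[int]:
    """Choose round(fraction * n) distinct feature indices."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must lie in [0, 1]")
    return set(rng.sample(range(n), int(round(fraction * n))))


def mixup(x_a: Sequence[float], x_b: Sequence[float], alpha: float = 0.2,
          rng: Optional[random.Random] = None) -> Tuple[List[float], float]:
    """Blend two rows as lam * x_a + (1 - lam) * x_b with lam ~ Beta(alpha, alpha)."""
    rng = rng if rng is not None else random.Random()
    lam = rng.betavariate(alpha, alpha)
    mixed = [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]
    # The targets would be blended with the same weight: lam * y_a + (1 - lam) * y_b.
    return mixed, lam


def cutmix(x_a: Sequence[float], x_b: Sequence[float], fraction: float = 0.3,
           rng: Optional[random.Random] = None) -> Tuple[List[float], float]:
    """Copy a random subset of features from x_b into x_a."""
    rng = rng if rng is not None else random.Random()
    swapped = _pick_features(len(x_a), fraction, rng)
    mixed = [b if i in swapped else a for i, (a, b) in enumerate(zip(x_a, x_b))]
    # The label weight for y_a is the share of features kept from x_a.
    lam = 1.0 - len(swapped) / len(x_a)
    return mixed, lam


def cutout(x: Sequence[float], fraction: float = 0.3, fill: float = 0.0,
           rng: Optional[random.Random] = None) -> List[float]:
    """Overwrite a random subset of features with a fill value; the label is unchanged."""
    rng = rng if rng is not None else random.Random()
    masked = _pick_features(len(x), fraction, rng)
    return [fill if i in masked else v for i, v in enumerate(x)]
```

Mixing one-hot encoded categorical columns this way produces fractional values, which is one reason such adaptations need care on mixed-type tabular data.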

Adapting refit to a custom dataset

Hi!

Loved the paper. I'm wondering how one could adapt your repo to run a refit experiment on a custom dataset, rather than on those from OpenML.

Some confusion with "cash_cocktail" option

I ran into your paper not too long ago and found it pretty interesting, thanks for sharing the code.

I'm looking to run this on my own dataset and I'm a bit confused as to the cash_cocktail option in main_experiment.py. My impression is that this automatically turns on all the options to search for HPO but when I run the code I get output that looks like this:

{'task_id': 233088, 'wall_time': 9000, 'func_eval_time': 1000, 'epochs': 105, 'seed': 11, 'tmp_dir': './runs/autoPyTorch_cocktails', 'output_dir': './runs/autoPyTorch_cocktails', 'nr_workers': 6, 'nr_threads': 1, 'cash_cocktail': True, 'use_swa': [False], 'use_se': [False], 'use_lookahead': [False], 'use_weight_decay': [False], 'use_batch_normalization': [False], 'use_skip_connection': [False], 'use_dropout': [False], 'mb_choice': 'none', 'augmentation': 'standard'}
{'task_id': 233088, 'wall_time': 9000, 'func_eval_time': 1000, 'epochs': 105, 'seed': 11, 'tmp_dir': './runs/autoPyTorch_cocktails', 'output_dir': './runs/autoPyTorch_cocktails', 'nr_workers': 6, 'nr_threads': 1, 'cash_cocktail': True, 'use_swa': [False], 'use_se': [False], 'use_lookahead': [False], 'use_weight_decay': [False], 'use_batch_normalization': [False], 'use_skip_connection': [False], 'use_dropout': [False], 'mb_choice': 'none', 'augmentation': 'standard'}

It looks like the options aren't being used? e.g. 'use_swa': [False], 'use_se': [False], 'use_lookahead': [False],...

I guess to rephrase the question: if one were to use your code on their private dataset, what options do you need to pass in to ensure that you are doing the full HPO? is it just the cash_cocktail flag?
