GithubHelp home page GithubHelp logo

somepago / saint Goto Github PK

View Code? Open in Web Editor NEW
376.0 8.0 60.0 234 KB

The official PyTorch implementation of recent paper - SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

License: Apache License 2.0

Python 73.44% Jupyter Notebook 25.96% Shell 0.60%
deep-learning tabular-data transformer

saint's Introduction

This repository is the official PyTorch implementation of SAINT. Find the paper on arxiv

SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Overview

Requirements

We recommend using anaconda or miniconda for python. Our code has been tested with python=3.8 on linux.

Create a conda environment from the yml file and activate it.

conda env create -f saint_environment.yml
conda activate saint_env

Make sure the following requirements are met

  • torch>=1.8.1
  • torchvision>=0.9.1

Optional

We used wandb to update our logs. But it is optional.

conda install -c conda-forge wandb 

Training & Evaluation

In each of our experiments, we use a single Nvidia GeForce RTX 2080Ti GPU.

To train the model(s) in the paper, run this command:

python train.py --dset_id <openml_dataset_id> --task <task_name> --attentiontype <attention_type> 

Pretraining is useful when there are few training data samples. Sample code looks like this. (Use train_robust.py file for pretraining and robustness experiments)

python train_robust.py --dset_id <openml_dataset_id> --task <task_name> --attentiontype <attention_type>  --pretrain --pt_tasks <pretraining_task_touse> --pt_aug <augmentations_on_data_touse> --ssl_samples <Number_of_labeled_samples>

Arguments

  • --dset_id : Dataset id from OpenML. Works with all the datasets mentioned in the paper. Works with all OpenML datasets.
  • --task : The task we want to perform. Pick from 'regression','multiclass', or 'binary'.
  • --attentiontype : Variant of SAINT. 'col' refers to SAINT-s variant, 'row' is SAINT-i, and 'colrow' refers to SAINT.
  • --embedding_size : Size of the feature embeddings
  • --transformer_depth : Depth of the model. Number of stages.
  • --attention_heads : Number of attention heads in each Attention layer.
  • --cont_embeddings : Style of embedding continuous data.
  • --pretrain : To enable pretraining
  • --pt_tasks : Losses we want to use for pretraining. Multiple arguments can be passed.
  • --pt_aug : Types of data augmentations used in pretraining. Multiple arguments are allowed. We support only mixup and CutMix right now.
  • --ssl_samples : Number of labeled samples used in semi-supervised experiments.
  • --pt_projhead_style : Projection head style used in contrastive pipeline.
  • --nce_temp : Temperature used in contrastive loss function.
  • --active_log : To update the logs onto wandb. This is optional

Most of the hyperparameters are hardcoded in train.py file. For datasets with really high number of features, we suggest using smaller batchsize, lower embedding dimension and fewer number of heads.

Evaluation

We choose the best model by evaluating the model on validation dataset. The AuROC(for binary classification datasets), Accuracy (for multiclass classification datasets), and RMSE (for regression datasets) of the best model on test datasets is printed after training is completed. If wandb is enabled, they are logged to 'test_auroc_bestep', 'test_accuracy_bestep', 'test_rmse_bestep' variables.

What's new in this version?

  • Regression and multiclass classification models are added.
  • Data can be accessed directly from openml just by calling the id of the dataset.

Acknowledgements

We would like to thank the following public repo from which we borrowed various utilites.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Cite us

@article{somepalli2021saint,
  title={SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training},
  author={Somepalli, Gowthami and Goldblum, Micah and Schwarzschild, Avi and Bruss, C Bayan and Goldstein, Tom},
  journal={arXiv preprint arXiv:2106.01342},
  year={2021}
}

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.