
somepago / saint

378 stars · 8 watchers · 59 forks · 234 KB

The official PyTorch implementation of the paper "SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training".

License: Apache License 2.0

Python 73.44% Jupyter Notebook 25.96% Shell 0.60%
deep-learning tabular-data transformer

saint's People

Contributors

somepago


saint's Issues

Multi-output regression

Hello,
I am currently working on a regression task with multiple label dimensions.
I saw that your code implements the regression task, but only for y_dim=1 and not for multiple dimensions (y_dim=n where n>1).

I tried running my regression task by changing the y_dim variable, but it does not work (apparently the dimensions get mixed up somewhere, and I am not sure where the right place to change the code is).
I wanted to ask whether there is a simple way to run the model on a regression task with several output dimensions?

Thank you
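For reference, a minimal sketch of what a multi-output regression head and loss could look like on top of a per-row representation; this is illustrative only (dim, y_dim, and rep are placeholder names, not the repository's code):

import torch
import torch.nn as nn

dim, y_dim = 32, 3                        # hypothetical embedding width and number of regression targets
head = nn.Sequential(                     # output layer sized y_dim instead of 1
    nn.Linear(dim, dim), nn.ReLU(),
    nn.Linear(dim, y_dim),
)
rep = torch.randn(8, dim)                 # stand-in for the [CLS] representation of 8 rows
y_true = torch.randn(8, y_dim)            # multi-dimensional regression targets
loss = nn.MSELoss()(head(rep), y_true)    # MSE works unchanged for y_dim > 1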

Problem with OpenML dataset ids

Thank you for sharing your great work!
When I tried to evaluate your results on all the datasets listed in your paper (e.g., Bank, Blastchar, Arrhythmia, ...), I ran into a problem with your code in data_openml.py.
The dataset ids listed in that file (1487, 44, ...) do not match the datasets listed in the paper (bank, blastchar, ...).
For example, for id 1487, when I call openml.datasets.get_dataset(1487), I get the ozone-level-8hr dataset.
Could you give me some suggestions for evaluating your results on the datasets listed in the paper?

Many thanks!
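For anyone hitting the same mismatch, a quick way to check which dataset an OpenML id actually refers to is to print its registered name with the openml Python client (a sketch; the ids below are the ones mentioned above):

import openml

for did in [1487, 44, 1590]:              # ids hard-coded in data_openml.py
    ds = openml.datasets.get_dataset(did)
    print(did, ds.name)                   # e.g. 1487 prints ozone-level-8hr, as noted above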

Issue about data

Hi,
I noticed in your paper: "Results are averaged over 5 trials and 14 binary classification datasets." However, your code only lists 'binary': [1487, 44, 1590, 42178, 1111, 31, 42733, 1494, 1017, 4134], i.e. 10 dataset ids. Could you provide the other datasets?

Deprecated saint_environment.yml file

Hi,

I am trying to reproduce your experiments as a baseline for my paper, but many of the dependencies in the yml file are deprecated. Would it be possible for you to advise on this, or provide an updated yml file?

Thank you!

Issues with datasets

Hi,

First of all - very impressive project and repository! Chapeau to you all.

I'm trying to reproduce some of the results, but I'm having issues with, e.g., HTRU2. I searched the OpenML datasets and found dataset id 43377, but it doesn't load with your code (the y values are None).

Maybe I'm looking in the wrong place, but could you provide a list of OpenML dataset ids to recreate the results in your paper?
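One likely reason for y coming back as None is that the OpenML upload has no registered default target, so the label column has to be named explicitly when fetching the data. A hedged sketch (dataset id 43377 is taken from the comment above; the column name "class" is only a guess):

import openml

ds = openml.datasets.get_dataset(43377)
print(ds.default_target_attribute)        # None here means no target column is registered for the upload
print(list(ds.features.values()))         # inspect the columns to find the actual label
X, y, cat_ind, names = ds.get_data(target="class")   # replace "class" with the real label column name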

Code for interpretability

Hello,

Thanks for this repo! Do you plan to release the code you used to plot the intersample attention and self-attention as in the paper (Section 5.2)? I would like to reproduce Figure 3.

Including target data while pretraining

Hello, thanks for the awesome work and the code.

I've applied your code to some datasets and have a couple of questions.

During pretraining, line 323 in data.py concatenates the categorical features with the target feature.
This concatenated categorical data passes through the embedding layer, and the result is used as the input to the transformer.
I couldn't find any code that separates the target data out before the concatenated data is passed into the transformer.

Is it okay to include the target data while pretraining the model?

Thank you.
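For reference, a minimal sketch (not the repository's data.py) of the kind of change being asked about, dropping the label column from the categorical block before it is embedded, assuming the label was appended as the last categorical column:

import torch

x_categ = torch.randint(0, 10, (8, 5))    # hypothetical batch: 4 categorical features + appended label
x_categ_pretrain = x_categ[:, :-1]        # drop the label column so pretraining never sees the target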

MLP is equivalent to a single linear layer

One more issue: your implementation of MLP in model.py is just a stack of linear layers with no non-linearities between them. This is mathematically equivalent to a single linear layer with in_dim = dims[0] and out_dim = dims[-1].

Why not use an activation between the layers?
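For comparison, a standard MLP builder with a non-linearity between consecutive layers (a generic sketch, not the repository's model.py):

import torch.nn as nn

def mlp(dims):
    # interleave ReLU between Linear layers so the stack no longer collapses to one linear map
    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(d_in, d_out), nn.ReLU()]
    return nn.Sequential(*layers[:-1])    # drop the trailing activation after the output layer

print(mlp([64, 128, 32, 1]))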

Handling missing values

Hello, thanks for the code and the awesome work.

I read your paper with great interest and have a question about the code.

In the paper (p. 14, Data preprocessing), it is written that
"Each feature (or columns) has a different missing value token to account for missing data.".

However, I found that the code just fills missing values with the mean for continuous features.

I wonder whether the missing-value token only works for categorical data.

It was very exciting to read the paper and I hope to apply the algorithm to my dataset soon!
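As an illustration of what a per-column missing-value token can look like for categorical features (a sketch under that reading of the paper, not the repository's preprocessing):

import pandas as pd

df = pd.DataFrame({"color": ["red", None, "blue"], "size": ["S", "M", None]})
for col in df.columns:
    # reserve one extra category per column as that column's own missing-value token
    df[col] = df[col].fillna(f"{col}_missing")
    df[col] = pd.Categorical(df[col]).codes   # the missing token gets its own embedding index
print(df)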

reps = model_saint.transformer(x_categ_enc, x_cont_enc)==nan

I applied this great model to a regression task, but the values become NaN in the model.transformer part.

class RowColTransformer(nn.Module):
    ...
    def forward(self, x, x_cont=None, mask=None):
        if x_cont is not None:
            x = torch.cat((x, x_cont), dim=1)
        _, n, _ = x.shape
        print("TRANSFORMER")
        if self.style == 'colrow':
            for attn1, ff1, attn2, ff2 in self.layers:
                x = attn1(x)  # here x == nan

Did this happen to anyone else during implementation? If anyone has used it on their own data, please let me know.

These are my hyperparameters:

model_saint = SAINT(
    categories = tuple(cat_dims.values()),#len(cat_dims)==2
    num_continuous = len(numerical_features)+1,         
    dim =128,                           
    dim_out = 1,                       
    depth = 6,                       
    heads = 8,                         
    attn_dropout = 0.1,             
    ff_dropout = 0.1,                  
    mlp_hidden_mults = (4, 2),       
    continuous_mean_std = None, 
    cont_embeddings = "MLP",
    attentiontype = 'col',
    final_mlp_style = 'sep',
    y_dim = 1
    )

optimizer: AdamW(model_saint.parameters(), lr=1e-3, weight_decay=5e-5)
batch size: 256
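A few generic PyTorch checks that usually help locate where NaNs first appear (not specific to SAINT; the commented lines assume the tensor and model names used above):

import torch

torch.autograd.set_detect_anomaly(True)   # raises at the operation that produced NaN/Inf in backward

def check(name, t):
    # report whether a tensor already contains NaN or Inf before it reaches the transformer
    print(name, "nan:", torch.isnan(t).any().item(), "inf:", torch.isinf(t).any().item())

# typical suspects: NaNs already present in the inputs, or exploding updates (try a lower lr or gradient clipping)
# check("x_categ_enc", x_categ_enc); check("x_cont_enc", x_cont_enc)
# torch.nn.utils.clip_grad_norm_(model_saint.parameters(), 1.0)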

Fixed [CLS] token during inference?

Hi

For inference, the CLS token (L157 and L160 in train.py) is still based on the ground-truth label; should it be a static CLS token instead?

Attention plotting code

Is there any progress on the release of the plotting code?

Thanks in advance!


Hi Fabien, I will release the attention plotting code in the next version. I am busy with another project rn, I am targeting the end of November for this. If you need it urgently let me know.

Originally posted by @somepago in #8 (comment)

Module is absent when importing

Hello!

There is an import in the pretraining.py file:

from baselines.data_openml import data_prep_openml, task_dset_ids, DataSetCatCon on line 4

However, there is no baselines folder in the repo, so there is an error when I attempt to run pretraining from the train file.

Thanks!
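If data_openml.py sits at the top level of the checkout rather than inside a baselines package, one workaround (an assumption about the layout, not an official fix) is to drop the package prefix in pretraining.py:

# workaround if the module lives at the repo root instead of a baselines/ package
from data_openml import data_prep_openml, task_dset_ids, DataSetCatCon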

Plotting attention for explainability

Hello Gowthami,

Thank you for this project. It shows an uplift in performance over xgboost for my use case. It would be of great help to get the attention plotting code (both self-attention and intersample attention) for the SAINT implementation, as shown in the paper.

Implementation of Attention module in Transformer

Thank you for sharing your work; it has been helping me a lot.
I have a question about your code relating to the Attention module of the Transformer. Am I wrong in thinking that the Attention module should have a dropout layer after the softmax function (link)? For example, in link or link, a dropout layer is used in the Attention module.
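For reference, the pattern described above, dropout applied to the attention weights right after the softmax, looks like this in a minimal scaled dot-product attention sketch (illustrative, not the repository's Attention module):

import torch
import torch.nn as nn

class SimpleAttention(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_dropout = nn.Dropout(dropout)          # applied to the attention weights

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_dropout(attn.softmax(dim=-1))   # dropout right after the softmax
        return attn @ v

print(SimpleAttention(64)(torch.randn(2, 10, 64)).shape)   # torch.Size([2, 10, 64])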

enumerate DataLoader is too slow.

for i, data in enumerate(trainloader, 0) #181
This code stops without ever iterating. I think it is stuck in an infinite loop; can you help me resolve it?
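If the loop never produces a batch, it is often a data-loading (worker process) problem rather than the model; a generic thing to try (an assumption, not a confirmed fix for this repo, and train_ds stands for whatever Dataset is being wrapped):

from torch.utils.data import DataLoader

# num_workers=0 keeps loading in the main process and rules out multiprocessing deadlocks;
# on Windows, worker processes also require the script to be guarded by `if __name__ == "__main__":`
trainloader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=0)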

Benchmark results difference

Hi,
I was trying to reproduce the benchmark (xgboost and lightgbm) results, but I can't get the same numbers shown in your paper.

I used this to split the dataset in train, valid and test:

saint/data.py

Line 190 in e0ee763

train["Set"] = np.random.choice(["train", "valid", "test"], p = datasplit, size=(train.shape[0],))

I used early stopping on validation, collected the test performance as the final result, and reran the experiment with 5 different seeds (0, 1, ..., 5), as you do for the SAINT model.

I used standard parameters for xgboost and lightgbm with some regularization.

I used the dataset you provide in the following link:
https://drive.google.com/file/d/1mJtWP9mRP0a10d1rT6b3ksYkp4XOpM0r/view?usp=sharing

The results I get are:

Model \ Dataset | Bank  | Blastchar | Arrhythmia | Arcene | Forest | Shoppers | Income | Volkert
lightgbm        | 93.46 | 83.71     | 93.18      | 85.25  | 99.79  | 93.23    | 92.03  | 71.46
xgboost         | 93.41 | 83.67     | 93.13      | 87.66  | 99.71  | 92.62    | 92.36  | 70.32

My experiments show a clear improvement over the benchmark results reported in the paper, as shown below:

Model \ Dataset | Bank   | Blastchar | Arrhythmia | Arcene | Forest | Shoppers | Income | Volkert
lightgbm        | +0.069 | +0.54     | +4.45      | +4.2   | +6.5   | +0.03    | -0.54  | +3.55
xgboost         | +0.45  | +1.89     | +11.14     | +6.25  | +4.38  | +0.11    | 0.05   | +1.37

Can you share the code used to calculate the benchmark results?

I also used fairly standard parameters to train xgboost and lightgbm:

xgboost:
- max_depth=8,
- learning_rate=0.01,
- tree_method = 'hist',
- subsample=0.75,
- colsample_bytree=0.75,
- reg_alpha= 0.5,
- reg_lambda= 0.5,

lightgbm:
- learning_rate= 0.01,
- max_depth= -1,
- num_leaves= 2**8,
- lambda_l1= 0.5,
- lambda_l2= 0.5,
- feature_fraction= 0.75,
- bagging_fraction= 0.75,
- bagging_freq = 1,

I think I can improve these results by tuning these parameters further.
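For what it is worth, a sketch of the protocol described above (seeded random split plus test AUROC averaged over 5 seeds); X, y, and the split proportions are placeholders, not the paper's exact setup:

import numpy as np
from sklearn.metrics import roc_auc_score
from xgboost import XGBClassifier

aucs = []
for seed in range(5):                       # average over 5 seeds, as in the comparison above
    np.random.seed(seed)
    split = np.random.choice(["train", "valid", "test"], p=[0.65, 0.15, 0.20], size=len(X))
    model = XGBClassifier(max_depth=8, learning_rate=0.01, tree_method="hist",
                          subsample=0.75, colsample_bytree=0.75,
                          reg_alpha=0.5, reg_lambda=0.5)
    model.fit(X[split == "train"], y[split == "train"])
    aucs.append(roc_auc_score(y[split == "test"], model.predict_proba(X[split == "test"])[:, 1]))
print(np.mean(aucs))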

fillna for continuous data

Hello, I'm a beginner interested in tabular learning. Your superb paper, SAINT, impressed me a lot, but I've had some problems understanding your code.

For

train.fillna(train.loc[train_indices, col].mean(), inplace=True)
or
X.fillna(X.loc[train_indices, col].mean(), inplace=True)

a) Why is it train.loc[train_indices, col] rather than train.loc[:, col]?
The validation and test data may also contain NaNs.
b) Why is it train.fillna rather than train[col].fillna?
It may fill NaNs in other columns as well.

I think the correct expression should be train[col].fillna(train.loc[:, col].mean(), inplace=True).

I'm not sure whether I am correct. I would appreciate a reply. Thank you very much!
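For reference, the column-wise variant being suggested would look something like this (X, col, and train_indices are assumed to already exist, as in the snippets above; whether to use the training-split mean or the full-column mean is the author's call, and the usual leak-avoidance convention keeps the training-split statistic):

# impute only this column, filling every row with the mean computed on the training split
X[col] = X[col].fillna(X.loc[train_indices, col].mean())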

Regression Task

Thanks for awesome work.

I'm using a tabular dataset for a regression task. I would like to predict the last column (float values) in the picture below.

[image of the dataset]

I'm not sure how I should set up the network, especially these two parameters:

categories = tuple(cat_dims),
num_continuous = len(con_idxs) 

For now I'm using

con_idxs = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]

If I change the last column's values to int using train[target] = train[target].astype(int) and use the following as cat_dims, it starts training, but I want to predict floating-point values.

cat_dims = np.append(np.array(cat_dims),np.array([50])).astype(int)

If I don't convert the target to int, it throws the following error:

 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.DoubleTensor instead (while checking arguments for embedding)
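The RuntimeError itself is a dtype issue: nn.Embedding expects integer (Long) indices for the categorical inputs, while a regression target should stay float and go only to the loss, never through an embedding. A hedged sketch of the casts (the tensor names are assumptions):

x_categ = x_categ.long()    # embedding indices must be int64, not double
x_cont  = x_cont.float()    # continuous features stay float
y_true  = y_true.float()    # keep the regression target as float and use it only in the MSE loss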

Environment: packages not found

Thank you so much for sharing this impressive work.
I failed to create the environment. Is there anything I can do to fix this error? The error details are listed below:

K:\library\saint>conda env create -f saint_environment.yml
Collecting package metadata (repodata.json): done
Solving environment: failed

ResolvePackageNotFound:
  - gmp==6.2.1=h58526e2_0
  - certifi==2021.5.30=py38h578d9bd_0
  - lame==3.100=h7f98852_1001
  - promise==2.3=py38h578d9bd_3
  - jupyter_core==4.7.1=py38h578d9bd_0
  - libglib==2.68.3=h3e27bee_0
  - setuptools==49.6.0=py38h578d9bd_3
  - ffmpeg==4.3=hf484d3e_0
  - markupsafe==2.0.1=py38h497a2fe_0
  - libprotobuf==3.17.2=h780b84a_0
  - libgomp==9.3.0=h2828fa1_19
  - protobuf==3.17.2=py38h709712a_0
  - yaml==0.2.5=h516909a_0
  - gst-plugins-base==1.14.0=hbbd80ab_1
  - freetype==2.10.4=h0708190_1
  - pcre==8.45=h9c3ff4c_0
  - tornado==6.1=py38h497a2fe_1
  - _openmp_mutex==4.5=1_gnu
  - debugpy==1.3.0=py38h709712a_0
  - xgboost==1.4.0=py38h578d9bd_0
  - expat==2.4.1=h9c3ff4c_0
  - kiwisolver==1.3.1=py38h1fd1430_1
  - pyzmq==22.1.0=py38h2035c66_0
  - glib==2.68.3=h9c3ff4c_0
  - tk==8.6.10=h21135ba_1
  - pysocks==1.7.1=py38h578d9bd_3
  - websocket-client==0.57.0=py38h578d9bd_4
  - ipython==7.25.0=py38hd0cf306_1
  - numpy-base==1.20.2=py38hfae3a4d_0
  - libffi==3.3=h58526e2_2
  - nbconvert==6.1.0=py38h578d9bd_0
  - libuuid==2.32.1=h7f98852_1000
  - numpy==1.20.2=py38h2d18471_0
  - mkl_random==1.2.2=py38h1abd341_0
  - pthread-stubs==0.4=h36c2ea0_1001
  - libpng==1.6.37=h21135ba_2
  - mkl_fft==1.3.0=py38h42c9631_2
  - chardet==4.0.0=py38h578d9bd_1
  - readline==8.1=h46c0cb4_0
  - psutil==5.8.0=py38h497a2fe_1
  - shortuuid==1.0.1=py38h578d9bd_4
  - gstreamer==1.14.0=h28cd5cc_2
  - ld_impl_linux-64==2.35.1=hea4e1c9_2
  - libgcc-ng==9.3.0=h2828fa1_19
  - xorg-libxau==1.0.9=h7f98852_0
  - mistune==0.8.4=py38h497a2fe_1004
  - pytorch==1.8.1=py3.8_cuda11.1_cudnn8.0.5_0
  - libunistring==0.9.10=h14c3975_0
  - fontconfig==2.13.1=hba837de_1005
  - importlib-metadata==4.6.1=py38h578d9bd_0
  - glib-tools==2.68.3=h9c3ff4c_0
  - libuv==1.41.0=h7f98852_0
  - click==8.0.1=py38h578d9bd_0
  - xorg-libxdmcp==1.1.3=h7f98852_0
  - mkl-service==2.4.0=py38h497a2fe_0
  - watchdog==0.10.4=py38h578d9bd_0
  - pillow==8.2.0=py38he98fc37_0
  - py-xgboost==1.4.0=py38h578d9bd_0
  - qt==5.9.7=h5867ecd_1
  - libidn2==2.3.1=h7f98852_0
  - brotlipy==0.7.0=py38h497a2fe_1001
  - libwebp-base==1.2.0=h7f98852_2
  - cryptography==3.4.7=py38ha5dfef3_0
  - gettext==0.19.8.1=h0b5b191_1005
  - scikit-learn==0.23.2=py38h0573a6f_0
  - libxcb==1.13=h7f98852_1003
  - argon2-cffi==20.1.0=py38h497a2fe_2
  - sqlite==3.35.5=h74cdb3f_0
  - nettle==3.6=he412f7d_0
  - openssl==1.1.1k=h7f98852_0
  - matplotlib==3.4.2=py38h578d9bd_0
  - anyio==3.2.1=py38h578d9bd_0
  - jedi==0.18.0=py38h578d9bd_2
  - libxml2==2.9.12=h03d6c58_0
  - sniffio==1.2.0=py38h578d9bd_1
  - xz==5.2.5=h516909a_1
  - wget==1.20.1=h22169c7_0
  - mkl==2021.2.0=h06a4308_296
  - libiconv==1.16=h516909a_0
  - jpeg==9b=h024ee3a_2
  - ca-certificates==2021.5.30=ha878542_0
  - gnutls==3.6.13=h85f3911_1
  - matplotlib-base==3.4.2=py38hcc49a3a_0
  - libgfortran-ng==7.3.0=hdf63c60_0
  - lcms2==2.12=h3be6417_0
  - icu==58.2=hf484d3e_1000
  - libxgboost==1.4.0=h9c3ff4c_0
  - pandoc==2.14.0.3=h7f98852_0
  - libsodium==1.0.18=h36c2ea0_1
  - dbus==1.13.18=hb2f20db_0
  - pandas==1.2.4=py38h1abd341_0
  - pyyaml==5.4.1=py38h497a2fe_0
  - zstd==1.4.9=ha95c52a_0
  - cudatoolkit==11.1.1=h6406543_8
  - python==3.8.10=h49503c6_1_cpython
  - _libgcc_mutex==0.1=conda_forge
  - zeromq==4.3.4=h9c3ff4c_0
  - pyrsistent==0.17.3=py38h497a2fe_2
  - cffi==1.14.5=py38ha65f79e_0
  - openh264==2.1.1=h780b84a_0
  - libtiff==4.2.0=h85742a9_0
  - lz4-c==1.9.3=h9c3ff4c_0
  - scipy==1.6.2=py38had2a1c9_1
  - ipykernel==6.0.2=py38hd0cf306_0
  - ninja==1.10.2=h4bd325d_0
  - pyqt==5.9.2=py38h05f1152_4
  - intel-openmp==2021.2.0=h06a4308_610
  - sip==4.19.13=py38he6710b0_0
  - zlib==1.2.11=h516909a_1010
  - bzip2==1.0.8=h7f98852_4
  - ncurses==6.2=h58526e2_4
  - libstdcxx-ng==9.3.0=h6de172a_19
  - terminado==0.10.1=py38h578d9bd_0

Question: only continuous variables (no category)

Is it possible to use SAINT on tabular data that contains only continuous variables, without any categorical ones?

We need to pass two parameters to the SAINT model: x_categ and x_cont.
Do I need to pass some torch.empty tensor as x_categ?
What should I pass as the "categories" parameter to the SAINT model? An empty tuple?
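One workaround people use when every feature is continuous is to pass a single dummy categorical column (for example, a constant [CLS]-style column); a hedged sketch of the idea, not a confirmed recipe for this repo:

import torch

batch = 32
x_cont  = torch.randn(batch, 22)                    # all continuous features
x_categ = torch.zeros(batch, 1, dtype=torch.long)   # one dummy categorical column, always category 0
categories = (1,)                                   # a single categorical "feature" with one possible value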
