
lucidrains / tab-transformer-pytorch


Implementation of TabTransformer, an attention network for tabular data, in PyTorch

License: MIT License

Python 100.00%
artificial-intelligence deep-learning transformer attention-mechanism tabular-data

tab-transformer-pytorch's Introduction

Tab Transformer

Implementation of Tab Transformer, an attention network for tabular data, in PyTorch. This simple architecture came within a hair's breadth of the performance of gradient-boosted decision trees (GBDT).

Update: Amazon AI claims to have beaten GBDT with attention on a real-world tabular dataset (predicting shipping cost).

Install

$ pip install tab-transformer-pytorch

Usage

import torch
import torch.nn as nn
from tab_transformer_pytorch import TabTransformer

cont_mean_std = torch.randn(10, 2)

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
    mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
    continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
)

x_categ = torch.randint(0, 5, (1, 5))     # category values, each in [0, number of unique values) for its column, in the order passed into the constructor above
x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually

pred = model(x_categ, x_cont) # (1, 1)
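The example above fills cont_mean_std with random numbers as a placeholder. With real data, one would presumably compute a per-column mean and standard deviation on the training split and stack them into the same (num_continuous, 2) shape. This minimal sketch assumes the model reads column 0 as the mean and column 1 as the standard deviation; the helper name is hypothetical and not part of the package.

```python
import torch


def compute_continuous_mean_std(x_cont_train: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: per-column statistics for continuous_mean_std.

    x_cont_train: (num_rows, num_continuous) float tensor from the training split only.
    Returns a (num_continuous, 2) tensor; column 0 = mean, column 1 = std (assumed layout).
    """
    mean = x_cont_train.mean(dim=0)
    # clamp so that a constant column does not cause a division by zero later
    std = x_cont_train.std(dim=0).clamp(min=eps)
    return torch.stack([mean, std], dim=-1)
```

cont_mean_std = compute_continuous_mean_std(x_cont_train) could then replace the torch.randn(10, 2) placeholder. Computing the statistics on the training split only keeps validation data out of the normalization.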

FT Transformer

This paper from Yandex improves on Tab Transformer by using a simpler scheme for embedding the continuous numerical values: each numerical feature is projected into its own token, so continuous features pass through the transformer alongside the categorical ones.

Included in this repository just for convenient comparison to Tab Transformer.

import torch
from tab_transformer_pytorch import FTTransformer

model = FTTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1                    # feed forward dropout
)

x_categ = torch.randint(0, 5, (1, 5))     # category values, each in [0, number of unique values) for its column, in the order passed into the constructor above
x_numer = torch.randn(1, 10)              # numerical values

pred = model(x_categ, x_numer) # (1, 1)
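The feature-tokenizer idea behind FT Transformer can be sketched as follows: each numerical feature j gets its own learned weight vector and bias of size dim, and its token is value × weight + bias. This is a minimal illustration of the technique under that description; the class name is hypothetical and it is not necessarily how this repository implements it.

```python
import torch
import torch.nn as nn


class NumericalEmbedder(nn.Module):
    """Hypothetical sketch: one learned (weight, bias) pair per numerical feature."""

    def __init__(self, dim: int, num_numerical: int):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_numerical, dim))
        self.biases = nn.Parameter(torch.randn(num_numerical, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_numerical) -> (batch, num_numerical, dim)
        # token_j = x_j * W_j + b_j, broadcast over the batch
        return x.unsqueeze(-1) * self.weights + self.biases
```

Because every continuous feature becomes its own token, attention can relate numerical and categorical features directly, whereas Tab Transformer only passes the categorical embeddings through attention.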

Unsupervised Training

To perform the type of unsupervised pre-training described in the paper, first convert your category tokens to the appropriate unique ids, then use Electra on model.transformer.
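One way to convert per-column category indices into globally unique ids is to shift each column by the total number of values in the columns before it, after reserving a few ids for special tokens. This sketch assumes two reserved special-token ids at the start of the vocabulary; the helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def category_offsets(categories, num_special_tokens: int = 2) -> torch.Tensor:
    # Prepend the reserved special-token slots, then take a running sum so that
    # column i starts right after the ids used by columns 0 .. i-1.
    counts = torch.tensor(list(categories))
    return F.pad(counts, (1, 0), value=num_special_tokens).cumsum(dim=-1)[:-1]


def to_unique_ids(x_categ: torch.Tensor, categories, num_special_tokens: int = 2) -> torch.Tensor:
    # x_categ: (batch, num_columns), column i holding values in [0, categories[i])
    return x_categ + category_offsets(categories, num_special_tokens)
```

With categories = (10, 5, 6, 5, 8) the offsets are [2, 12, 17, 23, 28], so a token model for Electra-style training would need a vocabulary of 2 + 34 = 36 ids.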


Citations

@misc{huang2020tabtransformer,
    title   = {TabTransformer: Tabular Data Modeling Using Contextual Embeddings},
    author  = {Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
    year    = {2020},
    eprint  = {2012.06678},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@article{Gorishniy2021RevisitingDL,
    title   = {Revisiting Deep Learning Models for Tabular Data},
    author  = {Yu. V. Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.11959}
}


tab-transformer-pytorch's Issues

Is there any training example for TabTransformer?

Hi,
I want to use it on a tabular dataset for a supervised learning task, but I don't really know how to train this model on a dataset (there seems to be no such content in the README). Could you please help me? Thank you.
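A minimal supervised training loop might look like the sketch below. It assumes a binary target with dim_out = 1, so it defaults to BCEWithLogitsLoss; for a multiclass dim_out, nn.CrossEntropyLoss with integer labels would be the usual swap. The function name and the (x_categ, x_cont, y) batch layout are assumptions, not part of the package.

```python
import torch
import torch.nn as nn


def train_one_epoch(model: nn.Module, batches, optimizer: torch.optim.Optimizer, loss_fn=None) -> float:
    """Run one pass over `batches`, an iterable of (x_categ, x_cont, y) tensors.

    x_categ: (batch, num_categories) long, x_cont: (batch, num_continuous) float,
    y: float targets shaped like the model output, e.g. (batch, 1) for dim_out = 1.
    Returns the mean training loss over the epoch.
    """
    loss_fn = loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss()
    model.train()
    total, count = 0.0, 0
    for x_categ, x_cont, y in batches:
        optimizer.zero_grad()
        logits = model(x_categ, x_cont)  # same call signature as TabTransformer / FTTransformer
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()
        total += loss.item()
        count += 1
    return total / max(count, 1)
```

For example, with the model from the README: create optimizer = torch.optim.Adam(model.parameters(), lr=1e-3), then call train_one_epoch(model, loader, optimizer) once per epoch, where loader is a DataLoader over a TensorDataset(x_categ, x_cont, y). Evaluating on a held-out split under model.eval() and torch.no_grad() between epochs is the usual next step.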

Questions about GBDT in the paper

Hi, I have read your paper recently. It is interesting work and outperforms existing methods.
But I have a few questions (maybe silly :) ).

  • Which GBDT method was used in the experiments? My guess is CatBoost, which focuses on categorical features.
  • Why is LightGBM not included among the GBDT references? LightGBM also has special handling for feature crosses.
  • The embedding of categorical features requires a very large number of parameters (d x m). Considering the number of parameters, the improvement over an MLP is not particularly large.
  • Can you report TabTransformer's running time? I would expect its overhead to be much greater than that of GBDT-based methods.

index -1 is out of bounds for dimension 1 with size 17

I encountered this problem during the training process. What is the possible reason for this problem, and how can I solve this problem? Thanks!

  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 583, in forward
    return self.tabnet(x)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 468, in forward
    steps_output, M_loss = self.encoder(x)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 160, in forward
    M = self.att_transformers[step](prior, att)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 637, in forward
    x = self.selector(x)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 109, in forward
    return sparsemax(input, self.dim)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 52, in forward
    tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
  File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 94, in _threshold_and_support
    tau = input_cumsum.gather(dim, support_size - 1)
RuntimeError: index -1 is out of bounds for dimension 1 with size 17
Experiment has terminated.
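This traceback comes from pytorch_tabnet's sparsemax rather than from this repository. An index of -1 there means the computed support size was zero (support_size - 1 = -1); one commonly reported cause, an assumption here and not confirmed in the thread, is NaN or infinite values reaching that layer, either from the input data or from a diverging loss. A quick check on the input tensor, using a hypothetical helper, might look like:

```python
from typing import List

import torch


def nonfinite_columns(x: torch.Tensor) -> List[int]:
    # x: (num_rows, num_features) float tensor.
    # Returns the indices of feature columns that contain NaN or +/-inf anywhere.
    bad = ~torch.isfinite(x)
    return bad.any(dim=0).nonzero(as_tuple=True)[0].tolist()
```

If this returns an empty list for the training data, a learning rate that is too high (producing NaN weights mid-training) is the next thing worth ruling out.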

Minor Bug: activation function being applied to output layer in class MLP

The code for class MLP mistakenly applies the activation function to the last (i.e. output) layer. The error is in the evaluation of the is_last flag. The current code is:

class MLP(nn.Module):
    def __init__(self, dims, act = None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            is_last = ind >= (len(dims) - 1)

The last line should be changed to is_last = ind >= (len(dims) - 2):

class MLP(nn.Module):
    def __init__(self, dims, act = None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            is_last = ind >= (len(dims) - 2)

If you like, I can do a pull request.
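For a list dims of length n there are n - 1 (dim_in, dim_out) pairs, indexed 0 to n - 2, so the original condition ind >= len(dims) - 1 is never true. A self-contained version of the class with the proposed fix (written as len(dims_pairs) - 1, the same value as len(dims) - 2) could look like this sketch; the layer-building body is an assumption, since the issue only quotes the first few lines.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, dims, act=None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            # the final pair has index len(dims_pairs) - 1 == len(dims) - 2
            is_last = ind >= (len(dims_pairs) - 1)
            layers.append(nn.Linear(dim_in, dim_out))
            if not is_last:
                # activation only between layers, never after the output layer
                layers.append(act if act is not None else nn.ReLU())
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)
```

With dims = [170, 84, 42, 1] this builds Linear, ReLU, Linear, ReLU, Linear, so the output stays as raw logits suitable for BCEWithLogitsLoss.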

Low GPU usage

Hi.

I'm having a problem running your code on my dataset: it's pretty slow. GPU usage averages 50% and each epoch takes almost 900 seconds.

My dataset has 590540 rows, 24 categorical features, and 192 continuous features. Categories are encoded using LabelEncoder. The total dataset size is around 600 MB. My GPU is an integrated NVIDIA RTX 3060 with 6 GB of RAM. The optimizer is Adam.

These are the software versions:

Windows 10

Python: 3.7.11
Pytorch: 1.7.0+cu110
Numpy: 1.21.2

Let me know if you need more info from my side.

Thanks.

Xin.
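One frequent reason for low GPU utilization with small tabular models (an assumption about this case, since the thread does not include the training code) is that per-row Python work in the data pipeline, rather than the model, is the bottleneck. If the tensors fit in GPU memory, a sketch that moves them once and slices shuffled mini-batches directly on the device avoids that overhead; a larger batch size also tends to help. The generator name is hypothetical.

```python
import torch


def device_batches(x_categ: torch.Tensor, x_cont: torch.Tensor, y: torch.Tensor, batch_size: int):
    """Yield shuffled mini-batches from tensors that already live on the target device.

    Move the full tensors once before training, e.g. x_cont = x_cont.to("cuda"),
    instead of copying each batch from the CPU.
    """
    num_rows = x_categ.shape[0]
    perm = torch.randperm(num_rows, device=x_categ.device)  # fresh shuffle on every call (epoch)
    for start in range(0, num_rows, batch_size):
        idx = perm[start:start + batch_size]
        yield x_categ[idx], x_cont[idx], y[idx]
```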

The paper describes one embedding for each column

Hi,

In the TabTransformer paper it says:
"Column embedding. For each categorical feature (column) i, we have an embedding lookup table e_φi(.), for i ∈ {1, 2, ..., m}. [...]"

In your source code, both implementations use one lookup table for all columns.
Or am I missing something? I am still learning.

Thank you!
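A single embedding table combined with per-column offsets behaves like m separate lookup tables, because each column only ever reads its own disjoint block of rows. The sketch below checks this numerically; the layout (two reserved special-token rows, then the columns one after another) is an assumption, and the function name is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def shared_vs_separate(categories, x_categ: torch.Tensor, dim: int = 8, num_special_tokens: int = 2):
    # One shared table with room for every value of every column, plus special tokens.
    table = nn.Embedding(sum(categories) + num_special_tokens, dim)
    offsets = F.pad(torch.tensor(list(categories)), (1, 0), value=num_special_tokens).cumsum(-1)[:-1]

    # Shared-table lookup: shift each column into its own block of rows.
    shared = table(x_categ + offsets)  # (batch, num_columns, dim)

    # Equivalent per-column tables: column i only sees rows offsets[i] .. offsets[i] + categories[i] - 1.
    per_column = [table.weight[int(offsets[i]): int(offsets[i]) + categories[i]] for i in range(len(categories))]
    separate = torch.stack([per_column[i][x_categ[:, i]] for i in range(len(categories))], dim=1)
    return shared, separate
```

Since the blocks never overlap, the parameter count and the lookups are the same as with m independent tables; the single table is mainly an implementation convenience.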

How to pretrain?

Hello, sir, I really appreciate your work! I am very curious about how to pretrain TabTransformer. Is there an example?
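The TabTransformer paper describes pre-training objectives that include replaced token detection, in which some categorical values are swapped for a random value of the same column and the model learns to flag which entries were replaced. A minimal sketch of the corruption step only, assuming each column i takes values in [0, categories[i]); the function name and replace_prob are hypothetical, and the detection head and loss are left out.

```python
import torch


def replace_tokens(x_categ: torch.Tensor, categories, replace_prob: float = 0.3, generator=None):
    """Corrupt categorical inputs for replaced-token-detection style pre-training.

    x_categ: (batch, num_columns) long tensor, column i in [0, categories[i]).
    Returns (corrupted, labels), where labels is 1.0 wherever the value actually changed.
    """
    cardinalities = torch.tensor(list(categories), dtype=torch.float)
    mask = torch.rand(x_categ.shape, generator=generator) < replace_prob
    # uniform random value drawn from each column's own range
    random_vals = (torch.rand(x_categ.shape, generator=generator) * cardinalities).long()
    corrupted = torch.where(mask, random_vals, x_categ)
    # a random draw can coincide with the original value; such entries count as unchanged
    labels = (corrupted != x_categ).float()
    return corrupted, labels
```

Feeding corrupted through the model and training a per-entry "was this replaced?" prediction against labels would complete the objective; that part depends on model internals not shown here.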

No Category Shared Embedding?

I noticed that this implementation does not seem to have the shared embedding between the values belonging to the same category (column) that the paper mentions (c_phi_i), unless I missed it. If it's indeed missing, do you have plans to add it?

Thanks for this implementation!
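In the paper's formulation, a column's embedding concatenates a column-identifier vector c_φi, shared by every value of column i, with a value-specific vector, giving a total dimension of d with l dimensions shared. A sketch of that idea, assuming l = d // 8 and placing the shared part first; the class and parameter names are hypothetical and this is not the repository's code.

```python
import torch
import torch.nn as nn


class ColumnEmbedding(nn.Module):
    def __init__(self, categories, dim: int, shared_divisor: int = 8):
        super().__init__()
        self.shared_dim = dim // shared_divisor  # l in the paper's notation (assumed d // 8)
        # value-specific part: one row per (column, value) pair
        self.value_embed = nn.Embedding(sum(categories), dim - self.shared_dim)
        # column-identifier part c_φi: one row per column, shared by all of its values
        self.column_embed = nn.Parameter(torch.randn(len(categories), self.shared_dim))
        offsets = torch.tensor([0, *categories[:-1]]).cumsum(0)
        self.register_buffer("offsets", offsets)

    def forward(self, x_categ: torch.Tensor) -> torch.Tensor:
        # x_categ: (batch, num_columns) -> (batch, num_columns, dim)
        values = self.value_embed(x_categ + self.offsets)
        columns = self.column_embed.unsqueeze(0).expand(x_categ.shape[0], -1, -1)
        return torch.cat([columns, values], dim=-1)
```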

TypeError: can only concatenate str (not "int") to str

Hi,

I am getting this error when I try to run pred = model(x_categ, x_cont)

One more question: are we supposed to give the exact number of unique values for each categorical column, in the order they appear in the data (categories), and the total number of continuous columns (num_continuous)? Is there anything else I am missing?

model = TabTransformer(categories = tuple(cat_list), num_continuous = len(continuous_cols), dim = 32, dim_out = 61, depth = 6, heads = 8, attn_dropout = 0.1, ff_dropout = 0.1, mlp_hidden_mults = (4, 2), mlp_act = nn.ReLU())
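On the second question: categories should be a tuple of Python ints, one per categorical column in the same order as the columns of x_categ, and x_categ must hold integer codes in [0, n_unique) rather than raw strings, while x_cont is a float tensor with num_continuous columns. A string slipping into cat_list is one possible source of the TypeError, but that is an assumption since the thread shows no traceback. A hypothetical helper that builds both consistently from raw columns:

```python
from typing import Dict, List, Sequence, Tuple


def encode_categoricals(columns: Dict[str, Sequence]) -> Tuple[Tuple[int, ...], List[List[int]]]:
    """Map raw categorical columns to integer codes.

    Returns (categories, rows): categories holds one int cardinality per column,
    in the dict's insertion order; rows holds one list of codes per data row.
    """
    cardinalities: List[int] = []
    encoded_columns: List[List[int]] = []
    for values in columns.values():
        # sort by string form so the mapping is deterministic even for mixed types
        vocab = {v: i for i, v in enumerate(sorted(set(values), key=str))}
        cardinalities.append(len(vocab))
        encoded_columns.append([vocab[v] for v in values])
    rows = [list(row) for row in zip(*encoded_columns)]
    return tuple(cardinalities), rows
```

torch.tensor(rows) then gives the long x_categ tensor, and TabTransformer(categories=cats, ...) receives matching cardinalities. With dim_out = 61 the output is presumably 61 class logits, which pairs with nn.CrossEntropyLoss.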

FT_Transformer - Attention weights

Thanks for these models. The FT_Transformer is working well. Is there a way to extract the attention weights from the model? I understand these can be used to get feature importance.
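One generic approach is to read the attention that a CLS token pays to each feature token in a layer. The sketch below uses PyTorch's built-in nn.MultiheadAttention as a stand-in for a transformer layer, not this repository's attention module, and the helper name is hypothetical. Attention weights are only a rough proxy for feature importance.

```python
import torch
import torch.nn as nn


def cls_attention_importance(tokens: torch.Tensor, attn: nn.MultiheadAttention) -> torch.Tensor:
    """Average attention from a CLS token (position 0) to each feature token.

    tokens: (batch, 1 + num_features, dim); attn must be built with batch_first=True.
    Returns a (num_features,) tensor.
    """
    with torch.no_grad():
        # weights: (batch, seq, seq), averaged over heads by default
        _, weights = attn(tokens, tokens, tokens, need_weights=True)
    return weights[:, 0, 1:].mean(dim=0)
```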

Extracting Latent Spaces

Dear Team,
first thank you for your awesome work. This is less an issue than a question.

My goal is to train a model with a contrastive learning method between tabular data and MRIs. Is it possible to get a latent space representation of my data from your model that I could use together with my images?
As I understand it, only the categorical variables go through the transformer, and the continuous ones go through a layer normalization.
As I understand it, your model works on labeled data; does this latent space actually have any specific meaning?

What could I return from your code to use as a latent space? And would it be possible to generate human-readable data back from this latent space?

So to summarize:
I want a usable latent space from tabular data that represents the relationships between the items in a meaningful way to use contrastive learning on it. Do you think your TabTransformer is suitable for this?

Thank you very much for your work and I hope you can help me.
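A generic way to pull out an intermediate representation is a forward hook that records the input to a chosen submodule. For TabTransformer, the input to the final MLP (the flattened contextual categorical embeddings concatenated with the normalized continuous features) is one candidate latent vector; that this submodule is reachable as model.mlp is an assumption to check against the installed version. The helper name is hypothetical.

```python
import torch
import torch.nn as nn


def capture_input(model: nn.Module, submodule: nn.Module, *inputs) -> torch.Tensor:
    """Run `model(*inputs)` and return the tensor that was fed into `submodule`."""
    captured = {}

    def hook(module, args, output):
        captured["features"] = args[0].detach()

    handle = submodule.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["features"]
```

Something like z = capture_input(model, model.mlp, x_categ, x_cont) would then give a (batch, dim * num_categories + num_continuous) tensor for a contrastive projection head. Decoding it back into human-readable rows would need a separately trained decoder, which the classifier does not provide.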

Intended usage of num_special_tokens?

From what I understand, these are supposed to be reserved for out-of-vocabulary (OOV) values. Is the intended usage to set OOV values in the input to some negative number and overwrite the offset? That seems to be what it would take to achieve the desired outcome, but it also seems somewhat confusing and clunky. Or perhaps I am misunderstanding its purpose? Thanks!
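Whether num_special_tokens was meant for OOV handling is not confirmed here. An alternative that sidesteps it entirely is to reserve code 0 of every column for unseen values at encoding time and count that extra slot in categories. A sketch with hypothetical helper names:

```python
from typing import Dict, Hashable, Iterable, List

UNKNOWN = 0  # hypothetical convention: code 0 in every column means "value not seen in training"


def fit_vocab(train_values: Iterable[Hashable]) -> Dict[Hashable, int]:
    # Known values get codes 1..k in first-seen order, leaving 0 free for unseen values.
    return {v: i + 1 for i, v in enumerate(dict.fromkeys(train_values))}


def encode(values: Iterable[Hashable], vocab: Dict[Hashable, int]) -> List[int]:
    return [vocab.get(v, UNKNOWN) for v in values]
```

The entry for such a column in categories is then len(vocab) + 1, so the model allocates an embedding row for the unknown code.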

Hyperparameters of the MLP part should be changed if they are meant to follow the paper

I appreciate your code :)

I want to raise an issue about the hyperparameters.
I think that, according to the paper, the hyperparameters of the MLP part should be changed.

According to Appendix B of the paper, "mlp_hidden_mults" is multiplied by "input_size",
and "l" is the shared embedding dimension.

The code should be changed as below (class TabTransformer(), def __init__()).

[Original code]

input_size = (dim * self.num_categories) + num_continuous
l = input_size // 8

hidden_dimensions = list(map(lambda t: l * t, mlp_hidden_mults))

[Modified code]

input_size = (dim * self.num_categories) + num_continuous
l = dim // 8 # to be used shared embedding

hidden_dimensions = list(map(lambda t: input_size * t, mlp_hidden_mults))

I think this could be very confusing because the authors of the paper use "l" for two different things (the size of the input and the dimension of the shared embedding).

Another person has already opened an issue about the shared embedding, so the code should be modified with that issue in mind too.
#12

Please check whether my opinion is correct or not.

Thank you.
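The difference between the two formulas is easy to see with the README's example settings (dim = 32, five categorical columns, ten continuous features, mlp_hidden_mults = (4, 2)). A small sketch that computes both; the function name is hypothetical.

```python
from typing import List, Sequence


def mlp_hidden_dims(dim: int, num_categories: int, num_continuous: int,
                    mlp_hidden_mults: Sequence[int], proposed: bool = False) -> List[int]:
    input_size = dim * num_categories + num_continuous
    # current code scales the multipliers by input_size // 8; the proposal scales by input_size itself
    base = input_size if proposed else input_size // 8
    return [base * m for m in mlp_hidden_mults]
```

Here input_size = 32 × 5 + 10 = 170, so the current code gives hidden sizes [84, 42] (170 // 8 = 21), while the proposed change gives [680, 340], a roughly eight times wider MLP.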
