
Severe overfitting about tabnet (18 comments, closed)

979-Ryan commented on May 31, 2024
Severe overfitting


Comments (18)

Optimox commented on May 31, 2024

Do you observe the same pattern with XGBoost or any other ML model? If so, this is data-related and not model-related.


979-Ryan commented on May 31, 2024

My learning-rate strategy is optimizer_params=dict(lr=1e-1), scheduler_params=dict(T_0=100, T_mult=1, eta_min=1e-2), scheduler_fn=CosineAnnealingWarmRestarts with the Adam optimizer, and patience=10.
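For illustration, a minimal sketch of that configuration with the pytorch-tabnet API; synthetic data stands in for the (unshared) dataset and the task is assumed to be regression:

```python
import numpy as np
import torch
from pytorch_tabnet.tab_model import TabNetRegressor

# Random stand-in data; replace with the real features/targets.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(10_000, 50)).astype(np.float32)
y_train = rng.normal(size=(10_000, 1)).astype(np.float32)
X_valid = rng.normal(size=(2_000, 50)).astype(np.float32)
y_valid = rng.normal(size=(2_000, 1)).astype(np.float32)

model = TabNetRegressor(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=1e-1),
    scheduler_fn=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    scheduler_params=dict(T_0=100, T_mult=1, eta_min=1e-2),
)
model.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    patience=10,  # early-stopping patience mentioned above
)
```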


Optimox commented on May 31, 2024

Your learning rate is probably too high; also, start with a simple learning-rate decay like OneCycleLR.
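For reference, a hedged sketch of what that suggestion could look like, assuming a pytorch-tabnet version whose scheduler_params accepts an is_batch_level flag for schedulers that step every batch (the values below are illustrative, not recommendations from this thread):

```python
import torch
from pytorch_tabnet.tab_model import TabNetRegressor

batch_size, max_epochs = 1024, 100
n_train = 100_000  # placeholder for the real number of training rows

model = TabNetRegressor(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_fn=torch.optim.lr_scheduler.OneCycleLR,
    scheduler_params=dict(
        is_batch_level=True,  # OneCycleLR must be stepped once per batch
        max_lr=5e-2,
        steps_per_epoch=n_train // batch_size + 1,
        epochs=max_epochs,
    ),
)
# model.fit(X_train, y_train, eval_set=[(X_valid, y_valid)],
#           max_epochs=max_epochs, batch_size=batch_size, patience=10)
```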


979-Ryan commented on May 31, 2024

> Your learning rate is probably too high; also, start with a simple learning-rate decay like OneCycleLR.

Having tried smaller initial learning rates like 2e-2 and 5e-2, the training loss didn't even decrease.


eduardocarvp commented on May 31, 2024

It might be worth having a look at the explanation matrices to check whether some of the features are causing the overfit, for example some sort of index that has not been dropped. Without more details about the data, it's probably going to be hard to diagnose exactly what might be happening.

Other than that, the batch size strikes me as pretty large. Maybe that's the reason you have to use such a high learning rate.
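A hedged sketch of the explanation-matrix check; model, X_valid, and feature_names are hypothetical placeholders for an already-fitted TabNet model, the validation features, and the matching column names:

```python
import numpy as np

# Per-sample, per-feature importances from the fitted model.
explain_matrix, masks = model.explain(X_valid)
mean_importance = explain_matrix.mean(axis=0)
mean_importance = mean_importance / mean_importance.sum()

# Print the ten most important features; a single dominant feature
# (e.g. an index column that was never dropped) is a common culprit.
for idx in np.argsort(mean_importance)[::-1][:10]:
    print(feature_names[idx], round(float(mean_importance[idx]), 3))

# Global, training-time importances are also available directly.
print(model.feature_importances_)
```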


979-Ryan commented on May 31, 2024

> It might be worth having a look at the explanation matrices to check whether some of the features are causing the overfit, for example some sort of index that has not been dropped. Without more details about the data, it's probably going to be hard to diagnose exactly what might be happening.

> Other than that, the batch size strikes me as pretty large. Maybe that's the reason you have to use such a high learning rate.

It's indeed possible that some of the features cause the overfit, but I've already configured lambda_sparse and gamma for regularization. Regarding the batch_size, I followed the original paper's recommendation, setting it between 1% and 10% of the training set size. Should I reduce it?


eduardocarvp commented on May 31, 2024

It is true that they use large batch sizes, up to 16K in the paper. The virtual batch size is always much smaller though, at 512 max.
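As a small, hedged sketch of that split (placeholder model and data, paper-style values):

```python
model.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    batch_size=16384,        # large batches, as in the paper
    virtual_batch_size=512,  # ghost batch norm runs on much smaller chunks
)
```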


979-Ryan commented on May 31, 2024

> Do you observe the same pattern with XGBoost or any other ML model? If so, this is data-related and not model-related.

LightGBM performs much better with the same loss and evaluation metric.


979-Ryan commented on May 31, 2024

It is also worth mentioning that training is extremely slow, around 9-10 minutes per epoch. Any advice on this?


Optimox commented on May 31, 2024

Do you have a GPU?


Optimox commented on May 31, 2024

What happens with batch_size=2048 and virtual_batch_size=256?


979-Ryan commented on May 31, 2024

> Do you have a GPU?

Yes, I'm training on an Nvidia 3090. I haven't tried a batch_size smaller than 16384. Are training speed and the learning-rate strategy related to batch size? If I use a smaller batch size, should I lower the learning rate correspondingly? Thank you!


Optimox commented on May 31, 2024

Training speed is directly proportional to your batch size, as long as (1) your GPU is not already at 100% usage and (2) your CPU is not the bottleneck. Beyond that point, a larger batch size will make training slower.

Batch size and learning rate are related in theory, yes. lr=1e-2, batch_size=1024, virtual_batch_size=256, and nothing else specified in the parameters has never let me down. If this does not work at all, I can't help you more unless you give me access to your dataset.
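A hedged sketch of that baseline (data arrays are placeholders, the task is assumed to be regression, and everything else is left at its defaults):

```python
from pytorch_tabnet.tab_model import TabNetRegressor

model = TabNetRegressor(optimizer_params=dict(lr=1e-2))
model.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    batch_size=1024,
    virtual_batch_size=256,
)
```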


979-Ryan commented on May 31, 2024

Perhaps what you said above contradicts what you mentioned in this issue: https://github.com//issues/391#issuecomment-1113099435. After I reduced the batch size to 4096 and the virtual batch size to 512, training speed was slower.


Optimox commented on May 31, 2024

The larger your batch size, the faster your training is. Where is the contradiction here?


979-Ryan commented on May 31, 2024

Sorry, I misunderstood what you said.


979-Ryan commented on May 31, 2024

> What happens with batch_size=2048 and virtual_batch_size=256?

I just tried a few epochs with batch size 4096 and virtual batch size 512. The performance is worse and the overfitting is heavier.


979-Ryan commented on May 31, 2024

> Training speed is directly proportional to your batch size, as long as (1) your GPU is not already at 100% usage and (2) your CPU is not the bottleneck. Beyond that point, a larger batch size will make training slower.

> Batch size and learning rate are related in theory, yes. lr=1e-2, batch_size=1024, virtual_batch_size=256, and nothing else specified in the parameters has never let me down. If this does not work at all, I can't help you more unless you give me access to your dataset.

Did you try this setting on large datasets with a large number of features?

