
Comments (4)

Optimox commented on May 15, 2024

Hi @hengck23, happy to see you here (big fan of your sharing on Kaggle)!

Indeed, when reading the paper it seems that the whole fc-bn-GLU block is shared, but since GLU is just an activation function it has no parameters to share, so the two parts that matter are fc and bn.

Our first implementation shared both fc and bn, but the network could not train; we removed the batch norm sharing and everything looked better (this is from early experimentation, so please take it with a grain of salt). My interpretation is that since a different mask is applied at each step, the input features look very different from one step to the next, so sharing the same batch norm weights doesn't allow proper batch normalization (features zeroed at one step interfere with the statistics of the same, non-zeroed features at another step).
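
To make the sharing pattern concrete, here is a minimal sketch of one fc-bn-GLU unit where the Linear layer can be reused across decision steps while every unit keeps its own BatchNorm (names and structure are illustrative only, not the actual pytorch-tabnet code):

```python
import torch.nn as nn
import torch.nn.functional as F

class GLUBlock(nn.Module):
    """fc -> bn -> GLU. The Linear layer may be shared across decision steps,
    but the BatchNorm always belongs to the block itself (illustrative sketch)."""

    def __init__(self, input_dim, output_dim, shared_fc=None):
        super().__init__()
        # reuse the shared Linear if one is given, otherwise create an independent one
        self.fc = shared_fc if shared_fc is not None else nn.Linear(input_dim, 2 * output_dim, bias=False)
        # each block owns its BatchNorm, even when fc is shared
        self.bn = nn.BatchNorm1d(2 * output_dim)

    def forward(self, x):
        return F.glu(self.bn(self.fc(x)), dim=-1)  # GLU halves the last dim back to output_dim

# the same Linear instance is passed to every step, so its weights are shared
shared_fc = nn.Linear(16, 2 * 8, bias=False)
steps = nn.ModuleList([GLUBlock(16, 8, shared_fc=shared_fc) for _ in range(3)])
```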

That's why we chose to have a separate batch normalization at each step. But I agree that this seems different from what the paper says, so it could be worth checking:

  • whether the official tensorflow implementation is sharing batch norm as well
  • whether sharing batch norm in the current implementation still breaks training

I'll try to have a look and come back to you with more information; if you run some experiments on this, please don't hesitate to share them here!

Cheers!


Optimox commented on May 15, 2024

I had a look at the tensorflow implementation here https://github.com/google-research/google-research/blob/master/tabnet/tabnet_model.py

You can see at lines 121 and 134 that they explicitly reuse the fc layers; this does not seem to be the case for the batch normalization layers. So I would say the figure in the paper is a bit confusing, and that the batch norm layers shouldn't/can't be shared here.
I'm pretty confident that this is the way to go!

You may still try to share batch norm as well in our implementation but I don't think it will work!

Hope my explanations are clear, let me know otherwise!


hengck23 commented on May 15, 2024

@Optimox Thanks for the answer. Your code is very well written; it helped me a lot in understanding the paper. Good work!

I checked the tensorflow code; it seems that only fc is shared and bn is not.

In my own experiments, I also confirm that performance ranks as follows:
share only fc (best) > all independent, no sharing > share both fc+bn (very bad)

Also, the results are rather sensitive to:

  • batch norm (including momentum and ghost batch size)
  • random seed: changing the random seed of the train data loader's sampler can affect results
  • learning rate

I further note that your code can be sped up by 2x if you:

  • remove "explaining the mask" for training (reduce about 3 sec per epoch for forest-cover experiment)
  • use num_workers in dataloader for multi threading (reduce 4 sec)
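
For the second point, the usual pattern is simply the following (the dataset, batch size and worker count here are placeholders):

```python
from torch.utils.data import DataLoader

# num_workers > 0 moves batch loading into parallel worker processes
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=4)
```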

There is significant overhead in ghost batch norm: by performing batch norm on multiple splits, it is 2x slower than normal batch norm. (The tensorflow code uses group norm, and I am exploring other, faster normalizations to replace batch norm.)
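
As a rough illustration of where that overhead comes from, here is a minimal ghost batch norm sketch (not the actual pytorch-tabnet class; names are placeholders): the batch is split into virtual chunks and each chunk is normalized separately.

```python
import torch
import torch.nn as nn

class GhostBatchNorm1d(nn.Module):
    """BatchNorm applied per virtual (ghost) batch: the input is split into
    chunks and each chunk is normalized on its own, which is where the extra
    cost over a single BatchNorm call comes from (illustrative sketch)."""

    def __init__(self, num_features, virtual_batch_size=128, momentum=0.01):
        super().__init__()
        self.virtual_batch_size = virtual_batch_size
        self.bn = nn.BatchNorm1d(num_features, momentum=momentum)

    def forward(self, x):
        # one BatchNorm call per ghost batch instead of one call for the whole batch
        n_chunks = max(1, x.size(0) // self.virtual_batch_size)
        return torch.cat([self.bn(chunk) for chunk in x.chunk(n_chunks, dim=0)], dim=0)
```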


Overall, it seems that this framework may be able to match boosting-tree frameworks like xgboost or lightgbm in terms of accuracy and speed. I am trying to "stabilize the results" and make training less sensitive to the issues described above.


FYI: I refactored your code for the tabnet model. I verified that loading a model trained with your code gives the same results at inference. You may want to take a look at: https://gist.github.com/hengck23/c21b8b6f2f34634687ebd8a4e963f560


Optimox commented on May 15, 2024

Thanks @hengck23,

About the speed-ups:

  • I already added num_workers as a parameter for users (see #98); the next release will make this easily accessible.
  • You are right that we probably don't need to compute mask explanations during training; I'll try to find a way to disable this.

About ghost batch norm: I guess this is one of the trickiest parts and one I don't master very well. I can see in your code from the link above that you have another way of implementing it; it might be worth switching to your implementation. Please don't hesitate to make a PR to change the code directly, I'd be glad to have you as a contributor.

About the momentum, just a quick remark: momentum values do not mean the same thing in tensorflow and pytorch. The conventions differ, and you need to use 1-m when going from one to the other, so the momentum from the paper must be subtracted from one in ours.
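
In other words (the 0.98 below is just a made-up value for illustration):

```python
import torch.nn as nn

# TensorFlow/Keras: moving_stat = momentum * moving_stat + (1 - momentum) * batch_stat
# PyTorch:          running_stat = (1 - momentum) * running_stat + momentum * batch_stat
tf_momentum = 0.98                                 # value as it would appear in a TF config
bn = nn.BatchNorm1d(64, momentum=1 - tf_momentum)  # becomes 0.02 in PyTorch terms
```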

Last but not least, we will soon add unit tests to make things easier to review and check.

Again, feel free to contribute with a PR!

Bests

