Comments (4)
Hi @hengck23, happy to see you here (big fan of your sharing on Kaggle)!
Indeed, reading the paper it seems that the whole fc-bn-glu block is shared, but since GLU is just an activation function it has no parameters to share, so the two parts that matter are the fc and the bn layers.
Our first implementation shared both fc and bn, but the network could not train; we removed the batch norm sharing and everything looked better (this comes from early experimentation, so please take it with a grain of salt). My interpretation is that since a different mask is applied at each step, the input features look very different from one step to the next, so sharing the same batch norm weights does not allow proper normalization (the features zeroed at one step interfere with the statistics of the non-zeroed features at another step).
That's why we chose to have a separate batch normalization at each step. I agree this seems to differ from what the paper says, so it could be worth checking:
- whether the official tensorflow implementation is sharing batch norm as well
- if sharing batch norm in the current implementation still breaks things down
I'll try to have a look and come back to you with more information; if you run any experiments on this, please don't hesitate to share them here!
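To make the idea above concrete, here is a minimal numpy sketch (all names are illustrative, not taken from the library) of one fc weight matrix reused at every decision step, with separate batch-norm parameters and statistics per step:

```python
import numpy as np

rng = np.random.default_rng(0)

n_features, n_steps, batch = 8, 3, 32
W = rng.normal(size=(n_features, n_features))     # ONE fc weight, shared by all steps

# Separate batch-norm parameters for EACH step (not shared)
bn_gamma = [np.ones(n_features) for _ in range(n_steps)]
bn_beta = [np.zeros(n_features) for _ in range(n_steps)]

def step_forward(x, mask, step):
    """Apply the step-specific mask, the shared fc, then the step's own BN."""
    h = (x * mask) @ W                 # masked input -> shared linear layer
    mu, var = h.mean(0), h.var(0)      # statistics computed per step
    h_norm = (h - mu) / np.sqrt(var + 1e-5)
    return bn_gamma[step] * h_norm + bn_beta[step]

x = rng.normal(size=(batch, n_features))
for step in range(n_steps):
    mask = (rng.random(n_features) > 0.5).astype(float)  # toy mask, differs each step
    out = step_forward(x, mask, step)
```

Because each step sees a differently-masked input, each step's BN gets its own statistics here, which is the behaviour described above.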
Cheers!
from tabnet.
I had a look at the tensorflow implementation here https://github.com/google-research/google-research/blob/master/tabnet/tabnet_model.py
You can see at lines 121 and 134 that they explicitly reuse the fc layers; this does not seem to be the case for the batch normalization layers. So I would say the figure in the paper is a bit confusing, but the batch norm layers shouldn't/can't be shared here.
So I'm pretty confident that this should be the way to go!
You may still try to share the batch norm as well in our implementation, but I don't think it will work!
Hope my explanations are clear, let me know otherwise!
@Optimox Thanks for the answer. Your code is very well written; it helped me a lot in understanding the paper. Good work!
I checked the tensorflow code. It seems that only the fc is shared, not the bn.
In my own experiments, I also confirm the performance ranking:
share only fc (best) > all independent, no sharing > share both fc+bn (very bad)
Also, the results are rather sensitive to:
- batch norm settings (including momentum and ghost batch size)
- the random seed (changing the seed of the train data loader's random sampler can affect results)
- the learning rate
I further note that your code can be sped up by 2x if you:
- skip "explaining the mask" during training (saves about 3 sec per epoch in the forest-cover experiment)
- use num_workers in the dataloader for multi-threaded loading (saves about 4 sec)
There is significant overhead in ghost batch norm: by performing batch norm on multiple splits, it is 2x slower than normal batch norm. (The tensorflow code uses group norm, and I am exploring other, faster normalizations to replace batch norm.)
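Ghost batch norm, as usually described, just splits the batch into virtual sub-batches and normalizes each one independently. A rough numpy sketch (helper names are illustrative) shows where the overhead comes from: one mean/variance computation per chunk instead of one for the whole batch.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Plain batch norm: statistics computed once over the whole batch."""
    return (x - x.mean(0)) / np.sqrt(x.var(0) + eps)

def ghost_batch_norm(x, virtual_batch_size, eps=1e-5):
    """Split the batch into 'ghost' chunks and normalize each independently.
    Each chunk triggers its own mean/var computation, hence the extra cost."""
    chunks = np.split(x, x.shape[0] // virtual_batch_size)
    return np.concatenate([batch_norm(c, eps) for c in chunks])

rng = np.random.default_rng(0)
x = rng.normal(size=(1024, 16))
y = ghost_batch_norm(x, virtual_batch_size=128)  # 8 normalizations instead of 1
```

With a batch of 1024 and a virtual batch size of 128, the statistics are computed 8 times per layer, which is consistent with the slowdown reported above.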
Overall, it seems this framework may be able to match boosting-tree frameworks like xgboost or lightgbm in terms of accuracy and speed. I am trying to stabilize the results and make training less sensitive to the problems described above.
FYI: I refactored your code for the tabnet model, and I verified that loading a model trained with your code gives the same results at inference. You may want to take a look at: https://gist.github.com/hengck23/c21b8b6f2f34634687ebd8a4e963f560
Thanks @hengck23,
About the speed-ups:
- I already added num_workers as a parameter for users (see #98); the next release will make this easily accessible.
- You are right that we probably don't need to compute mask explanations during training; I'll try to find a way to disable this.
About ghost batch norm, I guess this is one of the trickiest parts and one I don't master very well. I can see in your code from the link above that you implement it differently; it might be worth switching to your implementation. Please don't hesitate to open a PR to change the code directly, I'd be glad to have you as a contributor.
About the momentum, a quick remark: I think momentum values don't mean the same thing in tensorflow and pytorch; the implementations differ and you need to use 1-m when going from one to the other, so the momentum from the paper needs to be subtracted from one in ours.
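To make that remark concrete: TensorFlow's running-average update weights the old running value by the momentum, while PyTorch's BatchNorm weights the new batch statistic, so the same behaviour requires m_pytorch = 1 - m_tensorflow. A plain-Python check (not library code):

```python
def tf_update(old, batch_stat, m_tf):
    # TensorFlow convention: momentum weights the OLD running value
    return m_tf * old + (1 - m_tf) * batch_stat

def torch_update(old, batch_stat, m_torch):
    # PyTorch convention: momentum weights the NEW batch statistic
    return (1 - m_torch) * old + m_torch * batch_stat

m_tf = 0.9                # e.g. a value quoted in a paper
m_torch = 1 - m_tf        # the equivalent PyTorch momentum

a = tf_update(5.0, 1.0, m_tf)
b = torch_update(5.0, 1.0, m_torch)
# both conventions produce the same running statistic (up to float rounding)
```

So a paper value like 0.9 should be passed as 0.1 to a PyTorch BatchNorm layer.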
Last but not least, we will soon add unit tests to make things easier to review and check.
Again, feel free to contribute with a PR!
Best