Comments (8)
I'm not sure what is causing the spike; I would need to dig deeper to give you an answer.
If you don't want to look at the feature importances, then yes, you can just skip the computation.
@Borda I think this might be due to the feature importance computation, which can be heavy for large datasets. Have you tried setting compute_importance=False when calling fit?
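A minimal sketch of what that call could look like, assuming a pytorch_tabnet release whose fit() accepts the compute_importance flag mentioned above; the tabnet_params and X/y names are placeholders for your own config and arrays:

```python
# Sketch (assumption: fit() accepts compute_importance in your installed
# pytorch_tabnet version). The library calls are left commented out so the
# snippet stands alone; only the kwargs are real here.
fit_kwargs = dict(
    eval_metric=["mae"],
    compute_importance=False,  # skip the post-training feature-importance pass
)
# model = TabNetRegressor(**tabnet_params)
# model.fit(X_train=X_train, y_train=y_train,
#           eval_set=[(X_valid, y_valid)], **fit_kwargs)
print(fit_kwargs["compute_importance"])
```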
Not yet, let me check it... but would it be possible to compute the importance on GPU (maybe with cudf), or use some booster over all CPU cores?
The code can certainly be improved, but that's not a quick fix. You may want to compute the feature importance on a smaller subset after training if you like: model._compute_feature_importances(X_subset)
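For context, the global importance is just the per-sample explanation masks summed over samples and then normalized (the reduction is visible in the _compute_feature_importances body quoted in the traceback later in this thread), so its cost scales linearly with the number of rows you pass in. A standalone NumPy sketch of that reduction, with a random matrix standing in for the masks a trained model's explain() would return:

```python
import numpy as np

# Stand-in for the (n_samples, n_features) explanation masks that
# explain(X, normalize=False) returns on a real trained model.
rng = np.random.default_rng(0)
M_explain = rng.random((1000, 8))

# Same reduction as TabModel._compute_feature_importances:
# sum over samples, then normalize to a probability-like vector.
sum_explain = M_explain.sum(axis=0)
feature_importances_ = sum_explain / np.sum(sum_explain)

print(feature_importances_.sum())  # normalized importances sum to 1
```

This is why a subset is usually enough: the normalized vector stabilizes well before you have processed the full training set.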
That would be a good alternative; I'm just not sure about using a protected API, which can change at any time, right?
I also hit an interesting crash:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[18], line 30
28 # Train a TabNet model for the current fold
29 model = TabNetRegressor(**tabnet_params)
---> 30 model._compute_feature_importances(fold_train_features[-100_000:])
31 model.fit(
32 X_train=fold_train_features, y_train=fold_train_target,
33 eval_set=[(fold_valid_features, fold_valid_target)],
34 eval_metric=['mae'], **FIT_PARAMETERS
35 )
36 # Free up memory by deleting fold specific variables
File /opt/conda/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py:759, in TabModel._compute_feature_importances(self, X)
750 def _compute_feature_importances(self, X):
751 """Compute global feature importance.
752
753 Parameters
(...)
757
758 """
--> 759 M_explain, _ = self.explain(X, normalize=False)
760 sum_explain = M_explain.sum(axis=0)
761 feature_importances_ = sum_explain / np.sum(sum_explain)
File /opt/conda/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py:336, in TabModel.explain(self, X, normalize)
318 def explain(self, X, normalize=False):
319 """
320 Return local explanation
321
(...)
334 Sparse matrix showing attention masks used by network.
335 """
--> 336 self.network.eval()
338 if scipy.sparse.issparse(X):
339 dataloader = DataLoader(
340 SparsePredictDataset(X),
341 batch_size=self.batch_size,
342 shuffle=False,
343 )
AttributeError: 'TabNetRegressor' object has no attribute 'network'
You need to first train your model with fit and compute_importance=False, and only then call model._compute_feature_importances(fold_train_features[-100_000:]). The network can't exist before fit, since the model does not yet know how many targets you have.
ok, but if I just want to fit and predict, then I do not need _compute_feature_importances at all, right?
also, looking a bit at the code, do you have insight into what is the root of the spike? Is it the sparse matrix computation or the dot product after it...
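One way to answer that empirically, rather than by reading the code, is to time the candidate stages separately. A generic sketch with NumPy stand-ins (the shapes and the matmul here are illustrative, not the real explain() internals):

```python
import time
import numpy as np

def timed(label, fn):
    # Tiny helper: run fn once and report wall-clock time.
    t0 = time.perf_counter()
    out = fn()
    print(f"{label}: {time.perf_counter() - t0:.3f}s")
    return out

rng = np.random.default_rng(0)
masks = rng.random((5000, 256))   # stand-in for per-step attention masks
weights = rng.random((256, 256))  # stand-in for the downstream product

# Time each suspected stage independently to see which one dominates.
summed = timed("mask aggregation", lambda: masks.sum(axis=0))
product = timed("dot product", lambda: masks @ weights)
```

Running the same two-stage timing inside explain() on a small batch would show directly whether the sparse aggregation or the dense product is the bottleneck.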