
Comments (8)

Optimox commented on June 9, 2024

I'm not sure what is causing the spike; I would need to dig deeper to give you an answer.

Yes, if you don't need the feature importances, you can simply skip the computation.

Optimox commented on June 9, 2024

@Borda I think this might be due to the feature importance computation, which can be heavy for large datasets. Have you tried setting compute_importance=False when calling fit?

Borda commented on June 9, 2024

Have you tried setting compute_importance=False when calling fit?

Not yet, let me check... but would it be possible to compute the importance on the GPU (maybe with cuDF) or to parallelize it across all CPU cores?

Optimox commented on June 9, 2024

The code can certainly be improved, but that's not a quick fix. If you like, you can compute the feature importance on a smaller subset after training: model._compute_feature_importances(X_subset)
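Optimox's suggestion boils down to running the computation on fewer rows. The last lines of _compute_feature_importances, shown in the traceback later in this thread (sum_explain = M_explain.sum(axis=0), then division by np.sum(sum_explain)), only need per-feature column sums, so the same normalized result could also be accumulated batch by batch without ever holding the full explanation matrix in memory. A minimal NumPy sketch, assuming each batch is a dense (batch_size, n_features) array of explanation masks; both helper names are hypothetical and not part of pytorch-tabnet:

```python
from typing import Iterable

import numpy as np


def global_importance_from_batches(mask_batches: Iterable[np.ndarray]) -> np.ndarray:
    """Normalized global feature importance from batches of explanation masks.

    Hypothetical helper: mirrors `M_explain.sum(axis=0)` followed by
    `sum_explain / np.sum(sum_explain)`, but only keeps a running
    (n_features,) vector instead of the full (n_samples, n_features) matrix.
    """
    total = None
    for batch in mask_batches:
        col_sums = np.asarray(batch, dtype=np.float64).sum(axis=0)
        total = col_sums if total is None else total + col_sums
    if total is None:
        raise ValueError("no batches were provided")
    return total / total.sum()


def subsample_rows(X: np.ndarray, n_rows: int, seed: int = 0) -> np.ndarray:
    """Random subset of rows (without replacement) to keep the computation cheap."""
    rng = np.random.default_rng(seed)
    n_rows = min(n_rows, X.shape[0])
    idx = rng.choice(X.shape[0], size=n_rows, replace=False)
    return X[idx]


# Demo on synthetic masks: streaming over 10 batches matches the one-shot formula.
M_demo = np.random.default_rng(42).random((1_000, 5))
streamed = global_importance_from_batches(np.array_split(M_demo, 10))
one_shot = M_demo.sum(axis=0) / M_demo.sum()
```

Summing per batch costs the same arithmetic as the one-shot version; the gain is that peak memory no longer grows with the number of rows.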

Borda commented on June 9, 2024

you may want to compute the feature importance on a smaller subset after training if you like: model._compute_feature_importances(X_subset)

That would be a good alternative; I'm just not sure about relying on a protected API, which can change at any time, right?
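One way to avoid depending on the protected method is to reproduce its last lines on top of the public explain(X, normalize=False) method, whose signature appears in the traceback in this thread. A sketch, assuming explain returns a tuple whose first element is an (n_samples, n_features) array (or sparse matrix), as the line M_explain, _ = self.explain(X, normalize=False) suggests. The helper name and the n_rows default are hypothetical, and this has not been checked against every pytorch-tabnet release:

```python
import numpy as np


def feature_importances_via_explain(model, X: np.ndarray, n_rows: int = 100_000) -> np.ndarray:
    """Global importance on the last `n_rows` rows, using only `model.explain`.

    Hypothetical helper. `model` must already be fitted; it is duck-typed, so
    anything exposing `explain(X, normalize=False) -> (masks, ...)` works.
    """
    X_subset = X[-n_rows:]  # same slicing as fold_train_features[-100_000:]
    M_explain, _ = model.explain(X_subset, normalize=False)
    # np.asarray(...).ravel() gives a flat (n_features,) vector for both dense
    # arrays and scipy sparse matrices (whose .sum(axis=0) returns a 2-D matrix).
    sum_explain = np.asarray(M_explain.sum(axis=0)).ravel()
    return sum_explain / np.sum(sum_explain)
```

With a trained TabNetRegressor this would be called as feature_importances_via_explain(model, fold_train_features), after fit has run with compute_importance=False.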

Borda commented on June 9, 2024

I also hit an interesting crash:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[18], line 30
     28 # Train a TabNet model for the current fold
     29 model = TabNetRegressor(**tabnet_params)
---> 30 model._compute_feature_importances(fold_train_features[-100_000:])
     31 model.fit(
     32     X_train=fold_train_features, y_train=fold_train_target,
     33     eval_set=[(fold_valid_features, fold_valid_target)],
     34     eval_metric=['mae'], **FIT_PARAMETERS
     35 )
     36 # Free up memory by deleting fold specific variables

File /opt/conda/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py:759, in TabModel._compute_feature_importances(self, X)
    750 def _compute_feature_importances(self, X):
    751     """Compute global feature importance.
    752 
    753     Parameters
   (...)
    757 
    758     """
--> 759     M_explain, _ = self.explain(X, normalize=False)
    760     sum_explain = M_explain.sum(axis=0)
    761     feature_importances_ = sum_explain / np.sum(sum_explain)

File /opt/conda/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py:336, in TabModel.explain(self, X, normalize)
    318 def explain(self, X, normalize=False):
    319     """
    320     Return local explanation
    321 
   (...)
    334         Sparse matrix showing attention masks used by network.
    335     """
--> 336     self.network.eval()
    338     if scipy.sparse.issparse(X):
    339         dataloader = DataLoader(
    340             SparsePredictDataset(X),
    341             batch_size=self.batch_size,
    342             shuffle=False,
    343         )

AttributeError: 'TabNetRegressor' object has no attribute 'network'

Optimox commented on June 9, 2024

You need to train your model first, calling fit with compute_importance=False, and only then call model._compute_feature_importances(fold_train_features[-100_000:]). The network can't exist until the model knows how many targets you have.
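The ordering constraint comes from the network being built lazily inside fit, once the output dimension is known from the targets. The toy class below is not pytorch-tabnet code; it is a hypothetical illustration of that lazy-construction pattern, and its "network" and "explanation" are deliberately trivial stand-ins. It reproduces the same kind of AttributeError when explain is called before fit:

```python
import numpy as np


class LazyModel:
    """Toy model whose 'network' only exists after fit() has seen the data."""

    def fit(self, X: np.ndarray, y: np.ndarray) -> "LazyModel":
        # The output size is only known once the targets are seen.
        output_dim = 1 if y.ndim == 1 else y.shape[1]
        # Stand-in for a real network: a random (input_dim, output_dim) weight matrix.
        rng = np.random.default_rng(0)
        self.network = rng.normal(size=(X.shape[1], output_dim))
        return self

    def explain(self, X: np.ndarray, normalize: bool = False):
        # Reading self.network before fit() raises AttributeError, because the
        # attribute is only assigned inside fit().
        weights = np.abs(self.network).sum(axis=1)  # one weight per input feature
        M_explain = np.abs(X) * weights
        if normalize:
            # Rescale each row so its entries sum to 1.
            M_explain = M_explain / M_explain.sum(axis=1, keepdims=True)
        return M_explain, None
```

Calling LazyModel().explain(X) fails immediately, while LazyModel().fit(X, y).explain(X) works, which matches the fit-then-explain order described above.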

Borda commented on June 9, 2024

you need to first train your model with fit and compute_importance=False, then model._compute_feature_importances(fold_train_features[-100_000:]), the network can't exist if the model does not know how many targets you have.

OK, but if I just want to fit and predict, then I don't need _compute_feature_importances at all, right?

Also, looking at the code a bit, do you have any insight into the root cause of the spike? Is it the sparse matrix computation or the dot product after it?
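To narrow down which step dominates, one option is to measure time and peak traced memory for each stage separately. The sketch below uses only NumPy and the standard-library tracemalloc and time modules. The three stages (stacking per-batch masks, projecting them through a reducing matrix, and the final column sum) are hypothetical stand-ins chosen to resemble the operations discussed here, with a dense NumPy matrix in place of the sparse one, so real pytorch-tabnet numbers may differ:

```python
import time
import tracemalloc
from typing import Callable, Dict, Tuple

import numpy as np


def profile_stage(fn: Callable[[], np.ndarray]) -> Tuple[np.ndarray, float, int]:
    """Run fn() and return (result, seconds, peak bytes traced during the call)."""
    tracemalloc.start()
    t0 = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - t0
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return result, elapsed, peak


def profile_importance_pipeline(n_rows: int = 200_000, n_embedded: int = 64,
                                n_features: int = 32, batch_size: int = 16_384,
                                seed: int = 0) -> Dict[str, Tuple[float, int]]:
    """Time and peak memory of each stage on synthetic data (all inputs hypothetical)."""
    rng = np.random.default_rng(seed)
    # Synthetic per-batch masks in an "embedded" feature space.
    batches = [rng.random((min(batch_size, n_rows - i), n_embedded), dtype=np.float32)
               for i in range(0, n_rows, batch_size)]
    # Dense 0/1 stand-in mapping each embedded column to one original feature.
    reducing = np.zeros((n_embedded, n_features), dtype=np.float32)
    reducing[np.arange(n_embedded), np.arange(n_embedded) % n_features] = 1.0

    report = {}
    stacked, t, peak = profile_stage(lambda: np.vstack(batches))
    report["stack"] = (t, peak)
    projected, t, peak = profile_stage(lambda: stacked @ reducing)
    report["project"] = (t, peak)
    _, t, peak = profile_stage(lambda: projected.sum(axis=0))
    report["column_sum"] = (t, peak)
    return report
```

Running profile_importance_pipeline() with sizes close to the real dataset and comparing the three entries shows which kind of operation grows fastest; wrapping the corresponding calls in a real run the same way would give the actual answer.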