Comments (3)
I am not very familiar with GridSearchCV
but maybe you could define the TabNet model inside the fit method — something similar to this (generated with the help of ChatGPT, so take it with a grain of salt):
import numpy as np
import torch
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.base import BaseEstimator, ClassifierMixin
class TabNetClassifierWrapper(BaseEstimator, ClassifierMixin):
    """Scikit-learn compatible wrapper around ``TabNetClassifier``.

    The TabNet model is built lazily inside :meth:`fit`, so that
    ``GridSearchCV`` can clone this estimator and override any
    hyperparameter before a model exists.
    """

    def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, n_independent=2,
                 n_shared=2, lambda_sparse=1e-4, optimizer_fn=torch.optim.Adam,
                 optimizer_params=None, mask_type="entmax",
                 scheduler_params=None, scheduler_fn=None, seed=0,
                 verbose=1, device_name="auto"):
        # Sklearn convention: store constructor arguments verbatim, no logic
        # in __init__, so get_params/set_params round-trip during cloning.
        # optimizer_params defaults to None instead of a shared mutable dict
        # (mutable-default pitfall); it is resolved to {"lr": 2e-2} in fit.
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps
        self.gamma = gamma
        self.n_independent = n_independent
        self.n_shared = n_shared
        self.lambda_sparse = lambda_sparse
        self.optimizer_fn = optimizer_fn
        self.optimizer_params = optimizer_params
        self.mask_type = mask_type
        self.scheduler_params = scheduler_params
        self.scheduler_fn = scheduler_fn
        self.seed = seed
        self.verbose = verbose
        self.device_name = device_name
        self.tabnet_model = None

    def fit(self, X, y):
        """Build a fresh TabNetClassifier from the stored hyperparameters and
        train it on (X, y).

        Returns ``self`` per sklearn convention.
        """
        optimizer_params = (dict(lr=2e-2) if self.optimizer_params is None
                            else self.optimizer_params)
        self.tabnet_model = TabNetClassifier(
            n_d=self.n_d, n_a=self.n_a, n_steps=self.n_steps,
            gamma=self.gamma, n_independent=self.n_independent,
            n_shared=self.n_shared, lambda_sparse=self.lambda_sparse,
            optimizer_fn=self.optimizer_fn, optimizer_params=optimizer_params,
            mask_type=self.mask_type, scheduler_params=self.scheduler_params,
            scheduler_fn=self.scheduler_fn, seed=self.seed,
            verbose=self.verbose,
            # Bug fix: device_name was accepted but never forwarded.
            device_name=self.device_name,
        )
        self.tabnet_model.fit(
            X_train=X,
            y_train=y,
            eval_set=[(X, y)]
        )
        return self

    def predict(self, X):
        """Predict class labels with the fitted TabNet model."""
        return self.tabnet_model.predict(X)

    def predict_proba(self, X):
        """Predict class probabilities with the fitted TabNet model."""
        return self.tabnet_model.predict_proba(X)
from tabnet.
Thank you for your response @Optimox.
I added a few parameters to make the wrapper complete:
class TabNetClassifierWrapper2(BaseEstimator, ClassifierMixin):
    """Scikit-learn compatible wrapper exposing the full set of
    ``TabNetClassifier`` constructor parameters.

    The model is built lazily inside :meth:`fit` so ``GridSearchCV`` can
    clone the estimator and vary hyperparameters freely.
    """

    def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=None,
                 cat_dims=None, cat_emb_dim=None, n_independent=2, n_shared=2,
                 epsilon=1e-15, seed=0, momentum=0.002, clip_value=None,
                 lambda_sparse=1e-3, optimizer_fn=torch.optim.Adam,
                 optimizer_params=None, scheduler_fn=None,
                 scheduler_params=None, verbose=1, device_name="auto",
                 mask_type="entmax", grouped_features=None,
                 n_shared_decoder=1, n_indep_decoder=1):
        # Sklearn convention: store arguments verbatim, no logic in __init__.
        # cat_idxs/cat_dims/cat_emb_dim/optimizer_params default to None
        # instead of mutable [] / dict(...) defaults (shared-state pitfall);
        # they are resolved to their effective values inside fit.
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps
        self.gamma = gamma
        self.cat_idxs = cat_idxs
        self.cat_dims = cat_dims
        self.cat_emb_dim = cat_emb_dim
        self.n_independent = n_independent
        self.n_shared = n_shared
        self.epsilon = epsilon
        self.seed = seed
        self.momentum = momentum
        self.clip_value = clip_value
        self.lambda_sparse = lambda_sparse
        self.optimizer_fn = optimizer_fn
        self.optimizer_params = optimizer_params
        self.scheduler_fn = scheduler_fn
        self.scheduler_params = scheduler_params
        # Kept for backward compatibility; never forwarded to TabNet.
        self.model_name = None
        self.verbose = verbose
        self.device_name = device_name
        self.mask_type = mask_type
        self.grouped_features = grouped_features
        self.n_shared_decoder = n_shared_decoder
        self.n_indep_decoder = n_indep_decoder

    def fit(self, X, y, eval_set=None):
        """Build a fresh TabNetClassifier and train it on (X, y).

        Parameters
        ----------
        X, y : training data and labels.
        eval_set : optional list of (X_val, y_val) tuples; defaults to
            ``[(X, y)]``. Bug fix: this used to be a required positional
            argument, so ``GridSearchCV`` (which calls ``fit(X, y)``)
            raised on every fold and reported ``nan`` scores.

        Returns
        -------
        self, per sklearn convention.
        """
        self.tabnet_model = TabNetClassifier(
            n_d=self.n_d,
            n_a=self.n_a,
            n_steps=self.n_steps,
            gamma=self.gamma,
            cat_idxs=self.cat_idxs if self.cat_idxs is not None else [],
            cat_dims=self.cat_dims if self.cat_dims is not None else [],
            cat_emb_dim=(self.cat_emb_dim
                         if self.cat_emb_dim is not None else []),
            n_independent=self.n_independent,
            n_shared=self.n_shared,
            epsilon=self.epsilon,
            seed=self.seed,
            momentum=self.momentum,
            clip_value=self.clip_value,
            lambda_sparse=self.lambda_sparse,
            optimizer_fn=self.optimizer_fn,
            optimizer_params=(dict(lr=2e-2)
                              if self.optimizer_params is None
                              else self.optimizer_params),
            scheduler_fn=self.scheduler_fn,
            scheduler_params=self.scheduler_params,
            verbose=self.verbose,
            device_name=self.device_name,
            mask_type=self.mask_type,
            grouped_features=self.grouped_features,
            n_shared_decoder=self.n_shared_decoder,
            n_indep_decoder=self.n_indep_decoder)
        self.tabnet_model.fit(
            X_train=X,
            y_train=y,
            eval_set=eval_set if eval_set is not None else [(X, y)],
            max_epochs=10
        )
        return self

    def predict(self, X):
        """Predict class labels with the fitted TabNet model."""
        return self.tabnet_model.predict(X)

    def predict_proba(self, X):
        """Predict class probabilities with the fitted TabNet model."""
        return self.tabnet_model.predict_proba(X)
However, with both your version and mine I get a NaN score during cross-validation:
TabNetClassifierWrapper()
Fitting 5 folds for each of 96 candidates, totalling 480 fits
[CV 1/5; 1/96] START gamma=1, mask_type=entmax, n_d=3, n_independent=1, n_shared=1, n_steps=1, seed=42
[CV 2/5; 1/96] START gamma=1, mask_type=entmax, n_d=3, n_independent=1, n_shared=1, n_steps=1, seed=42
[CV 1/5; 1/96] END gamma=1, mask_type=entmax, n_d=3, n_independent=1, n_shared=1, n_steps=1, seed=42;, score=nan total time= 0.2s
[CV 2/5; 1/96] END gamma=1, mask_type=entmax, n_d=3, n_independent=1, n_shared=1, n_steps=1, seed=42;, score=nan total time= 0.1s
from tabnet.
Do you also get NaN scores if you simply launch the training with TabNetClassifier directly, outside of GridSearchCV?
from tabnet.
Related Issues (20)
- Count the number of parameters HOT 2
- Loss goes to -inf HOT 1
- The mask tensor M in script tab_network.py needs to be transformed to realize the objective stated in the paper: "γ is a relaxation parameter – when γ = 1, a feature is enforced to be used only at one decision step".
- Current version on conda-forge is 4.0 while 4.1 is already released HOT 8
- Minimal working example for TabNetRegressor/Classifier HOT 4
- Transfer learning, capability to change structure of model HOT 1
- Generate Embeddings for Tabular Data HOT 1
- TabNet overfits (help wanted, not a bug) HOT 9
- TabNetRegressor vs other networks HOT 1
- spike in memory when training ends HOT 8
- Severe overfitting HOT 18
- OOM problem when I search hyperparameters with Tabnet HOT 3
- Support for complex-valued datasets HOT 4
- Different classification variables in the test set and train set HOT 1
- Struggling to get model to fit - Help Wanted HOT 7
- Optimizing TabNet for Disease Classification with Continuous Audio Features HOT 1
- Interpreting Sparsity on Global Importance HOT 5
- ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() HOT 1
- Validation loss HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from tabnet.