GithubHelp home page GithubHelp logo

Comments (3)

Optimox avatar Optimox commented on May 30, 2024

I am not very familiar with GridSearchCV but maybe you could define the tabnet model inside the fit method, something similar to this (generated with the help of ChatGPT so take it with a grain of salt):

from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TabNetClassifierWrapper(BaseEstimator, ClassifierMixin):
    def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, n_independent=2, n_shared=2, 
                 lambda_sparse=1e-4, optimizer_fn=torch.optim.Adam, optimizer_params=dict(lr=2e-2), 
                 mask_type="entmax", scheduler_params=None, scheduler_fn=None, seed=0,
                 verbose=1, device_name="auto"):
        
        # Initialize parameters
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps
        self.gamma = gamma
        self.n_independent = n_independent
        self.n_shared = n_shared
        self.lambda_sparse = lambda_sparse
        self.optimizer_fn = optimizer_fn
        self.optimizer_params = optimizer_params
        self.mask_type = mask_type
        self.scheduler_params = scheduler_params
        self.scheduler_fn = scheduler_fn
        self.seed = seed
        self.verbose = verbose
        self.device_name = device_name
        
        self.tabnet_model = None

    def fit(self, X, y):
        self.tabnet_model = TabNetClassifier(
            n_d=self.n_d, n_a=self.n_a, n_steps=self.n_steps,
            gamma=self.gamma, n_independent=self.n_independent,
            n_shared=self.n_shared, lambda_sparse=self.lambda_sparse,
            optimizer_fn=self.optimizer_fn, optimizer_params=self.optimizer_params,
            mask_type=self.mask_type, scheduler_params=self.scheduler_params,
            scheduler_fn=self.scheduler_fn, seed=self.seed, verbose=self.verbose
        )
        
        self.tabnet_model.fit(
            X_train=X,
            y_train=y,
            eval_set=[(X, y)]
        )
        return self

    def predict(self, X):
        return self.tabnet_model.predict(X)

    def predict_proba(self, X):
        return self.tabnet_model.predict_proba(X)

from tabnet.

timbaessler avatar timbaessler commented on May 30, 2024

Thank you for your response @Optimox,

I added a few parameters to make it complete

class TabNetClassifierWrapper2(BaseEstimator, ClassifierMixin):
    def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=list([]), cat_dims=list([]), cat_emb_dim=list([]),
                 n_independent=2, n_shared=2, epsilon=1e-15,seed=0,momentum=0.002,clip_value=None,
                 lambda_sparse=1e-3, optimizer_fn=torch.optim.Adam, optimizer_params=dict(lr=2e-2), scheduler_fn=None,
                 scheduler_params=None, verbose=1 , device_name="auto", mask_type="entmax", grouped_features=None,
                 n_shared_decoder=1, n_indep_decoder=1):

        # Initialize parameters
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps
        self.gamma = gamma
        self.cat_idxs = cat_idxs
        self.cat_dims= cat_dims
        self.cat_emb_dim = cat_emb_dim
        self.n_independent = n_independent
        self.n_shared = n_shared
        self.epsilon= epsilon
        self.seed = seed
        self.momentum = momentum
        self.clip_value = clip_value
        self.lambda_sparse = lambda_sparse
        self.optimizer_fn = optimizer_fn
        self.optimizer_params = optimizer_params
        self.scheduler_fn = scheduler_fn
        self.scheduler_params = scheduler_params
        self.model_name = None
        self.verbose = verbose
        self.device_name = device_name
        self.mask_type = mask_type
        self.grouped_features = grouped_features
        self.n_shared_decoder = n_shared_decoder
        self.n_indep_decoder = n_indep_decoder

    def fit(self, X, y, eval_set):
        self.tabnet_model = TabNetClassifier(
            n_d=self.n_d,
            n_a=self.n_a,
            n_steps=self.n_steps,
            gamma=self.gamma,
            cat_idxs=self.cat_idxs,
            cat_dims=self.cat_dims,
            cat_emb_dim=self.cat_emb_dim,
            n_independent=self.n_independent,
            n_shared=self.n_shared,
            epsilon=self.epsilon,
            seed=self.seed,
            momentum=self.momentum,
            clip_value=self.clip_value,
            lambda_sparse=self.lambda_sparse,
            optimizer_fn=self.optimizer_fn,
            optimizer_params=self.optimizer_params,
            scheduler_fn=self.scheduler_fn,
            scheduler_params=self.scheduler_params,
            #model_name=self.model_name,
            verbose=self.verbose,
            device_name=self.device_name,
            mask_type=self.mask_type,
            grouped_features=self.grouped_features,
            n_shared_decoder=self.n_shared_decoder,
            n_indep_decoder=self.n_indep_decoder)

        self.tabnet_model.fit(
            X_train=X,
            y_train=y,
            eval_set=eval_set,
            max_epochs=10
        )
        return self

    def predict(self, X):
        return self.tabnet_model.predict(X)

    def predict_proba(self, X):
        return self.tabnet_model.predict_proba(X)

However, with your version and mine I get nan-score during CV

TabNetClassifierWrapper()
Fitting 5 folds for each of 96 candidates, totalling 480 fits
[CV 1/5; 1/96] START gamma=1, mask_type=entmax, n_d=3, n_independent=1, n_shared=1, n_steps=1, seed=42
[CV 2/5; 1/96] START gamma=1, mask_type=entmax, n_d=3, n_independent=1, n_shared=1, n_steps=1, seed=42
[CV 1/5; 1/96] END gamma=1, mask_type=entmax, n_d=3, n_independent=1, n_shared=1, n_steps=1, seed=42;, score=nan total time=   0.2s
[CV 2/5; 1/96] END gamma=1, mask_type=entmax, n_d=3, n_independent=1, n_shared=1, n_steps=1, seed=42;, score=nan total time=   0.1s

from tabnet.

Optimox avatar Optimox commented on May 30, 2024

Do you also get Nan scores if you simply launch the training with TabNetClassifier outside of GridSearchCV ?

from tabnet.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.