swordli111 / gradientbaseddecisiontrees

This project is forked from s-marton/gradientbaseddecisiontrees


GradTree: Gradient-Based Axis-Aligned Decision Trees

License: MIT License


🌳 GradTree: Gradient-Based Decision Trees 🌳

🌳 GradTree is a novel approach for learning hard, axis-aligned decision trees with gradient descent!

πŸ” What's new?

  • Reformulation of decision trees as dense representations
  • Approximation of the step function with sigmoid and entmax functions
  • Straight-through (ST) operator to retain the inductive bias of hard, axis-aligned splits (see the sketch below)

πŸ“ Details on the method can be found in the preprint available under: https://arxiv.org/abs/2305.03515

🚀 Follow-Up Work: "GRANDE: Gradient-Based Decision Tree Ensembles"

🌳 GRANDE is a novel gradient-based decision tree ensemble method for tabular data: https://github.com/s-marton/GRANDE

πŸ” What's new?

  • End-to-end gradient descent for tree ensembles
  • Combines the inductive bias of hard, axis-aligned splits with the flexibility of gradient-based optimization
  • Advanced instance-wise weighting to learn representations for both simple and complex relations in one model (see the conceptual sketch below)

πŸ“ More details can be found in our prepring: https://arxiv.org/abs/2309.17130

Installation

To install the latest official release of the package, use the pip command below:

pip install GradTree

More details can be found under: https://pypi.org/project/GradTree/

Cite us

@article{marton2023learning,
  title={GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent},
  author={Marton, Sascha and LΓΌdtke, Stefan and Bartelt, Christian and Stuckenschmidt, Heiner},
  journal={arXiv preprint arXiv:2305.03515},
  year={2023}
}

Usage

Example usage is shown below and is also available in GradTree_minimal_example.ipynb. Please note that a GPU is required to achieve competitive runtimes.
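
A quick way to verify that a GPU is visible before training (a minimal check, assuming TensorFlow as the backend used by GradTree):

import tensorflow as tf

# an empty list means no GPU is visible and training will fall back to the CPU
print(tf.config.list_physical_devices('GPU'))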

Load Data

import numpy as np
from sklearn.model_selection import train_test_split
import openml

dataset = openml.datasets.get_dataset(40536)
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
categorical_feature_indices = [idx for idx, idx_bool in enumerate(categorical_indicator) if idx_bool]

X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42)

y_train = y_train.values.codes.astype(np.float64)
y_valid = y_valid.values.codes.astype(np.float64)
y_test = y_test.values.codes.astype(np.float64)

Preprocessing, Hyperparameters and Training

GradTree requires categorical features to be encoded appropriately. The best results are achieved using Leave-One-Out Encoding for high-cardinality categorical features and One-Hot Encoding for low-cardinality categorical features. Furthermore, all features should be normalized using a quantile transformation. Passing the categorical indices to the model will automatically preprocess the data accordingly.
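
For reference, a manual equivalent of this preprocessing could look roughly like the sketch below. It assumes the category_encoders package and uses an arbitrary cardinality threshold of 10 to separate high- from low-cardinality features; when the categorical indices are passed via args['cat_idx'], none of this is necessary:

from category_encoders import LeaveOneOutEncoder, OneHotEncoder
from sklearn.preprocessing import QuantileTransformer

CARDINALITY_THRESHOLD = 10  # illustrative cut-off, not a value prescribed by GradTree

cat_cols = [X_train.columns[idx] for idx in categorical_feature_indices]
high_card = [col for col in cat_cols if X_train[col].nunique() > CARDINALITY_THRESHOLD]
low_card = [col for col in cat_cols if col not in high_card]

# Leave-One-Out Encoding for high-cardinality, One-Hot Encoding for low-cardinality features
loo = LeaveOneOutEncoder(cols=high_card)
ohe = OneHotEncoder(cols=low_card, use_cat_names=True)
X_train_enc = ohe.fit_transform(loo.fit_transform(X_train, y_train), y_train)
X_valid_enc = ohe.transform(loo.transform(X_valid))
X_test_enc = ohe.transform(loo.transform(X_test))

# normalize all features with a quantile transformation
qt = QuantileTransformer(output_distribution='normal')
X_train_enc = qt.fit_transform(X_train_enc)
X_valid_enc = qt.transform(X_valid_enc)
X_test_enc = qt.transform(X_test_enc)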

In the following, we train the model using the default parameters. GradTree already achieves strong results with its default parameters, but hyperparameter optimization (HPO) can increase the performance even further. An appropriate grid is specified in the model class.

from GradTree import GradTree

params = {
        'depth': 5,
        'n_estimators': 2048,

        'learning_rate_weights': 0.005,
        'learning_rate_index': 0.01,
        'learning_rate_values': 0.01,
        'learning_rate_leaf': 0.01,

        'optimizer': 'SWA',
        'cosine_decay_steps': 0,

        'initializer': 'RandomNormal',

        'loss': 'crossentropy',
        'focal_loss': False,
        'temperature': 0.0,

        'from_logits': True,
        'apply_class_balancing': True,

        'dropout': 0.0,

        'selected_variables': 0.8,
        'data_subset_fraction': 1.0,
}

args = {
    'epochs': 1_000,
    'early_stopping_epochs': 25,
    'batch_size': 64,

    'cat_idx': categorical_feature_indices, # put list of categorical indices
    'objective': 'binary',
    
    'metrics': ['F1'], # F1, Accuracy, R2
    'random_seed': 42,
    'verbose': 1,       
}

model_gradtree = GradTree(params=params, args=args)

model_gradtree.fit(X_train=X_train,
          y_train=y_train,
          X_val=X_valid,
          y_val=y_valid)


Evaluate Model

import sklearn.metrics

preds = model_gradtree.predict(X_test)

accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds[:,1]))
f1_score = sklearn.metrics.f1_score(y_test, np.round(preds[:,1]), average='macro')
roc_auc = sklearn.metrics.roc_auc_score(y_test, preds[:,1], average='macro')

print('Accuracy:', accuracy)
print('F1 Score:', f1_score)
print('ROC AUC:', roc_auc)

More

Please note that this is an experimental implementation that is not yet fully tested. If you encounter any errors or observe unexpected behavior, please let me know.

The code for reproducing the experiments from the paper is now located in the separate folder ./experiments_paper_gradtree/.
