
microsoft / dgt

Learning Accurate Decision Trees with Bandit Feedback via Quantized Gradient Descent

License: MIT License

Python 99.73% Shell 0.27%


Dense Gradient Tree

This repository houses the supporting code for the paper Learning Accurate Decision Trees with Bandit Feedback via Quantized Gradient Descent.

The Dense Gradient Tree (DGT) technique supports learning decision trees of a given height for (a) multi-class classification and (b) regression, under both (i) standard supervised feedback and (ii) bandit feedback. In the bandit feedback setting, the true loss function is unknown to the learning algorithm; the learner can only query the loss for a given prediction. The goal then is to learn decision trees in an online manner: at each round, the learner maintains a tree model, makes a prediction for the presented features, receives a loss, and updates the tree model.
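The online protocol described above (predict, receive a loss, update) can be sketched as follows. This is not DGT: a hypothetical linear model stands in for the tree, and the update uses a central finite-difference estimate of the loss derivative, assuming the learner may query the loss at a couple of nearby predictions each round. The names `make_loss_oracle`, `LinearLearner`, and `run_online` are illustrative only.

```python
import random


def make_loss_oracle(y_true):
    """Return a loss oracle for one round; the learner never sees y_true."""
    def oracle(prediction):
        # Hidden loss (squared error here); only its value is revealed.
        return (prediction - y_true) ** 2
    return oracle


class LinearLearner:
    """Hypothetical stand-in model; DGT would maintain a decision tree instead."""

    def __init__(self, dim, lr=0.05, delta=1e-3):
        self.w = [0.0] * dim
        self.lr = lr
        self.delta = delta

    def predict(self, x):
        return sum(wi * xi for wi, xi in zip(self.w, x))

    def update(self, x, pred, oracle):
        # Estimate dLoss/dPrediction from two extra loss queries
        # (central finite difference); the loss formula itself is never used.
        g = (oracle(pred + self.delta) - oracle(pred - self.delta)) / (2 * self.delta)
        # Chain rule: dPrediction/dw_i = x_i for a linear model.
        self.w = [wi - self.lr * g * xi for wi, xi in zip(self.w, x)]


def run_online(rounds=2000, seed=0):
    rng = random.Random(seed)
    true_w = [2.0, -1.0]  # unknown to the learner
    learner = LinearLearner(dim=len(true_w))
    losses = []
    for _ in range(rounds):
        x = [rng.uniform(-1.0, 1.0) for _ in true_w]
        oracle = make_loss_oracle(sum(a * b for a, b in zip(true_w, x)))
        pred = learner.predict(x)        # 1. predict on the presented features
        losses.append(oracle(pred))      # 2. receive the loss for that prediction
        learner.update(x, pred, oracle)  # 3. update the model
    return learner, losses


if __name__ == "__main__":
    learner, losses = run_online()
    print("learned weights:", learner.w)
```

The point of the sketch is that `update` only ever calls `oracle`, never the loss formula; that is the constraint the bandit setting imposes.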

Setup

  1. Install necessary packages

Create a new conda environment named dgt_env with python==3.6.8 and pytorch==1.7.0, and install all dependencies into it:

$ conda env create -f dgt_env.yml
$ conda activate dgt_env
  2. Change the working directory to src:
$ cd src
  3. Run the algorithm

To reproduce some of our results, run:
$ bash run.sh

  • By default, the script runs our algorithm with height 6 on the ailerons dataset. The commands for abalone, satimage, and pendigits are commented out.
  • To change the height of the learnt tree, change the value passed to the --height flag.
  • The --proc_per_gpu option sets how many processes run on each GPU. It defaults to 4, which suits a typical GPU; on a GPU with less memory, you may need to reduce it.
  • The --num_gpu option sets how many GPUs to parallelize over (GPU device ordinals are assumed to start at 0). It defaults to 1.
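As a rough illustration of how --proc_per_gpu and --num_gpu bound parallelism, the sketch below assigns a list of runs to device ordinals starting at 0, with at most proc_per_gpu * num_gpu runs active at once. The round-robin scheme and the `plan_jobs` helper are hypothetical; run.sh may schedule its processes differently.

```python
from collections import Counter


def plan_jobs(num_jobs, proc_per_gpu=4, num_gpu=1):
    """Hypothetical round-robin plan returning (job, device, wave) triples.

    Jobs in the same wave run concurrently; each wave places at most
    proc_per_gpu jobs on each of the devices cuda:0 .. cuda:{num_gpu-1}.
    """
    if proc_per_gpu < 1 or num_gpu < 1:
        raise ValueError("proc_per_gpu and num_gpu must be >= 1")
    slots = proc_per_gpu * num_gpu  # maximum concurrent processes
    plan = []
    for job in range(num_jobs):
        wave = job // slots
        device = (job % slots) % num_gpu  # device ordinals start at 0
        plan.append((job, "cuda:%d" % device, wave))
    return plan


if __name__ == "__main__":
    # e.g. 10 runs on 2 GPUs with 4 processes per GPU -> 2 waves
    plan = plan_jobs(10, proc_per_gpu=4, num_gpu=2)
    for job, device, wave in plan:
        print(job, device, wave)
    # Load per (wave, device) never exceeds proc_per_gpu.
    print(Counter((wave, device) for _, device, wave in plan))
```

Under this scheme, lowering --proc_per_gpu shrinks each wave, which is why it helps on GPUs with less memory: fewer processes share one device at a time.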

Note: For the abalone dataset, we report the final performance across 5 different shuffles.

  4. Check Results

Final scores, i.e., the mean test RMSE (regression) or accuracy (classification) and its standard deviation, can be found in the file ./out/exp@{dataset}_{height}@{start_time}/meanstd-exps/meanstd-run-summary.csv under the columns test_acc_mean and test_acc_std.
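To collect these numbers programmatically, a small Python sketch like the following can glob the output directories for one dataset and height and read the two columns. It assumes the CSV is comma-separated with a header row; `read_summaries` is a hypothetical helper, not part of the repository.

```python
import csv
import glob
import os


def read_summaries(out_dir, dataset, height):
    """Return [(csv_path, test_acc_mean, test_acc_std), ...] for every run of
    `dataset` at `height`, following ./out/exp@{dataset}_{height}@{start_time}/..."""
    pattern = os.path.join(
        out_dir, "exp@%s_%s@*" % (dataset, height),
        "meanstd-exps", "meanstd-run-summary.csv",
    )
    results = []
    for path in sorted(glob.glob(pattern)):  # one match per start_time
        with open(path, newline="") as f:
            for row in csv.DictReader(f):
                results.append(
                    (path, float(row["test_acc_mean"]), float(row["test_acc_std"]))
                )
    return results


if __name__ == "__main__":
    for path, mean, std in read_summaries("./out", "ailerons", 6):
        print("%s: %.4f +/- %.4f" % (path, mean, std))
```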

Code Contributors

Ajaykrishna Karthikeyan
Naman Jain

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
