kingfengji / mgbdt

This is the official clone of the implementation of the NIPS'18 paper Multi-Layered Gradient Boosting Decision Trees (mGBDT).

Languages: Python 99.46%, Shell 0.54%
Topics: gbdt, gradient-boosting-decision-trees, mgbdt, representation-learning, target-propagation

mgbdt's Issues

Problem with Pop

/opt/conda/lib/python3.7/site-packages/joblib/parallel.py in <listcomp>(.0)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257
    258     def __len__(self):

/kaggle/working/mGBDT/lib/mgbdt/model/online_xgb.py in fit_increment(self, X, y, num_boost_round, params)
     13         for k, v in extra_params.items():
     14             params[k] = v
---> 15         params.pop("n_estimators")
     16
     17         if callable(self.objective):

KeyError: 'n_estimators'
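The error occurs when the params dict handed to fit_increment never contained an "n_estimators" key. A minimal sketch of a defensive fix (my assumption, not an official patch from the maintainer) is to give pop a default value so a missing key no longer raises:

    # Sketch: dict.pop with a default avoids the KeyError seen above.
    params = {"max_depth": 5, "learning_rate": 0.1}  # no "n_estimators" key
    params.pop("n_estimators", None)  # returns None instead of raising KeyError
    print(params)  # {'max_depth': 5, 'learning_rate': 0.1}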

Performance of your model on regression tasks

Description

@kingfengji Thanks for making the code available. I believe that multi-layered gradient boosting decision trees are a very elegant and powerful approach! I applied your model to the Boston housing dataset but wasn't able to outperform a baseline xgboost model.

Details

To compare your approach against several alternatives, I ran a small benchmark study using the following approaches, where all models share the same hyper-parameters (a setup sketch follows the list):

  • baseline xgboost model (xgboost)
  • mGBDT with xgboost for hidden and output layer (mGBDT_XGBoost)
  • mGBDT with xgboost for hidden but with linear model for output layer (mGBDT_Linear)
  • linear model as implemented here (Linear)
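For concreteness, here is a rough sketch of how the mGBDT_XGBoost variant could be assembled for Boston housing (13 input features, scalar target). The class and argument names follow the style of the repo's examples, and passing a PyTorch loss is an assumption based on this benchmark's setup, not a verified API guarantee:

    import numpy as np
    import torch
    from mgbdt import MGBDT, MultiXGBModel

    # Placeholder data in place of the real Boston housing arrays.
    x_train = np.random.rand(100, 13).astype(np.float32)
    y_train = np.random.rand(100, 1).astype(np.float32)

    # Multi-layered GBDT trained by target propagation; whether MGBDT accepts
    # a torch loss object directly is assumed here.
    net = MGBDT(loss=torch.nn.L1Loss(), target_lr=1.0, epsilon=0.1)

    # First layer: forward mapping F only; no inverse mapping G is needed.
    net.add_layer("tp_layer",
                  F=MultiXGBModel(input_size=13, output_size=16,
                                  learning_rate=0.1, max_depth=5, num_boost_round=5),
                  G=None)

    # Output layer: F maps the representation to the target; G learns the
    # pseudo-inverse used to propagate targets back to the hidden layer.
    net.add_layer("tp_layer",
                  F=MultiXGBModel(input_size=16, output_size=1,
                                  learning_rate=0.1, max_depth=5, num_boost_round=5),
                  G=MultiXGBModel(input_size=1, output_size=16,
                                  learning_rate=0.1, max_depth=5, num_boost_round=5))

    net.init(x_train, n_rounds=5)           # initialize the forward mappings
    net.fit(x_train, y_train, n_epochs=50)  # joint training via target propagation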

I am using PyTorch's L1Loss for model training and the MAE for evaluation; all models are trained in serial mode. The results are as follows:

[Figure: MAE comparison of xgboost, mGBDT_XGBoost, mGBDT_Linear, and Linear]

In particular, I observe the following:

  • irrespective of the hyper-parameters and number of epochs, a baseline xgboost model tends to outperform your approach
  • with an increasing number of epochs, the runtime per epoch grows considerably. Any idea as to why this happens?
  • using mGBDT_Linear,
    • I wasn't able to use PyTorch's MSELoss, since the loss exploded after some iterations even after normalizing X. Should we, as with neural networks, also scale y to avoid exploding gradients? (see the scaling sketch after this list)
    • the training loss starts at exceptionally high values, then decreases before it starts to increase again
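On the scaling question, here is a minimal sketch of standardizing the target and inverting the transform afterwards — plain scikit-learn usage as a common remedy for exploding losses, not something the repo itself prescribes:

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    y_train = np.array([21.6, 34.7, 50.0, 8.3])  # toy Boston-style target values
    y_scaler = StandardScaler()
    # Scale y to zero mean and unit variance before fitting the network.
    y_scaled = y_scaler.fit_transform(y_train.reshape(-1, 1)).ravel()
    # ... fit on the scaled target, then map predictions back to the original scale:
    y_orig = y_scaler.inverse_transform(y_scaled.reshape(-1, 1)).ravel()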

Additional Questions

  • Given that you have mostly used your approach for classification tasks, is there anything we need to change before using it for regression tasks, other than the PyTorch loss?
  • Besides the loss of F, can we also track how well the target propagation is working by evaluating the reconstruction loss of G? (a monitoring sketch follows this list)
  • When using mGBDT with a linear output layer, would we generally expect better results than with an xgboost output layer?
  • What is the benefit of a linear output layer compared to an xgboost layer?
  • For training F and G, you are currently using the MSELoss for the xgboost models. Do you have any experience with modifying this loss?
  • What is the effect of the number of iterations used to initialize the model before training?
  • What is the relationship between the number of boosting iterations (for xgboost training) and the number of epochs (for mGBDT training)?
  • In Section 4 of your paper you state: "The experiments for this section is mainly designed to empirically examine if it is feasible to jointly train the multi-layered structure proposed by this work. That is, we make no claims that the current structure can outperform CNNs in computer vision tasks." Does that mean your intention is neither to outperform existing deep learning models such as CNNs, nor existing GBM models such as XGBoost, but rather to show that a decision-tree-based model can also learn meaningful representations that can then be used for downstream tasks?
  • Connected to the previous question: gradient boosting models are already very strong learners that obtain very good results in many applications. What would be your motivation for stacking multiple layers of such a model? Could it even happen that, given the implicit error-correction mechanism of GBMs, training several of them leads to a drop in accuracy?
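Regarding the reconstruction-loss question above, one way to monitor target propagation is to measure how well G inverts F on the hidden representation. A hedged sketch, where F and G stand for any fitted forward/inverse mappings exposing a predict method (hypothetical names; the repo's layer objects may expose these differently):

    import numpy as np

    def reconstruction_loss(F, G, H):
        """Mean squared error between hidden outputs H and their reconstruction G(F(H))."""
        H_rec = G.predict(F.predict(H))
        return float(np.mean((H_rec - H) ** 2))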

Code

To reproduce the results, you can use the attached notebook.

ModelComparison.zip

@kingfengji I would highly appreciate your feedback. Many thanks.

Can not find the uci dataset

Hi,
I want to run the uci_year and uci_adult demos, but I can't find the get_data.sh file mentioned in the README. Could you please upload it, or describe the data format so I can prepare the data myself?
I also noticed that the code uses a features file, which is not in the repository either.

Environment

Code written for Python 3.5 should generally be compatible with later Python 3 versions; are you sure a dedicated Python 3.5 build is necessary?
