MesaTEE GBDT-RS : a fast and secure GBDT library, supporting TEEs such as Intel SGX and ARM TrustZone


MesaTEE GBDT-RS is a gradient boosting decision tree (GBDT) library written entirely in safe Rust; the library contains no unsafe code.

MesaTEE GBDT-RS provides both training and inference capabilities, and it can run inference with models trained by xgboost.

New! The MesaTEE GBDT-RS paper has been accepted by IEEE S&P'19!

Supported Tasks

Supported tasks for both training and inference:

  1. Linear regression: use the SquaredError or LAD loss type
  2. Binary classification (labeled with 1 and -1): use the LogLikelyhood loss type

Compatibility with xgboost

At this time, MesaTEE GBDT-RS supports inference with models trained by xgboost. The model should be trained by xgboost with the following configuration:

  1. booster: gbtree
  2. objective: "reg:linear", "reg:logistic", "binary:logistic", "binary:logitraw", "multi:softprob", "multi:softmax" or "rank:pairwise".

We have tested that MesaTEE GBDT-RS is compatible with xgboost 0.81 and 0.82.

Quick Start

Training Steps

  1. Set configuration
  2. Load training data
  3. Train the model
  4. (optional) Save the model

Inference Steps

  1. Load the model
  2. Load the test data
  3. Run inference on the test data

Example

    use gbdt::config::Config;
    use gbdt::decision_tree::{DataVec, PredVec};
    use gbdt::gradient_boost::GBDT;
    use gbdt::input::{InputFormat, load};

    let mut cfg = Config::new();
    cfg.set_feature_size(22);
    cfg.set_max_depth(3);
    cfg.set_iterations(50);
    cfg.set_shrinkage(0.1);
    cfg.set_loss("LogLikelyhood"); 
    cfg.set_debug(true);
    cfg.set_data_sample_ratio(1.0);
    cfg.set_feature_sample_ratio(1.0);
    cfg.set_training_optimization_level(2);

    // load data
    let train_file = "dataset/agaricus-lepiota/train.txt";
    let test_file = "dataset/agaricus-lepiota/test.txt";

    let mut input_format = InputFormat::csv_format();
    input_format.set_feature_size(22);
    input_format.set_label_index(22);
    let mut train_dv: DataVec = load(train_file, input_format).expect("failed to load training data");
    let test_dv: DataVec = load(test_file, input_format).expect("failed to load test data");

    // train and save model
    let mut gbdt = GBDT::new(&cfg);
    gbdt.fit(&mut train_dv);
    gbdt.save_model("gbdt.model").expect("failed to save the model");

    // load model and do inference
    let model = GBDT::load_model("gbdt.model").expect("failed to load the model");
    let predicted: PredVec = model.predict(&test_dv);

Example code

  • Linear regression: examples/iris.rs
  • Binary classification: examples/agaricus-lepiota.rs

Use models trained by xgboost

Steps

  1. Use xgboost to train a model
  2. Use examples/convert_xgboost.py to convert the model
    • Usage: python convert_xgboost.py xgboost_model_path objective output_path
    • Note: convert_xgboost.py depends on the xgboost Python library. The converted model can be used on machines without xgboost installed
  3. In Rust code, call GBDT::load_from_xgboost(model_path, objective) to load the model
  4. Do inference
  5. (optional) Call GBDT::save_model to save the model to MesaTEE GBDT-RS native format.

Example code

  • "reg:linear": examples/test-xgb-reg-linear.rs
  • "reg:logistic": examples/test-xgb-reg-logistic.rs
  • "binary:logistic": examples/test-xgb-binary-logistic.rs
  • "binary:logitraw": examples/test-xgb-binary-logistic.rs
  • "multi:softprob": examples/test-xgb-multi-softprob.rs
  • "multi:softmax": examples/test-xgb-multi-softmax.rs
  • "rank:pairwise": examples/test-xgb-rank-pairwise.rs

Multi-threading

Training:

At this time, training in MesaTEE GBDT-RS is single-threaded.

Inference:

The related inference functions are single-threaded, but they are thread-safe. We provide a multi-threaded inference example in examples/test-multithreads.rs

SGX usage

Because MesaTEE GBDT-RS is written in pure Rust, with the help of rust-sgx-sdk it can easily be used in an SGX enclave by adding the following dependency:

    gbdt_sgx = { git = "https://github.com/mesalock-linux/gbdt-rs" }

This imports a crate named gbdt_sgx. If you prefer the usual name gbdt:

    gbdt = { package = "gbdt_sgx", git = "https://github.com/mesalock-linux/gbdt-rs" }

For more information and concrete examples, please look at the directory sgx/gbdt-sgx-test.

License

Apache 2.0

Authors

Tianyi Li @n0b0dyCN [email protected]

Tongxin Li @litongxin1991 [email protected]

Yu Ding @dingelish [email protected]

Steering Committee

Tao Wei, Yulong Zhang

Acknowledgment

Thanks to @qiyiping for the great previous work gbdt. We read that code before starting this project.

gbdt-rs's People

Contributors

dingelish, litongxin1991, mesejo, mssun, npatsakula, pruthvikar, tommady


gbdt-rs's Issues

let config return config :)

It would be much easier to configure the Config struct if the set_* methods returned self instead of ().

Thank you so much for your work BTW

Unable to use xgboost trained model

Hello,
Thank you for the great library.
I have an issue: I train a model in Python using xgboost, but when I try to use "from_xgboost_dump"
I get the following error: thread 'main' panicked at 'failed to load the model: ParseIntError { kind: Empty }'

I was wondering if it could be related to the version of the xgboost library used?

Thank you

Feature importance

Feature importance shows how valuable each feature was in training a model. It can be used to select features.

Providing feature importance would be a good addition to gbdt-rs.

ValueType

Hey, I would like to change the ValueType from f32 to f64. Is there any way to do it myself, or would you have to implement new functionality?

I saw in your source code that you defined it this way:

    ///! For now we only support std::$t using this macro.
    /// We will generalize ValueType in future.
    macro_rules! def_value_type {
        ($t: tt) => {
            pub type ValueType = $t;
            pub const VALUE_TYPE_MAX: ValueType = std::$t::MAX;
            pub const VALUE_TYPE_MIN: ValueType = std::$t::MIN;
            pub const VALUE_TYPE_UNKNOWN: ValueType = VALUE_TYPE_MIN;
        };
    }

    // use continous variables for decision tree
    def_value_type!(f32);

Can you bind ValueType to f64 behind a feature flag?

Thanks,
Alexis D.

Crossfold analysis example?

Do you have a way to do this?

Or do I have to split the data myself (fine to do), and if so, can I call model.fit() multiple times to update the model, or will each call overwrite it?

small difference between gbdt-rs and rust-xgboost (native)

Hi,

I'm experiencing a small delta between predictions (same model, same inputs) of gbdt-rs and rust-xgboost (https://github.com/davechallis/rust-xgboost), which is based on the C++ implementation, using gbtree and logistic regression.

I'm researching this at the moment and suspect a few possible causes:

  1. floating point precision differences native to C++ vs Rust
  2. different XGB implementation
  3. I'm training in Python and loading into Rust via the convert script -- so maybe a problem in reading the dump on the Rust side (I assume the save side is OK because it's using the C++ lib)

From your experience, is this a known issue? Or can you point me toward a more specific direction to research among what I listed above?

Thanks

UPDATE:
I have now narrowed it down to initializing parameters on the Python side vs the Rust side. It looks like some of the parameters are not loaded, or are taken into account differently. When the models on both the Python and Rust sides are loaded with no parameters, the results are equal.

How to specify &[f64] input?

I have inference samples encoded as a list of f64 values. How can I use gbdt to run inference with a pretrained xgboost model?

reg:squarederror

Hi and thanks for your work.

Python's XGBoost now uses reg:squarederror as the default objective. Is this not supported here (it's not in the list of supported objectives)?

Use more Rust like interfaces for Config

Just some small changes

For gbdt::config::Config,

  • A ConfigBuilder (following the builder pattern) with a build() method that creates a Config object would be more Rust-like.
  • The to_string() method is not necessary if you implement Display.
  • The config.set_loss function takes a str rather than an enum.

Are you open to PRs?

Impurity calculation question

Here is the get_impurity method in decision_tree.rs:

    for pair in sorted_data.iter() {
        let (index, feature_value) = *pair;
        if feature_value == VALUE_TYPE_UNKNOWN {
            let cv: &CacheValue = &cache.cache_value[index];
            s += cv.s;
            ss += cv.ss;
            c += cv.c;
            unknown += 1;
        } else {
            break;
        }
    }

I want to ask why break is used here instead of continue. I think the purpose of this code is to count and accumulate the samples with unknown values.
