bonsai's Introduction

bonsai

bonsai provides bindings for additional tree-based model engines for use with the parsnip package.

This package is based on the work done in the treesnip repository by Athos Damiani, Daniel Falbel, and Roel Hogervorst. bonsai is the official CRAN version of the package; new development will reside here.

Installation

You can install the most recent official release of bonsai with:

install.packages("bonsai")

You can install the development version of bonsai from GitHub with:

# install.packages("pak")
pak::pak("tidymodels/bonsai")

Available Engines

The bonsai package provides additional engines for the models in the following table:

model          engine    mode
boost_tree     lightgbm  regression
boost_tree     lightgbm  classification
decision_tree  partykit  regression
decision_tree  partykit  classification
rand_forest    partykit  regression
rand_forest    partykit  classification
rand_forest    aorsf     classification
rand_forest    aorsf     regression
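As a minimal sketch of how these engines plug into parsnip, the example below fits a boosted tree with the lightgbm engine. It assumes bonsai, parsnip and lightgbm are installed; the mtcars data and the trees/min_n values are arbitrary illustrative choices.

```r
library(parsnip)
library(bonsai)  # loading bonsai registers its engines with parsnip

# Boosted tree specification using the lightgbm engine
spec <- boost_tree(trees = 50, min_n = 5) |>
  set_engine("lightgbm") |>
  set_mode("regression")

# Fit with the formula interface, then predict on a few rows
fitted <- fit(spec, mpg ~ ., data = mtcars)
preds <- predict(fitted, new_data = head(mtcars))
print(preds)
```

Swapping "lightgbm" for "partykit" or "aorsf" (with decision_tree() or rand_forest()) follows the same pattern for the other rows of the table.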

Code of Conduct

Please note that the bonsai project is released with a Contributor Code of Conduct. By contributing to this project, you agree to abide by its terms.

bonsai's People

Contributors

bcjaeger, hfrick, jameslamb, p-schaefer, simonpcouch, topepo

bonsai's Issues

pass `stop_iter` as `early_stopping_rounds` to `train_lightgbm`

One of boost_tree's main arguments maps to a main argument of lgb.train rather than to params, where the rest go:

boost_tree: stop_iter The number of iterations without improvement before stopping (specific engines only).

lgb.train: early_stopping_rounds Activates early stopping. When this parameter is non-null, training will stop if the evaluation of any metric on any validation set fails to improve for early_stopping_rounds consecutive boosting rounds. If training stops early, the returned model will have attribute best_iter set to the iteration number of the best iteration.
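For reference, this is roughly what early stopping looks like when calling lightgbm directly. It is a sketch assuming the lightgbm R package is installed; the 24/8 split of mtcars, min_data_in_leaf = 3 and the patience of 5 rounds are arbitrary choices. Per the documentation quoted above, the stopping rule is evaluated on the valids datasets, so at least one is supplied here.

```r
library(lightgbm)

x <- as.matrix(mtcars[, -1])
y <- mtcars$mpg

# Hold out 8 of the 32 rows as a validation set
set.seed(1)
train_idx <- sample(nrow(x), 24)
dtrain <- lgb.Dataset(x[train_idx, ], label = y[train_idx])
dvalid <- lgb.Dataset.create.valid(dtrain, x[-train_idx, ], label = y[-train_idx])

model <- lgb.train(
  params = list(
    objective = "regression",
    metric = "l2",
    min_data_in_leaf = 3L,
    verbose = -1L
  ),
  data = dtrain,
  nrounds = 200L,
  valids = list(valid = dvalid),
  # stop after 5 consecutive rounds without improvement on the validation set
  early_stopping_rounds = 5L
)

# Iteration with the best validation score
model$best_iter
```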

Release bonsai 0.3.1

Prepare for release:

  • git pull
  • Check current CRAN check results
  • Polish NEWS
  • urlchecker::url_check()
  • devtools::build_readme()
  • devtools::check(remote = TRUE, manual = TRUE)
  • devtools::check_win_devel()
  • revdepcheck::cloud_check()
  • Update cran-comments.md
  • git push

Submit to CRAN:

  • usethis::use_version('patch')
  • devtools::submit_cran()
  • Approve email

Wait for CRAN...

  • Accepted 🎉
  • usethis::use_github_release()
  • usethis::use_dev_version(push = TRUE)

bump minimum required R version

The latest R 3.6 runs give:

  Error: 
  ! error in pak subprocess
  Caused by error: 
  ! Could not solve package dependencies:
  * deps::.: Can't install dependency tune
  * tune: Needs R >= 4.0
  * any::sessioninfo: dependency conflict
  * any::rcmdcheck: dependency conflict

interface for `stop_iter` validation set

stop_iter sets how many iterations without an improvement in the objective function are allowed before training is halted.

xgboost allows for passing a proportion of the training set to use as an internal validation set via parameters (?details_boost_tree_xgboost).

When stop_iter is supplied (as early_stopping_rounds), lightgbm requires a valids argument: "a list of lgb.Dataset objects, used for validation". This seems like a much more cumbersome interface than a proportion; a good portion of bonsai's lightgbm infrastructure is devoted to formatting lgb.Dataset objects. Allowing users to pass a proportion of the training data feels much more straightforward, but it is a possibly confusing interface, given that other set_engine arguments are always in the format the modeling function takes them. As of now, the training set is passed.
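One way the proportion-based interface discussed above could work is to carve the validation rows out of the training data before building the lgb.Dataset objects. The helper make_valids_from_prop() below is hypothetical (it is not part of bonsai or lightgbm) and assumes a numeric predictor matrix x and a label vector y.

```r
library(lightgbm)

# Hypothetical helper (not a bonsai/lightgbm function): hold out a proportion
# `prop` of the training rows and return the training lgb.Dataset plus a
# `valids` list suitable for lgb.train()
make_valids_from_prop <- function(x, y, prop = 0.1) {
  stopifnot(is.matrix(x), length(y) == nrow(x), prop > 0, prop < 1)
  n_val <- max(1L, floor(nrow(x) * prop))
  val_idx <- sample(nrow(x), n_val)
  dtrain <- lgb.Dataset(x[-val_idx, , drop = FALSE], label = y[-val_idx])
  dvalid <- lgb.Dataset.create.valid(
    dtrain, x[val_idx, , drop = FALSE], label = y[val_idx]
  )
  list(data = dtrain, valids = list(validation = dvalid), n_val = n_val)
}

set.seed(1)
held_out <- make_valids_from_prop(as.matrix(mtcars[, -1]), mtcars$mpg, prop = 0.25)
held_out$n_val
```

The result could then be handed to lgb.train(params, data = held_out$data, valids = held_out$valids, early_stopping_rounds = ...). The trade-off raised above remains: the proportion would be a bonsai-specific convention rather than lgb.train's own format.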

[LightGBM] [Fatal] [tweedie]: at least one target label is negative

I receive an error when using bonsai to train a LightGBM model with a tweedie objective.

The same error does not occur when I train the model with lightgbm directly, without bonsai.

I have included an example below in which I replace the target (medv) of one record in the Boston dataset with a zero.

Thank you!

library(lightgbm)
#> Loading required package: R6
library(caret)
#> Loading required package: ggplot2
#> Loading required package: lattice
library(bonsai)
#> Warning: package 'bonsai' was built under R version 4.2.1
#> Loading required package: parsnip

boston <- MASS::Boston
str(boston)
#> 'data.frame':    506 obs. of  14 variables:
#>  $ crim   : num  0.00632 0.02731 0.02729 0.03237 0.06905 ...
#>  $ zn     : num  18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
#>  $ indus  : num  2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
#>  $ chas   : int  0 0 0 0 0 0 0 0 0 0 ...
#>  $ nox    : num  0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
#>  $ rm     : num  6.58 6.42 7.18 7 7.15 ...
#>  $ age    : num  65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
#>  $ dis    : num  4.09 4.97 4.97 6.06 6.06 ...
#>  $ rad    : int  1 2 2 3 3 3 5 5 5 5 ...
#>  $ tax    : num  296 242 242 222 222 222 311 311 311 311 ...
#>  $ ptratio: num  15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
#>  $ black  : num  397 397 393 395 397 ...
#>  $ lstat  : num  4.98 9.14 4.03 2.94 5.33 ...
#>  $ medv   : num  24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...
dim(boston)
#> [1] 506  14

set.seed(12)

indexes <- createDataPartition(boston$medv, p = .85, list = FALSE)
train <- boston[indexes, ]
test <- boston[-indexes, ]

train_x <- train[, -14]
train_x <- scale(train_x)[, ]
train_y <- train[, 14]

test_x <- test[, -14]
test_x <- scale(test[, -14])[, ]
test_y <- test[, 14]

dtrain <- lgb.Dataset(train_x, label = train_y)
dtest <- lgb.Dataset.create.valid(dtrain, test_x, label = test_y)

# define parameters
params <- list(
  objective = "tweedie",
  tweedie_variance_power = 1.5,
  metric = "l2",
  min_data = 1L,
  learning_rate = .3
)

# validation data
valids <- list(test = dtest)

# train model
model <- lgb.train(
  params = params,
  data = dtrain,
  nrounds = 5L,
  valids = valids
)
#> [LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.001566 seconds.
#> You can set `force_row_wise=true` to remove the overhead.
#> And if memory is not enough, you can set `force_col_wise=true`.
#> [LightGBM] [Info] Total Bins 1078
#> [LightGBM] [Info] Number of data points in the train set: 432, number of used features: 13
#> [LightGBM] [Info] Start training from score 3.117960
#> [1] "[1]:  test's l2:35.351"
#> [1] "[2]:  test's l2:21.9136"
#> [1] "[3]:  test's l2:15.9518"
#> [1] "[4]:  test's l2:13.4712"
#> [1] "[5]:  test's l2:12.782"

# Test tweedie with a zero in target variable

boston <- MASS::Boston

boston[, 14][[1]] <- 0

str(boston)
#> 'data.frame':    506 obs. of  14 variables:
#>  $ crim   : num  0.00632 0.02731 0.02729 0.03237 0.06905 ...
#>  $ zn     : num  18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
#>  $ indus  : num  2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
#>  $ chas   : int  0 0 0 0 0 0 0 0 0 0 ...
#>  $ nox    : num  0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
#>  $ rm     : num  6.58 6.42 7.18 7 7.15 ...
#>  $ age    : num  65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
#>  $ dis    : num  4.09 4.97 4.97 6.06 6.06 ...
#>  $ rad    : int  1 2 2 3 3 3 5 5 5 5 ...
#>  $ tax    : num  296 242 242 222 222 222 311 311 311 311 ...
#>  $ ptratio: num  15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
#>  $ black  : num  397 397 393 395 397 ...
#>  $ lstat  : num  4.98 9.14 4.03 2.94 5.33 ...
#>  $ medv   : num  0 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...
dim(boston)
#> [1] 506  14

set.seed(12)

indexes <- createDataPartition(boston$medv, p = .85, list = FALSE)
train <- boston[indexes, ]
test <- boston[-indexes, ]

train_x <- train[, -14]
train_x <- scale(train_x)[, ]
train_y <- train[, 14]

test_x <- test[, -14]
test_x <- scale(test[, -14])[, ]
test_y <- test[, 14]

dtrain <- lgb.Dataset(train_x, label = train_y)
dtest <- lgb.Dataset.create.valid(dtrain, test_x, label = test_y)

# define parameters
params <- list(
  objective = "tweedie",
  tweedie_variance_power = 1.5,
  metric = "l2",
  min_data = 1L,
  learning_rate = .3
)

# validation data
valids <- list(test = dtest)

# train model
model <- lgb.train(
  params = params,
  data = dtrain,
  nrounds = 5L,
  valids = valids
)
#> [LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.001375 seconds.
#> You can set `force_row_wise=true` to remove the overhead.
#> And if memory is not enough, you can set `force_col_wise=true`.
#> [LightGBM] [Info] Total Bins 1071
#> [LightGBM] [Info] Number of data points in the train set: 432, number of used features: 13
#> [LightGBM] [Info] Start training from score 3.115550
#> [1] "[1]:  test's l2:72.8958"
#> [1] "[2]:  test's l2:94.8558"
#> [1] "[3]:  test's l2:92.4295"
#> [1] "[4]:  test's l2:91.6841"
#> [1] "[5]:  test's l2:95.9142"

bonsai_tweedie_test <-
  boost_tree() %>%
  set_engine(
    engine = "lightgbm",
    objective = "tweedie",
    tweedie_variance_power = 1.7,
    metric = "l2"
  ) %>%
  set_mode(mode = "regression") %>%
  fit(
    formula = medv ~ .,
    data = boston
  )
#> Error in try({ : [tweedie]: at least one target label is negative
#> Error in initialize(...): lgb.Booster: cannot create Booster handle
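The error is surprising because the modified target contains a zero but no negative values. A quick sanity check, using the same modification of MASS::Boston as the reprex (written here as medv[1] <- 0, which sets the same cell as boston[, 14][[1]] <- 0):

```r
boston <- MASS::Boston
boston$medv[1] <- 0  # same change as in the reprex: first medv value set to 0

# Smallest label and number of negative labels
min(boston$medv)
sum(boston$medv < 0)
```

A minimum of 0 and no negative labels means the data itself does not explain the "[tweedie]: at least one target label is negative" message.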

Custom metrics

Guys, could you please share a small example of how to use custom metrics in

set_engine(
  "lightgbm",
  ...,
  eval = CUSTOM_METRIC()
)
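For what it's worth, lightgbm itself accepts a custom evaluation function through the eval argument of lgb.train(): a function of (preds, dtrain) that returns a list with name, value and higher_better. The sketch below calls lightgbm directly; whether bonsai forwards an eval engine argument to lgb.train() is not confirmed here. The label accessor differs across lightgbm versions (get_field() in 4.x, getinfo() in 3.x), so the hypothetical get_label() helper checks which one is exported.

```r
library(lightgbm)

# Hypothetical helper: read the label from an lgb.Dataset under 3.x or 4.x
get_label <- function(dataset) {
  if ("get_field" %in% getNamespaceExports("lightgbm")) {
    lightgbm::get_field(dataset, "label")
  } else {
    lightgbm::getinfo(dataset, "label")
  }
}

# Custom metric: mean absolute error (lower is better)
custom_mae <- function(preds, dtrain) {
  labels <- get_label(dtrain)
  list(name = "custom_mae", value = mean(abs(preds - labels)), higher_better = FALSE)
}

x <- as.matrix(mtcars[, -1])
dtrain <- lgb.Dataset(x, label = mtcars$mpg)

model <- lgb.train(
  params = list(objective = "regression", min_data_in_leaf = 5L, verbose = -1L),
  data = dtrain,
  nrounds = 10L,
  valids = list(train = dtrain),  # evaluate the metric on the training data
  eval = custom_mae
)

# Recorded value of the custom metric at each iteration
mae_history <- unlist(lgb.get.eval.result(model, "train", "custom_mae"))
```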

{lightgbm} v4.0.0 is coming

👋 Hello! I'm James, one of the maintainers of LightGBM.

We were excited to see the release of {bonsai}. Thanks so much for making it easier for some R users to work with {lightgbm}! I want to do what I can to support you all.

First, I wanted to let you know that the next release of LightGBM (the entire project, including the R package) will be a major version release with significant breaking changes. See this discussion of v4.0.0: microsoft/LightGBM#5153. We don't have a planned date for that release yet, as we have been struggling with a lack of maintainer attention / activity. But I expect it will be months, not weeks, from now.

Please open issues at https://github.com/microsoft/LightGBM/issues if there's anything we could do to make {lightgbm} easier to use with {bonsai}.

I'm opening this issue mainly to ask a question. In this period prior to v4.0.0, I'm willing to contribute and maintain the following here:

  • patches that make {bonsai} compatible with the latest release on CRAN (v3.3.2) and the upcoming release (v4.0.0) so that the next release of {lightgbm} isn't disruptive to you
  • CI job(s) testing against the latest development version of LightGBM

Are you open to such contributions?

Thanks for your time and consideration.

Release bonsai 0.3.0

Prepare for release:

  • git pull
  • Check current CRAN check results
  • Polish NEWS
  • urlchecker::url_check()
  • devtools::build_readme()
  • devtools::check(remote = TRUE, manual = TRUE)
  • devtools::check_win_devel()
  • revdepcheck::cloud_check()
  • Update cran-comments.md
  • git push
  • Draft blog post
  • Slack link to draft blog in #open-source-comms

Submit to CRAN:

  • usethis::use_version('minor')
  • devtools::submit_cran()
  • Approve email

Wait for CRAN...

  • Accepted 🎉
  • Finish & publish blog post
  • Add link to blog post in pkgdown news menu
  • usethis::use_github_release()
  • usethis::use_dev_version(push = TRUE)
  • Tweet

aorsf support for `mtry_prop`

aorsf is a great addition to bonsai! Any chance of supporting mtry_prop?

library(tidymodels)
library(bonsai)

set.seed(1)
folds <- vfold_cv(mtcars, v = 5)

rec <- recipe(cyl ~ ., data = mtcars)

mod_lgbm <- boost_tree(mtry = tune()) |> 
  set_engine("lightgbm", count = FALSE) |>
  set_mode("regression")

mod_aorsf <- rand_forest(mtry = tune()) |> 
  set_engine("aorsf", count = FALSE) |>
  set_mode("regression")

lgbm_wflow <- workflow() |>
  add_model(mod_lgbm) |>
  add_recipe(rec)

aorsf_wflow <- workflow() |>
  add_model(mod_aorsf) |>
  add_recipe(rec)

# lightgbm supports mtry_prop
param_info <-
  lgbm_wflow |>
  extract_parameter_set_dials() |>
  update(mtry = mtry_prop(c(0, 1)))

tune_grid(
  lgbm_wflow, 
  resamples = folds, 
  param_info = param_info,
  metrics = metric_set(rmse)
  )
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics          .notes          
#>   <list>         <chr> <list>            <list>          
#> 1 <split [25/7]> Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
#> 2 <split [25/7]> Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
#> 3 <split [26/6]> Fold3 <tibble [10 × 5]> <tibble [0 × 3]>
#> 4 <split [26/6]> Fold4 <tibble [10 × 5]> <tibble [0 × 3]>
#> 5 <split [26/6]> Fold5 <tibble [10 × 5]> <tibble [0 × 3]>

# could aorsf do the same?
param_info <-
  aorsf_wflow |>
  extract_parameter_set_dials() |>
  update(mtry = mtry_prop(c(0, 1)))

tune_grid(
  aorsf_wflow, 
  resamples = folds, 
  param_info = param_info,
  metrics = metric_set(rmse)
  )
#> → A | error:   there were unrecognized arguments:
#>                  count is unrecognized - did you mean control?
#> There were issues with some computations   A: x1
#> There were issues with some computations   A: x50
#> 
#> Warning: All models failed. Run `show_notes(.Last.tune.result)` for more
#> information.
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics .notes           
#>   <list>         <chr> <list>   <list>           
#> 1 <split [25/7]> Fold1 <NULL>   <tibble [10 × 3]>
#> 2 <split [25/7]> Fold2 <NULL>   <tibble [10 × 3]>
#> 3 <split [26/6]> Fold3 <NULL>   <tibble [10 × 3]>
#> 4 <split [26/6]> Fold4 <NULL>   <tibble [10 × 3]>
#> 5 <split [26/6]> Fold5 <NULL>   <tibble [10 × 3]>
#> 
#> There were issues with some computations:
#> 
#>   - Error(s) x50: there were unrecognized arguments:   count is unrecognized - did ...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.

Created on 2024-07-21 with reprex v2.1.1
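Note that the aorsf error above is raised for the count = FALSE engine argument, which aorsf reports as unrecognized. Until aorsf accepts a proportion, an mtry_prop value could in principle be turned back into the integer count that aorsf's mtry expects by scaling it by the number of predictors. The helper prop_to_count() below is hypothetical, and its rounding rule (round down, never fewer than one predictor) is an assumption rather than what bonsai or parsnip would necessarily use.

```r
# Hypothetical helper: convert a proportion of predictors into a count,
# rounding down and never returning fewer than one predictor
prop_to_count <- function(prop, n_predictors) {
  stopifnot(all(prop >= 0 & prop <= 1), n_predictors >= 1)
  pmax(1L, as.integer(floor(prop * n_predictors)))
}

# mtcars has 10 predictors when modeling cyl ~ .
prop_to_count(c(0, 0.25, 0.5, 1), n_predictors = 10)
# [1]  1  2  5 10
```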

R CMD check runs with 1600%+ CPU load

Hi, I'm running revdep checks and noticed that bonsai consumes a lot of CPU during the package tests, and for quite a long time, i.e. not just for a brief period. Running this on a shared machine risks overloading the machine for others. I don't know whether the number of CPU cores used is fixed or a function of the number of physical CPU cores on the machine.

Here's a snapshot as an example:

$ pstree -c -l -a -p 285309
sh,285309 /software/c4/cbi/software/R-4.2.2-gcc10/lib64/R/bin/Rcmd check bonsai_0.2.0.tar.gz --no-manual -o /c4/home/henrik/repositories/globals/revdep/checks/bonsai/old
  └─R,285313 --no-restore --no-echo --args nextArgbonsai_0.2.0.tar.gznextArg--no-manualnextArg-onextArg/c4/home/henrik/repositories/globals/revdep/checks/bonsai/old
     └─sh,287890 -c LANGUAGE=en _R_CHECK_INTERNALS2_=1 '/software/c4/cbi/software/R-4.2.2-gcc10/lib64/R/bin/R' --vanilla --no-echo < '/scratch/henrik/1009210/RtmpQphiUe/file45a8140f4c4c1'
         └─R,287891 --vanilla --no-echo
            └─sh,287904 /software/c4/cbi/software/R-4.2.2-gcc10/lib64/R/bin/Rcmd BATCH --vanilla testthat.R testthat.Rout
                └─R,287908 -f testthat.R --restore --save --no-readline --vanilla
                   ├─{R},287930
                   ├─{R},288042
                   ├─{R},288043
                   ├─{R},288044
                   ├─{R},288045
                   ├─{R},288046
                   ├─{R},288047
                   ├─{R},288048
                   ├─{R},288049
                   ├─{R},288050
                   ├─{R},288051
                   ├─{R},288052
                   ├─{R},288053
                   ├─{R},288054
                   ├─{R},288055
                   ├─{R},288056
                   ├─{R},288057
                   ├─{R},288058
                   ├─{R},288059
                   ├─{R},288060
                   ├─{R},288061
                   ├─{R},288062
                   ├─{R},288063
                   ├─{R},288064
                   ├─{R},288065
                   ├─{R},288066
                   ├─{R},288067
                   ├─{R},288068
                   ├─{R},288069
                   ├─{R},288070
                   ├─{R},288071
                   └─{R},288072
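If the load comes from LightGBM's OpenMP threads (an assumption; the snapshot alone does not show which library spawned them), LightGBM's num_threads parameter caps it. Its default of 0 means OpenMP's default number of threads, which is typically every available core. A sketch calling lightgbm directly, with an arbitrary cap of two threads; whether bonsai forwards num_threads from set_engine() is an assumption to check against the bonsai documentation.

```r
library(lightgbm)

x <- as.matrix(mtcars[, -1])
dtrain <- lgb.Dataset(x, label = mtcars$mpg)

# Cap LightGBM at two threads instead of the OpenMP default
model <- lgb.train(
  params = list(
    objective = "regression",
    num_threads = 2L,
    min_data_in_leaf = 5L,
    verbose = -1L
  ),
  data = dtrain,
  nrounds = 10L
)
model$current_iter()
```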

bump minimum aorsf version

Once the version of aorsf that follows up on 0.1.4 is on CRAN, bump the minimum required package version! At that point, those changes should be ready to go to CRAN.

Related to #73.

lightgbm model doesn't see sample_size parameter from boost_tree

While testing the bonsai package, I noticed some odd behaviour regarding tuning parameters. When I indicated via boost_tree which parameters should be tuned, I saw that lightgbm did not use sample_size. Is that the correct behaviour?
I used the following code to configure the model.

pre_proc <-
  recipe(status ~ ., data = data_train) %>%
  update_role(cif, data_alertu, okres_alertu, new_role = "id variable") %>%
  step_zv(all_predictors()) %>%
  step_corr(all_predictors(), threshold = .9)

gbm_mod <-
  boost_tree(
    mode = "classification",
    trees = tune(),
    tree_depth = tune(),
    min_n = tune(),
    loss_reduction = tune(),
    sample_size = tune(),
    mtry = tune(),
    learn_rate = tune()
  ) %>%
  set_engine("lightgbm")

gbm_wflow <-
  workflow() %>%
  add_model(gbm_mod) %>%
  add_recipe(pre_proc)

gbm_set <- extract_parameter_set_dials(gbm_wflow) %>% update(mtry = mtry(c(10, 50)))

gbm_set returned only 6 parameters, although I set 7 to be tuned. I saw that sample_size is missing. Shouldn't it be translated to lightgbm's bagging_fraction parameter?
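On the LightGBM side, row subsampling is configured with bagging_fraction, and according to the LightGBM parameter documentation it only takes effect when bagging_freq is also set to a positive value. A sketch calling lightgbm directly, with arbitrary values (80% of rows, resampled every iteration):

```r
library(lightgbm)

x <- as.matrix(mtcars[, -1])
dtrain <- lgb.Dataset(x, label = mtcars$mpg)

model <- lgb.train(
  params = list(
    objective = "regression",
    bagging_fraction = 0.8,  # sample 80% of the rows ...
    bagging_freq = 1L,       # ... and redraw the sample at every iteration
    min_data_in_leaf = 5L,
    seed = 1L,
    verbose = -1L
  ),
  data = dtrain,
  nrounds = 10L
)
model$current_iter()
```

Any mapping of sample_size onto these two parameters would need to set both; leaving bagging_freq at its default of 0 would make bagging_fraction a no-op.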

Things to keep in mind when saving lightgbm models?

Some months ago I had issues with saving and {butcher}ing (reducing the file size of a saved model) xgboost models via tidymodels. What this came down to (IIUC) is that tidymodels does not currently support native serialization of those models.

Is that something I would have to worry about when working with {bonsai}/{lightgbm} as well? More generally, what is the recommended approach for saving lightgbm models and reading them back in for prediction, or for using the saved models in a "prediction package" later? Is it okay to "just" use saveRDS.lgb.Booster() and readRDS.lgb.Booster()?
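One option that avoids R serialization of the booster entirely is LightGBM's own text model format, via lgb.save() and lgb.load(). This is a sketch assuming a plain file path is acceptable for storage. For a model fit through bonsai, the booster would first be pulled out with extract_fit_engine().

```r
library(lightgbm)

x <- as.matrix(mtcars[, -1])
dtrain <- lgb.Dataset(x, label = mtcars$mpg)
model <- lgb.train(
  params = list(objective = "regression", min_data_in_leaf = 5L, verbose = -1L),
  data = dtrain,
  nrounds = 10L
)

# Write the model in LightGBM's text format, then read it back
path <- tempfile(fileext = ".txt")
lgb.save(model, path)
reloaded <- lgb.load(path)

# The reloaded booster should reproduce the original predictions
p_original <- predict(model, x)
p_reloaded <- predict(reloaded, x)
all.equal(p_original, p_reloaded)
```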

Lightgbm tuned with parallel processing unable to predict

A lightgbm model (after using last_fit()) is unable to predict when parallel processing is used.
I tried both registerDoParallel and registerDoFuture, and both gave me the same error.
I tried both multiclass and binary classification problems; with parallel processing, both still gave me the error.

Error messages are mentioned below in the reprexes:

Multiclass tuning without parallel processing

No issues: able to predict and to extract feature importance.

library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.2.1
#> Warning: package 'broom' was built under R version 4.2.1
#> Warning: package 'scales' was built under R version 4.2.1
#> Warning: package 'infer' was built under R version 4.2.1
#> Warning: package 'modeldata' was built under R version 4.2.1
#> Warning: package 'parsnip' was built under R version 4.2.1
#> Warning: package 'rsample' was built under R version 4.2.1
#> Warning: package 'tibble' was built under R version 4.2.1
#> Warning: package 'workflows' was built under R version 4.2.1
#> Warning: package 'workflowsets' was built under R version 4.2.1
library(bonsai)
library(palmerpenguins)
#> Warning: package 'palmerpenguins' was built under R version 4.2.1
#> 
#> Attaching package: 'palmerpenguins'
#> The following object is masked from 'package:modeldata':
#> 
#>     penguins

split <- penguins |>
  initial_split(strata = species)

penguins_train <- training(split)
penguins_test <- testing(split)
folds <- vfold_cv(penguins_train, strata = species, 3)

recipe_basic <- penguins_train |>
  recipe(species ~ .)

lightgbm_spec <- boost_tree(trees = tune()) |>
  set_engine(
    "lightgbm",
    objective = "multiclass",
    metric = "multi_error",
    num_class = !!length(unique(penguins_train$species))
  ) |>
  set_mode("classification")

lightgbm_wflow <- workflow(preprocessor = recipe_basic,
                           spec = lightgbm_spec)

training_grid_results <- lightgbm_wflow |>
  tune_grid(resamples = folds,
            grid = 5)
#> Warning: package 'lightgbm' was built under R version 4.2.1

last_fit <- lightgbm_wflow |>
  finalize_workflow(select_best(training_grid_results, "roc_auc")) |>
  last_fit(split)

last_fit |>
  extract_workflow() |>
  predict(head(penguins_test))
#> # A tibble: 6 × 1
#>   .pred_class
#>   <fct>      
#> 1 Adelie     
#> 2 Adelie     
#> 3 Adelie     
#> 4 Adelie     
#> 5 Adelie     
#> 6 Adelie

last_fit |>
  extract_fit_engine() |>
  lightgbm::lgb.importance() |>
  lightgbm::lgb.plot.importance()

Created on 2022-09-07 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22 ucrt)
#>  os       Windows 10 x64 (build 19042)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_Singapore.utf8
#>  ctype    English_Singapore.utf8
#>  tz       Asia/Kuala_Lumpur
#>  date     2022-09-07
#>  pandoc   2.17.1.1 @ C:/Program Files/RStudio/bin/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package        * version    date (UTC) lib source
#>  assertthat       0.2.1      2019-03-21 [1] CRAN (R 4.2.0)
#>  backports        1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  bonsai         * 0.2.0      2022-08-31 [1] CRAN (R 4.2.0)
#>  broom          * 1.0.1      2022-08-29 [1] CRAN (R 4.2.1)
#>  class            7.3-20     2022-01-16 [1] CRAN (R 4.2.0)
#>  cli              3.3.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  codetools        0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>  colorspace       2.0-3      2022-02-21 [1] CRAN (R 4.2.0)
#>  crayon           1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  curl             4.3.2      2021-06-23 [1] CRAN (R 4.2.0)
#>  data.table       1.14.2     2021-09-27 [1] CRAN (R 4.2.0)
#>  DBI              1.1.3      2022-06-18 [1] CRAN (R 4.2.0)
#>  dials          * 1.0.0      2022-06-14 [1] CRAN (R 4.2.0)
#>  DiceDesign       1.9        2021-02-13 [1] CRAN (R 4.2.0)
#>  digest           0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  dplyr          * 1.0.9      2022-04-28 [1] CRAN (R 4.2.0)
#>  ellipsis         0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate         0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>  fansi            1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap          1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  foreach          1.5.2      2022-02-02 [1] CRAN (R 4.2.0)
#>  fs               1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  furrr            0.3.1      2022-08-15 [1] CRAN (R 4.2.1)
#>  future           1.27.0     2022-07-22 [1] CRAN (R 4.2.1)
#>  future.apply     1.9.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  generics         0.1.3      2022-07-05 [1] CRAN (R 4.2.1)
#>  ggplot2        * 3.3.6      2022-05-03 [1] CRAN (R 4.2.0)
#>  globals          0.16.1     2022-08-28 [1] CRAN (R 4.2.0)
#>  glue             1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  gower            1.0.0      2022-02-03 [1] CRAN (R 4.2.0)
#>  GPfit            1.0-8      2019-02-08 [1] CRAN (R 4.2.0)
#>  gtable           0.3.0      2019-03-25 [1] CRAN (R 4.2.0)
#>  hardhat          1.2.0      2022-06-30 [1] CRAN (R 4.2.1)
#>  highr            0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools        0.5.3      2022-07-18 [1] CRAN (R 4.2.1)
#>  httr             1.4.3      2022-05-04 [1] CRAN (R 4.2.0)
#>  infer          * 1.0.2      2022-06-10 [1] CRAN (R 4.2.1)
#>  ipred            0.9-13     2022-06-02 [1] CRAN (R 4.2.1)
#>  iterators        1.0.14     2022-02-05 [1] CRAN (R 4.2.0)
#>  jsonlite         1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr            1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  lattice          0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lava             1.6.10     2021-09-02 [1] CRAN (R 4.2.0)
#>  lhs              1.1.5      2022-03-22 [1] CRAN (R 4.2.0)
#>  lifecycle        1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  lightgbm       * 3.3.2      2022-01-14 [1] CRAN (R 4.2.1)
#>  listenv          0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  lubridate        1.8.0      2021-10-07 [1] CRAN (R 4.2.0)
#>  magrittr         2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  MASS             7.3-57     2022-04-22 [1] CRAN (R 4.2.0)
#>  Matrix           1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>  mime             0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>  modeldata      * 1.0.0      2022-07-01 [1] CRAN (R 4.2.1)
#>  munsell          0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>  nnet             7.3-17     2022-01-16 [1] CRAN (R 4.2.0)
#>  palmerpenguins * 0.1.1      2022-08-15 [1] CRAN (R 4.2.1)
#>  parallelly       1.32.1     2022-07-21 [1] CRAN (R 4.2.1)
#>  parsnip        * 1.0.1      2022-08-18 [1] CRAN (R 4.2.1)
#>  pillar           1.8.1      2022-08-19 [1] CRAN (R 4.2.1)
#>  pkgconfig        2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  prodlim          2019.11.13 2019-11-17 [1] CRAN (R 4.2.0)
#>  purrr          * 0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R6             * 2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp             1.0.9      2022-07-08 [1] CRAN (R 4.2.1)
#>  recipes        * 1.0.1      2022-07-07 [1] CRAN (R 4.2.0)
#>  reprex           2.0.2      2022-08-17 [1] CRAN (R 4.2.1)
#>  rlang            1.0.4      2022-07-12 [1] CRAN (R 4.2.1)
#>  rmarkdown        2.16       2022-08-24 [1] CRAN (R 4.2.1)
#>  rpart            4.1.16     2022-01-24 [1] CRAN (R 4.2.0)
#>  rsample        * 1.1.0      2022-08-08 [1] CRAN (R 4.2.1)
#>  rstudioapi       0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>  scales         * 1.2.1      2022-08-20 [1] CRAN (R 4.2.1)
#>  sessioninfo      1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi          1.7.6      2021-11-29 [1] CRAN (R 4.2.0)
#>  stringr          1.4.1      2022-08-20 [1] CRAN (R 4.2.1)
#>  survival         3.4-0      2022-08-09 [1] CRAN (R 4.2.1)
#>  tibble         * 3.1.8      2022-07-22 [1] CRAN (R 4.2.1)
#>  tidymodels     * 1.0.0      2022-07-13 [1] CRAN (R 4.2.1)
#>  tidyr          * 1.2.0      2022-02-01 [1] CRAN (R 4.2.0)
#>  tidyselect       1.1.2      2022-02-21 [1] CRAN (R 4.2.0)
#>  timeDate         4021.104   2022-07-19 [1] CRAN (R 4.2.1)
#>  tune           * 1.0.0      2022-07-07 [1] CRAN (R 4.2.0)
#>  utf8             1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs            0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  withr            2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  workflows      * 1.0.0      2022-07-05 [1] CRAN (R 4.2.1)
#>  workflowsets   * 1.0.0      2022-07-12 [1] CRAN (R 4.2.1)
#>  xfun             0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>  xml2             1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>  yaml             2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>  yardstick      * 1.0.0      2022-06-06 [1] CRAN (R 4.2.0)
#> 
#>  [1] C:/Users/dchoy/AppData/Local/Programs/R/R-4.2.0/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Multiclass tuning with parallel processing (registerDoParallel)

Unable to predict or to extract feature importance.

library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.2.1
#> Warning: package 'broom' was built under R version 4.2.1
#> Warning: package 'scales' was built under R version 4.2.1
#> Warning: package 'infer' was built under R version 4.2.1
#> Warning: package 'modeldata' was built under R version 4.2.1
#> Warning: package 'parsnip' was built under R version 4.2.1
#> Warning: package 'rsample' was built under R version 4.2.1
#> Warning: package 'tibble' was built under R version 4.2.1
#> Warning: package 'workflows' was built under R version 4.2.1
#> Warning: package 'workflowsets' was built under R version 4.2.1
library(bonsai)
library(palmerpenguins)
#> Warning: package 'palmerpenguins' was built under R version 4.2.1
#> 
#> Attaching package: 'palmerpenguins'
#> The following object is masked from 'package:modeldata':
#> 
#>     penguins
library(doParallel)
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loading required package: iterators
#> Loading required package: parallel
library(parallel)

split <- penguins |>
  initial_split(strata = species)

penguins_train <- training(split)
penguins_test <- testing(split)
folds <- vfold_cv(penguins_train, strata = species, 3)

recipe_basic <- penguins_train |>
  recipe(species ~ .)

lightgbm_spec <- boost_tree(trees = tune()) |>
  set_engine(
    "lightgbm",
    objective = "multiclass",
    metric = "multi_error",
    num_class = !!length(unique(penguins_train$species))
  ) |>
  set_mode("classification")

lightgbm_wflow <- workflow(preprocessor = recipe_basic,
                           spec = lightgbm_spec)

all_cores <- detectCores(logical = FALSE)
cl <- makePSOCKcluster(all_cores - 1)
registerDoParallel(cl)

training_grid_results <- lightgbm_wflow |>
  tune_grid(resamples = folds,
            grid = 5)

last_fit <- lightgbm_wflow |>
  finalize_workflow(select_best(training_grid_results, "roc_auc")) |>
  last_fit(split)

last_fit |>
  extract_workflow() |>
  predict(head(penguins_test))
#> Error in predictor$predict(data = data, start_iteration = start_iteration, : Attempting to use a Booster which no longer exists. This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.

last_fit |>
  extract_fit_engine() |>
  lightgbm::lgb.importance() |>
  lightgbm::lgb.plot.importance()
#> Error in booster$dump_model(num_iteration = num_iteration): Attempting to use a Booster which no longer exists. This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.

Created on 2022-09-07 with reprex v2.0.2
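The error message points at a likely cause: the Booster is backed by an external pointer that does not survive R serialization, and fitted objects returned from PSOCK workers are serialized on the way back. Under that reading (an assumption, not confirmed in this thread), one workaround is to move the model as its text representation and rebuild it on the receiving side. A sketch of that round trip using the Booster's save_model_to_string() method and lgb.load(model_str = ...), with serialize()/unserialize() standing in for the transfer between processes:

```r
library(lightgbm)

x <- as.matrix(mtcars[, -1])
dtrain <- lgb.Dataset(x, label = mtcars$mpg)
model <- lgb.train(
  params = list(objective = "regression", min_data_in_leaf = 5L, verbose = -1L),
  data = dtrain,
  nrounds = 10L
)

# Convert the booster to a plain character string, which serializes safely
model_str <- model$save_model_to_string()

# Simulate sending the string to another R process
received <- unserialize(serialize(model_str, connection = NULL))

# Rebuild a working booster from the string on the receiving side
rebuilt <- lgb.load(model_str = received)
all.equal(predict(model, x), predict(rebuilt, x))
```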

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22 ucrt)
#>  os       Windows 10 x64 (build 19042)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_Singapore.utf8
#>  ctype    English_Singapore.utf8
#>  tz       Asia/Kuala_Lumpur
#>  date     2022-09-07
#>  pandoc   2.17.1.1 @ C:/Program Files/RStudio/bin/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package        * version    date (UTC) lib source
#>  assertthat       0.2.1      2019-03-21 [1] CRAN (R 4.2.0)
#>  backports        1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  bonsai         * 0.2.0      2022-08-31 [1] CRAN (R 4.2.0)
#>  broom          * 1.0.1      2022-08-29 [1] CRAN (R 4.2.1)
#>  class            7.3-20     2022-01-16 [1] CRAN (R 4.2.0)
#>  cli              3.3.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  codetools        0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>  colorspace       2.0-3      2022-02-21 [1] CRAN (R 4.2.0)
#>  crayon           1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  data.table       1.14.2     2021-09-27 [1] CRAN (R 4.2.0)
#>  DBI              1.1.3      2022-06-18 [1] CRAN (R 4.2.0)
#>  dials          * 1.0.0      2022-06-14 [1] CRAN (R 4.2.0)
#>  DiceDesign       1.9        2021-02-13 [1] CRAN (R 4.2.0)
#>  digest           0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  doParallel     * 1.0.17     2022-02-07 [1] CRAN (R 4.2.0)
#>  dplyr          * 1.0.9      2022-04-28 [1] CRAN (R 4.2.0)
#>  ellipsis         0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate         0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>  fansi            1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap          1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  foreach        * 1.5.2      2022-02-02 [1] CRAN (R 4.2.0)
#>  fs               1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  furrr            0.3.1      2022-08-15 [1] CRAN (R 4.2.1)
#>  future           1.27.0     2022-07-22 [1] CRAN (R 4.2.1)
#>  future.apply     1.9.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  generics         0.1.3      2022-07-05 [1] CRAN (R 4.2.1)
#>  ggplot2        * 3.3.6      2022-05-03 [1] CRAN (R 4.2.0)
#>  globals          0.16.1     2022-08-28 [1] CRAN (R 4.2.0)
#>  glue             1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  gower            1.0.0      2022-02-03 [1] CRAN (R 4.2.0)
#>  GPfit            1.0-8      2019-02-08 [1] CRAN (R 4.2.0)
#>  gtable           0.3.0      2019-03-25 [1] CRAN (R 4.2.0)
#>  hardhat          1.2.0      2022-06-30 [1] CRAN (R 4.2.1)
#>  highr            0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools        0.5.3      2022-07-18 [1] CRAN (R 4.2.1)
#>  infer          * 1.0.2      2022-06-10 [1] CRAN (R 4.2.1)
#>  ipred            0.9-13     2022-06-02 [1] CRAN (R 4.2.1)
#>  iterators      * 1.0.14     2022-02-05 [1] CRAN (R 4.2.0)
#>  jsonlite         1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr            1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  lattice          0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lava             1.6.10     2021-09-02 [1] CRAN (R 4.2.0)
#>  lhs              1.1.5      2022-03-22 [1] CRAN (R 4.2.0)
#>  lifecycle        1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  lightgbm         3.3.2      2022-01-14 [1] CRAN (R 4.2.1)
#>  listenv          0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  lubridate        1.8.0      2021-10-07 [1] CRAN (R 4.2.0)
#>  magrittr         2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  MASS             7.3-57     2022-04-22 [1] CRAN (R 4.2.0)
#>  Matrix           1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>  modeldata      * 1.0.0      2022-07-01 [1] CRAN (R 4.2.1)
#>  munsell          0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>  nnet             7.3-17     2022-01-16 [1] CRAN (R 4.2.0)
#>  palmerpenguins * 0.1.1      2022-08-15 [1] CRAN (R 4.2.1)
#>  parallelly       1.32.1     2022-07-21 [1] CRAN (R 4.2.1)
#>  parsnip        * 1.0.1      2022-08-18 [1] CRAN (R 4.2.1)
#>  pillar           1.8.1      2022-08-19 [1] CRAN (R 4.2.1)
#>  pkgconfig        2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  prodlim          2019.11.13 2019-11-17 [1] CRAN (R 4.2.0)
#>  purrr          * 0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R6               2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp             1.0.9      2022-07-08 [1] CRAN (R 4.2.1)
#>  recipes        * 1.0.1      2022-07-07 [1] CRAN (R 4.2.0)
#>  reprex           2.0.2      2022-08-17 [1] CRAN (R 4.2.1)
#>  rlang            1.0.4      2022-07-12 [1] CRAN (R 4.2.1)
#>  rmarkdown        2.16       2022-08-24 [1] CRAN (R 4.2.1)
#>  rpart            4.1.16     2022-01-24 [1] CRAN (R 4.2.0)
#>  rsample        * 1.1.0      2022-08-08 [1] CRAN (R 4.2.1)
#>  rstudioapi       0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>  scales         * 1.2.1      2022-08-20 [1] CRAN (R 4.2.1)
#>  sessioninfo      1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi          1.7.6      2021-11-29 [1] CRAN (R 4.2.0)
#>  stringr          1.4.1      2022-08-20 [1] CRAN (R 4.2.1)
#>  survival         3.4-0      2022-08-09 [1] CRAN (R 4.2.1)
#>  tibble         * 3.1.8      2022-07-22 [1] CRAN (R 4.2.1)
#>  tidymodels     * 1.0.0      2022-07-13 [1] CRAN (R 4.2.1)
#>  tidyr          * 1.2.0      2022-02-01 [1] CRAN (R 4.2.0)
#>  tidyselect       1.1.2      2022-02-21 [1] CRAN (R 4.2.0)
#>  timeDate         4021.104   2022-07-19 [1] CRAN (R 4.2.1)
#>  tune           * 1.0.0      2022-07-07 [1] CRAN (R 4.2.0)
#>  utf8             1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs            0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  withr            2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  workflows      * 1.0.0      2022-07-05 [1] CRAN (R 4.2.1)
#>  workflowsets   * 1.0.0      2022-07-12 [1] CRAN (R 4.2.1)
#>  xfun             0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>  yaml             2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>  yardstick      * 1.0.0      2022-06-06 [1] CRAN (R 4.2.0)
#> 
#>  [1] C:/Users/dchoy/AppData/Local/Programs/R/R-4.2.0/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Multiclass tuning with parallel processing (registerDoFuture)

Unable to predict or extract feature importance.

library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.2.1
#> Warning: package 'broom' was built under R version 4.2.1
#> Warning: package 'scales' was built under R version 4.2.1
#> Warning: package 'infer' was built under R version 4.2.1
#> Warning: package 'modeldata' was built under R version 4.2.1
#> Warning: package 'parsnip' was built under R version 4.2.1
#> Warning: package 'rsample' was built under R version 4.2.1
#> Warning: package 'tibble' was built under R version 4.2.1
#> Warning: package 'workflows' was built under R version 4.2.1
#> Warning: package 'workflowsets' was built under R version 4.2.1
library(bonsai)
library(palmerpenguins)
#> Warning: package 'palmerpenguins' was built under R version 4.2.1
#> 
#> Attaching package: 'palmerpenguins'
#> The following object is masked from 'package:modeldata':
#> 
#>     penguins
library(doFuture)
#> Warning: package 'doFuture' was built under R version 4.2.1
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loading required package: future
#> Warning: package 'future' was built under R version 4.2.1
library(parallel)

split <- penguins |>
  initial_split(strata = species)

penguins_train <- training(split)
penguins_test <- testing(split)
folds <- vfold_cv(penguins_train, strata = species, 3)

recipe_basic <- penguins_train |>
  recipe(species ~ .)

lightgbm_spec <- boost_tree(trees = tune()) |>
  set_engine(
    "lightgbm",
    objective = "multiclass",
    metric = "multi_error",
    num_class = !!length(unique(penguins_train$species))
  ) |>
  set_mode("classification")

lightgbm_wflow <- workflow(preprocessor = recipe_basic,
                           spec = lightgbm_spec)

all_cores <- detectCores(logical = FALSE)
registerDoFuture()
cl <- makeCluster(all_cores)
plan(cluster, workers = cl)

training_grid_results <- lightgbm_wflow |>
  tune_grid(resamples = folds,
            grid = 5)

last_fit <- lightgbm_wflow |>
  finalize_workflow(select_best(training_grid_results, "roc_auc")) |>
  last_fit(split)

last_fit |>
  extract_workflow() |>
  predict(head(penguins_test))
#> Error in predictor$predict(data = data, start_iteration = start_iteration, : Attempting to use a Booster which no longer exists. This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.

last_fit |>
  extract_fit_engine() |>
  lightgbm::lgb.importance() |>
  lightgbm::lgb.plot.importance()
#> Error in booster$dump_model(num_iteration = num_iteration): Attempting to use a Booster which no longer exists. This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.

Created on 2022-09-07 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22 ucrt)
#>  os       Windows 10 x64 (build 19042)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_Singapore.utf8
#>  ctype    English_Singapore.utf8
#>  tz       Asia/Kuala_Lumpur
#>  date     2022-09-07
#>  pandoc   2.17.1.1 @ C:/Program Files/RStudio/bin/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package        * version    date (UTC) lib source
#>  assertthat       0.2.1      2019-03-21 [1] CRAN (R 4.2.0)
#>  backports        1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  bonsai         * 0.2.0      2022-08-31 [1] CRAN (R 4.2.0)
#>  broom          * 1.0.1      2022-08-29 [1] CRAN (R 4.2.1)
#>  class            7.3-20     2022-01-16 [1] CRAN (R 4.2.0)
#>  cli              3.3.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  codetools        0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>  colorspace       2.0-3      2022-02-21 [1] CRAN (R 4.2.0)
#>  crayon           1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  data.table       1.14.2     2021-09-27 [1] CRAN (R 4.2.0)
#>  DBI              1.1.3      2022-06-18 [1] CRAN (R 4.2.0)
#>  dials          * 1.0.0      2022-06-14 [1] CRAN (R 4.2.0)
#>  DiceDesign       1.9        2021-02-13 [1] CRAN (R 4.2.0)
#>  digest           0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  doFuture       * 0.12.2     2022-04-26 [1] CRAN (R 4.2.1)
#>  dplyr          * 1.0.9      2022-04-28 [1] CRAN (R 4.2.0)
#>  ellipsis         0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate         0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>  fansi            1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap          1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  foreach        * 1.5.2      2022-02-02 [1] CRAN (R 4.2.0)
#>  fs               1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  furrr            0.3.1      2022-08-15 [1] CRAN (R 4.2.1)
#>  future         * 1.27.0     2022-07-22 [1] CRAN (R 4.2.1)
#>  future.apply     1.9.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  generics         0.1.3      2022-07-05 [1] CRAN (R 4.2.1)
#>  ggplot2        * 3.3.6      2022-05-03 [1] CRAN (R 4.2.0)
#>  globals          0.16.1     2022-08-28 [1] CRAN (R 4.2.0)
#>  glue             1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  gower            1.0.0      2022-02-03 [1] CRAN (R 4.2.0)
#>  GPfit            1.0-8      2019-02-08 [1] CRAN (R 4.2.0)
#>  gtable           0.3.0      2019-03-25 [1] CRAN (R 4.2.0)
#>  hardhat          1.2.0      2022-06-30 [1] CRAN (R 4.2.1)
#>  highr            0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools        0.5.3      2022-07-18 [1] CRAN (R 4.2.1)
#>  infer          * 1.0.2      2022-06-10 [1] CRAN (R 4.2.1)
#>  ipred            0.9-13     2022-06-02 [1] CRAN (R 4.2.1)
#>  iterators        1.0.14     2022-02-05 [1] CRAN (R 4.2.0)
#>  jsonlite         1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr            1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  lattice          0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lava             1.6.10     2021-09-02 [1] CRAN (R 4.2.0)
#>  lhs              1.1.5      2022-03-22 [1] CRAN (R 4.2.0)
#>  lifecycle        1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  lightgbm         3.3.2      2022-01-14 [1] CRAN (R 4.2.1)
#>  listenv          0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  lubridate        1.8.0      2021-10-07 [1] CRAN (R 4.2.0)
#>  magrittr         2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  MASS             7.3-57     2022-04-22 [1] CRAN (R 4.2.0)
#>  Matrix           1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>  modeldata      * 1.0.0      2022-07-01 [1] CRAN (R 4.2.1)
#>  munsell          0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>  nnet             7.3-17     2022-01-16 [1] CRAN (R 4.2.0)
#>  palmerpenguins * 0.1.1      2022-08-15 [1] CRAN (R 4.2.1)
#>  parallelly       1.32.1     2022-07-21 [1] CRAN (R 4.2.1)
#>  parsnip        * 1.0.1      2022-08-18 [1] CRAN (R 4.2.1)
#>  pillar           1.8.1      2022-08-19 [1] CRAN (R 4.2.1)
#>  pkgconfig        2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  prodlim          2019.11.13 2019-11-17 [1] CRAN (R 4.2.0)
#>  purrr          * 0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R6               2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp             1.0.9      2022-07-08 [1] CRAN (R 4.2.1)
#>  recipes        * 1.0.1      2022-07-07 [1] CRAN (R 4.2.0)
#>  reprex           2.0.2      2022-08-17 [1] CRAN (R 4.2.1)
#>  rlang            1.0.4      2022-07-12 [1] CRAN (R 4.2.1)
#>  rmarkdown        2.16       2022-08-24 [1] CRAN (R 4.2.1)
#>  rpart            4.1.16     2022-01-24 [1] CRAN (R 4.2.0)
#>  rsample        * 1.1.0      2022-08-08 [1] CRAN (R 4.2.1)
#>  rstudioapi       0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>  scales         * 1.2.1      2022-08-20 [1] CRAN (R 4.2.1)
#>  sessioninfo      1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi          1.7.6      2021-11-29 [1] CRAN (R 4.2.0)
#>  stringr          1.4.1      2022-08-20 [1] CRAN (R 4.2.1)
#>  survival         3.4-0      2022-08-09 [1] CRAN (R 4.2.1)
#>  tibble         * 3.1.8      2022-07-22 [1] CRAN (R 4.2.1)
#>  tidymodels     * 1.0.0      2022-07-13 [1] CRAN (R 4.2.1)
#>  tidyr          * 1.2.0      2022-02-01 [1] CRAN (R 4.2.0)
#>  tidyselect       1.1.2      2022-02-21 [1] CRAN (R 4.2.0)
#>  timeDate         4021.104   2022-07-19 [1] CRAN (R 4.2.1)
#>  tune           * 1.0.0      2022-07-07 [1] CRAN (R 4.2.0)
#>  utf8             1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs            0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  withr            2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  workflows      * 1.0.0      2022-07-05 [1] CRAN (R 4.2.1)
#>  workflowsets   * 1.0.0      2022-07-12 [1] CRAN (R 4.2.1)
#>  xfun             0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>  yaml             2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>  yardstick      * 1.0.0      2022-06-06 [1] CRAN (R 4.2.0)
#> 
#>  [1] C:/Users/dchoy/AppData/Local/Programs/R/R-4.2.0/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Binary class tuning with parallel processing (registerDoFuture)

Unable to predict or extract feature importance.

library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.2.1
#> Warning: package 'broom' was built under R version 4.2.1
#> Warning: package 'scales' was built under R version 4.2.1
#> Warning: package 'infer' was built under R version 4.2.1
#> Warning: package 'modeldata' was built under R version 4.2.1
#> Warning: package 'parsnip' was built under R version 4.2.1
#> Warning: package 'rsample' was built under R version 4.2.1
#> Warning: package 'tibble' was built under R version 4.2.1
#> Warning: package 'workflows' was built under R version 4.2.1
#> Warning: package 'workflowsets' was built under R version 4.2.1
library(bonsai)
library(doFuture)
#> Warning: package 'doFuture' was built under R version 4.2.1
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loading required package: future
#> Warning: package 'future' was built under R version 4.2.1
library(parallel)

split <- modeldata::bivariate_test |>
  initial_split()

data_train <- training(split)
data_test <- testing(split)
folds <- vfold_cv(data_train, 3)

recipe_basic <- data_train |>
  recipe(Class ~ .)

lightgbm_spec <- boost_tree(trees = tune()) |>
  set_engine("lightgbm") |>
  set_mode("classification")

lightgbm_wflow <- workflow(preprocessor = recipe_basic,
                           spec = lightgbm_spec)

all_cores <- detectCores(logical = FALSE)
registerDoFuture()
cl <- makeCluster(all_cores)
plan(cluster, workers = cl)

training_grid_results <- lightgbm_wflow |>
  tune_grid(resamples = folds,
            grid = 5)

last_fit <- lightgbm_wflow |>
  finalize_workflow(select_best(training_grid_results, "roc_auc")) |>
  last_fit(split)

last_fit |>
  extract_workflow() |>
  predict(head(data_test))
#> Error in predictor$predict(data = data, start_iteration = start_iteration, : Attempting to use a Booster which no longer exists. This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.

last_fit |>
  extract_fit_engine() |>
  lightgbm::lgb.importance() |>
  lightgbm::lgb.plot.importance()
#> Error in booster$dump_model(num_iteration = num_iteration): Attempting to use a Booster which no longer exists. This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.

Created on 2022-09-07 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22 ucrt)
#>  os       Windows 10 x64 (build 19042)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_Singapore.utf8
#>  ctype    English_Singapore.utf8
#>  tz       Asia/Kuala_Lumpur
#>  date     2022-09-07
#>  pandoc   2.17.1.1 @ C:/Program Files/RStudio/bin/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.2.0)
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  bonsai       * 0.2.0      2022-08-31 [1] CRAN (R 4.2.0)
#>  broom        * 1.0.1      2022-08-29 [1] CRAN (R 4.2.1)
#>  class          7.3-20     2022-01-16 [1] CRAN (R 4.2.0)
#>  cli            3.3.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  codetools      0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>  colorspace     2.0-3      2022-02-21 [1] CRAN (R 4.2.0)
#>  crayon         1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  data.table     1.14.2     2021-09-27 [1] CRAN (R 4.2.0)
#>  DBI            1.1.3      2022-06-18 [1] CRAN (R 4.2.0)
#>  dials        * 1.0.0      2022-06-14 [1] CRAN (R 4.2.0)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.2.0)
#>  digest         0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  doFuture     * 0.12.2     2022-04-26 [1] CRAN (R 4.2.1)
#>  dplyr        * 1.0.9      2022-04-28 [1] CRAN (R 4.2.0)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate       0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>  fansi          1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  foreach      * 1.5.2      2022-02-02 [1] CRAN (R 4.2.0)
#>  fs             1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.2.1)
#>  future       * 1.27.0     2022-07-22 [1] CRAN (R 4.2.1)
#>  future.apply   1.9.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.2.1)
#>  ggplot2      * 3.3.6      2022-05-03 [1] CRAN (R 4.2.0)
#>  globals        0.16.1     2022-08-28 [1] CRAN (R 4.2.0)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  gower          1.0.0      2022-02-03 [1] CRAN (R 4.2.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.2.0)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.2.0)
#>  hardhat        1.2.0      2022-06-30 [1] CRAN (R 4.2.1)
#>  highr          0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools      0.5.3      2022-07-18 [1] CRAN (R 4.2.1)
#>  infer        * 1.0.2      2022-06-10 [1] CRAN (R 4.2.1)
#>  ipred          0.9-13     2022-06-02 [1] CRAN (R 4.2.1)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.2.0)
#>  jsonlite       1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr          1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  lattice        0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lava           1.6.10     2021-09-02 [1] CRAN (R 4.2.0)
#>  lhs            1.1.5      2022-03-22 [1] CRAN (R 4.2.0)
#>  lifecycle      1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  lightgbm       3.3.2      2022-01-14 [1] CRAN (R 4.2.1)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  lubridate      1.8.0      2021-10-07 [1] CRAN (R 4.2.0)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  MASS           7.3-57     2022-04-22 [1] CRAN (R 4.2.0)
#>  Matrix         1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>  modeldata    * 1.0.0      2022-07-01 [1] CRAN (R 4.2.1)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>  nnet           7.3-17     2022-01-16 [1] CRAN (R 4.2.0)
#>  parallelly     1.32.1     2022-07-21 [1] CRAN (R 4.2.1)
#>  parsnip      * 1.0.1      2022-08-18 [1] CRAN (R 4.2.1)
#>  pillar         1.8.1      2022-08-19 [1] CRAN (R 4.2.1)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.2.0)
#>  purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp           1.0.9      2022-07-08 [1] CRAN (R 4.2.1)
#>  recipes      * 1.0.1      2022-07-07 [1] CRAN (R 4.2.0)
#>  reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.1)
#>  rlang          1.0.4      2022-07-12 [1] CRAN (R 4.2.1)
#>  rmarkdown      2.16       2022-08-24 [1] CRAN (R 4.2.1)
#>  rpart          4.1.16     2022-01-24 [1] CRAN (R 4.2.0)
#>  rsample      * 1.1.0      2022-08-08 [1] CRAN (R 4.2.1)
#>  rstudioapi     0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.2.1)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi        1.7.6      2021-11-29 [1] CRAN (R 4.2.0)
#>  stringr        1.4.1      2022-08-20 [1] CRAN (R 4.2.1)
#>  survival       3.4-0      2022-08-09 [1] CRAN (R 4.2.1)
#>  tibble       * 3.1.8      2022-07-22 [1] CRAN (R 4.2.1)
#>  tidymodels   * 1.0.0      2022-07-13 [1] CRAN (R 4.2.1)
#>  tidyr        * 1.2.0      2022-02-01 [1] CRAN (R 4.2.0)
#>  tidyselect     1.1.2      2022-02-21 [1] CRAN (R 4.2.0)
#>  timeDate       4021.104   2022-07-19 [1] CRAN (R 4.2.1)
#>  tune         * 1.0.0      2022-07-07 [1] CRAN (R 4.2.0)
#>  utf8           1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs          0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  workflows    * 1.0.0      2022-07-05 [1] CRAN (R 4.2.1)
#>  workflowsets * 1.0.0      2022-07-12 [1] CRAN (R 4.2.1)
#>  xfun           0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>  yaml           2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>  yardstick    * 1.0.0      2022-06-06 [1] CRAN (R 4.2.0)
#> 
#>  [1] C:/Users/dchoy/AppData/Local/Programs/R/R-4.2.0/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Rows with missing values in input data seem to be entirely dropped when using lightgbm

Consider the penguin example from the bonsai documentation, slightly modified to include all of the covariates in the penguin dataset:

library(tidymodels)
library(modeldata)
library(bonsai)

mod <-
  boost_tree() %>%
  set_engine(engine = "lightgbm", num_threads = 8, verbose = 2) %>%
  set_mode(mode = "classification") %>%
  fit(formula = species ~ ., data = penguins)

When I run this (with tidymodels 1.1.0 and bonsai 0.2.1), the LightGBM debugging messages tell me that only 333 rows of the penguins data are used, even though all 344 rows have a label and at least some usable features:

[LightGBM] [Debug] Dataset::GetMultiBinFromAllFeatures: sparse rate 0.164164
[LightGBM] [Debug] init for col-wise cost 0.000095 seconds, init for row-wise cost 0.000126 seconds
[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000226 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Debug] Using Dense Multi-Val Bin
[LightGBM] [Info] Total Bins 262
[LightGBM] [Info] Number of data points in the train set: 333, number of used features: 6
[LightGBM] [Info] Start training from score -0.824536
[LightGBM] [Info] Start training from score -1.588635
[LightGBM] [Info] Start training from score -1.029019
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Debug] Trained a tree with leaves = 6 and depth = 3
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Debug] Trained a tree with leaves = 7 and depth = 3

333 happens to be exactly the count of rows with no missing values:

> penguins %>% complete.cases %>% sum
[1] 333

My guess is that something inside bonsai or parsnip is applying complete.cases() to the input data, but I wasn't able to find it.

Is this intended? LightGBM is noteworthy for its explicit support for missing values, and I know it's possible to take advantage of that through the standard LightGBM R package interface, but I'd love to stay in the tidymodels ecosystem if possible. Is there an option I can pass to fit(), or even to the model spec, to make sure rows with missing values are kept?

Thanks in advance for any suggestions!
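
The symptom (333 of 344 rows used, exactly the number of complete cases) matches what R's formula machinery does by default. The minimal base-R sketch below assumes, without having confirmed it in the parsnip or bonsai source, that the formula path used by `fit()` ends in a `model.frame()` call that inherits the session's default `na.action`. It uses the built-in `airquality` data instead of penguins so that it needs no extra packages:

```r
# Base-R sketch only: this is not bonsai's or parsnip's actual code path.
data(airquality)  # built-in data with NAs in Ozone and Solar.R

# The stats package sets this option to "na.omit" in a standard session
default_action <- getOption("na.action")

n_all      <- nrow(airquality)
n_complete <- sum(complete.cases(airquality))

# Default na.action: incomplete rows are dropped without any message
mf_default <- model.frame(Ozone ~ ., data = airquality)

# na.pass keeps every row and leaves the NAs for the model to handle
mf_pass <- model.frame(Ozone ~ ., data = airquality, na.action = na.pass)

print(c(all = n_all, complete = n_complete,
        default = nrow(mf_default), pass = nrow(mf_pass)))
```

If this is the mechanism, the places to look would be how `na.action` is set on the formula path, or whether an x/y interface such as parsnip's `fit_xy()` passes the NAs through to LightGBM; both are untested suggestions here.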

aorsf with doParallel cluster has interesting error

Hi Simon -

TL;DR: I think doParallel clusters are causing an issue with the aorsf engine.

I was eager to test the new aorsf random forest engine after your blog post & new bonsai release to CRAN!

My goal was to benchmark it against some bagged and boosted trees for one of my projects. When I switched the rand_forest engine to aorsf, I kept getting a weird error: "parsnip could not locate an implementation for rand_forest regression model specifications using the aorsf engine." But I had updated bonsai and parsnip, and even switched R versions, as I thought I was losing my mind. To make matters even more confusing, I set up a simple reprex yesterday, and it worked (?!). So I thought something was going on with function masking or my environment, and I kept refreshing the environment and testing both my main script and the reproducible example.

The one thing I didn't have in my reprex yesterday was the pre-training setting: (cluster <- makePSOCKcluster(8); registerDoParallel(cluster)). Once I initialize this cluster in the reprex, there seems to be an issue with training the aorsf model (I think?). It's the only way I was able to reproduce the error in the reprex.

I am wondering if I should go about this differently: not run a cluster for aorsf at all, or maybe use a different cluster library? I initialize clusters for the bagged and boosted trees, so there will be one in my environment unless I run aorsf in a different script altogether (I am currently running a Quarto code chunk for each model). I'm open to feedback or a solution, or to hearing that I'm just losing my mind. Each time parsnip says it could not locate the implementation, show_engines() shows that the engine is there.

# setup libs, data, recipe ----------------------------------------------------------

library(doParallel) # The issue arose after creating a doParallel cluster? I think so??
#> Warning: package 'doParallel' was built under R version 4.4.1
#> Loading required package: foreach
#> Warning: package 'foreach' was built under R version 4.4.1
#> Loading required package: iterators
#> Warning: package 'iterators' was built under R version 4.4.1
#> Loading required package: parallel

library(parsnip)
#> Warning: package 'parsnip' was built under R version 4.4.1


library(rsample)
#> Warning: package 'rsample' was built under R version 4.4.1
library(tune)
#> Warning: package 'tune' was built under R version 4.4.1
library(yardstick)
#> Warning: package 'yardstick' was built under R version 4.4.1
library(workflows)
#> Warning: package 'workflows' was built under R version 4.4.1
library(recipes)
#> Warning: package 'recipes' was built under R version 4.4.1
#> Loading required package: dplyr
#> Warning: package 'dplyr' was built under R version 4.4.1
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step
library(dplyr)
library(bonsai)
#> Warning: package 'bonsai' was built under R version 4.4.1
#library(aorsf)
library(tidymodels)   # THE ISSUE IS AN UNDERLYING LIB mask? ??? I don't think so
#> Warning: package 'tidymodels' was built under R version 4.4.1
#> Warning: package 'broom' was built under R version 4.4.1
#> Warning: package 'dials' was built under R version 4.4.1
#> Warning: package 'scales' was built under R version 4.4.1
#> Warning: package 'ggplot2' was built under R version 4.4.1
#> Warning: package 'infer' was built under R version 4.4.1
#> Warning: package 'modeldata' was built under R version 4.4.1
#> Warning: package 'purrr' was built under R version 4.4.1
#> Warning: package 'tibble' was built under R version 4.4.1
#> Warning: package 'tidyr' was built under R version 4.4.1
#> Warning: package 'workflowsets' was built under R version 4.4.1
library(finetune)
#> Warning: package 'finetune' was built under R version 4.4.1
#library(dplyr)

training <- ChickWeight |> ungroup() |> tibble::as_tibble() |> mutate(Chick = as.numeric(Chick))

folds <- vfold_cv(data = training, v = 10)

rm_smth <- "Diet"

model_recipe <- 
   recipe(weight ~ ., training) |>
    step_rm(any_of(!!rm_smth)) |>
    step_dummy(all_nominal_predictors()) |>
    step_YeoJohnson(all_nominal_predictors()) 


# random forest spec and grid ----------------------------------------
forest_grid <-  expand.grid(
  trees = c(20, 50),
  mtry = c(2, 5, 7)
)

orf_spec <- rand_forest(
  trees = tune(),
  mtry = tune()) |>
  set_engine("aorsf") |> # issues arise for aorsf:
  set_mode("regression")


# pre training settings ---
cluster <- makePSOCKcluster(8)
registerDoParallel(cluster)

# model creation -------------------------------------------
orf_results <-
  finetune::tune_race_anova(
    workflow() |>
      add_recipe(model_recipe) |>
      add_model(orf_spec),
    resamples = folds,
    grid = forest_grid,
    control = control_race(),
    metrics = metric_set(yardstick::rmse)
  )
#> Warning: All models failed. Run `show_notes(.Last.tune.result)` for more
#> information.
#> Error in `test_parameters_gls()`:
#> ! There were no valid metrics for the ANOVA model.
# post training settings ---
stopCluster(cluster)
registerDoSEQ()


show_notes(.Last.tune.result)
#> unique notes:
#> ──────────────────────────────────────────────────────────────
#> Error:
#> ! parsnip could not locate an implementation for `rand_forest`
#>   regression model specifications using the `aorsf` engine.
parsnip::show_engines("rand_forest")
#> # A tibble: 10 × 2
#>    engine       mode          
#>    <chr>        <chr>         
#>  1 ranger       classification
#>  2 ranger       regression    
#>  3 randomForest classification
#>  4 randomForest regression    
#>  5 spark        classification
#>  6 spark        regression    
#>  7 partykit     regression    
#>  8 partykit     classification
#>  9 aorsf        classification
#> 10 aorsf        regression
# select_best(orf_results)

Created on 2024-06-27 with reprex v2.1.0
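
A PSOCK worker is a separate R process that only has what it is explicitly given. One plausible explanation for the error, assumed here rather than confirmed: bonsai registers its engines with parsnip when bonsai is loaded, so a worker that loads parsnip but not bonsai has no aorsf entry, while `show_engines()` in the main session still lists it. The base-R sketch below shows the underlying behaviour, with the `stats4` package (shipped with R) standing in for bonsai:

```r
library(parallel)

# One fresh PSOCK worker: a separate R process started via Rscript
cl <- makePSOCKcluster(1)

# Attach a package in the main session only
library(stats4)
main_has <- "stats4" %in% loadedNamespaces()

# The worker does not inherit packages loaded in the main session
worker_before <- clusterEvalQ(cl, "stats4" %in% loadedNamespaces())[[1]]

# Load it explicitly on the worker, as one would with library(bonsai)
invisible(clusterEvalQ(cl, library(stats4)))
worker_after <- clusterEvalQ(cl, "stats4" %in% loadedNamespaces())[[1]]

stopCluster(cl)

print(c(main = main_has, worker_before = worker_before,
        worker_after = worker_after))
```

For the reprex above, the analogous step would be `clusterEvalQ(cluster, library(bonsai))` right after `registerDoParallel(cluster)`, or listing `"bonsai"` in the `pkgs` argument of the tune control functions (e.g. `control_grid()`); both are untested suggestions here.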

catboost

Will bonsai support catboost?

Upkeep for bonsai

Pre-history

  • usethis::use_readme_rmd()
  • usethis::use_roxygen_md()
  • usethis::use_github_links()
  • usethis::use_pkgdown_github_pages()
  • usethis::use_tidy_github_labels()
  • usethis::use_tidy_style()
  • usethis::use_tidy_description()
  • urlchecker::url_check()

2020

  • usethis::use_package_doc()
    Consider letting usethis manage your @importFrom directives here.
    usethis::use_import_from() is handy for this.
  • usethis::use_testthat(3) and upgrade to 3e (see the testthat 3e vignette)
  • Align the names of R/ files and test/ files for workflow happiness.
    The docs for usethis::use_r() include a helpful script.
    usethis::rename_files() may be useful.

2021

  • usethis::use_tidy_dependencies()
  • usethis::use_tidy_github_actions() and update artisanal actions to use setup-r-dependencies
  • Remove check environments section from cran-comments.md
  • Bump required R version in DESCRIPTION to 3.5
  • Use lifecycle instead of artisanal deprecation messages, as described in "Communicate lifecycle changes in your functions"
  • Make sure RStudio appears in Authors@R of DESCRIPTION like so, if appropriate:
    person("RStudio", role = c("cph", "fnd"))

2022

2023

Necessary:

  • Update copyright holder in DESCRIPTION: person(given = "Posit, PBC", role = c("cph", "fnd"))
  • Double-check that the license file uses '[package] authors' as the copyright holder; run use_mit_license()
  • Update email addresses *@rstudio.com -> *@posit.co
  • Update logo (https://github.com/rstudio/hex-stickers); run use_tidy_logo()
  • usethis::use_tidy_coc()

Optional:

  • Review 2022 checklist to see if you completed the pkgdown updates
  • Prefer pak::pak("org/pkg") over devtools::install_github("org/pkg") in README
  • Consider running use_tidy_dependencies() and/or replace compat files with use_standalone()
  • use_standalone("r-lib/rlang", "types-check") instead of home grown argument checkers
  • Add alt-text to pictures, plots, etc; see https://posit.co/blog/knitr-fig-alt/ for examples
  • usethis::use_tidy_github_actions()

Supporting EBMs (Explainable Boosting Machines)

Hi, I was wondering whether there are plans to support EBMs, since they appear quite promising in terms of performance (matching XGBoost, etc.) whilst remaining very interpretable.

There is even an R package, interpret, but I haven't found a way to make it play nicely with parsnip.

handling `mtry` with lightgbm

The package currently uses this setting:

bonsai/R/lightgbm_data.R

Lines 161 to 168 in eb13b9e

parsnip::set_model_arg(
  model = "boost_tree",
  eng = "lightgbm",
  parsnip = "mtry",
  original = "feature_fraction",
  func = list(pkg = "dials", fun = "mtry"),
  has_submodel = FALSE
)

feature_fraction is actually mtry / number_of_predictors, though. From what I understand, rules::mtry_prop can be of use here. Still need to look into the most principled way to then handle mtry passed to boost_tree.
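A minimal sketch of that conversion, assuming a hypothetical helper `mtry_to_feature_fraction()` that is not part of bonsai or dials: divide the predictor count by the number of predictors and cap the result at 1.

```r
# Hypothetical helper (not in bonsai or dials): express a count-based mtry
# as the proportion of predictors that lightgbm's feature_fraction* expects.
mtry_to_feature_fraction <- function(mtry, n_predictors) {
  stopifnot(length(mtry) == 1, length(n_predictors) == 1,
            mtry >= 1, n_predictors >= 1)
  # Cap at 1 rather than erroring when mtry exceeds the predictor count.
  min(mtry, n_predictors) / n_predictors
}

mtry_to_feature_fraction(3, 10)   # 0.3
mtry_to_feature_fraction(12, 10)  # 1 (capped)
```

Whether to cap silently or to error when mtry exceeds the number of predictors is a design choice; this sketch caps.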

resolving aliased argument names with lightgbm

Following up on @jameslamb's comment here—thank you for being willing to discuss. :)


Some background, for the GitHub archeologists:

lightgbm allows passing many of its arguments with aliases. On the parsnip side, these include both main and engine arguments to boost_tree(), including the now-tunable engine argument num_leaves. On the lightgbm side, these include both "core" and "control" arguments.

As of now, any aliases supplied to set_engine are passed through the dots of bonsai::train_lightgbm() to the dots of lightgbm::lgb.train(). lightgbm's machinery takes care of resolving aliases, with some rules that generally prevent silent failures while tuning:

https://github.com/microsoft/LightGBM/blob/e45fc48405e9877138ffb5f7e1fd4c449752d323/R-package/R/utils.R#L176-L181

  • If a main argument is marked for tuning and its main translation (i.e. lightgbm's non-alias argument name) is supplied as an engine arg, parsnip machinery will throw warnings that the engine argument will be ignored. e.g. for min_n -> min_data_in_leaf:
! Bootstrap1: preprocessor 1/1, model 1/3: The following arguments cannot be manually modified and were removed: min_data_in_leaf
...
  • If a main argument is marked for tuning and a lightgbm alias is supplied as an engine arg, we ignore the alias silently. (Note that bonsai::train_lightgbm() sets lgb.train()'s verbose argument to 1L if one isn't supplied.)

  • The scariest issue I'd anticipate is the user not touching the main argument (that will be translated to the main, non-alias lgb.train argument), but setting the alias in set_engine(). In that case, the bonsai::train_lightgbm() default kicks in, and the user-supplied engine argument is silently ignored in favor of the default supplied as the non-alias lightgbm argument.🫣

Reprex:
library(tidymodels)
library(bonsai)
library(testthat)

data("penguins", package = "modeldata")

penguins <- penguins[complete.cases(penguins),]
penguins_split <- initial_split(penguins)
set.seed(1)
boots <- bootstraps(training(penguins_split), 3)
base_wf <- workflow() %>% add_formula(bill_length_mm ~ .)

Marking a main argument for tuning, as usual:

bt_spec <-
  boost_tree(min_n = tune()) %>%
  set_engine("lightgbm") %>%
  set_mode("regression")

bt_wf <-
  base_wf %>%
  add_model(bt_spec)

set.seed(1)
bt_res_correct <- tune_grid(bt_wf, boots, grid = 3, control = control_grid(save_pred = TRUE))

bt_res_correct
#> # Tuning results
#> # Bootstrap sampling 
#> # A tibble: 3 × 5
#>   splits           id         .metrics         .notes           .predictions
#>   <list>           <chr>      <list>           <list>           <list>      
#> 1 <split [249/93]> Bootstrap1 <tibble [6 × 5]> <tibble [0 × 3]> <tibble>    
#> 2 <split [249/93]> Bootstrap2 <tibble [6 × 5]> <tibble [0 × 3]> <tibble>    
#> 3 <split [249/97]> Bootstrap3 <tibble [6 × 5]> <tibble [0 × 3]> <tibble>

Marking a main argument for tuning, and supplying its non-alias translation as engine arg:

bt_spec <-
  boost_tree(min_n = tune()) %>%
  set_engine("lightgbm", min_data_in_leaf = 1) %>%
  set_mode("regression")

bt_wf <-
  base_wf %>%
  add_model(bt_spec)

set.seed(1)
bt_res_both <- tune_grid(bt_wf, boots, grid = 3)
#> ! Bootstrap1: preprocessor 1/1, model 1/3: The following arguments cannot be manually modified and were removed: mi...
#> ! Bootstrap1: preprocessor 1/1, model 2/3: The following arguments cannot be manually modified and were removed: mi...
#> ! Bootstrap1: preprocessor 1/1, model 3/3: The following arguments cannot be manually modified and were removed: mi...
#> ! Bootstrap2: preprocessor 1/1, model 1/3: The following arguments cannot be manually modified and were removed: mi...
#> ! Bootstrap2: preprocessor 1/1, model 2/3: The following arguments cannot be manually modified and were removed: mi...
#> ! Bootstrap2: preprocessor 1/1, model 3/3: The following arguments cannot be manually modified and were removed: mi...
#> ! Bootstrap3: preprocessor 1/1, model 1/3: The following arguments cannot be manually modified and were removed: mi...
#> ! Bootstrap3: preprocessor 1/1, model 2/3: The following arguments cannot be manually modified and were removed: mi...
#> ! Bootstrap3: preprocessor 1/1, model 3/3: The following arguments cannot be manually modified and were removed: mi...

bt_res_both
#> # Tuning results
#> # Bootstrap sampling 
#> # A tibble: 3 × 4
#>   splits           id         .metrics         .notes          
#>   <list>           <chr>      <list>           <list>          
#> 1 <split [249/93]> Bootstrap1 <tibble [6 × 5]> <tibble [3 × 3]>
#> 2 <split [249/93]> Bootstrap2 <tibble [6 × 5]> <tibble [3 × 3]>
#> 3 <split [249/97]> Bootstrap3 <tibble [6 × 5]> <tibble [3 × 3]>
#> 
#> There were issues with some computations:
#> 
#>   - Warning(s) x9: The following arguments cannot be manually modified and were remo...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.

Marking a main argument for tuning, and supplying an alias of its translation as an engine arg:

set.seed(1)
bt_spec <-
  boost_tree(min_n = tune()) %>%
  set_engine("lightgbm", min_data_per_leaf = 1) %>%
  set_mode("regression")

bt_wf <-
  base_wf %>%
  add_model(bt_spec)

bt_res_alias <- 
  tune_grid(
    bt_wf, boots, grid = 3, 
    control = control_grid(extract = extract_fit_engine, save_pred = TRUE)
  )

Note that both params end up in the resulting object, though only one is referenced when making predictions.

bt_res_alias %>%
  pull(.extracts) %>%
  `[[`(1) 
#> # A tibble: 3 × 3
#>   min_n .extracts  .config             
#>   <int> <list>     <chr>               
#> 1    13 <lgb.Bstr> Preprocessor1_Model1
#> 2    33 <lgb.Bstr> Preprocessor1_Model2
#> 3    25 <lgb.Bstr> Preprocessor1_Model3

lgb_fit <- bt_res_alias %>%
  pull(.extracts) %>%
  `[[`(1) %>%
  pull(.extracts) %>%
  `[[`(1)

lgb_fit$params$min_data_in_leaf
#> [1] 13
lgb_fit$params$min_data_per_leaf
#> [1] 1

# all good
expect_equal(
  collect_predictions(bt_res_correct),
  collect_predictions(bt_res_alias)
)

bt_mets_correct <- 
  bt_res_correct %>%
  select_best(metric = "rmse") %>%
  finalize_workflow(bt_wf, parameters = .) %>%
  last_fit(penguins_split)

bt_mets_alias <- 
  bt_res_alias %>%
  select_best(metric = "rmse") %>%
  finalize_workflow(bt_wf, parameters = .) %>%
  last_fit(penguins_split)

# all good
expect_equal(
  bt_mets_correct$.metrics,
  bt_mets_alias$.metrics
)

Created on 2022-11-04 with reprex v2.0.2


I think the best approach here would be to raise a warning or error whenever an alias that maps to a main boost_tree() argument is supplied, and note that it can be resolved by passing the value as a main argument to boost_tree(). Otherwise, passing aliases as engine arguments (i.e. aliases that don't map to main arguments) seems unproblematic to me. Another option is to set verbose to a level that lets lightgbm propagate its own messages about duplicated aliases whenever any alias is supplied, though this feels like it might obscure train_lightgbm()'s role in passing a non-aliased argument. Either way, this requires being able to detect when an alias is supplied.

A question for you, James, if you're up for it: is there any sort of dictionary that we could reference that would contain these mappings? A list like the one currently output by lightgbm:::.PARAMETER_ALIASES() would be perfect, though that also contains the parameters listed under "Learning Control Parameters".

We could also put that together ourselves; we'd just need the mappings for 8 of them:

library(tidymodels)
library(bonsai)

get_from_env("boost_tree_args") %>%
  filter(engine == "lightgbm")
#> # A tibble: 8 × 5
#>   engine   parsnip        original                func             has_submodel
#>   <chr>    <chr>          <chr>                   <list>           <lgl>       
#> 1 lightgbm tree_depth     max_depth               <named list [2]> FALSE       
#> 2 lightgbm trees          num_iterations          <named list [2]> TRUE        
#> 3 lightgbm learn_rate     learning_rate           <named list [2]> FALSE       
#> 4 lightgbm mtry           feature_fraction_bynode <named list [2]> FALSE       
#> 5 lightgbm min_n          min_data_in_leaf        <named list [2]> FALSE       
#> 6 lightgbm loss_reduction min_gain_to_split       <named list [2]> FALSE       
#> 7 lightgbm sample_size    bagging_fraction        <named list [2]> FALSE       
#> 8 lightgbm stop_iter      early_stopping_rounds   <named list [2]> FALSE

Created on 2022-11-04 with reprex v2.0.2
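The alias check proposed above could look roughly like the following sketch. The alias map is a hand-written, non-exhaustive assumption based on LightGBM's parameter documentation (lightgbm itself remains the authoritative source, e.g. lightgbm:::.PARAMETER_ALIASES() as mentioned above), and `find_main_arg_aliases()` is a hypothetical helper, not part of bonsai.

```r
# ASSUMPTION: hand-written, non-exhaustive alias map for lightgbm parameters
# that boost_tree() main arguments translate to. Check lightgbm's own alias
# list before relying on any entry.
main_arg_aliases <- list(
  min_data_in_leaf        = c("min_data_per_leaf", "min_data",
                              "min_child_samples", "min_samples_leaf"),
  num_iterations          = c("num_iteration", "n_iter", "num_tree",
                              "num_trees", "num_round", "num_rounds",
                              "num_boost_round", "n_estimators"),
  learning_rate           = c("shrinkage_rate", "eta"),
  feature_fraction_bynode = c("sub_feature_bynode", "colsample_bynode"),
  bagging_fraction        = c("sub_row", "subsample", "bagging"),
  min_gain_to_split       = c("min_split_gain")
)

# Hypothetical helper: return the engine-argument names that alias a
# main-argument parameter (names = alias, values = canonical lightgbm name).
find_main_arg_aliases <- function(engine_args, aliases = main_arg_aliases) {
  lookup <- stats::setNames(
    rep(names(aliases), lengths(aliases)),
    unlist(aliases, use.names = FALSE)
  )
  lookup[intersect(names(engine_args), names(lookup))]
}

# min_data_per_leaf aliases min_data_in_leaf (the translation of min_n);
# num_leaves is an ordinary engine argument and is not flagged.
found <- find_main_arg_aliases(list(min_data_per_leaf = 1, num_leaves = 31))
if (length(found) > 0) {
  message("Set these via boost_tree() main arguments instead: ",
          paste0(names(found), " -> ", found, collapse = ", "))
}
```

In a real implementation the `message()` would presumably become a warning or error, as suggested above.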

predict `type` warning with lightgbm

── Warning ('test-lightgbm.R:207'): boost_tree with lightgbm ───────────────────
predict.lgb.Booster: Found the following passed through '...': type. These will be used, but in future 
releases of lightgbm, this warning will become an error. Add these to 'params' instead. See 
?predict.lgb.Booster for documentation on how to call this function.

'reshape' argument to predict() will be removed in LightGBM v4.0.0

As of microsoft/LightGBM#4971, lightgbm:::predict.lgb.Booster() no longer accepts the argument reshape. That pull request will be a part of LightGBM v4.0.0 (#34, microsoft/LightGBM#5153).

{bonsai} passes reshape = TRUE to predict.lgb.Booster() in several places.

p <- stats::predict(object$fit, prepare_df_lgbm(new_data), reshape = TRUE, ...)

stats::predict(object$fit, prepare_df_lgbm(new_data), reshape = TRUE, rawscore = TRUE, ...)

bonsai/R/lightgbm.R

Lines 304 to 311 in ad9bb39

p <-
  stats::predict(
    object$fit,
    prepare_df_lgbm(new_data),
    reshape = TRUE,
    params = list(predict_disable_shape_check = TRUE),
    ...
  )

As a result, {bonsai}'s unit tests fail against {lightgbm}'s development version (confirmed with the approach mentioned in #34 (comment)).

{bonsai}'s uses of predict.lgb.Booster(..., reshape = TRUE ...) should be removed.
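One way to drop reshape = TRUE without depending on the installed lightgbm version would be to normalize the shape of whatever predict() returns. This sketch assumes, for illustration, that older releases return a flat vector of n_rows * n_classes scores for multiclass models (each row's class scores stored contiguously) when reshape is not set, while newer releases return a matrix; `as_prediction_matrix()` is a hypothetical helper, not part of bonsai or lightgbm.

```r
# Hypothetical helper: coerce multiclass predictions to an n_rows x n_classes
# matrix, whether they arrive as a matrix or as a flat, row-major vector.
as_prediction_matrix <- function(pred, n_classes) {
  if (is.matrix(pred)) {
    return(pred)
  }
  stopifnot(length(pred) %% n_classes == 0)
  matrix(pred, ncol = n_classes, byrow = TRUE)
}

flat <- c(0.7, 0.2, 0.1,   # scores for row 1
          0.1, 0.3, 0.6)   # scores for row 2
as_prediction_matrix(flat, n_classes = 3)
```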

`feature_fraction` description maps to wrong parameter

From the documentation, feature_fraction is described as:

Rather than as a number, lightgbm::lgb.train()’s feature_fraction argument encodes mtry as the proportion of predictors that will be randomly sampled at each split.

This doesn't line up with the documentation from lightgbm

It seems like the parameter should be feature_fraction_bynode to align with the parsnip documentation and other models.
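To make the distinction concrete, here is a simplified simulation (not LightGBM's actual sampling code) of drawing the same fraction of predictors once per tree versus once per node. The predictor names, the fraction of 0.3 and the four nodes are arbitrary example values.

```r
# Simplified illustration only (not LightGBM's implementation): which
# predictors each of four nodes in one tree may consider for its split.
set.seed(1)
predictors <- paste0("x", 1:10)
fraction <- 0.3
n_nodes <- 4
k <- max(1, round(fraction * length(predictors)))

# Per tree (feature_fraction-style): one subset, shared by every node.
tree_subset <- sample(predictors, k)
per_tree <- lapply(seq_len(n_nodes), function(i) tree_subset)

# Per node (feature_fraction_bynode-style): a fresh subset at every split,
# matching the "sampled at each split" description of mtry.
per_node <- lapply(seq_len(n_nodes), function(i) sample(predictors, k))
```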

Method for passing linear_tree or other params that should be routed to lgb.Dataset rather than lgb.train

The current engine implementation for lightgbm sends additional engine arguments to the function lightgbm::lgb.train. However, one of the dataset parameters is really a model option, linear_tree, which fits linear regression models in the leaf nodes rather than constant models in each leaf.

Currently, the tidymodels engine argument mechanics don't allow for specifying where additional engine arguments should be directed. One could simply add a single fixed argument for linear_tree as I did in this crude example for experimentation, but that's not really a long term solution, if there are other lightgbm arguments that people would like to access in a similar manner.

Would it be possible to more fully expose engine arguments for lightgbm in a way that distinguishes between those intended for lgb.train() vs. lgb.Dataset(), or is an argument by argument method the best we can do?
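One argument-agnostic option would be to route engine arguments by name. In this hypothetical sketch, `route_engine_args()` splits a list of engine arguments into a part destined for lgb.Dataset() and a part for lgb.train(); the set of dataset-level names is an illustrative, non-exhaustive assumption that would need to be checked against LightGBM's "Dataset Parameters" documentation.

```r
# ASSUMPTION: illustrative, non-exhaustive list of dataset-level parameters.
dataset_param_names <- c("linear_tree", "max_bin", "min_data_in_bin",
                         "feature_pre_filter", "use_missing", "zero_as_missing")

# Hypothetical helper: split engine arguments by destination.
route_engine_args <- function(engine_args, dataset_names = dataset_param_names) {
  to_dataset <- names(engine_args) %in% dataset_names
  list(
    dataset = engine_args[to_dataset],   # would go to lgb.Dataset()
    train   = engine_args[!to_dataset]   # would go to lgb.train() params
  )
}

routed <- route_engine_args(list(linear_tree = TRUE, num_leaves = 15, max_bin = 63))
str(routed)
```

A name-based list keeps the user interface unchanged (everything still goes through set_engine()), at the cost of maintaining the list as LightGBM evolves.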

partykit prediction differences

re:

# TODO unsure why this is close but not the same
# expect_error_free(cf_pred_2 <- predict(cf_fit_2, mtcars)$.pred)
# pk_pred_2 <- unname(predict(pk_fit_2, mtcars))
# expect_equal(pk_pred_2, cf_pred_2)

EDIT: passing mincriterion and testtype does not fix the issue—made a mistake with RNG. Still troubleshooting here.

Reprex:

library(partykit)
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
library(bonsai)
library(parsnip)
library(testthat)
  
set.seed(1)
cf_fit_1 <-
  rand_forest(trees = 5) %>%
  set_engine("partykit") %>%
  set_mode("regression") %>%
  fit(mpg ~ ., data = mtcars)

set.seed(1)

pk_fit_1 <- cforest(mpg ~ ., data = mtcars, ntree = 5)

expect_equal(pk_fit_1$fitted, cf_fit_1$fit$fitted)

cf_pred_1 <- predict(cf_fit_1, mtcars)$.pred

pk_pred_1 <- unname(predict(pk_fit_1, mtcars))

expect_equal(pk_pred_1, cf_pred_1)
#> Error: `pk_pred_1` not equal to `cf_pred_1`.
#> 32/32 mismatches (average diff: 0.73)
#> [1] 21.8 - 21.8 == 0.0188
#> [2] 21.8 - 21.8 == 0.0188
#> [3] 26.2 - 26.1 == 0.0329
#> [4] 19.8 - 17.9 == 1.9173
#> [5] 16.4 - 15.8 == 0.6046
#> [6] 21.7 - 17.9 == 3.7786
#> [7] 16.4 - 15.8 == 0.6046
#> [8] 26.2 - 24.0 == 2.1366
#> [9] 26.2 - 26.1 == 0.0329
#> ...

Created on 2022-05-03 by the reprex package (v2.0.1)

Missing details_boost_tree_lightgbm help page

On the Windows platform, the command ?details_boost_tree_lightgbm returns:
No documentation for ‘details_boost_tree_lightgbm’ in specified packages and libraries: you could try ‘??details_boost_tree_lightgbm’

Running ??details_boost_tree_lightgbm shows "No results found".

The partykit help pages are fine.

`aorsf` - engine: model fit fails if `mtry` is specified

Hi,

the model fit fails if mtry is specified for the aorsf engine. If it is not specified, the fit works with the default engine values.

library(bonsai)
#> Loading required package: parsnip

# This works with default mtry value
rf_mod <- 
  rand_forest() %>%
  set_engine(engine = "aorsf") %>%
  set_mode(mode = "regression") %>% 
  set_args(min_n = 1, trees = 2, importance = "permute") %>% 
  fit(
    formula = mpg ~ .,
    data = mtcars 
  )


rf_mod
#> parsnip model object
#> 
#> ---------- Oblique random regression forest
#> 
#>      Linear combinations: Accelerated Linear regression
#>           N observations: 32
#>                  N trees: 2
#>       N predictors total: 10
#>    N predictors per node: 4
#>  Average leaves per tree: 7.5
#> Min observations in leaf: 1
#>           OOB stat value: 0.27
#>            OOB stat type: RSQ
#>      Variable importance: permute
#> 
#> -----------------------------------------



# Error occurs...
rf_mod_w_mtry <- 
  rand_forest() %>%
  set_engine(engine = "aorsf") %>%
  set_mode(mode = "regression") %>% 
  set_args(mtry = 3, min_n = 1, trees = 2, importance = "permute") %>% 
  fit(
    formula = mpg ~ .,
    data = mtcars 
  )
#> Error in ncol(source): object 'x' not found

Created on 2024-08-08 with reprex v2.0.2


Thank you in advance and best regards

Feature idea - Linking hyperparameters during CV

Problem

Within LightGBM, num_leaves is capped at 2 ^ max_depth. For example, if num_leaves is set to 1000 and max_depth is set to 5, then LightGBM will likely end up creating a full-depth tree with 32 (2 ^ 5) leaves per iteration.

{bonsai} / {parsnip} have no knowledge of the relationship between these parameters. As a result, during cross-validation, Bayesian optimization and other CV search methods will spend a significant amount of time exploring meaningless hyperparameter space where num_leaves > 2 ^ max_depth. This results in longer CV times, especially for large models with many parameters.

Idea

One potential solution is to explicitly link num_leaves and max_depth specifically for the LightGBM model spec. I implemented this link in my treesnip fork by essentially adding two engine arguments:

  1. link_max_depth - Boolean. When FALSE, max_depth is equal to whatever is passed via engine/model arg. When TRUE, max_depth is equal to floor(log2(num_leaves)) + link_max_depth_add.
  2. link_max_depth_add - Integer. Value added to max_depth. For example, if link_max_depth is TRUE, num_leaves is 1000, and link_max_depth_add is 2, then max_depth = floor(log2(1000)) + 2, or 11.

This would improve cross-validation times by restricting the hyperparameter space that needs to be explored while leaving the default options untouched. Ideally, it could even be generalized (within {parsnip}) to other model types that have intrinsically linked hyperparameters. However, not sure if this fits with the Tidymodels way of doing things. If it's totally out-of-scope, then feel free to close this issue.
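The proposed rule fits in a few lines of R. `linked_max_depth()` is a hypothetical function mirroring the link_max_depth_add idea above, and the grid values are made up; the second part shows a related (also hypothetical) alternative that simply drops grid candidates violating the num_leaves <= 2 ^ max_depth cap.

```r
# Hypothetical implementation of the proposed rule:
# max_depth = floor(log2(num_leaves)) + link_max_depth_add
linked_max_depth <- function(num_leaves, link_max_depth_add = 0L) {
  as.integer(floor(log2(num_leaves)) + link_max_depth_add)
}

linked_max_depth(1000, link_max_depth_add = 2)  # 11, as in the example above

# Related alternative (also hypothetical): keep only grid candidates that
# respect LightGBM's num_leaves <= 2^max_depth cap.
grid <- expand.grid(num_leaves = c(16, 31, 64, 1000), max_depth = c(5, 8))
grid <- grid[grid$num_leaves <= 2^grid$max_depth, ]
nrow(grid)  # 5
```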

Add gbm?

This was, for a while, the premier boosting package in R. There was a Community post the other day where someone was setting up parsnip to use it.

allow passing main arguments to `lightgbm::lgb.train`

At the moment, train_lightgbm only passes elements from set_engine to elements in the params argument to lightgbm::lgb.train. This seems to be where most parameter adjustments happen, but some of the main arguments would be nice to be able to set as well.

library(bonsai)
data(penguins, package = "modeldata")

m <- boost_tree() %>%
  set_engine("lightgbm", nrounds = 100) %>%
  set_mode("regression") %>%
  fit(bill_length_mm ~ ., data = penguins)
#> [LightGBM] [Warning] Unknown parameter: nrounds

Created on 2022-05-02 by the reprex package (v2.0.1)
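A hedged sketch of one way train_lightgbm() could separate the two kinds of arguments: compare engine-argument names against the formal arguments of the training function, pass matches directly, and put the rest in params. To keep the example self-contained, `fake_lgb_train()` is a stand-in that borrows a few argument names from lightgbm::lgb.train() (params, data, nrounds, verbose); real code would inspect formals(lightgbm::lgb.train) instead.

```r
# Stand-in for lightgbm::lgb.train(): argument names borrowed from it,
# defaults are placeholders. It simply echoes what it received.
fake_lgb_train <- function(params = list(), data, nrounds = 10L, verbose = 1L) {
  list(params = params, nrounds = nrounds, verbose = verbose)
}

# Hypothetical routing: engine arguments matching a formal argument of `fn`
# (other than params/data) are passed directly; the rest go into `params`.
call_with_routed_args <- function(fn, data, engine_args) {
  main_names <- setdiff(names(formals(fn)), c("params", "data"))
  is_main <- names(engine_args) %in% main_names
  args <- c(
    list(params = engine_args[!is_main], data = data),
    engine_args[is_main]
  )
  do.call(fn, args)
}

res <- call_with_routed_args(
  fake_lgb_train,
  data = NULL,
  engine_args = list(nrounds = 100, num_leaves = 15)
)
```

With this routing, nrounds would reach lgb.train()'s own argument rather than landing in params, where the reprex above shows lightgbm reporting "Unknown parameter: nrounds".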

properly resolve `obj` alias `objective`

The error message "Error in init(env): For early stopping, valids must have at least one element" is produced when trying to train a multi-class model with:

lgb_model <-
  boost_tree(
    stop_iter = 1000
  ) %>%
  set_engine(
    "lightgbm",
    validation = 0.2,
    objective = 'multiclass',
    metric = "none",
    eval = 'auc_mu',
    num_class = n_classes
  )

multi_predict() doesn't support `type = "raw"` predictions for `{lightgbm}` classification models

There is some code in {bonsai} that looks like it was intended to support multi_predict(..., type = "raw") for {lightgbm} classification models.

bonsai/R/lightgbm_data.R

Lines 146 to 158 in 6c090e1

parsnip::set_pred(
  model = "boost_tree",
  eng = "lightgbm",
  mode = "classification",
  type = "raw",
  value = parsnip::pred_value_template(
    pre = NULL,
    post = NULL,
    func = c(pkg = "bonsai", fun = "predict_lightgbm_classification_raw"),
    object = quote(object),
    new_data = quote(new_data)
  )
)

However, I don't believe {bonsai} actually respects type = "raw" for multi_predict().

Reproducible Example

See the following code for evidence of this claim. I saw this behavior with both {lightgbm} v3.3.2 installed from CRAN and the latest development version (microsoft/LightGBM@c7102e5).

sessionInfo():
R version 4.1.0 (2021-05-18)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS 12.2.1

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] modeldata_1.0.0   lightgbm_3.3.2    R6_2.5.1          dplyr_1.0.9       bonsai_0.1.0.9000
[6] parsnip_1.0.0    

loaded via a namespace (and not attached):
 [1] rstudioapi_0.13   magrittr_2.0.3    tidyselect_1.1.2  munsell_0.5.0     lattice_0.20-45  
 [6] colorspace_2.0-3  rlang_1.0.4       fansi_1.0.3       tools_4.1.0       hardhat_1.2.0    
[11] grid_4.1.0        data.table_1.14.2 gtable_0.3.0      utf8_1.2.2        cli_3.3.0        
[16] withr_2.5.0       ellipsis_0.3.2    tibble_3.1.7      lifecycle_1.0.1   crayon_1.5.1     
[21] Matrix_1.4-0      purrr_0.3.4       ggplot2_3.3.6     tidyr_1.2.0       vctrs_0.4.1      
[26] glue_1.6.2        compiler_4.1.0    pillar_1.8.0      dials_1.0.0       generics_0.1.3   
[31] scales_1.2.0      jsonlite_1.8.0    DiceDesign_1.9    pkgconfig_2.0.3
library(bonsai)
library(dplyr)
library(lightgbm)
library(modeldata)
library(parsnip)

data("penguins", package = "modeldata")
penguins <- penguins[complete.cases(penguins),]

penguins_subset <- penguins[1:10,]
penguins_subset_numeric <-
    penguins_subset %>%
    mutate(across(where(is.character), ~as.factor(.x))) %>%
    mutate(across(where(is.factor), ~as.integer(.x) - 1))

clf_multiclass_fit <-
    boost_tree(trees = 5) %>%
    set_engine("lightgbm") %>%
    set_mode("classification") %>%
    fit(species ~ ., data = penguins)

new_data <-
    penguins_subset_numeric %>%
    select(-species) %>%
    as.matrix()

preds_bonsai_raw <-
    multi_predict(
        clf_multiclass_fit
        , new_data = new_data[1, , drop = FALSE]
        , trees = seq_len(4)
        , type = "raw"
    )

preds_lgb_raw <-
    t(sapply(
        X = seq_len(4)
        , FUN = function(booster, new_data, num_iteration) {
            booster$predict(new_data, num_iteration = num_iteration, rawscore = TRUE)
        }
        , booster = clf_multiclass_fit$fit
        , new_data = new_data[1, , drop = FALSE]
    ))

preds_bonsai_prob <-
    multi_predict(
        clf_multiclass_fit
        , new_data = new_data[1, , drop = FALSE]
        , trees = seq_len(4)
        , type = "prob"
    )

The predictions from multi_predict(..., type = "raw") look like probabilities (between 0 and 1, sum to 1) and don't match {lightgbm}'s output for raw predictions.

preds_bonsai_raw[[".pred"]][[1]]
# A tibble: 4 × 4
#  trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
#  <int>        <dbl>           <dbl>        <dbl>
#      1        0.500           0.184        0.316
#      2        0.556           0.165        0.279
#      3        0.607           0.147        0.246
#      4        0.652           0.131        0.217

preds_lgb_raw
#            [,1]      [,2]      [,3]
# [1,] -0.6724811 -1.672408 -1.132757
# [2,] -0.5392134 -1.754103 -1.230182
# [3,] -0.4193116 -1.834036 -1.322633
# [4,] -0.3093926 -1.912255 -1.411070

type = "prob" predictions look correct, and like probabilities.

preds_bonsai_prob[[".pred"]][[1]]
# A tibble: 4 × 4
#   trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
#  <int>        <dbl>           <dbl>        <dbl>
# 1     1        0.500           0.184        0.316
# 2     2        0.556           0.165        0.279
# 3     3        0.607           0.147        0.246
# 4     4        0.652           0.131        0.217

I observed the same thing for binary classification models. This doesn't matter for regression models, because "raw" predictions are the default for {lightgbm} regression models using built-in objectives.
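The two outputs can be related directly: applying a row-wise softmax to the raw scores printed by booster$predict(..., rawscore = TRUE) reproduces, to three decimals, the values that multi_predict(..., type = "raw") returned, which supports the reading that {bonsai} is returning probabilities for type = "raw". The sketch uses only base R and the numbers copied from the output above.

```r
# Raw scores copied from preds_lgb_raw above (rows = trees 1 to 4).
raw_scores <- rbind(
  c(-0.6724811, -1.672408, -1.132757),
  c(-0.5392134, -1.754103, -1.230182),
  c(-0.4193116, -1.834036, -1.322633),
  c(-0.3093926, -1.912255, -1.411070)
)

# Row-wise softmax; subtracting each row's max first avoids overflow.
softmax_rows <- function(x) {
  e <- exp(x - apply(x, 1, max))
  e / rowSums(e)
}

# Rounded, these match the .pred_* columns returned for type = "raw".
round(softmax_rows(raw_scores), 3)
```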

Notes for Maintainers

I believe the issue is that this block does not contain an if (type == "raw") condition:

bonsai/R/lightgbm.R

Lines 366 to 375 in 6c090e1

} else {
  if (type == "class") {
    pred <- predict_lightgbm_classification_class(object, new_data, num_iteration = tree)
    pred <- tibble::tibble(.pred_class = factor(pred, levels = object$lvl))
  } else {
    pred <- predict_lightgbm_classification_prob(object, new_data, num_iteration = tree)
    names(pred) <- paste0(".pred_", names(pred))
  }

Is it expected that {bonsai} supports multi_predict(..., type = "raw") for {lightgbm} classification models? If so, would you be open to me putting up a pull request to add this support?

Thanks for your time and consideration.

Feature idea - provide custom validation sets for early stopping

Thanks for creating this excellent package. I created a similar fork of treesnip but am planning to replace it with {bonsai} in all our production models.

One feature that I think would be incredibly useful in {bonsai} is the ability to provide custom validation sets during early stopping (instead of using a random split of the training data). This would have a few potential benefits:

  1. More training data. In many cases, you're already going to have a validation set aside from a classic train, validate, test split. Currently, {bonsai} will further split the training data into a training subset and a validation subset specifically for early stopping. Instead, it would be ideal to be able to pass the validation set directly, so that all of the training data would be used for training.
  2. Ability to do more complex cross-validation. Certain cross-validation techniques (rolling origin, spatial, etc.) don't rely on a random sample of the training data and instead use some sort of partitioning (time or geographic). Allowing custom validation data would let users use the "correct" validation set for early stopping when using these more complex methods.
  3. Better integration with tidymodels. Tidymodels supports k-fold and other types of cross-validation. Using the validation set created for each fold rather than splitting a separate validation set specifically for early stopping would be much simpler.
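For illustration, this base-R sketch contrasts a random validation split with a user-defined, time-ordered one of the kind described in point 2. The data frame, the date column and the 20% holdout are made-up example values; the point of the issue is that there is currently no way to hand such a set to {bonsai}.

```r
# Illustrative only: made-up data with a date column.
set.seed(1)
df <- data.frame(
  date = seq(as.Date("2024-01-01"), by = "day", length.out = 100),
  y = rnorm(100)
)

# Random split: validation rows are scattered across the whole period.
random_val_idx <- sample(nrow(df), size = floor(0.2 * nrow(df)))

# Time-ordered split: validation rows are the most recent 20% of dates, so
# early stopping would be judged on data later than anything trained on.
df <- df[order(df$date), ]
n_val <- floor(0.2 * nrow(df))
time_val_idx <- seq(nrow(df) - n_val + 1, nrow(df))

train_part <- df[-time_val_idx, ]
valid_part <- df[time_val_idx, ]
```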

Let me know if this is out-of-scope for this project. If not, I'm happy to contribute if needed.
