cpi's People

Contributors

dswatson, mnwright

cpi's Issues

`resampling = "oob"` fails when an observation is never out-of-bag

Hello! Thanks very much for developing this methodology and package for calculating variable importance.

I'm using it to draw inferences from a random forest model fit to large-ish data with collinearity among the predictors. I've found the CPI approach more satisfying than arbitrarily dropping predictors whose pairwise correlation exceeds some threshold.

I was exploring the use of the "oob" method for computing loss, and kept getting the following error:

Error in cpi_fun(j) : 
  task 1 failed - "missing value where TRUE/FALSE needed"

I tracked it down to cases where an observation is never out-of-bag (i.e., it is in-bag in every tree). This can happen when observations are weighted or when there aren't very many trees. You can recreate the error with this example from the package's help file by setting the num.trees argument to something low:

library(mlr3)
library(mlr3learners)
library(cpi)

mytask <- as_task_regr(iris, target = "Petal.Length")

# With only 10 trees, some observations end up in-bag in every tree
cpi(task = mytask,
    learner = lrn("regr.ranger", keep.inbag = TRUE, num.trees = 10, seed = 2),
    resampling = "oob",
    knockoff_fun = seqknockoff::knockoffs_seq)

In the weeds

The problem starts with the creation of the oob_idx object here:

oob_idx <- ifelse(simplify2array(mod$model$inbag.counts) == 0, TRUE, NA)

If an observation is never out-of-bag across all trees, its row of oob_idx is entirely NA, so the rowMeans() calculation on this line returns NaN for that observation even with na.rm = TRUE:

y_hat <- rowMeans(oob_idx * preds, na.rm = TRUE)
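The mechanism is easy to reproduce outside {cpi}; a minimal standalone sketch (the matrices here are made up for illustration):

```r
# Two observations, three trees; observation 2 is in-bag in every tree
inbag_counts <- matrix(c(0, 1, 0,
                         1, 2, 1), nrow = 2, byrow = TRUE)
oob_idx <- ifelse(inbag_counts == 0, TRUE, NA)  # TRUE where out-of-bag, NA otherwise
preds <- matrix(0.5, nrow = 2, ncol = 3)        # dummy per-tree predictions
y_hat <- rowMeans(oob_idx * preds, na.rm = TRUE)
y_hat  # 0.5 NaN -- the mean of an all-NA row is NaN even with na.rm = TRUE
```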

Then there are NaNs in the predictions, which put NAs in the loss returned by compute_loss(), and those carry through to be NA when calculating dif here:

cpi/R/cpi.R, line 272 at 42a7e0b:

dif <- err_reduced - err_full

The CPI, calculated as mean(dif), then returns NA, and the if (cpi == 0) line errors out here:

cpi/R/cpi.R, line 292 at 42a7e0b:

if (cpi == 0) {
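The final error is just base R's behavior when if() is handed NA; a standalone illustration:

```r
dif <- c(0.1, NaN, -0.2)  # one observation's loss difference is NaN
cpi_value <- mean(dif)    # NaN, because na.rm is not set
msg <- tryCatch(if (cpi_value == 0) "zero" else "nonzero",
                error = function(e) conditionMessage(e))
msg  # "missing value where TRUE/FALSE needed"
```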

I think you can get around this by including na.rm = TRUE in the cpi <- mean(dif) call, but maybe there's a deeper philosophical issue: perhaps it isn't prudent to silently drop the "only ever in-bag" observations from the CPI calculation, particularly if they ended up always in-bag because of observation weights that favor them. If so, perhaps the best change would be a more informative error message?

cpi <- mean(dif, na.rm = TRUE)
se <- sd(dif, na.rm = TRUE) / sqrt(length(which(!is.na(dif))))

Consequences of log loss as loss measure for binary classification when the optimal classification threshold != 0.5

Hello CPI team,

This is such a fantastic package and I'm so happy that I found it (and the suite of papers your team has written about state-of-the-art random forests for inference).

This may be a question that points to a broader theoretical question, but the {cpi} package is how I came to it, so I'm starting here...

tl;dr
What are the implications of using log loss as the loss function when the tuned classification threshold is not at 0.5? I'm sure someone has thought/written about this, but I'm having trouble finding those papers!

Example
A model prediction of 0.6 for an observation in the positive class seems like it should carry different information about loss depending on whether the optimal classification threshold for the model is 0.5 or 0.2 (for instance). In both cases, a prediction of 0.6 correctly classifies the observation as belonging to the positive class. But with a tuned threshold of 0.5, a prediction of 0.6 is barely over the threshold, whereas with a tuned threshold of 0.2 it is quite a bit over it. Yet the log loss would be equal in both cases. Naively, I would have expected a loss function to show more loss for a prediction of 0.6 when the classification threshold is 0.5, and less loss when the threshold is 0.2.
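In code form: log loss for a positive-class observation is just -log(p_hat), so the threshold never enters the calculation at all.

```r
log_loss_pos <- function(p_hat) -log(p_hat)  # positive-class observation
log_loss_pos(0.6)  # ~0.51 whether the tuned threshold is 0.5 or 0.2
```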

Possible solution?
Do we need to rescale the model's probability predictions to account for a tuned classification threshold that isn't 0.5 before calculating log loss? That is, all predictions below the classification threshold get rescaled to [0, 0.5] and all predictions above it get rescaled to [0.5, 1]?
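One way to sketch that rescaling (a hypothetical helper, not part of {cpi}): piecewise-linearly map [0, t] onto [0, 0.5] and [t, 1] onto [0.5, 1], where t is the tuned threshold.

```r
# Hypothetical helper: rescale probabilities so the tuned threshold t maps to 0.5
rescale_prob <- function(p, t) {
  ifelse(p <= t,
         0.5 * p / t,
         0.5 + 0.5 * (p - t) / (1 - t))
}

rescale_prob(0.6, t = 0.5)        # 0.60 -- unchanged when the threshold is 0.5
rescale_prob(0.6, t = 0.2)        # 0.75 -- a more "confident" positive prediction
-log(rescale_prob(0.6, t = 0.5))  # ~0.51
-log(rescale_prob(0.6, t = 0.2))  # ~0.29 -- less loss, as hoped for above
```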

Eventual goal
I have a highly unbalanced binary classification problem with multicollinearity among the features, mostly continuous features (just 1 categorical feature with only 5 levels), and a desire to better understand which features are important and how (i.e., the shape of their relationship to the target).

What I've tried
I've played around with modifying the {cpi} package to calculate loss at an aggregated scale (i.e., per test data set) rather than a per-observation scale, using measures more robust to class imbalance (the Matthews correlation coefficient). The actual implementation of that modification to {cpi} is here. In that case, I relied on repeated spatial cross-validation for "significance" of the CPI for each feature, since the implemented statistical tests rely on having CPI on a per-observation scale (before taking the mean to report a per-feature CPI value). But this strikes me as perhaps overly conservative, so I'm revisiting the default CPI loss functions.
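For reference, the per-fold aggregation idea boils down to computing one MCC from each fold's confusion counts instead of a per-observation loss; a base-R sketch, independent of the linked implementation:

```r
# Matthews correlation coefficient from binary truth/prediction vectors (0/1)
mcc <- function(truth, pred) {
  tp <- sum(pred == 1 & truth == 1)
  tn <- sum(pred == 0 & truth == 0)
  fp <- sum(pred == 1 & truth == 0)
  fn <- sum(pred == 0 & truth == 1)
  denom <- sqrt(tp + fp) * sqrt(tp + fn) * sqrt(tn + fp) * sqrt(tn + fn)
  if (denom == 0) return(0)  # convention: treat undefined MCC as 0
  (tp * tn - fp * fn) / denom
}

# One value per test fold, not per observation
truth <- c(1, 1, 1, 0, 0, 0, 0, 0)
pred  <- c(1, 1, 0, 0, 0, 0, 1, 0)
mcc(truth, pred)  # 7/15, about 0.47
```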
