Comments (5)
You're right that there must be some bug causing this result. Negative SAGE values should only occur when a feature is hurting the model's performance, so it cannot happen for all your features given that the model performs well.
I had to refresh my memory on a couple details from the implementation, but here are some of the important details:
- I'm imagining that your labels are represented by an array with shape `(n,)`, and that your model's predictions are either `(n,)` for the probability of class 1 or `(n, 2)` for the probabilities of classes -1 and 1 respectively. Let me know if either of those details is incorrect here.
- In our implementation of the cross-entropy loss, the first step is to make sure the predictions have shape `(n, num_classes)`. You can see that this is done here. Then, we extract the probability for the true class here. You can imagine that this could cause a problem when using -1 as a class label, because both -1 and 1 would lead to us extracting the last column.
- Like you said, someone brought up this issue a couple of years ago, so I tried to put a fix in place. The fix is performed within the `verify_model_data` function. Inside that function, one of the first things we do is make sure the labels have shape `(n,)` (see here). Next, if the problem is binary classification (which we check for here), we turn the -1 values into 0 (here). So I'm a bit surprised to encounter this error; I expected this fix to catch such situations.
Can you think of any reason why this fix wouldn't be working properly? Or would you be able to try manually changing your label encoding? Let me know what you think. It's also possible the error is unrelated to the label encoding, but this seems like a good place to start.
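To make the indexing hazard above concrete, here's a standalone sketch (illustrative NumPy only, not SAGE's actual code) of how a raw -1 label selects the last column when extracting the true-class probability:

```python
import numpy as np

# Per-class probabilities with shape (n, 2). Column 0 corresponds to the
# "negative" class (label -1), column 1 to label 1.
pred = np.array([
    [0.9, 0.1],   # model is confident in the negative class
    [0.2, 0.8],   # model is confident in class 1
])
labels = np.array([-1, 1])

# Fancy indexing with the raw -1/1 labels: -1 means "last column" in
# NumPy, so both labels extract column 1.
extracted = pred[np.arange(len(labels)), labels]
print(extracted)  # [0.1 0.8] -- wrong for the first sample

# After remapping -1 -> 0, the correct probabilities are extracted:
fixed_labels = np.where(labels == -1, 0, labels)
extracted_fixed = pred[np.arange(len(fixed_labels)), fixed_labels]
print(extracted_fixed)  # [0.9 0.8]
```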
from sage.
Dear @iancovert,
thanks for your response.
- My predictions have shape `(n, 2)` if `predict_proba` is used and `(n,)` otherwise.

My experiments:

- I tried changing the labels from `[-1, 1]` to `[0, 1]`. For these tests, I either use `predict_proba` or `predict` via a callable. The estimated values are approximately the same for all variants, and I didn't run into any of the "invalid probabilities" checks either. Thus, I conclude that it's a problem on the classifier side.
- So far the relevant metric for my classifiers was accuracy. I estimated the cross-entropy loss for the classifiers and noticed that it is higher than that of a random guess (assuming p = 0.5, i.e. -ln(0.5)), as it explodes for some samples. I guess this can also cause the loss to explode during imputation / permutation, ultimately resulting in all-negative feature contributions.
- In consequence, I experimented with the zero-one loss again. The zero-one loss also seems appropriate because some of my classifiers deliver only hard probabilities. Here the problem is less severe, but it still persists: I have two models with negative-only contributions, despite accuracy > 0.65. I have thought about debugging this further to see what the loss is before and after. I might also compare against some naive approach like random feature permutation to better understand the issue. Do you have any other ideas how I could dig deeper?
I'll try to post my solution if there is one. Otherwise, I think we can close this issue for now.
Thanks in advance.
This is helpful information; I forgot that you're using some classifiers that output only hard probabilities. In that case, we won't be able to calculate the cross-entropy loss even given the `predict_proba` output: any incorrect prediction will result in infinite cross-entropy loss, because the probability mass on the correct class is zero. The cross-entropy loss requires probabilistic predictions (this is why we try to use `predict_proba` whenever possible when preparing the model), but we'll need to avoid making that assumption here.
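This failure mode can be reproduced with a few lines of NumPy (a generic cross-entropy sketch, not SAGE's implementation):

```python
import numpy as np

def cross_entropy(pred, target):
    """Per-sample cross-entropy for (n, 2) probabilities and {0, 1} labels."""
    p_true = pred[np.arange(len(target)), target]
    return -np.log(p_true)

hard_pred = np.array([
    [1.0, 0.0],   # hard prediction: class 0
    [0.0, 1.0],   # hard prediction: class 1
])
target = np.array([0, 0])  # the second prediction is wrong

# The wrong hard prediction puts zero mass on the true class,
# and -log(0) is infinite.
with np.errstate(divide="ignore"):
    print(cross_entropy(hard_pred, target))  # [ 0. inf]
```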
So one thing is for sure, we'll need to use your zero-one loss. The next question is why even those results look strange so far. Looking at your implementation here, here's what I'm observing:
- Assuming our model preparation works correctly, the package should be calling `predict_proba` on your model, and therefore getting an array with shape `(n, 2)` containing 0s and 1s. Neither of the if statements under `# Add a dimension to prediction probabilities if necessary` should be triggered, so `pred` should retain that shape.
- We should have `target.ndim == 1`, so I believe the loss will be calculated according to `loss = (np.argmax(pred, axis=1) != target).astype(float)`. And this actually looks correct; I'm not sure how this could go wrong.
With that in mind, I'm thinking the issue could be what you mentioned about the model's accuracy vs a random guess: if the accuracy we observe (0.65) is less than what we would get by guessing the most likely class (which could be > 0.65 depending on the class imbalance), then it would make sense to observe negative SAGE values. The way to think about this: if your model had no input features, it would roughly guess the true class proportions (say 0.2 and 0.8), and it would get accuracy of 0.8. The SAGE values should sum to 0.65 - 0.8 = -0.15, so it would make sense for many or all of them to be negative.
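The arithmetic in that argument can be sketched directly (the 0.2/0.8 class proportions are the hypothetical numbers from the example above, not measured values):

```python
# A featureless model predicts the average class probabilities, so its
# argmax always picks the majority class and its accuracy equals the
# majority class proportion.
p_minority, p_majority = 0.2, 0.8
baseline_accuracy = p_majority   # accuracy with no features
model_accuracy = 0.65            # observed accuracy with all features

# SAGE values sum to the performance gain from using the features:
sage_sum = model_accuracy - baseline_accuracy
print(round(sage_sum, 2))  # -0.15: the total is negative, so many or
                           # all individual values can be negative
```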
Let me know what you think.
Dear @iancovert,
Thanks for your response. It made me go through the zero-one loss again, and I think we finally solved it 🎉.
What I didn't try previously was changing the labels from `[-1, 1]` to `[0, 1]` when paired with `ZeroOneLoss`; I had only tested that for the cross-entropy loss. I'll have to verify / investigate why the targets are not converted automatically here and see where exactly it fails.
The SAGE values look rather reasonable for different models if I clip the labels to `[0, 1]` using `y_test.clip(0)`:
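The conversion problem can be reproduced in isolation (a standalone sketch, not the package's actual loss code): `np.argmax` returns class indices 0 or 1, which never equal -1, so every sample whose true label is -1 is counted as an error.

```python
import numpy as np

pred = np.array([
    [0.9, 0.1],   # predicted class 0
    [0.2, 0.8],   # predicted class 1
])
target = np.array([-1, 1])  # -1/1 label encoding

# Zero-one loss computed against the raw -1/1 labels: the first sample
# is flagged as wrong even though the prediction matches, because the
# argmax index 0 can never equal -1.
loss = (np.argmax(pred, axis=1) != target).astype(float)
print(loss)  # [1. 0.]

# Clipping the labels to [0, 1] (as with y_test.clip(0)) fixes this:
fixed = (np.argmax(pred, axis=1) != target.clip(0)).astype(float)
print(fixed)  # [0. 0.]
```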
To the other points you raised:
- My dataset is approximately balanced, so the zero-one loss of a random prediction is around 0.5, and so is its accuracy. That made me wonder why accuracy >= 0.65 gave me all-negative SAGE values.
- All my models are wrapped in an sklearn wrapper derived from `BaseEstimator` and `ClassifierMixin`, and `predict_proba` is called. Another one is a `CatBoostClassifier`, for which `predict_proba` is also called.
Thanks for your help. I'll consider the point when preparing the PR for zero-one loss and investigate it a bit further.
Awesome, I'm relieved we figured this out. Let me know if you figure out where the label conversion went wrong, hopefully it would require only a small fix. And looking forward to the PR!