Comments (46)
@QuentinAmbard First of all, I want to thank you for this excellent job. I'm also interested in pyspark. The problem is what @evah mentioned: you can build an explainer based on pyspark, but you cannot feed your explainer with a pyspark DataFrame. Theoretically, one could define a UserDefinedFunction, collect the features from spark to get a numpy or a pandas object, run your explainer and let spark do the dispatching. Unfortunately, to do so, objects needed for the computation must be serializable (through pickle) in order to be dispatched to every node, and an explainer is not.
A possible solution is to expose a serialize and a deserialize method for an explainer. The serialize method should return an object made of dictionaries, lists, and other "basic" Python types. The deserialize method of course takes the serialized object and builds a proper explainer from it.
In this way things might work: the serialized object could be dispatched, every node would build its explainer and the calculation could be carried out over the entire DataFrame.
Is it feasible to have those methods?
Thank you very much
PS: mostly interested in TreeExplainer if that helps in some way.
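For illustration, here is a minimal sketch of the pattern proposed above, assuming an explainer could be reduced to and rebuilt from basic Python types; the deserialize_explainer helper is hypothetical and not part of shap:

import pandas as pd

def make_shap_partition_fn(serialized_explainer, columns):
    # "serialized_explainer" is a plain-Python (picklable) description of the explainer,
    # so Spark can ship it to every executor inside this closure.
    def compute(rows):
        # Rebuild the explainer locally on the executor (hypothetical helper).
        explainer = deserialize_explainer(serialized_explainer)
        pdf = pd.DataFrame(list(rows), columns=columns)
        for row in explainer.shap_values(pdf):
            yield tuple(float(v) for v in row)
    return compute

# shap_rdd = X_df.rdd.mapPartitions(make_shap_partition_fn(serialized, X_df.columns))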
from shap.
I'm also interested in a PySpark implementation of this. I tried calculating SHAP values using the code below:
import pandas as pd
from pyspark.sql import Row

X_df = X_df.repartition(10)
X_columns = X_df.columns

def add_shap(rows):
    # Collect this partition into pandas and compute SHAP values for it ("explainer" is built on the driver).
    rows_pd = pd.DataFrame(rows, columns=X_columns)
    shap_values = explainer.shap_values(rows_pd.drop(["Respondent"], axis=1))
    # One Row per record: the "Respondent" id followed by its SHAP values.
    return [Row(*([int(rows_pd["Respondent"][i])] + [float(f) for f in shap_values[i]]))
            for i in range(len(shap_values))]

shap_df = X_df.rdd.mapPartitions(add_shap).toDF(X_columns)
Code reference: Databricks
But this code throws the error below:
Traceback (most recent call last):
File "/usr/lib/spark/python/pyspark/serializers.py", line 597, in dumps
return cloudpickle.dumps(obj, 2)
File "/usr/lib/spark/python/pyspark/cloudpickle.py", line 863, in dumps
cp.dump(obj)
File "/usr/lib/spark/python/pyspark/cloudpickle.py", line 260, in dump
return Pickler.dump(self, obj)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 409, in dump
self.save(obj)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
f(self, obj) # Call unbound method with explicit self
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 736, in save_tuple
save(element)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
f(self, obj) # Call unbound method with explicit self
File "/usr/lib/spark/python/pyspark/cloudpickle.py", line 400, in save_function
self.save_function_tuple(obj)
File "/usr/lib/spark/python/pyspark/cloudpickle.py", line 549, in save_function_tuple
save(state)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
f(self, obj) # Call unbound method with explicit self
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
self._batch_setitems(obj.items())
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 847, in _batch_setitems
save(v)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
f(self, obj) # Call unbound method with explicit self
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
self._batch_setitems(obj.items())
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 852, in _batch_setitems
save(v)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 521, in save
self.save_reduce(obj=obj, *rv)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 634, in save_reduce
save(state)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
f(self, obj) # Call unbound method with explicit self
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
self._batch_setitems(obj.items())
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 847, in _batch_setitems
save(v)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 521, in save
self.save_reduce(obj=obj, *rv)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 634, in save_reduce
save(state)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
f(self, obj) # Call unbound method with explicit self
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
self._batch_setitems(obj.items())
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 847, in _batch_setitems
save(v)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 521, in save
self.save_reduce(obj=obj, *rv)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 634, in save_reduce
save(state)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
f(self, obj) # Call unbound method with explicit self
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
self._batch_setitems(obj.items())
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 847, in _batch_setitems
save(v)
File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 496, in save
rv = reduce(self.proto)
File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
answer, self.gateway_client, self.target_id, self.name)
File "/usr/lib/spark/python/pyspark/sql/utils.py", line 63, in deco
return f(*a, **kw)
File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 332, in get_return_value
format(target_id, ".", name, value))
py4j.protocol.Py4JError: An error occurred while calling o3332.__getstate__. Trace:
py4j.Py4JException: Method __getstate__([]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
at py4j.Gateway.invoke(Gateway.java:274)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
---------------------------------------------------------------------------
Py4JError Traceback (most recent call last)
/usr/lib/spark/python/pyspark/serializers.py in dumps(self, obj)
596 try:
--> 597 return cloudpickle.dumps(obj, 2)
598 except pickle.PickleError:
/usr/lib/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol)
862 cp = CloudPickler(file,protocol)
--> 863 cp.dump(obj)
864 return file.getvalue()
/usr/lib/spark/python/pyspark/cloudpickle.py in dump(self, obj)
259 try:
--> 260 return Pickler.dump(self, obj)
261 except RuntimeError as e:
/opt/conda/anaconda/lib/python3.6/pickle.py in dump(self, obj)
408 self.framer.start_framing()
--> 409 self.save(obj)
410 self.write(STOP)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
475 if f is not None:
--> 476 f(self, obj) # Call unbound method with explicit self
477 return
/opt/conda/anaconda/lib/python3.6/pickle.py in save_tuple(self, obj)
735 for element in obj:
--> 736 save(element)
737 # Subtle. Same as in the big comment below.
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
475 if f is not None:
--> 476 f(self, obj) # Call unbound method with explicit self
477 return
/usr/lib/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name)
399 or themodule is None):
--> 400 self.save_function_tuple(obj)
401 return
/usr/lib/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func)
548 state['qualname'] = func.__qualname__
--> 549 save(state)
550 write(pickle.TUPLE)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
475 if f is not None:
--> 476 f(self, obj) # Call unbound method with explicit self
477 return
/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
820 self.memoize(obj)
--> 821 self._batch_setitems(obj.items())
822
/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
846 save(k)
--> 847 save(v)
848 write(SETITEMS)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
475 if f is not None:
--> 476 f(self, obj) # Call unbound method with explicit self
477 return
/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
820 self.memoize(obj)
--> 821 self._batch_setitems(obj.items())
822
/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
851 save(k)
--> 852 save(v)
853 write(SETITEM)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
520 # Save the reduce() output and finally memoize the object
--> 521 self.save_reduce(obj=obj, *rv)
522
/opt/conda/anaconda/lib/python3.6/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, obj)
633 if state is not None:
--> 634 save(state)
635 write(BUILD)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
475 if f is not None:
--> 476 f(self, obj) # Call unbound method with explicit self
477 return
/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
820 self.memoize(obj)
--> 821 self._batch_setitems(obj.items())
822
/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
846 save(k)
--> 847 save(v)
848 write(SETITEMS)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
520 # Save the reduce() output and finally memoize the object
--> 521 self.save_reduce(obj=obj, *rv)
522
/opt/conda/anaconda/lib/python3.6/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, obj)
633 if state is not None:
--> 634 save(state)
635 write(BUILD)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
475 if f is not None:
--> 476 f(self, obj) # Call unbound method with explicit self
477 return
/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
820 self.memoize(obj)
--> 821 self._batch_setitems(obj.items())
822
/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
846 save(k)
--> 847 save(v)
848 write(SETITEMS)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
520 # Save the reduce() output and finally memoize the object
--> 521 self.save_reduce(obj=obj, *rv)
522
/opt/conda/anaconda/lib/python3.6/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, obj)
633 if state is not None:
--> 634 save(state)
635 write(BUILD)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
475 if f is not None:
--> 476 f(self, obj) # Call unbound method with explicit self
477 return
/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
820 self.memoize(obj)
--> 821 self._batch_setitems(obj.items())
822
/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
846 save(k)
--> 847 save(v)
848 write(SETITEMS)
/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
495 if reduce is not None:
--> 496 rv = reduce(self.proto)
497 else:
/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
1256 return_value = get_return_value(
-> 1257 answer, self.gateway_client, self.target_id, self.name)
1258
/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
62 try:
---> 63 return f(*a, **kw)
64 except py4j.protocol.Py4JJavaError as e:
/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
331 "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
--> 332 format(target_id, ".", name, value))
333 else:
Py4JError: An error occurred while calling o3332.__getstate__. Trace:
py4j.Py4JException: Method __getstate__([]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
at py4j.Gateway.invoke(Gateway.java:274)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
During handling of the above exception, another exception occurred:
PicklingError Traceback (most recent call last)
<ipython-input-370-4f054869b703> in <module>
----> 1 shap_df1 = shap_df.withColumn("shap_val", shapvalue_udf1(f.col("shap")))
/usr/lib/spark/python/pyspark/sql/udf.py in wrapper(*args)
187 @functools.wraps(self.func, assigned=assignments)
188 def wrapper(*args):
--> 189 return self(*args)
190
191 wrapper.__name__ = self._name
/usr/lib/spark/python/pyspark/sql/udf.py in __call__(self, *cols)
165
166 def __call__(self, *cols):
--> 167 judf = self._judf
168 sc = SparkContext._active_spark_context
169 return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
/usr/lib/spark/python/pyspark/sql/udf.py in _judf(self)
149 # and should have a minimal performance impact.
150 if self._judf_placeholder is None:
--> 151 self._judf_placeholder = self._create_judf()
152 return self._judf_placeholder
153
/usr/lib/spark/python/pyspark/sql/udf.py in _create_judf(self)
158 sc = spark.sparkContext
159
--> 160 wrapped_func = _wrap_function(sc, self.func, self.returnType)
161 jdt = spark._jsparkSession.parseDataType(self.returnType.json())
162 judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
/usr/lib/spark/python/pyspark/sql/udf.py in _wrap_function(sc, func, returnType)
33 def _wrap_function(sc, func, returnType):
34 command = (func, returnType)
---> 35 pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
36 return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
37 sc.pythonVer, broadcast_vars, sc._javaAccumulator)
/usr/lib/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command)
2418 # the serialized command will be compressed by broadcast
2419 ser = CloudPickleSerializer()
-> 2420 pickled_command = ser.dumps(command)
2421 if len(pickled_command) > (1 << 20): # 1M
2422 # The broadcast will have same life cycle as created PythonRDD
/usr/lib/spark/python/pyspark/serializers.py in dumps(self, obj)
605 msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
606 cloudpickle.print_exec(sys.stderr)
--> 607 raise pickle.PicklingError(msg)
608
609
PicklingError: Could not serialize object: Py4JError: An error occurred while calling o3332.__getstate__. Trace:
py4j.Py4JException: Method __getstate__([]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
at py4j.Gateway.invoke(Gateway.java:274)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Has anyone faced a similar issue and been able to resolve it? I'm happy to work together if anyone is working on resolving this.
from shap.
Pardon for self-promotion, but you can most certainly use SHAP with Pyspark. It doesn't require any changes in SHAP. Here's my proof of existence: https://databricks.com/blog/2019/06/17/detecting-bias-with-shap.html
from shap.
In case anyone's interested, there's finally an example notebook for our Monte Carlo implementation of Shapley values. cc @shadab-entrepreneur
from shap.
Thanks @slundberg, as soon as I get something working properly I'll get back to you.
from shap.
Often with KernelExplainer you use a single background reference value (like the mean feature values), so that won't be a big issue. And you can then explain each sample individually if you like.
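For concreteness, a minimal sketch of that single-background-reference setup, assuming "model" exposes a predict function over numpy arrays and "X_sample" is a small pandas DataFrame (both hypothetical names):

import numpy as np
import shap

# One background reference row: the mean of each feature.
background = X_sample.mean().values.reshape(1, -1)
explainer = shap.KernelExplainer(model.predict, background)
# Explain samples individually or in small batches.
shap_values = explainer.shap_values(X_sample.iloc[:10])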
from shap.
@QuentinAmbard Thanks for offering your notebook, yet your answer did not solve the problem ...
- the Python version of your code uses a pandas DataFrame and runs single-node xgboost
- even if the Scala version is parallel xgboost, it is Scala, and this issue's author is asking for PySpark
- people still need SHAP for Spark models (random forest, GBT, etc.), not for xgboost models
from shap.
@QuentinAmbard Actually I just tried it 🤦♀️ the head-aching part is that it takes a PySpark model, but the X (feature set) still needs to be in pandas. I guess that to make it able to take a PySpark DataFrame, we would need to re-implement the whole algorithm in PySpark.
from shap.
@shadab-entrepreneur that looks very similar to Niloy Gupta's, my, and others' work here.
I think you may need to yield rather than return, to attempt to directly answer your question?
Thanks @ijoseph
I tried using Shparkley as well and I'm getting the same error. Is there any demo notebook for it?
from shap.
I don't have any plans for that since I am not a Spark user. What Spark workflow do you have in mind that you would like SHAP to integrate with? I might be able to suggest something in a specific case (though I don't know much about spark).
from shap.
Hi Scott,
I would also like to understand whether there is an integration of the SHAP method with the Spark ML lib?
Any updates in this regard?
Thanks,
Pavitra
from shap.
@pavitrasrinivasan The TreeExplainer now works for sklearn, so that part is done. As for spark integration I don't know what that would look like since I don't use spark. Do you need something special to run a python package in spark?
from shap.
Thanks for your response!
Yes, I will not be able to use the Python-based package directly in Spark. I would need a Spark-compatible version.
Thanks,
Pavitra
from shap.
Okay, if anyone wants to tackle adding spark support that would be great. I am happy to help, but having not used spark myself I can't promise I'll get to it anytime soon by myself.
from shap.
Okay, if anyone wants to tackle adding spark support that would be great. I am happy to help, but having not used spark myself I can't promise I'll get to it anytime soon by myself.
Hi Scott, I've been using Spark for quite some time and would love to help adding support for Spark tree based models.
from shap.
@xboard great! If you take a look at a current example of parsing a tree model it should help you get started. Basically we build a TreeEnsemble object from many Tree objects, and each Tree object is built by parsing the trees from the original model and saving them in a form similar to how sklearn stores tree data: https://github.com/slundberg/shap/blob/5346fd4001ebb5e536fae68afb4097658def300b/shap/explainers/tree.py#L427-L439
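To make the target concrete, here is a rough sketch (not shap's actual parser) of the per-tree, sklearn-style arrays a Tree object is built from; the exact keys can vary between shap versions:

import numpy as np

# One entry per node; -1 marks "no child" (a leaf), -2 marks "no feature/threshold".
single_tree = {
    "children_left":      np.array([1, -1, -1]),
    "children_right":     np.array([2, -1, -1]),
    "children_default":   np.array([1, -1, -1]),    # child to follow when the feature is missing
    "features":           np.array([0, -2, -2]),    # feature index split on at each node
    "thresholds":         np.array([0.5, -2.0, -2.0]),
    "values":             np.array([[0.0], [1.0], [-1.0]]),  # node/leaf output values
    "node_sample_weight": np.array([100.0, 60.0, 40.0]),
}
# A TreeEnsemble is then assembled from one such dict per tree in the Spark model.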
from shap.
@xboard Did you get anywhere with this?
For RandomForestClassifier we can use the following attributes, but they are internally represented as Java objects. What can we do?
- model.trees
- model.minInfoGain
elif str(type(model)).endswith("pyspark.ml.classification.RandomForestClassifier'>"):
    scaling = 1.0 / len(model.trees)  # output is average of trees
    self.trees = [Tree(e.tree, normalize=True, scaling=scaling) for e in model.trees]
    self.objective = objective_name_map.get(model.minInfoGain, None)
    self.tree_output = "probability"
from shap.
@slundberg thanks for developing this great package! I'm also interested in a Spark implementation but as a first step would like to try using the model-agnostic version on a model trained with Spark.
As far as I can tell, it shouldn't matter where the model is trained, as long as I can pass a function that makes predictions based on the given features to KernelExplainer. The problem is that the dataset would be too large to fit in memory, which is why I'm using Spark. One suggestion I've had is to pass a down-sampled (but hopefully representative) version of the training data. I'm curious whether you think that would work?
from shap.
Hi,
Is there any stable release of SHAP for PySpark to date? Sorry if my question doesn't make sense, but I am new to PySpark and need SHAP for a clustering analysis I am trying to solve.
from shap.
Not that I know of yet. It would be great to have.
from shap.
What about the random forest in Spark ML? It would also be good to have a Scala version to calculate SHAP values.
We can use the sparkContext to get the tree objects, as in https://github.com/fivejjs/spark-tree-plotting
from shap.
I've got a draft version of the TreeExplainer working with a simple Spark decision tree using PySpark; I'll try to improve it to support random forest and commit it. @slundberg can you confirm that categorical splits aren't supported?
from shap.
I've created #721, feedback is welcome.
from shap.
Does SHAP support PySpark models yet?
I have tried using SHAP with RandomForestRegressionModel (PySpark ML), but got this error: "NotImplementedError: CategoricalSplit are not yet implemented".
Is there a list of supported models/libs for SHAP?
Thank you.
from shap.
Yes it's supported, but my understanding is that shap doesn't support categorical split so you can't have categorical features in your model (for example a stringIndexer will create a categorical split)
from shap.
Yes it's supported, but my understanding is that shap doesn't support categorical split so you can't have categorical features in your model (for example a stringIndexer will create a categorical split)
Thanks @QuentinAmbard. I tried with another simple model and it can run properly.
I wonder if there are cases where I would need categorical features in my model.
However, another error occurs: "AttributeError: 'DataFrame' object has no attribute 'shape'". My guess is that it's because I'm using a Spark DataFrame as the input for shap_values?
from shap.
Categorical features are often required in trees (using a one hot encoder isn't a good alternative)...
There is an example in this thread: #884
As per this thread you also need to disable the additivity check for now; #905 will solve this issue (I still need to make spark work on Windows in the appveyor build).
Yes, I assume you're using a Spark DataFrame.
If you need to distribute the computation you can implement a pandas UDF, call shap_values inside each batch, and merge the output into a Spark DataFrame; a sketch follows this comment.
Here is a scala example with a distributed shap computation: https://databricks.com/blog/2019/06/17/detecting-bias-with-shap.html
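A minimal sketch of that pandas UDF approach, assuming "explainer" is a picklable TreeExplainer built on the driver, "feature_cols" (hypothetical) lists the model's numeric input columns, and the model has a single output (one SHAP array per row):

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType, col
from pyspark.sql.types import ArrayType, DoubleType

@pandas_udf(ArrayType(DoubleType()), PandasUDFType.SCALAR)
def shap_values_udf(*cols):
    # Each batch arrives as pandas Series; rebuild the feature frame and explain it.
    X = pd.concat(cols, axis=1)
    X.columns = feature_cols
    shap_values = explainer.shap_values(X, check_additivity=False)
    return pd.Series([[float(v) for v in row] for row in shap_values])

# df_with_shap = df.withColumn("shap_values", shap_values_udf(*[col(c) for c in feature_cols]))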
from shap.
Yes it's supported, but my understanding is that shap doesn't support categorical split so you can't have categorical features in your model (for example a stringIndexer will create a categorical split)
Any solution for "NotImplementedError: CategoricalSplit are not yet implemented" in pyspark?
from shap.
@slundberg is there a way to support categorical splits with the current shap implementation?
I'm not sure how this is implemented in XGBoost; is it because there are no categorical splits, only numerical ones (so it's actually not implemented)?
from shap.
I am using RandomForest and GBT in pyspark and I have to transform some categorical features using StringIndexer and some using oneHotEncoder. Is there a way/workaround to use SHAP on the transformed categorical features?
from shap.
A categorical multi-way split is something that is not yet in the C++ version of TreeExplainer. So that would need to be added to support this. I would like to say I’ll just take a few days and write it, but I am not sure when I will have those free days. I’ll try to get it in sooner rather than later...
from shap.
@evah I think it does solve the problem ;)
The notebook just shows you how you can compute your explanations using Spark on multiple nodes (explanations of any model, in this case a single-node xgboost).
If you want to get the explanation of your Spark tree-based model, you can find an example here:
https://github.com/QuentinAmbard/shap/blob/%23866_spark_regressor_support/tests/explainers/test_tree.py#L190
(As sacmax mentioned, categorical features are not implemented)
from shap.
@shadab-entrepreneur that looks very similar to Niloy Gupta's, my, and others' work here.
I think you may need to yield rather than return, to attempt to directly answer your question?
from shap.
Hm, interesting. The readme has example usage but there's no example notebook as of yet.
Looking again at your stack trace, are you calling df.rdd.mapPartitions inside of shapvalue_udf1? Mostly guessing based on what the error you got means (py4j.Py4JException: Method __getstate__([]) does not exist) and this syntax:
shap_df1 = shap_df.withColumn("shap_val", shapvalue_udf1(f.col("shap")))
from shap.
@shadab-entrepreneur have you solved the issue? I'm facing the same problem.
from shap.
@davezhouwa Nope, still I'm facing the same issue. Were you able to solve it?
from shap.
It's working fine in my notebook; make sure you can serialize the explainer:
pickle.dumps(explainer)
If you can't, then something is wrong (you didn't remove the line). You need to edit it manually as it's not yet released in the latest shap version on PyPI.
model = pipeline.stages[-1].bestModel
print(model)
==> Out[22]: GBTClassificationModel: uid = GBTClassifier_367df9840c81, numTrees=5, numClasses=2, numFeatures=8

explainer = shap.TreeExplainer(model)
# You can check that the explainer can be serialized:
# pickle.dumps(explainer)
And now it's working in a pandas_udf, or using the new mapInPandas in Spark 3:
def compute_shap_values(iterator):
    # "iterator" yields one pandas DataFrame per batch of the Spark DataFrame.
    for X in iterator:
        yield pd.DataFrame(explainer.shap_values(X, check_additivity=False))

display(dataset.mapInPandas(compute_shap_values, schema=dataset.schema).toDF(*[x + "_shap_value" for x in test.columns]))
from shap.
@shadab-entrepreneur, I followed @QuentinAmbard's suggestions and it worked.
from shap.
Hoping someone might benefit, my approach to the problems described in this issue was:
- randomly sample the target Spark DataFrame (to make sure the data fits the master node)
- convert the DF to a numpy array
- calculate SHAP
from typing import List
import numpy as np
from pyspark.sql import DataFrame

def df2numpy(df: DataFrame, cols: List[str], nsamples: int) -> np.ndarray:
    """
    Converts a DataFrame into a Numpy array (matrix). "cols" gives the order of
    the matrix columns; "nsamples" is the approximate number of records.
    """
    # This is not efficient; "count" may be expensive to compute.
    fraction = nsamples / df.count()
    samp_df = df.sample(fraction=fraction, seed=6969) if fraction < 1 else df
    return np.array(samp_df.select(*cols).collect())
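For the last step, a short usage sketch, assuming "spark_df", "feature_cols", and an already-built "explainer" (all hypothetical names):

# Sample roughly 10k rows to the driver and explain them with shap.
X = df2numpy(spark_df, feature_cols, nsamples=10_000)
shap_values = explainer.shap_values(X)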
I also had problems with categorical features (NotImplementedError: CategoricalSplit are not yet implemented). I had forgotten that Spark implicitly uses column metadata to identify indexed features and make "categorical splits". To solve this issue, I just created another DF without any metadata:
# "gbt" is an instance of GBTClassifier.
# "df" contains numeric and categorical/indexed features.
pipe = Pipeline(stages=[assembler, gbt]).fit(
# Remove metadata of the input DF. We don't want to handle categorical
# values when training a GBT model (Spark models implicitly use
# metadata to handle categorical features).
df.rdd.toDF()
)
from shap.
Hoping someone might benefit, my approach to the problems described in this issue was:
- randomly sample the target Spark DataFrame (to make sure the data fits the master node)
- convert the DF to a numpy array
- calculate SHAP
I did not find that removing the categorical feature metadata fixed the "CategoricalSplit are not yet implemented" error.
Has anyone found a fix for this?
from shap.
Hey everyone,
Is there any update on CategoricalSplit implementation?
Thanks in advance!
from shap.
Hello!
Will CategoricalSplit be implemented in the next release? Any information?
from shap.
Hi, any news on this? I have seen the contributionPredictionCol, but I am not sure whether SHAP values can now be calculated on the Spark version of XGBoost, or whether there are still corner cases or workarounds.
from shap.
Hi, is CategoricalSplit yet implemented?
from shap.