
A Spark version in plan? about shap HOT 46 OPEN

slundberg avatar slundberg commented on May 3, 2024
A Spark version in plan?

from shap.

Comments (46)

nickthegeek2877 avatar nickthegeek2877 commented on May 3, 2024 3

@QuentinAmbard First of all, I want to thank you for this excellent job. I'm also interested in PySpark. The problem is what @evah mentioned: you can build an explainer based on a PySpark model, but you cannot feed your explainer a PySpark DataFrame. Theoretically, one could define a UserDefinedFunction, collect the features from Spark to get a numpy or a pandas object, run the explainer, and let Spark do the dispatching. Unfortunately, to do so, the objects needed for the computation must be serializable (through pickle) in order to be dispatched to every node, and an explainer is not.

A possible solution is to expose a serialize and a deserialize method for an explainer. The serialize method should return an object made of dictionaries, lists, and other "basic" Python types. The deserialize method, of course, takes the serialized object and builds a proper explainer from it.

In this way, things might work: the serialized object could be dispatched, every node would build its own explainer, and the calculation could be carried out over the entire DataFrame.
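To make the idea concrete, here is a rough sketch of how such methods might be used from PySpark. The serialize()/deserialize() methods are hypothetical (they are exactly what is being requested here), and explainer, feature_cols, df and the spark session are assumed to already exist:

import pandas as pd
import shap

# Hypothetical API: serialize() would return only "basic" Python types, so the payload
# can be broadcast to the executors; deserialize() would rebuild the explainer there.
payload = explainer.serialize()                          # driver side
payload_bc = spark.sparkContext.broadcast(payload)

def explain_partition(rows):
    local_explainer = shap.TreeExplainer.deserialize(payload_bc.value)  # executor side
    pdf = pd.DataFrame(rows, columns=feature_cols)
    for shap_row in local_explainer.shap_values(pdf):
        yield [float(v) for v in shap_row]

shap_rdd = df.rdd.mapPartitions(explain_partition)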

Is it feasible to have those methods?

Thank you very much

PS: mostly interested in TreeExplainer if that helps in some way.

from shap.

techwithshadab avatar techwithshadab commented on May 3, 2024 3

I'm also interested in a PySpark implementation of this. I tried calculating SHAP values using the UDF code below:

import pandas as pd
from pyspark.sql import Row

X_df = X_df.repartition(10)
X_columns = X_df.columns

def add_shap(rows):
    # Collect the partition into pandas and compute SHAP values for it.
    rows_pd = pd.DataFrame(rows, columns=X_columns)
    shap_values = explainer.shap_values(rows_pd.drop(["Respondent"], axis=1))
    # One output Row per record: the "Respondent" id followed by its SHAP values.
    return [Row(*([int(rows_pd["Respondent"][i])] + [float(f) for f in shap_values[i]]))
            for i in range(len(shap_values))]

shap_df = X_df.rdd.mapPartitions(add_shap).toDF(X_columns)

Code reference: Databricks
But this code throws the error below:


Traceback (most recent call last):
  File "/usr/lib/spark/python/pyspark/serializers.py", line 597, in dumps
    return cloudpickle.dumps(obj, 2)
  File "/usr/lib/spark/python/pyspark/cloudpickle.py", line 863, in dumps
    cp.dump(obj)
  File "/usr/lib/spark/python/pyspark/cloudpickle.py", line 260, in dump
    return Pickler.dump(self, obj)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 409, in dump
    self.save(obj)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
    f(self, obj) # Call unbound method with explicit self
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 736, in save_tuple
    save(element)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
    f(self, obj) # Call unbound method with explicit self
  File "/usr/lib/spark/python/pyspark/cloudpickle.py", line 400, in save_function
    self.save_function_tuple(obj)
  File "/usr/lib/spark/python/pyspark/cloudpickle.py", line 549, in save_function_tuple
    save(state)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
    f(self, obj) # Call unbound method with explicit self
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
    self._batch_setitems(obj.items())
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 847, in _batch_setitems
    save(v)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
    f(self, obj) # Call unbound method with explicit self
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
    self._batch_setitems(obj.items())
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 852, in _batch_setitems
    save(v)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 521, in save
    self.save_reduce(obj=obj, *rv)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 634, in save_reduce
    save(state)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
    f(self, obj) # Call unbound method with explicit self
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
    self._batch_setitems(obj.items())
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 847, in _batch_setitems
    save(v)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 521, in save
    self.save_reduce(obj=obj, *rv)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 634, in save_reduce
    save(state)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
    f(self, obj) # Call unbound method with explicit self
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
    self._batch_setitems(obj.items())
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 847, in _batch_setitems
    save(v)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 521, in save
    self.save_reduce(obj=obj, *rv)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 634, in save_reduce
    save(state)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 476, in save
    f(self, obj) # Call unbound method with explicit self
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 821, in save_dict
    self._batch_setitems(obj.items())
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 847, in _batch_setitems
    save(v)
  File "/opt/conda/anaconda/lib/python3.6/pickle.py", line 496, in save
    rv = reduce(self.proto)
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/lib/spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 332, in get_return_value
    format(target_id, ".", name, value))
py4j.protocol.Py4JError: An error occurred while calling o3332.__getstate__. Trace:
py4j.Py4JException: Method __getstate__([]) does not exist
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
	at py4j.Gateway.invoke(Gateway.java:274)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)



---------------------------------------------------------------------------
Py4JError                                 Traceback (most recent call last)
/usr/lib/spark/python/pyspark/serializers.py in dumps(self, obj)
    596         try:
--> 597             return cloudpickle.dumps(obj, 2)
    598         except pickle.PickleError:

/usr/lib/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol)
    862         cp = CloudPickler(file,protocol)
--> 863         cp.dump(obj)
    864         return file.getvalue()

/usr/lib/spark/python/pyspark/cloudpickle.py in dump(self, obj)
    259         try:
--> 260             return Pickler.dump(self, obj)
    261         except RuntimeError as e:

/opt/conda/anaconda/lib/python3.6/pickle.py in dump(self, obj)
    408             self.framer.start_framing()
--> 409         self.save(obj)
    410         self.write(STOP)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    475         if f is not None:
--> 476             f(self, obj) # Call unbound method with explicit self
    477             return

/opt/conda/anaconda/lib/python3.6/pickle.py in save_tuple(self, obj)
    735             for element in obj:
--> 736                 save(element)
    737             # Subtle.  Same as in the big comment below.

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    475         if f is not None:
--> 476             f(self, obj) # Call unbound method with explicit self
    477             return

/usr/lib/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name)
    399                 or themodule is None):
--> 400             self.save_function_tuple(obj)
    401             return

/usr/lib/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func)
    548             state['qualname'] = func.__qualname__
--> 549         save(state)
    550         write(pickle.TUPLE)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    475         if f is not None:
--> 476             f(self, obj) # Call unbound method with explicit self
    477             return

/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
    820         self.memoize(obj)
--> 821         self._batch_setitems(obj.items())
    822 

/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
    846                     save(k)
--> 847                     save(v)
    848                 write(SETITEMS)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    475         if f is not None:
--> 476             f(self, obj) # Call unbound method with explicit self
    477             return

/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
    820         self.memoize(obj)
--> 821         self._batch_setitems(obj.items())
    822 

/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
    851                 save(k)
--> 852                 save(v)
    853                 write(SETITEM)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    520         # Save the reduce() output and finally memoize the object
--> 521         self.save_reduce(obj=obj, *rv)
    522 

/opt/conda/anaconda/lib/python3.6/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, obj)
    633         if state is not None:
--> 634             save(state)
    635             write(BUILD)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    475         if f is not None:
--> 476             f(self, obj) # Call unbound method with explicit self
    477             return

/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
    820         self.memoize(obj)
--> 821         self._batch_setitems(obj.items())
    822 

/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
    846                     save(k)
--> 847                     save(v)
    848                 write(SETITEMS)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    520         # Save the reduce() output and finally memoize the object
--> 521         self.save_reduce(obj=obj, *rv)
    522 

/opt/conda/anaconda/lib/python3.6/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, obj)
    633         if state is not None:
--> 634             save(state)
    635             write(BUILD)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    475         if f is not None:
--> 476             f(self, obj) # Call unbound method with explicit self
    477             return

/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
    820         self.memoize(obj)
--> 821         self._batch_setitems(obj.items())
    822 

/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
    846                     save(k)
--> 847                     save(v)
    848                 write(SETITEMS)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    520         # Save the reduce() output and finally memoize the object
--> 521         self.save_reduce(obj=obj, *rv)
    522 

/opt/conda/anaconda/lib/python3.6/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, obj)
    633         if state is not None:
--> 634             save(state)
    635             write(BUILD)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    475         if f is not None:
--> 476             f(self, obj) # Call unbound method with explicit self
    477             return

/opt/conda/anaconda/lib/python3.6/pickle.py in save_dict(self, obj)
    820         self.memoize(obj)
--> 821         self._batch_setitems(obj.items())
    822 

/opt/conda/anaconda/lib/python3.6/pickle.py in _batch_setitems(self, items)
    846                     save(k)
--> 847                     save(v)
    848                 write(SETITEMS)

/opt/conda/anaconda/lib/python3.6/pickle.py in save(self, obj, save_persistent_id)
    495             if reduce is not None:
--> 496                 rv = reduce(self.proto)
    497             else:

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 

/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    331                     "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
--> 332                     format(target_id, ".", name, value))
    333         else:

Py4JError: An error occurred while calling o3332.__getstate__. Trace:
py4j.Py4JException: Method __getstate__([]) does not exist
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
	at py4j.Gateway.invoke(Gateway.java:274)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)



During handling of the above exception, another exception occurred:

PicklingError                             Traceback (most recent call last)
<ipython-input-370-4f054869b703> in <module>
----> 1 shap_df1 = shap_df.withColumn("shap_val", shapvalue_udf1(f.col("shap")))

/usr/lib/spark/python/pyspark/sql/udf.py in wrapper(*args)
    187         @functools.wraps(self.func, assigned=assignments)
    188         def wrapper(*args):
--> 189             return self(*args)
    190 
    191         wrapper.__name__ = self._name

/usr/lib/spark/python/pyspark/sql/udf.py in __call__(self, *cols)
    165 
    166     def __call__(self, *cols):
--> 167         judf = self._judf
    168         sc = SparkContext._active_spark_context
    169         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

/usr/lib/spark/python/pyspark/sql/udf.py in _judf(self)
    149         # and should have a minimal performance impact.
    150         if self._judf_placeholder is None:
--> 151             self._judf_placeholder = self._create_judf()
    152         return self._judf_placeholder
    153 

/usr/lib/spark/python/pyspark/sql/udf.py in _create_judf(self)
    158         sc = spark.sparkContext
    159 
--> 160         wrapped_func = _wrap_function(sc, self.func, self.returnType)
    161         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
    162         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(

/usr/lib/spark/python/pyspark/sql/udf.py in _wrap_function(sc, func, returnType)
     33 def _wrap_function(sc, func, returnType):
     34     command = (func, returnType)
---> 35     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     36     return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
     37                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)

/usr/lib/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command)
   2418     # the serialized command will be compressed by broadcast
   2419     ser = CloudPickleSerializer()
-> 2420     pickled_command = ser.dumps(command)
   2421     if len(pickled_command) > (1 << 20):  # 1M
   2422         # The broadcast will have same life cycle as created PythonRDD

/usr/lib/spark/python/pyspark/serializers.py in dumps(self, obj)
    605                 msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
    606             cloudpickle.print_exec(sys.stderr)
--> 607             raise pickle.PicklingError(msg)
    608 
    609 

PicklingError: Could not serialize object: Py4JError: An error occurred while calling o3332.__getstate__. Trace:
py4j.Py4JException: Method __getstate__([]) does not exist
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
	at py4j.Gateway.invoke(Gateway.java:274)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)

Has anyone faced a similar issue and been able to resolve it? I'm happy to work together if anyone is working on resolving this.

from shap.

srowen avatar srowen commented on May 3, 2024 2

Pardon the self-promotion, but you can most certainly use SHAP with PySpark. It doesn't require any changes in SHAP. Here's my proof of existence: https://databricks.com/blog/2019/06/17/detecting-bias-with-shap.html

from shap.

ijoseph avatar ijoseph commented on May 3, 2024 2

In case anyone's interested, there's finally an example notebook for our Monte Carlo implementation of Shapley values. cc @shadab-entrepreneur

from shap.

xboard avatar xboard commented on May 3, 2024 1

Thanks @slundberg, as soon as I get something working properly I'll get back to you.

from shap.

slundberg avatar slundberg commented on May 3, 2024 1

Often with KernelExplainer you use a single background reference value (like the mean feature values), so that won't be a big issue. And you can then explain each sample individually if you like.
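As a minimal sketch of that idea (purely illustrative: predict_fn is assumed to be any callable mapping a numpy array of rows to predictions, and X_sample a pandas DataFrame of features):

import shap

background = X_sample.mean().values.reshape(1, -1)                    # one reference row: the feature means
explainer = shap.KernelExplainer(predict_fn, background)
shap_values_one = explainer.shap_values(X_sample.iloc[0, :].values)   # explain a single sample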

from shap.

evah avatar evah commented on May 3, 2024 1

@QuentinAmbard Thanks for offering your notebook, yet your answer did not solve the problem ...

  • in the Python version, your code uses a pandas DataFrame running single-node XGBoost
  • even if the Scala version runs parallel XGBoost, it is Scala, and this issue's author is asking for PySpark
  • people still need SHAP for Spark models (random forest, GBT, etc.), not just for XGBoost models

from shap.

evah avatar evah commented on May 3, 2024 1

@QuentinAmbard Actually I just tried it 🤦‍♀️ The headache-inducing part is that it takes a PySpark model, but X (the feature set) still needs to be in pandas. I guess that to make it accept a PySpark DataFrame, we would need to re-implement the whole algorithm in PySpark.

from shap.

techwithshadab avatar techwithshadab commented on May 3, 2024 1

@shadab-entrepreneur that looks very similar to Niloy Gupta's, my, and others' work here

I think you may need to yield rather than return, to attempt to directly answer your question?

Thanks @ijoseph
I tried using Shparkley as well and got the same error. Is there a demo notebook for it?

from shap.

slundberg avatar slundberg commented on May 3, 2024

I don't have any plans for that since I am not a Spark user. What Spark workflow do you have in mind that you would like SHAP to integrate with? I might be able to suggest something in a specific case (though I don't know much about spark).

from shap.

richardxy avatar richardxy commented on May 3, 2024

from shap.

slundberg avatar slundberg commented on May 3, 2024

from shap.

pavitrasrinivasan avatar pavitrasrinivasan commented on May 3, 2024

Hi Scott,

I would also like to understand whether there is an integration of the SHAP method with Spark ML.
Any updates in this regard?

Thanks,
Pavitra

from shap.

slundberg avatar slundberg commented on May 3, 2024

@pavitrasrinivasan The TreeExplainer now works for sklearn, so that part is done. As for Spark integration, I don't know what that would look like since I don't use Spark. Do you need something special to run a Python package in Spark?

from shap.

pavitrasrinivasan avatar pavitrasrinivasan commented on May 3, 2024

Thanks for your response !

Yes, I will not be able to use the Python-based package directly in Spark. I would need a Spark-compatible version.

Thanks,
Pavitra

from shap.

slundberg avatar slundberg commented on May 3, 2024

Okay, if anyone wants to tackle adding spark support that would be great. I am happy to help, but having not used spark myself I can't promise I'll get to it anytime soon by myself.

from shap.

xboard avatar xboard commented on May 3, 2024

Okay, if anyone wants to tackle adding spark support that would be great. I am happy to help, but having not used spark myself I can't promise I'll get to it anytime soon by myself.

Hi Scott, I've been using Spark for quite some time and would love to help adding support for Spark tree based models.

from shap.

slundberg avatar slundberg commented on May 3, 2024

@xboard great! If you take a look at a current example of parsing a tree model it should help you get started. Basically we build a TreeEnsemble object from many Tree objects, and each Tree object is built by parsing the trees from the original model and saving them in a form similar to how sklearn stores tree data: https://github.com/slundberg/shap/blob/5346fd4001ebb5e536fae68afb4097658def300b/shap/explainers/tree.py#L427-L439
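To give a feel for the target format, here is a toy sketch of the kind of dictionary TreeExplainer can consume for custom tree models (modelled on the custom-tree example in the SHAP docs; the numbers are invented, and a Spark parser would have to produce arrays like these by walking each pyspark tree):

import numpy as np
import shap

# A single tree: one split on feature 0 at threshold 0.5, with two leaves.
tree = {
    "children_left":      np.array([1, -1, -1]),            # -1 marks a leaf
    "children_right":     np.array([2, -1, -1]),
    "children_default":   np.array([2, -1, -1]),            # where missing values go
    "features":           np.array([0, -2, -2]),            # -2 marks a leaf
    "thresholds":         np.array([0.5, -2.0, -2.0]),
    "values":             np.array([[0.4], [0.0], [1.0]]),  # node output values
    "node_sample_weight": np.array([100.0, 60.0, 40.0]),    # training samples per node
}
explainer = shap.TreeExplainer({"trees": [tree]})
print(explainer.shap_values(np.array([[0.3], [0.7]])))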

from shap.

haribaskar avatar haribaskar commented on May 3, 2024

@xboard Did you get anywhere?
For RandomForestClassifier we can use the following attributes, but they are internally represented as Java objects. What can we do?

  • model.trees
  • model.minInfoGain

elif str(type(model)).endswith("pyspark.ml.classification.RandomForestClassifier'>"):
    scaling = 1.0 / len(model.estimators_)  # output is average of trees
    self.trees = [Tree(e.tree, normalize=True, scaling=scaling) for e in model.trees]
    self.objective = objective_name_map.get(model.minInfoGain, None)
    self.tree_output = "probability"

from shap.

jbertscher avatar jbertscher commented on May 3, 2024

@slundberg thanks for developing this great package! I'm also interested in a Spark implementation, but as a first step I would like to try using the model-agnostic version on a model trained with Spark.

As far as I can tell, it shouldn't matter where the model is trained, as long as I can pass a function that makes predictions based on the given features to KernelExplainer. The problem is that the dataset would be too large to fit in memory, which is why I'm using Spark. One suggestion I've had is to pass a down-sampled (but hopefully representative) version of the training data. I'm curious to know whether you think that would work?
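For what it's worth, a rough sketch of that setup, where everything is an assumption for illustration (pipeline_model a fitted Spark PipelineModel whose output column is "prediction", feature_cols the raw feature column names, spark/df the usual session and training DataFrame). Every KernelExplainer evaluation round-trips through Spark, so this only makes sense for explaining a small set of rows:

import pandas as pd
import shap

def predict_fn(X):
    # KernelExplainer hands us a numpy array of (possibly perturbed) rows.
    pdf = pd.DataFrame(X, columns=feature_cols)
    sdf = spark.createDataFrame(pdf)
    return pipeline_model.transform(sdf).select("prediction").toPandas()["prediction"].values

# Down-sampled (hopefully representative) background set, collected to the driver.
background = df.select(*feature_cols).sample(fraction=0.001, seed=42).toPandas()
explainer = shap.KernelExplainer(predict_fn, shap.sample(background, 100))
shap_values = explainer.shap_values(background.head(10))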

from shap.

Chakri-V-V avatar Chakri-V-V commented on May 3, 2024

Hi ,

Is there any stable release of Shap in Pyspark till date? Sorry if my question doesn't make sense, but I am new to Pyspark & have a requirement of SHAP with a clustering analysis I am trying to solve.

from shap.

slundberg avatar slundberg commented on May 3, 2024

Not that I know of yet. Would be great to have.

from shap.

fivejjs avatar fivejjs commented on May 3, 2024

What about random forests in Spark ML? It would also be good to have a Scala version to calculate SHAP values.

We can use sparkContext to get the tree objects as in https://github.com/fivejjs/spark-tree-plotting

from shap.

QuentinAmbard avatar QuentinAmbard commented on May 3, 2024

I've got a draft version of the TreeExplainer working with a simple Spark decision tree using PySpark. I'll try to improve it to support random forests and commit it. @slundberg can you confirm that categorical splits aren't supported?

from shap.

QuentinAmbard avatar QuentinAmbard commented on May 3, 2024

I've created #721, feedback is welcome.

from shap.

ppakawatk avatar ppakawatk commented on May 3, 2024

Does SHAP support PySpark models yet?
I have tried using SHAP with RandomForestRegressionModel (PySpark ML), but got this error: "NotImplementedError: CategoricalSplit are not yet implemented".

Is there a list of supported models/libraries for SHAP?

Thank you.

from shap.

QuentinAmbard avatar QuentinAmbard commented on May 3, 2024

Yes, it's supported, but my understanding is that SHAP doesn't support categorical splits, so you can't have categorical features in your model (for example, a StringIndexer will create a categorical split).

from shap.

ppakawatk avatar ppakawatk commented on May 3, 2024

Yes, it's supported, but my understanding is that SHAP doesn't support categorical splits, so you can't have categorical features in your model (for example, a StringIndexer will create a categorical split).

Thanks @QuentinAmbard. I tried with another simple model and it runs properly.
I wonder about the cases where I would need to use categorical features in my model.

However, another error occurs: "AttributeError: 'DataFrame' object has no attribute 'shape'". My guess is that it's because I'm using a Spark DataFrame as the input for shap_values?

from shap.

QuentinAmbard avatar QuentinAmbard commented on May 3, 2024

Categorical features are often required in trees (using a one-hot encoder isn't a good alternative)...

There is an example in this thread: #884
As per that thread you also need to disable the additivity check for now; #905 will solve this issue (I still need to make Spark work on Windows in the appveyor build)

Yes, I assume you're using a DataFrame.
If you need to distribute it you can implement a pandas UDF, call shap_values inside each batch, and merge the output into a Spark DataFrame.
Here is a Scala example with a distributed SHAP computation: https://databricks.com/blog/2019/06/17/detecting-bias-with-shap.html
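A minimal sketch of that pandas UDF idea (Spark 2.4-style GROUPED_MAP; assumptions: a picklable explainer for a regressor, a feature_cols list, a DataFrame df of double-typed feature columns, and an arbitrary "part_id" column used only to split the work):

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Output schema: one double column of SHAP values per feature.
out_schema = ", ".join("{}_shap double".format(c) for c in feature_cols)

@pandas_udf(out_schema, PandasUDFType.GROUPED_MAP)
def shap_for_batch(pdf):
    # Each group arrives as a pandas DataFrame; compute its SHAP values locally.
    vals = explainer.shap_values(pdf[feature_cols], check_additivity=False)
    return pd.DataFrame(vals, columns=["{}_shap".format(c) for c in feature_cols])

shap_df = df.groupBy("part_id").apply(shap_for_batch)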

from shap.

sacmax avatar sacmax commented on May 3, 2024

Yes, it's supported, but my understanding is that SHAP doesn't support categorical splits, so you can't have categorical features in your model (for example, a StringIndexer will create a categorical split).

Any solution for "NotImplementedError: CategoricalSplit are not yet implemented" in PySpark?

from shap.

QuentinAmbard avatar QuentinAmbard commented on May 3, 2024

@slundberg is there a way to support categorical splits with the current SHAP implementation?
I'm not sure how this is handled in XGBoost; is it because there are no categorical splits, only numerical ones (so it's actually not implemented)?

from shap.

sacmax avatar sacmax commented on May 3, 2024

I am using RandomForest and GBT in PySpark, and I have to transform some categorical features using StringIndexer and some using OneHotEncoder. Is there a way or workaround to use SHAP on the transformed categorical features?

from shap.

slundberg avatar slundberg commented on May 3, 2024

A categorical multi-way split is something that is not yet in the C++ version of TreeExplainer. So that would need to be added to support this. I would like to say I’ll just take a few days and write it, but I am not sure when I will have those free days. I’ll try to get it in sooner rather than later...

from shap.

QuentinAmbard avatar QuentinAmbard commented on May 3, 2024

@evah I think it does solve the problem ;)
The notebook just shows how you can compute your explanations using Spark on multiple nodes (explanations of any model, in this case a single-node XGBoost model).
If you want to get the explanation of your Spark tree-based model, you can find an example here:
https://github.com/QuentinAmbard/shap/blob/%23866_spark_regressor_support/tests/explainers/test_tree.py#L190
(As sacmax mentioned, categorical features are not implemented)

from shap.

ijoseph avatar ijoseph commented on May 3, 2024

@shadab-entrepreneur that looks very similar to Niloy Gupta's, my, and others' work here

I think you may need to yield rather than return, to attempt to directly answer your question?

from shap.

ijoseph avatar ijoseph commented on May 3, 2024

Hm, interesting. The readme has example usage but there's no example notebook as of yet.

Looking again at your stack trace, are you calling df.rdd.mapPartitions inside of shapvalue_udf1? I'm mostly guessing, based on what the error you got means: py4j.Py4JException: Method __getstate__([]) does not exist and this syntax:

shap_df1 = shap_df.withColumn("shap_val", shapvalue_udf1(f.col("shap")))

from shap.

davezhouwa avatar davezhouwa commented on May 3, 2024

@shadab-entrepreneur have you solved the issue? I'm facing the same problem.

from shap.

techwithshadab avatar techwithshadab commented on May 3, 2024

@davezhouwa Nope, I'm still facing the same issue. Were you able to solve it?

from shap.

QuentinAmbard avatar QuentinAmbard commented on May 3, 2024

It's working fine in my notebook. Make sure you can serialize the explainer:
pickle.dumps(explainer)
If you can't, then something is wrong (you didn't remove the line).
You need to edit it manually, as the fix isn't yet released in the latest shap version on PyPI.

model = pipeline.stages[-1].bestModel
print(model)
# ==> GBTClassificationModel: uid = GBTClassifier_367df9840c81, numTrees=5, numClasses=2, numFeatures=8
explainer = shap.TreeExplainer(model)
# You can check that the explainer can be serialized:
# pickle.dumps(explainer)

And now it's working in a pandas_udf, or using the new mapInPandas in Spark 3:

def compute_shap_values(iterator):
    for X in iterator:
        yield pd.DataFrame(explainer.shap_values(X, check_additivity=False))

display(dataset.mapInPandas(compute_shap_values, schema=dataset.schema).toDF(*[x + "_shap_value" for x in test.columns]))

from shap.

davezhouwa avatar davezhouwa commented on May 3, 2024

@shadab-entrepreneur, I followed @QuentinAmbard's suggestions and it worked.

from shap.

boechat107 avatar boechat107 commented on May 3, 2024

Hoping someone might benefit, my approach to the problems described in this issue was:

  1. randomly sample the target Spark DataFrame (to make sure the data fits the master node)
  2. convert the DF to a numpy array
  3. calculate SHAP
from typing import List

import numpy as np
from pyspark.sql import DataFrame

def df2numpy(df: DataFrame, cols: List[str], nsamples: int) -> np.ndarray:
    """
    Converts a DataFrame into a Numpy array (matrix). "cols" gives the order of
    the matrix columns; "nsamples" is the approximate number of records.
    """
    # This is not efficient; "count" may be expensive to compute.
    fraction = nsamples / df.count()
    samp_df = df.sample(fraction=fraction, seed=6969) if fraction < 1 else df
    return np.array(samp_df.select(*cols).collect())

I also had problems with categorical features (NotImplementedError: CategoricalSplit are not yet implemented). I had forgotten that Spark implicitly uses column metadata to identify indexed features and make "categorical splits". To solve this issue, I just created another DF without any metadata:

    # "gbt" is an instance of GBTClassifier.
    # "df" contains numeric and categorical/indexed features.
    pipe = Pipeline(stages=[assembler, gbt]).fit(
        # Remove metadata of the input DF. We don't want to handle categorical
        # values when training a GBT model (Spark models implicitly use
        # metadata to handle categorical features).
        df.rdd.toDF()
    )
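For completeness, a hypothetical sketch of how the pieces above might be combined, assuming a shap version with the Spark tree support discussed in this thread (#721), that feature_cols lists the assembled feature columns in order, and that the last pipeline stage is the fitted GBT model:

import shap

X = df2numpy(df, feature_cols, nsamples=10000)
explainer = shap.TreeExplainer(pipe.stages[-1])
shap_values = explainer.shap_values(X, check_additivity=False)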

from shap.

zwag20 avatar zwag20 commented on May 3, 2024

@boechat107 I did not find that removing the metadata fixed the "NotImplementedError: CategoricalSplit are not yet implemented" error.
Has anyone found a fix for this?

from shap.

demirbilek95 avatar demirbilek95 commented on May 3, 2024

Hey everyone,
Is there any update on CategoricalSplit implementation?
Thanks in advance!

from shap.

varunbanda avatar varunbanda commented on May 3, 2024

Hello!

Will CategoricalSplit be implemented in the next release? Any information?

from shap.

olbapjose avatar olbapjose commented on May 3, 2024

Hi, any news on this? I have seen the contributionPredictionCol, but I am not sure whether SHAP values can now be calculated on the Spark version of XGBoost, or whether there are still corner cases or workarounds needed.

from shap.

venser12 avatar venser12 commented on May 3, 2024

Hi, is CategoricalSplit implemented yet?

from shap.
