
mateoespinosa / tabcbm


Official Implementation of TMLR's paper: "TabCBM: Concept-based Interpretable Neural Networks for Tabular Data"

Home Page: https://openreview.net/pdf?id=TIsrnWpjQ0

License: MIT License

Python 100.00%
artificial-intelligence concept-based-explanations concept-based-models concepts explainability explainable-ai explainable-artificial-intelligence interpretable-deep-learning neural-networks tabular tabular-data xai

tabcbm's Introduction

TabCBM: Concept-based Interpretable Neural Networks for Tabular Data

License: MIT Python 3.7+ Paper Poster

TabCBM Architecture

This repository contains the official implementation of our TMLR paper "TabCBM: Concept-based Interpretable Neural Networks for Tabular Data" and its corresponding version in ICML's Workshop on Interpretable Machine Learning for Healthcare (IMLH 2023).

This work was done by Mateo Espinosa Zarlenga, Zohreh Shams, Michael Edward Nelson, Been Kim, and Mateja Jamnik.

TL;DR

There have been significant efforts in recent years to design neural architectures that can explain their predictions using high-level units of information referred to as "concepts". Nevertheless, these methods have thus far been neither designed for nor deployed on tabular tasks, leaving crucial domains such as healthcare and genomics out of the scope of concept-based interpretable models. In this work, we first provide a novel definition of what a concept entails in a general tabular domain and then propose Tabular Concept Bottleneck Models (TabCBMs), a family of interpretable self-explaining neural architectures capable of discovering high-level concept explanations for tabular tasks without sacrificing state-of-the-art performance.

Abstract

Concept-based interpretability addresses the opacity of deep neural networks by constructing an explanation for a model's prediction using high-level units of information referred to as concepts. Research in this area, however, has been mainly focused on image and graph-structured data, leaving high-stakes tasks whose data is tabular out of reach of existing methods. In this paper, we address this gap by introducing the first definition of what a high-level concept may entail in tabular data. We use this definition to propose Tabular Concept Bottleneck Models (TabCBMs), a family of interpretable self-explaining neural architectures capable of learning high-level concept explanations for tabular tasks. As our method produces concept-based explanations whether partial concept supervision or no concept supervision is available at training time, it is adaptable to settings where concept annotations are missing. We evaluate our method in both synthetic and real-world tabular tasks and show that TabCBM outperforms or performs competitively compared to state-of-the-art methods, while providing a high level of interpretability as measured by its ability to discover known high-level concepts. Finally, we show that TabCBM can discover important high-level concepts in synthetic datasets inspired by critical tabular tasks (e.g., single-cell RNAseq) and allows for human-in-the-loop concept interventions in which an expert can identify and correct mispredicted concepts to boost the model's performance.

Installation

You can locally install this package by first cloning this repository:

$ git clone https://github.com/mateoespinosa/tabcbm

We provide an automatic installation mechanism through our standalone setup.py. To install the package, move into the cloned directory (cd tabcbm) and run:

$ python setup.py install

After running this, you should be able to import our package locally using

import tabcbm

Usage

High-level Usage

In this repository, we include a standalone TensorFlow implementation of Tabular Concept Bottleneck Models (TabCBMs) which can be easily trained from scratch given a set of samples that may or may not have binary concept annotations.

To use our model's implementation, you first need to install all of our code's requirements, either from requirements.txt or by following the installation instructions above.

After you have installed all dependencies, you should be able to import TabCBM as a standalone Keras Model as follows:

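import tensorflow as tf
from tabcbm.models.tabcbm import TabCBM

#####
# Define your dataset as NumPy arrays
#####

x_train = ...  # np.ndarray with shape (batch, features) containing training samples
y_train = ...  # np.ndarray with shape (batch,) containing integer class labels
x_test = ...  # np.ndarray with shape (batch, features) containing test samples
y_test = ...  # np.ndarray with shape (batch,) containing integer class labels
c_train_real = None  # Optional np.ndarray with shape (batch, n_supervised_concepts) of binary concept labels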

#####
# Construct the model's hyperparameters
#####
n_concepts = ...  # Number of concepts we wish to discover in the task

tab_cbm_params = dict(
    features_to_concepts_model=..., # Keras model used as the feature-to-latent-code encoder (i.e., $\phi$)
    concepts_to_labels_model=..., # Keras model used as the concept-scores-to-label predictor (i.e., f)
    loss_fn=..., # An appropriate loss function following TensorFlow's loss function API
    latent_dims=32, # Size of latent space for concept embeddings
    n_concepts=n_concepts,  # Number of concepts we wish to discover in the task
    n_supervised_concepts=0, # Change to another number if concept supervision is expected (i.e., we have a c_train matrix)

    # Loss hypers
    coherence_reg_weight=..., # Scalar loss weight for the coherence regularizer
    diversity_reg_weight=..., # Scalar loss weight for the diversity regularizer
    feature_selection_reg_weight=..., # Scalar loss weight for the specificity regularizer
    concept_prediction_weight=..., # Scalar loss weight hyper for the concept predictive loss (only relevant if n_supervised_concepts != 0)
)

#####
# Perform its self-supervised pre-training first
#####

ss_tabcbm = TabCBM(
    self_supervised_mode=True,
    **tab_cbm_params,
)
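ss_tabcbm.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
# Evaluate the self-supervised loss once on a tiny batch so the model's weights are built before fit()
ss_tabcbm._compute_self_supervised_loss(x_test[:2, :])
ss_tabcbm.set_weights(ss_tabcbm.get_weights())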

ss_tabcbm.fit(
    x=x_train,
    y=y_train,
    validation_split=0.2,
    epochs=100,
    batch_size=256
)

#####
# And now do its supervised training stage
#####

tabcbm = TabCBM(
    self_supervised_mode=False,
    concept_generators=ss_tabcbm.concept_generators,
    prior_masks=ss_tabcbm.feature_probabilities,
    **tab_cbm_params,
)
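tabcbm.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
# Evaluate the supervised loss once on a tiny batch so the model's weights are built before fit()
tabcbm._compute_supervised_loss(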
    x_test[:2, :],
    y_test[:2],
    c_true=(
        c_train_real[:2, :]
        if c_train_real is not None else None
    ),
)
tabcbm.fit(
    x=x_train,
    y=y_train,
    validation_split=0.2,
    epochs=100,
    batch_size=256
)

For a step-by-step example showing how to generate a dataset and configure a TabCBM for training on your own custom dataset, please see our Synth-Nonlin example notebook.
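The usage example above expects x_train as a (batch, features) float array and y_train as a (batch,) vector of integer labels. When starting from your own data (for instance, the .values of a pandas DataFrame), the sketch below shows one way to produce those arrays. prepare_tabular_arrays is a hypothetical helper, not part of the tabcbm package, and it assumes the labels identify classes: integer-label cross-entropy losses such as TF's sparse categorical cross-entropy reject labels outside [0, n_classes).

```python
import numpy as np


def prepare_tabular_arrays(features, labels, dtype=np.float32):
    """Hypothetical helper (not part of tabcbm): convert raw tabular data into
    x with shape (batch, features) and y with shape (batch,) holding class
    indices in [0, n_classes)."""
    x = np.asarray(features, dtype=dtype)
    if x.ndim != 2:
        raise ValueError(f"expected a 2-D feature matrix, got shape {x.shape}")
    # np.unique sorts the distinct labels; return_inverse gives, for every
    # sample, the position of its label in that sorted list (0..n_classes-1).
    classes, y = np.unique(np.asarray(labels), return_inverse=True)
    return x, y.reshape(-1).astype(np.int64), classes


# Toy stand-in for e.g. df[feature_columns].values and df[label_column].values
x_train, y_train, classes = prepare_tabular_arrays(
    [[0.1, 2.0], [0.3, 1.5], [0.2, 1.0]],
    ["b", "a", "b"],
)
```

Because every distinct value becomes its own class, a continuous target (such as a price) would produce a very large number of classes; such a target should be binned into a manageable number of classes first.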

Further documentation on this model's parameters, as well as test scripts showing usage examples, will be incorporated as part of an ongoing refactor. Until then, feel free to raise an issue or reach out if you want specific details on any TabCBM parameter.

Class Arguments

Our TabCBM module takes the following keyword arguments at initialization:

  • features_to_concepts_model (tf.keras.Model): A tensorflow model mapping input features with shape $(B, ...)$ to a set of latent codes with shape $(B, m)$. This is the latent code encoder $\phi$ used in Figure 1 of our paper.
  • concepts_to_labels_model (tf.keras.Model): A tensorflow model mapping a set of concept scores in $[0, 1]$ with shape $(B, k')$ to a set of output probabilities for each of the task classes (i.e., a tensor with shape $(B, L)$). This is the label predictor $f$ used in Figure 1 of our paper.
  • latent_dims (int): The dimensionality $m$ of the latent code and the concept embeddings.
  • n_concepts (int): Number of total concepts to use for this model. This number must include both supervised concepts (if any) and unsupervised/discovered concepts.
  • masking_values (np.ndarray or None): The values to use when masking each of the input features. This should be an array with $n$ elements, one for each input feature. If None (the default), we use an array of all zeros (i.e., all samples are masked using zero masks).
  • features_to_embeddings_model (tf.keras.Model or None): An optional tensorflow model used to preprocess the input features before passing them to features_to_concepts_model. This argument can be used to incorporate learnable embedding-based models when working with categorical features, where we would like to learn an embedding for each category of each discrete feature. If not provided (i.e., set to None), we assume no input preprocessing is needed. If provided, the effective input shape of features_to_concepts_model is expected to match the effective output shape of features_to_embeddings_model.
  • cov_mat (np.ndarray or None): Empirical $(n \times n)$ covariance matrix of the input training features. This covariance is used for learning correlated gate masks that take cross-feature correlations into account, avoiding leakage when a feature is masked (as in SEFS). If not provided (i.e., set to None), we assume that all features are independent of each other (i.e., the covariance matrix is the identity matrix).
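cov_mat is the empirical covariance of the training features, and masking_values defines what a masked feature is replaced with. The numpy sketch below computes both and then illustrates the SEFS-style idea of correlated gates using a Gaussian copula. The gate sampler and its convention (gate 1 keeps a feature, gate 0 replaces it with its masking value) are our own illustrative assumptions, not the code TabCBM runs internally.

```python
import math

import numpy as np


def correlation_from_covariance(cov):
    """Rescale a covariance matrix to a correlation matrix (unit diagonal)."""
    std = np.sqrt(np.diag(cov))
    return cov / np.outer(std, std)


def sample_correlated_gates(pi, cov, n_samples, rng):
    """Hypothetical helper: draw binary gates with P(m_j = 1) = pi[j] whose
    dependence follows `cov`, via a Gaussian copula."""
    corr = correlation_from_covariance(cov)
    z = rng.multivariate_normal(np.zeros(len(pi)), corr, size=n_samples)
    # The standard normal CDF, applied elementwise, maps z to uniform marginals u.
    u = 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))
    return (u < pi).astype(np.float32)


rng = np.random.default_rng(0)
x_train = rng.normal(size=(500, 4))
x_train[:, 1] = x_train[:, 0] + 0.1 * rng.normal(size=500)  # a strongly correlated pair

cov_mat = np.cov(x_train, rowvar=False)      # (n, n) empirical covariance
masking_values = np.zeros(x_train.shape[1])  # the documented default (zero masks)

# Assumed convention: gate 1 keeps a feature, gate 0 replaces it with its masking value.
gates = sample_correlated_gates(np.full(4, 0.5), cov_mat, n_samples=500, rng=rng)
x_masked = gates * x_train + (1.0 - gates) * masking_values
```

Because features 0 and 1 are strongly correlated, their gates tend to switch off together, so the model cannot trivially recover a masked feature from its near-duplicate; this is the leakage that the cov_mat argument is meant to prevent.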

The different components of TabCBM's loss can be configured through the following arguments:

  • loss_fn (Callable[[tf.Tensor, tf.Tensor], tf.Tensor]): The loss function for the downstream task. This is a differentiable TF function that takes the true labels y_true and the predicted labels y_pred for a batch of B inputs and returns a vector of shape (B,) with the loss for each sample. Defaults to TF's unreduced categorical cross-entropy.
  • coherence_reg_weight (float): Weight for the coherence regulariser (called $\lambda_\text{co}$ in the paper). Defaults to 0.1.
  • diversity_reg_weight (float): Weight for the diversity regulariser (called $\lambda_\text{div}$ in the paper). Defaults to 5.
  • feature_selection_reg_weight (float): Weight for the specificity regulariser (called $\lambda_\text{spect}$ in the paper). Defaults to 5.
  • top_k (int): Number of nearest neighbours k used when computing the coherence loss. This argument should be tuned carefully and must be smaller than the batch size. Defaults to 32.
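As a rough illustration of how these weights interact, the sketch below assumes the regularisers enter the training objective additively, each scaled by its weight, on top of the task loss. combined_loss and check_top_k are hypothetical helpers operating on already-computed scalar terms; they do not reproduce how TabCBM computes the coherence, diversity, or specificity terms themselves. The rationale given for the top_k bound (neighbours are searched within the batch, excluding the sample itself) is likewise an assumption.

```python
def combined_loss(task_loss, coherence, diversity, specificity, concept_loss=0.0,
                  coherence_reg_weight=0.1, diversity_reg_weight=5.0,
                  feature_selection_reg_weight=5.0, concept_prediction_weight=0.0):
    """Hypothetical: weighted sum of already-computed per-batch loss terms,
    using the documented default weights."""
    return (task_loss
            + coherence_reg_weight * coherence
            + diversity_reg_weight * diversity
            + feature_selection_reg_weight * specificity
            + concept_prediction_weight * concept_loss)


def check_top_k(top_k, batch_size):
    """Assumed rationale: neighbours come from the same batch, excluding the
    sample itself, so at most batch_size - 1 of them exist."""
    if not 0 < top_k < batch_size:
        raise ValueError(f"top_k={top_k} must be in (0, batch_size={batch_size})")


check_top_k(top_k=32, batch_size=256)  # documented default top_k with the example's batch size
total = combined_loss(task_loss=0.7, coherence=0.2, diversity=0.1, specificity=0.05)
```

With the default weights, the diversity and specificity terms are scaled 50 times more strongly than the coherence term, which is worth keeping in mind when their raw magnitudes differ.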

If some ground-truth concepts labels are provided during training, then this can be indicated through the following arguments:

  • n_supervised_concepts (int): Number of concepts for which supervision will be provided. If non-zero, each sample is expected to come with a binary annotation vector of length n_supervised_concepts. The value of n_supervised_concepts should be less than n_concepts. Defaults to 0 (i.e., no ground-truth concepts provided).
  • concept_prediction_weight (int): When ground-truth concepts are provided during training, this value is the weight of the concept prediction loss for the supervised concepts. Defaults to 0 (i.e., no concept prediction loss).

Note that if ground-truth concepts are provided during training, the first n_supervised_concepts concept scores will correspond to the provided concepts, in the same order as they appear in the training concept label vector. Moreover, our TabCBM implementation supports partial concept annotations (i.e., some samples may have only some of their concepts annotated). To use this, set unknown concept labels to NaN in the corresponding training samples.
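The sketch below shows the expected layout of a partially annotated concept matrix (one column per supervised concept, NaN for unknown labels) and one way such NaN entries can be excluded from a concept loss that compares against the first n_supervised_concepts concept scores. masked_concept_bce is a hypothetical numpy helper for illustration, not the package's internal concept loss.

```python
import numpy as np

n_concepts, n_supervised_concepts = 5, 2

# Annotations for the supervised concepts only; NaN marks an unknown label.
c_train = np.array([
    [1.0, 0.0],
    [np.nan, 1.0],
    [0.0, np.nan],
])


def masked_concept_bce(c_true, concept_scores, eps=1e-7):
    """Hypothetical: mean binary cross-entropy over annotated entries only,
    comparing against the first c_true.shape[1] concept scores."""
    c_pred = np.clip(concept_scores[:, :c_true.shape[1]], eps, 1.0 - eps)
    known = ~np.isnan(c_true)
    targets = np.where(known, c_true, 0.0)  # placeholder value, masked out below
    bce = -(targets * np.log(c_pred) + (1.0 - targets) * np.log(1.0 - c_pred))
    return bce[known].mean()


scores = np.full((3, n_concepts), 0.5)  # concept scores in [0, 1] for all n_concepts
loss = masked_concept_bce(c_train, scores)
```

With uninformative scores of 0.5, every annotated entry contributes log(2), and the two NaN entries contribute nothing.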

Our TabCBM class also provides arguments to support the end-to-end incorporation of the self-supervised pipeline into its training (as shown in the example above). These arguments are:

  • self_supervised_mode (bool): Whether this model runs in self-supervised pretraining mode for its mask generator modules. If True, calling .fit(...) uses the SEFS pretext self-supervised task to pretrain these modules. If False, the mask generators are assumed to be already pretrained and .fit(...) proceeds directly to end-to-end training of the entire TabCBM model. See the TabCBM example notebook for how to use this parameter for pretraining. Defaults to False.
  • g_model (tf.keras.Model or None): Model used to reconstruct the input features from the learnt latent code during self-supervised pretraining. If not provided (i.e., set to None), it defaults to a simple 3-layer ReLU MLP with a 500-unit hidden layer. This model is only used during self-supervised pretraining; it plays no role in end-to-end TabCBM training or in any subsequent inference.
  • gate_estimator_weight (float): Weight used during self-supervised mask generator pretraining for the regularizer that penalizes the model for not correctly predicting the mask applied to each sample. Defaults to 1.
  • include_bn (bool): Whether to include a learnable batch normalization layer that preprocesses the input features before any concept embeddings/scores are computed. Defaults to False.
  • rec_model_units (List[int]): The layer sizes of the MLP used for the reconstruction model during the self-supervised pretraining stage. Defaults to [64].
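The two pretraining signals described above (reconstructing the input with g_model and predicting the applied mask, weighted by gate_estimator_weight) can be sketched as a single objective. The choice of mean squared error for reconstruction and binary cross-entropy for mask prediction is an assumption made for illustration, and pretext_loss is a hypothetical numpy helper rather than TabCBM's actual implementation.

```python
import numpy as np


def pretext_loss(x, x_reconstructed, gates, gate_probs,
                 gate_estimator_weight=1.0, eps=1e-7):
    """Hypothetical SEFS-style pretext objective: mean squared error between
    the clean input and the reconstruction produced from the masked input,
    plus a weighted binary cross-entropy for predicting the applied gates."""
    reconstruction = np.mean((x - x_reconstructed) ** 2)
    p = np.clip(gate_probs, eps, 1.0 - eps)
    gate_bce = -np.mean(gates * np.log(p) + (1.0 - gates) * np.log(1.0 - p))
    return reconstruction + gate_estimator_weight * gate_bce


x = np.array([[1.0, 2.0, 3.0], [0.0, -1.0, 0.5]])
gates = np.array([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
# Perfect reconstruction in both cases; only the quality of the gate guesses differs.
uninformed = pretext_loss(x, x, gates, np.full_like(gates, 0.5))
confident = pretext_loss(x, x, gates, np.where(gates == 1.0, 0.9, 0.1))
```

Raising gate_estimator_weight makes the mask-prediction term dominate, pushing the encoder to retain information about which features were masked.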

Finally, we provide some control over the architecture used for the concept generators via the following arguments:

  • concept_generator_units (List[int]): The layer sizes of the MLP used for the concept generator models (i.e., the $\rho^{(i)}$ models). Defaults to [64].
  • concept_generators (List[tf.keras.Model] or None): A list of n_concepts TF models to be used as the concept generators $\rho$. If not provided (i.e., set to None), each concept generator is instantiated as a ReLU MLP with layer sizes concept_generator_units.
  • prior_masks (np.ndarray or None): Initial values for the logit probabilities of TabCBM's masks. If not provided (i.e., set to None), every mask's logit probability is initialized uniformly at random from $[-1, 1]$.
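The default initialization of prior_masks can be mirrored in numpy as follows. The (n_concepts, n_features) layout is an assumption (one mask per concept over the input features), and init_prior_mask_logits is a hypothetical helper; passing a trained model's feature_probabilities, as in the usage example, replaces this random start.

```python
import numpy as np


def init_prior_mask_logits(n_concepts, n_features, rng):
    """Hypothetical: draw every mask logit uniformly from [-1, 1], matching the
    documented default. The (n_concepts, n_features) layout is assumed."""
    return rng.uniform(-1.0, 1.0, size=(n_concepts, n_features))


def sigmoid(z):
    """Map logits to probabilities."""
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
prior_masks = init_prior_mask_logits(n_concepts=3, n_features=12, rng=rng)
feature_probs = sigmoid(prior_masks)
```

Logits in [-1, 1] correspond to selection probabilities of roughly 0.27 to 0.73, so no feature starts out strongly included in or excluded from any concept.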

Experiment Reproducibility

Running Experiments

To reproduce the experiments in our paper, use the run_experiments.py script in experiments/ after installing the package as indicated above. You can run all experiments by calling this script with the appropriate config from the experiments/configs/ directory. For example, to run our experiments on the Synth-Linear dataset (see our paper), execute the following command:

$ python experiments/run_experiments.py dot -c experiments/configs/linear_tab_synth_config.yaml -o results/synth_linear

This should generate a summary of all the results after execution has terminated and dump all results/trained models/logs into the given output directory (results/synth_linear/ in this case).

tabcbm's Issues

What are the models needed in TabCBM?

Hello!
I am trying to use TabCBM on my own tabular datasets, which are stored in CSV files. The provided model requires some sub-models, such as:
feature_to_concept_model
concept_to_feature_model

What do these parts mean? The paper does not describe them in detail. For example, assume that we have training data with 1000 rows and 7 feature columns, plus a 1000-element vector of binary labels. How can we use TabCBM to train a model and evaluate it on test data?

How can I use my custom dataframe with it?

Hello!

Having looked at your launch example, I still don't understand how to use my "custom" dataframe (well, dataframe.values, i.e., an np array) with it.
Do I not need the part with "generate_tabular_synth_data" at all in this case?
But then where should I get c_train/c_test/ground_truth_concept_masks from?
Or can I just set them to None?
I tried splitting my initial df into x_train/x_test and likewise for y, then called ".values" on all of these dfs, and got this:

Error itself:

InvalidArgumentError Traceback (most recent call last)
Cell In[352], line 9
7 pretrain_epochs = 50
8 batch_size = 1024
----> 9 pretrain_hist = end_to_end_model.fit(
10 x=x_train,
11 y=y_train,
12 epochs=pretrain_epochs,
13 batch_size=batch_size,
14 validation_split=validation_size,
15 verbose=1,
16 )

File D:\anaconda3\Lib\site-packages\keras\src\utils\traceback_utils.py:122, in filter_traceback..error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.traceback)
120 # To get the full stack trace, call:
121 # keras.config.disable_traceback_filtering()
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb

File D:\anaconda3\Lib\site-packages\tensorflow\python\eager\execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
51 try:
52 ctx.ensure_initialized()
---> 53 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
54 inputs, attrs, num_outputs)
55 except core._NotOkStatusException as e:
56 if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits defined at (most recent call last):
File "", line 198, in _run_module_as_main

File "", line 88, in _run_code

File "D:\anaconda3\Lib\site-packages\ipykernel_launcher.py", line 17, in

File "D:\anaconda3\Lib\site-packages\traitlets\config\application.py", line 992, in launch_instance

File "D:\anaconda3\Lib\site-packages\ipykernel\kernelapp.py", line 701, in start

File "D:\anaconda3\Lib\site-packages\tornado\platform\asyncio.py", line 195, in start

File "D:\anaconda3\Lib\asyncio\windows_events.py", line 321, in run_forever

File "D:\anaconda3\Lib\asyncio\base_events.py", line 608, in run_forever

File "D:\anaconda3\Lib\asyncio\base_events.py", line 1936, in _run_once

File "D:\anaconda3\Lib\asyncio\events.py", line 84, in _run

File "D:\anaconda3\Lib\site-packages\ipykernel\kernelbase.py", line 534, in dispatch_queue

File "D:\anaconda3\Lib\site-packages\ipykernel\kernelbase.py", line 523, in process_one

File "D:\anaconda3\Lib\site-packages\ipykernel\kernelbase.py", line 429, in dispatch_shell

File "D:\anaconda3\Lib\site-packages\ipykernel\kernelbase.py", line 767, in execute_request

File "D:\anaconda3\Lib\site-packages\ipykernel\ipkernel.py", line 429, in do_execute

File "D:\anaconda3\Lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell

File "D:\anaconda3\Lib\site-packages\IPython\core\interactiveshell.py", line 3051, in run_cell

File "D:\anaconda3\Lib\site-packages\IPython\core\interactiveshell.py", line 3106, in _run_cell

File "D:\anaconda3\Lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner

File "D:\anaconda3\Lib\site-packages\IPython\core\interactiveshell.py", line 3311, in run_cell_async

File "D:\anaconda3\Lib\site-packages\IPython\core\interactiveshell.py", line 3493, in run_ast_nodes

File "D:\anaconda3\Lib\site-packages\IPython\core\interactiveshell.py", line 3553, in run_code

File "C:\Users\Timur.c\AppData\Local\Temp\ipykernel_10576\2716014408.py", line 9, in

File "D:\anaconda3\Lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler

File "D:\anaconda3\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 314, in fit

File "D:\anaconda3\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 117, in one_step_on_iterator

File "D:\anaconda3\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 104, in one_step_on_data

File "D:\anaconda3\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 54, in train_step

File "D:\anaconda3\Lib\site-packages\keras\src\trainers\trainer.py", line 316, in compute_loss

File "D:\anaconda3\Lib\site-packages\keras\src\trainers\compile_utils.py", line 609, in call

File "D:\anaconda3\Lib\site-packages\keras\src\trainers\compile_utils.py", line 645, in call

File "D:\anaconda3\Lib\site-packages\keras\src\losses\loss.py", line 43, in call

File "D:\anaconda3\Lib\site-packages\keras\src\losses\losses.py", line 22, in call

File "D:\anaconda3\Lib\site-packages\keras\src\losses\losses.py", line 1722, in sparse_categorical_crossentropy

File "D:\anaconda3\Lib\site-packages\keras\src\ops\nn.py", line 1567, in sparse_categorical_crossentropy

File "D:\anaconda3\Lib\site-packages\keras\src\backend\tensorflow\nn.py", line 638, in sparse_categorical_crossentropy

Received a label value of 200000000 which is outside the valid range of [0, 540). Label values: 18000000 16000000 37000000 11300000 <...> 27000000 48000000 [[{{node compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]] [Op:__inference_one_step_on_iterator_68526]

Also, my split arrays (at least what I intended to be split) have the following properties:

x_train has shape (13024, 12) and type float64
y_train has shape (13024,) and type int64
x_test has shape (3257, 12) and type float64
y_test has shape (3257,) and type int64

Before this, I tried passing my df into the "generate_tabular_synth_data" function as the variable named "latent". That block ran, but I got very strange losses on the order of -10E6 (negative!) and zero accuracy.
And even after this "half-done" block, the next block gave me this error:
"ValueError: Only one class present in y_true. ROC AUC score is not defined in that case."
My understanding is that this error comes from "generate_tabular_synth_data" not providing any way to specify the y variable itself.

So I think I have just missed something very important; I'd be happy if you could help me understand it all :)
