
VL-SHAP

Official Repo for the paper: "Interpreting Vision and Language Generative Models with Semantic Visual Priors".


Explain vision-and-language (VL) generative models with KernelSHAP-based, sentence-level visual explanations that exploit the model's visual semantic priors.

Overview

UPDATES 🔥

  • 25/11/2023 You can try VL-SHAP with the Gradio demo 🎮 here: michelecafagna26/vl-shap-demo
  • 21/11/2023 If you have trouble installing STEGO, check out the branch vl-shap/adding_clipseg, where we replace the STEGO model with the CLIPSeg semantic segmentation model. CLIPSeg is compatible with the latest Python and PyTorch releases and is easier to install. Moreover, it lets you control the semantic masks through textual prompts! 🚀 (See the sketch below for the general idea.)
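As a rough illustration of prompt-controlled masks (this is not the adding_clipseg branch's own API, just the off-the-shelf CLIPSeg model from Hugging Face transformers, with arbitrary example prompts):

import requests
from io import BytesIO

import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# load the public CLIPSeg checkpoint
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clipseg = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# any RGB image works; this URL is the one used in the OFA example further down
img_url = "https://farm4.staticflickr.com/3663/3392599156_e94f7d1098_z.jpg"
img = Image.open(BytesIO(requests.get(img_url).content)).convert("RGB")

# textual prompts: each one yields its own segmentation mask
prompts = ["a person", "a vehicle", "the background"]

inputs = processor(text=prompts, images=[img] * len(prompts), padding=True, return_tensors="pt")
with torch.no_grad():
    logits = clipseg(**inputs).logits  # one low-resolution heatmap per prompt

masks = torch.sigmoid(logits) > 0.5  # boolean masks, one per prompt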

Requirements

3.6.9 <= python <= 3.10.12
pytorch <= 1.13
torchvision <= 0.14

Installation

pip install git+https://github.com/lucasb-eyer/pydensecrf.git
pip install git+https://github.com/michelecafagna26/vl-shap.git#egg=semshap

Example: Explaining the OFA Visual Question Answering Model

Install OFA from the official repo (one way to do it is sketched right below), then run the Python code that follows to extract the semantic masks:
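For reference, a possible way to install the transformers-compatible OFA; the branch name and checkpoint location below are assumptions based on the OFA repository at the time of writing, so check its official instructions if they have changed:

# clone the branch exposing OFA through the transformers API (assumed branch name)
git clone --single-branch --branch feature/add_transformers https://github.com/OFA-Sys/OFA.git
pip install OFA/transformers/
# fetch an OFA checkpoint, e.g. OFA-large, from the Hugging Face Hub (assumed location)
git clone https://huggingface.co/OFA-Sys/OFA-large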

import requests
from io import BytesIO
from pathlib import Path
from PIL import Image

from transformers import OFATokenizer, OFAModel

from semshap.masking import generate_dff_masks, generate_superpixel_masks
from semshap.plot import heatmap, barh, plot_masks
from semshap.explainers import BaseExplainer


import torch
import torch.nn as nn
from torchvision import transforms

import matplotlib.pyplot as plt

ckpt_dir = "/path/to/the/model/ofa-models/OFA-large" # change this
device = "cuda" if torch.cuda.is_available() else "cpu"
img_url = "https://farm4.staticflickr.com/3663/3392599156_e94f7d1098_z.jpg"

# load the model
model = OFAModel.from_pretrained(ckpt_dir, use_cache=False).to(device)
tokenizer = OFATokenizer.from_pretrained(ckpt_dir)

# load the image
response = requests.get(img_url)
img = Image.open(BytesIO(response.content))

# Generate semantic masks
patch_resize_transform = transforms.Compose([
    lambda image: image.convert("RGB"),
    transforms.Resize(img.size, interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# extract CNN features from the model
with torch.no_grad():
    visual_embeds = model.encoder.embed_images(patch_resize_transform(img).unsqueeze(0).to(device))

visual_embeds = visual_embeds.detach().cpu().squeeze(0).permute(1, 2, 0)

# generate DFF semantic masks
out = generate_dff_masks(visual_embeds, k=10, img_size=img.size, mask_th=25, return_heatmaps=True)

# to visualize the masks run
# plot_masks(out['masks'])

The explainer expects a model that generates a caption given an image: model(img) --> caption. Therefore, we write a simple wrapper around our model that takes care of the preprocessing and decoding it requires.

class ModelWrapper(nn.Module):
    def __init__(self, model, tokenizer, question, resolution, device="cpu"):
        super().__init__()

        self.resolution = resolution
        self.num_beams = 5
        self.no_repeat_ngram_size = 3
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            transforms.Resize(self.resolution, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # encode the question once; it is reused for every perturbed image
        self.inputs = tokenizer([question], return_tensors="pt").input_ids.to(self.device)

    def forward(self, img):
        # put here all the code needed to generate a caption (here, an answer) from an image
        patch_img = self.patch_resize_transform(img).unsqueeze(0).to(self.device)
        out_ids = self.model.generate(self.inputs, patch_images=patch_img, num_beams=self.num_beams,
                                      no_repeat_ngram_size=self.no_repeat_ngram_size)

        return self.tokenizer.batch_decode(out_ids, skip_special_tokens=True)[0]

Now we can generate a caption (here, the answer to the question) in this way:

question = "What is the subject doing?"
model_wrapper = ModelWrapper(model, tokenizer, question, resolution=img.size, device=device)
model_wrapper(img)

We have everything we need to create and run the KernelSHAP explainer.

explainer = BaseExplainer(model_wrapper, device=device)
shap, base = explainer.explain(img, out['masks'], k=-1)

We can visualize the Shapley values corresponding to the visual feature masks as a bar chart by running:

labels = [f"f_{i}" for i in range(shap.shape[0])]
barh(labels, shap)

and the visual explanation with:

heatmap(img, out['heatmaps'], shap, alpha=0.65)
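If you want a quick textual summary in addition to the plots, the Shapley values can also be ranked directly. A minimal sketch, assuming shap comes back as (or can be converted to) a 1-D NumPy array on the CPU, as its indexing above suggests:

import numpy as np

shap_values = np.asarray(shap)

# list the visual features from most to least influential on the generated sentence
for i in np.argsort(-np.abs(shap_values)):
    print(f"f_{i}: {shap_values[i]:+.4f}")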

In this way, you can generate sentence-based visual semantic explanations like the ones shown in the paper.

More Examples

For more use cases, check the notebooks provided in the repository.

Troubleshooting

In case of problems with the pydensecrf module, try:

pip install cython
pip install git+https://github.com/lucasb-eyer/pydensecrf.git

Acknowledgement

  • The STEGO model has been adapted from the official repo.
  • The KernelSHAP implementation is based on the official SHAP repo.

Citation Information

@ARTICLE{10.3389/frai.2023.1220476,
  AUTHOR={Cafagna, Michele and Rojas-Barahona, Lina M. and van Deemter, Kees and Gatt, Albert},
  TITLE={Interpreting vision and language generative models with semantic visual priors},
  JOURNAL={Frontiers in Artificial Intelligence},
  VOLUME={6},
  YEAR={2023},
  URL={https://www.frontiersin.org/articles/10.3389/frai.2023.1220476},
  DOI={10.3389/frai.2023.1220476},
  ISSN={2624-8212},
}
