castorini / daam

Diffusion attentive attribution maps for interpreting Stable Diffusion.

License: MIT License

Python 3.22% CSS 0.03% Jupyter Notebook 96.74%
explainable-ai huggingface pytorch stable-diffusion generative-ai diffusion

daam's Introduction

What the DAAM: Interpreting Stable Diffusion Using Cross Attention

example image

Updated to support Stable Diffusion XL (SDXL) and Diffusers 0.21.1!

I regularly update this codebase. Please submit an issue if you have any questions.

In our paper, we propose diffusion attentive attribution maps (DAAM), a cross attention-based approach for interpreting Stable Diffusion. Check out our demo: https://huggingface.co/spaces/tetrisd/Diffusion-Attentive-Attribution-Maps. See our documentation, hosted by GitHub pages, and our Colab notebook, updated for v0.1.0.

Getting Started

First, install PyTorch for your platform. Then, install DAAM with pip install daam, or, if you want an editable version of the library, run git clone https://github.com/castorini/daam && pip install -e daam. Finally, log in with huggingface-cli login to access many of the Stable Diffusion models; you'll need to get a token at HuggingFace.co.
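
If you prefer to log in from Python instead of the shell, here is a minimal sketch using the huggingface_hub library (the token string below is a placeholder for your own token):

from huggingface_hub import login

login(token='hf_xxx')  # equivalent to running huggingface-cli login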

Running the Website Demo

Simply run daam-demo in a shell and navigate to http://localhost:8080. The same demo as the one on HuggingFace Spaces will show up.

Using DAAM as a CLI Utility

DAAM comes with a simple generation script for people who want to quickly try it out. Try running

$ mkdir -p daam-test && cd daam-test
$ daam "A dog running across the field."
$ ls
a.heat_map.png    field.heat_map.png    generation.pt   output.png  seed.txt
dog.heat_map.png  running.heat_map.png  prompt.txt

Your current working directory will now contain the generated image as output.png and a DAAM map for every word, as well as some auxiliary data. You can see more options for daam by running daam -h. To use Stable Diffusion XL as the backend, run daam --model xl-base-1.0 "Dog jumping".

Using DAAM as a Library

Import and use DAAM as follows:

from daam import trace, set_seed
from diffusers import DiffusionPipeline
from matplotlib import pyplot as plt
import torch


model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
device = 'cuda'

pipe = DiffusionPipeline.from_pretrained(model_id, use_auth_token=True, torch_dtype=torch.float16, use_safetensors=True, variant='fp16')
pipe = pipe.to(device)

prompt = 'A dog runs across the field'
gen = set_seed(0)  # for reproducibility

with torch.no_grad():
    with trace(pipe) as tc:
        out = pipe(prompt, num_inference_steps=50, generator=gen)
        heat_map = tc.compute_global_heat_map()
        heat_map = heat_map.compute_word_heat_map('dog')
        heat_map.plot_overlay(out.images[0])
        plt.show()
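
As an aside not shown in the original example, the same trace can produce overlays for several words. The sketch below is meant to run inside the with trace(pipe) as tc: block above, in place of the single-word example; the word list and output filenames are only illustrative:

# Illustrative sketch: one overlay per prompt word, saved to disk.
global_heat_map = tc.compute_global_heat_map()

for word in ['dog', 'field']:
    word_heat_map = global_heat_map.compute_word_heat_map(word)
    word_heat_map.plot_overlay(out.images[0])
    plt.savefig(f'{word}.heat_map.png')
    plt.clf()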

You can also serialize and deserialize the DAAM maps pretty easily:

from daam import GenerationExperiment, trace

with trace(pipe) as tc:
    pipe('A dog and a cat')
    exp = tc.to_experiment('experiment-dir')
    exp.save()  # experiment-dir now contains all the data and heat maps

exp = GenerationExperiment.load('experiment-dir')  # load the experiment

We'll continue adding docs. In the meantime, check out the GenerationExperiment, GlobalHeatMap, and DiffusionHeatMapHooker classes, as well as the daam/run/*.py example scripts. You can download the COCO-Gen dataset from the paper at http://ralphtang.com/coco-gen.tar.gz. If clicking the link doesn't work in your browser, copy and paste it into a new tab, or use a CLI utility such as wget.

See Also

Citation

@inproceedings{tang2023daam,
    title = "What the {DAAM}: Interpreting Stable Diffusion Using Cross Attention",
    author = "Tang, Raphael  and
      Liu, Linqing  and
      Pandey, Akshat  and
      Jiang, Zhiying  and
      Yang, Gefei  and
      Kumar, Karun  and
      Stenetorp, Pontus  and
      Lin, Jimmy  and
      Ture, Ferhan",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    year = "2023",
    url = "https://aclanthology.org/2023.acl-long.310",
}

daam's People

Contributors

daemon, eltociear, furkangozukara, ji-xin, likicode, nityanandmathur, rockerboo, saltaccount, wangdong2023

daam's Issues

Feasibility of using code/logic on non-diffusers repository (applying DAAM to image editing models)

Hi,

Thanks for this great paper+codebase! A deserved best paper award.

I am thinking of applying the DAAM idea to image editing models (e.g., InstructPix2Pix: https://github.com/timothybrooks/instruct-pix2pix) and was wondering how hard you think it would be to get the same logic running in a non-diffusers codebase such as the linked Pix2Pix, which is older CompVis code for SD 1.5. For example, which key parts of your code would need to be reimplemented, including the evaluation from the paper (IoU on COCO-Gen with a 0.4 threshold)?

Let me know what you think!
Best,
Benno

Showing <|startoftext|> and <|endoftext|>

Hi,

Thanks for the great work on visualising the cross-attention of diffusion models. Is it possible to visualise the cross-attention maps for the <|startoftext|> and <|endoftext|> tokens? They play an important role in diffusion models too.

Trace hook never unhooks the attention processor

daam/daam/trace.py

Lines 281 to 282 in a323129

def _hook_impl(self):
    self.module.set_processor(self)

This hook never unhooks, which can cause problems for the network later.

So I capture the original processor and add an unhook implementation that restores it:

    def _hook_impl(self):
        self.original_processor = self.module.processor
        self.module.set_processor(self)

    def _unhook_impl(self):
        self.module.set_processor(self.original_processor)

I can make a PR in the future, but I'm in a messy situation at the moment.

How to use without a pipeline?

I am not using the HF pipeline in some of my scripts. Is it possible to hook into the UNet only, to collect probabilities for specific tokens?

DAAM with mu

Hey! Great job on this repo! Very clean documentation and a useful idea.

Support for Prompt Embeddings Input Argument during Inference

This is more of a request, but would you be able to support custom embeddings and negative embeddings as pipeline arguments? The reason I want this is so I can use prompt engineering techniques such as prompt weighting/emphasis, which aren't directly supported in diffusers. HuggingFace suggests using the compel library to generate your own text embeddings and then passing those into the pipeline: https://huggingface.co/docs/diffusers/using-diffusers/weighted_prompts

However, when trying to use this with DAAM, I get the following error:

import torch

from daam import trace
from diffusers import StableDiffusionPipeline
from compel import Compel

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

compel = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder, truncate_long_prompts=False)

prompt = "a person"
negative_prompt = "bad art" 

conditioning = compel.build_conditioning_tensor(prompt)
negative_conditioning = compel.build_conditioning_tensor(negative_prompt)

[embeddings, negative_embeddings] = compel.pad_conditioning_tensors_to_same_length([conditioning, negative_conditioning])

with torch.autocast("cuda", dtype=torch.float16), torch.no_grad():
    with trace(pipeline) as tc:
        image = pipeline(
            prompt_embeds=embeddings, 
            negative_prompt_embeds=negative_embeddings, 
            height=512,
            width=512,
            num_images_per_prompt=1,
            num_inference_steps=35,
            guidance_scale=7.5,  
        ).images[0]

ERROR MESSAGE:

Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:12                                                                                   │
│                                                                                                  │
│    9                                                                                             │
│   10 with torch.autocast("cuda", dtype=torch.float16), torch.no_grad():                          │
│   11 │   with trace(pipeline) as tc:                                                             │
│ ❱ 12 │   │   image = pipeline(                                                                   │
│   13 │   │   │   prompt_embeds=embeddings,                                                       │
│   14 │   │   │   height=height,                                                                  │
│   15 │   │   │   width=width,                                                                    │
│                                                                                                  │
│ /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in           │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /opt/conda/envs/pytorch/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeli │
│ ne_stable_diffusion.py:645 in __call__                                                           │
│                                                                                                  │
│   642 │   │   do_classifier_free_guidance = guidance_scale > 1.0                                 │
│   643 │   │                                                                                      │
│   644 │   │   # 3. Encode input prompt                                                           │
│ ❱ 645 │   │   prompt_embeds = self._encode_prompt(                                               │
│   646 │   │   │   prompt,                                                                        │
│   647 │   │   │   device,                                                                        │
│   648 │   │   │   num_images_per_prompt,                                                         │
│                                                                                                  │
│ /opt/conda/envs/pytorch/lib/python3.10/site-packages/daam/trace.py:146 in _hooked_encode_prompt  │
│                                                                                                  │
│   143 │   │   return image, has_nsfw                                                             │
│   144 │                                                                                          │
│   145 │   def _hooked_encode_prompt(hk_self, _: StableDiffusionPipeline, prompt: Union[str, Li   │
│ ❱ 146 │   │   if not isinstance(prompt, str) and len(prompt) > 1:                                │
│   147 │   │   │   raise ValueError('Only single prompt generation is supported for heat map co   │
│   148 │   │   elif not isinstance(prompt, str):                                                  │
│   149 │   │   │   last_prompt = prompt[0]                                                        │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: object of type 'NoneType' has no len()

Compel documentation - https://github.com/damian0815/compel

Report a mistake (maybe)

Thanks for your code. I am following your work, and I found a piece of confusing code that might be a mistake. It is in the file hook.py lines 103 and 104:
up_names = ['up'] * len(model.up_blocks)
down_names = ['down'] * len(model.up_blocks)
Shouldn't it probably be:
down_names = ['down'] * len(model.down_blocks)

stable-cascade

Can the method be used to show the cross-attention map in Stable Cascade?

Add support for non-square output

Hi,

Thanks for the great work! I wonder if you can add support for non-square image outputs? For example, an output with a height of 704 and a width of 512. That would be helpful. Thank you!

Beta release 0.1.0

Gearing up for a beta release in Jan or Feb.

  • Docs
  • More general introspection
  • CI and unit testing
  • Website UI

Clarification needed

Thanks for your great work! I want to know why we need the operation below. We only need half of the attention maps; for example, if we have 8 heads, then map_.size(0) below will be 16. But why do we have 16 in the first place, considering we only have 8 heads in each transformer block? Can you show me where diffusers does this? Really confused, thank you!

map_ = map_[map_.size(0) // 2:] # Filter out unconditional
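
To make the shapes concrete, here is a minimal, self-contained sketch of the slicing in question, using hypothetical sizes (8 heads, so map_.size(0) is 16 before the slice):

import torch

# Hypothetical sizes for illustration only: 8 heads, 77 text tokens, a 16x16 latent map.
num_heads, tokens, pixels = 8, 77, 16 * 16

# The attention processor receives twice as many maps as there are heads; the first
# half corresponds to the unconditional branch that the line above filters out.
map_ = torch.rand(2 * num_heads, pixels, tokens)
print(map_.size(0))              # 16

map_ = map_[map_.size(0) // 2:]  # Filter out unconditional
print(map_.size(0))              # 8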

Installation problem

Thanks for your great work!
I wonder if it is possible to upgrade diffusers to 0.11.0. Many current works need an updated diffusers. Much appreciated.

Colab Demo

Thank you for the amazing work. I am trying to run the Colab demo. I am getting the error below while loading the model:

model = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-base')
AttributeError: module transformers has no attribute CLIPImageProcessor

about image size and prompt length

Hi, firstly thank you for sharing your great work.

While reading your paper and code, I had a question about image size and prompt length.
It seems like your code only accepts image sizes of 512x512 or 768x768, with a maximum prompt length of 75+2, which is the limit of the text encoder.
However, if my understanding is correct, there is no problem with getting attribution maps while generating an image of any size with a longer prompt.

Is that correct? Or is there another reason that DAAM is coded with such a limitation?

[Question] about perturbation- and gradient-based methods

Hi, @daemon !

Thanks for sharing this exciting work!
I read your paper, and I have a question about how to run the experiments with perturbation- and gradient-based methods mentioned in Section 2.2, Diffusion Attentive Attribution Maps.

You mentioned that gradient methods require back-propagation through all T timesteps.
I agree that this is a problem. However, in the following sentence, you wrote that even minor perturbations cause different images to be generated.

How did you confirm this problem? I'd like to try both perturbation- and gradient-based methods.
Could you tell me how I can verify them?

Best regards.

Newer version of diffusers fails due to tensor being passed to safety checker

With diffusers 0.21.2, the safety checker hook fails because tensors are passed instead of NumPy arrays.

Traceback (most recent call last):
  File "/mnt/900/builds/sd-scripts/daam/test.py", line 161, in <module>
    main(args)
  File "/mnt/900/builds/sd-scripts/daam/test.py", line 115, in main
    out = pipe(
  File "/mnt/900/builds/sd-scripts/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/900/builds/sd-scripts/.venv/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 708, in __call__
    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
  File "/mnt/900/builds/sd-scripts/daam/daam/trace.py", line 140, in _hooked_run_safety_checker
    pil_image = self.numpy_to_pil(image)
  File "/mnt/900/builds/sd-scripts/.venv/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py", line 1699, in numpy_to_pil
    return numpy_to_pil(images)
  File "/mnt/900/builds/sd-scripts/.venv/lib/python3.10/site-packages/diffusers/utils/pil_utils.py", line 43, in numpy_to_pil
    images = (images * 255).round().astype("uint8")
AttributeError: 'Tensor' object has no attribute 'astype'. Did you mean: 'dtype'?

I'm aware the library isn't supported at this version. I'm making a PR to fix it with backwards compatibility.

heatmap for num_inference_steps=1 using callable

Hi,
I want to compare heat maps between running the diffusion process for one step and for 20 steps.
How should I add code to your Colab?
I think I should use the callback and callback_steps options in 'model().images[0]',
but I don't know what to do afterwards when plotting the heat maps.
I would greatly appreciate any guidance, thanks!

Time_idx

Hi,

Thanks for the code! A quick question: when computing the heat maps for the first timestep with
compute_global_heat_map(prompt, time_idx=time_idx)

would time_idx=0 be the first denoising step, in other words step 999 as per a scheduler with 1000 steps? Or would that rather be the last denoising step, the same as the scheduler's timestep?

Thanks!

cross attention map of an existing image

Hi. I want to ask how to compute the cross-attention map for an existing image. For example, I already have an image of a human, and my text is "a man is jumping". I want to get the heat map for the word "jump".

In your code, it seems that the heat map is collected from the generation process.

Thanks.

about head-dependent pairs

I want to see and use the head-dependent pairs from Section 4. Would you release the 8,000 head-dependent pairs?

How to evaluate DAAM in multi-object images?

Hi, thanks for sharing this wonderful work! I have some questions regarding multi-object images. As I see in the code:

daam/daam/run/evaluate.py

Lines 157 to 171 in 299de09

m = expand_m(m)
mask = torch.ones_like(m)
mask[m < args.max_threshold] = 0
gt_map = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)[:, :, 3].astype(np.float32) / 255
gt_map[gt_map < 1.0] = 0
gt_map = torch.from_numpy(gt_map)

if args.segment == '80':
    word = merge_dict.get(word, word)

if word not in CLASSES:
    continue

iou = compute_iou(mask, gt_map)

You compute a binary mask for each object, am I correct? Then I think it is a little different from the standard segmentation evaluation protocol, where we compute a label for each pixel in an image. So if we want to evaluate IoU on a multi-object image, what should DAAM do? What if one pixel activates for several nouns? Can we compare the attention values for different nouns directly? Looking forward to your reply, thanks.
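
For reference, the IoU being discussed here is the standard binary-mask IoU; a minimal sketch of that formulation (a generic version, not necessarily identical to DAAM's compute_iou):

import torch

def binary_iou(pred: torch.Tensor, target: torch.Tensor) -> float:
    # Both inputs are expected to be binary (0/1) masks of the same shape.
    pred, target = pred.bool(), target.bool()
    intersection = (pred & target).sum().item()
    union = (pred | target).sum().item()
    return intersection / union if union > 0 else 0.0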
