[ICCV 2023] Generative Prompt Model for Weakly Supervised Object Localization

License: Apache License 2.0

diffusion iccv2023 wsol

genpromp's Introduction

Generative Prompt Model for Weakly Supervised Object Localization

This is the official implementation of the paper Generative Prompt Model for Weakly Supervised Object Localization, accepted at ICCV 2023. This repository contains PyTorch training code, evaluation code, pre-trained models, and a visualization method.

1. Contents

2. Introduction
3. Results
4. Getting Started
   4.1 Installation
   4.2 Dataset and Files Preparation
   4.3 Training
   4.4 Inference
   4.5 Extra Options
5. Contacts
6. Acknowledgment
7. Citation

2. Introduction

Weakly supervised object localization (WSOL) remains challenging when object localization models are learned from image category labels alone. Conventional methods that discriminatively train activation models ignore representative yet less discriminative object parts. In this study, we propose a generative prompt model (GenPromp), defining the first generative pipeline that localizes less discriminative object parts by formulating WSOL as a conditional image denoising procedure. During training, GenPromp converts image category labels to learnable prompt embeddings, which are fed to a generative model to conditionally recover the noised input image and learn representative embeddings. During inference, GenPromp combines the representative embeddings with discriminative embeddings (queried from an off-the-shelf vision-language model) to obtain both representative and discriminative capacity. The combined embeddings are finally used to generate multi-scale, high-quality attention maps, which facilitate localizing the full object extent. Experiments on CUB-200-2011 and ILSVRC show that GenPromp outperforms the best discriminative models, setting a solid baseline for WSOL with generative models.
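
For illustration only (our notation, not taken verbatim from the paper), the first training stage optimizes the learnable prompt embedding $f_r$ with the standard conditional denoising objective of latent diffusion models, and inference conditions the UNet on a weighted combination of $f_r$ and the discriminative embedding $f_d$:

$$\mathcal{L} = \mathbb{E}_{z,\, \epsilon \sim \mathcal{N}(0, I),\, t}\big[\,\|\epsilon - \epsilon_\theta(z_t, t, f_r)\|_2^2\,\big], \qquad f_c = w \, f_r + (1 - w) \, f_d,$$

where $z_t$ is the noised image latent at timestep $t$, $\epsilon_\theta$ is the denoising UNet, and $w$ is the combine ratio discussed in Extra Options below.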

3. Results

We re-train GenPromp with a better learning schedule on 4 x A100. The performance of GenPromp on CUB-200-2011 is further improved.

Method | Dataset | Cls. Backbone | Top-1 Loc | Top-5 Loc | GT-known Loc
GenPromp | CUB-200-2011 | EfficientNet-B7 | 87.0 | 96.1 | 98.0
GenPromp (Re-train) | CUB-200-2011 | EfficientNet-B7 | 87.2 (+0.2) | 96.3 (+0.2) | 98.3 (+0.3)
GenPromp | ImageNet | EfficientNet-B7 | 65.2 | 73.4 | 75.0

4. Getting Started

4.1 Installation

To set up the environment for GenPromp, we use conda to manage dependencies. We use CUDA 11.3 in our experiments. Run the following commands to install GenPromp:

conda create -n gpm python=3.8 -y && conda activate gpm
pip install --upgrade pip
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install --upgrade diffusers[torch]==0.13.1
pip install transformers==4.29.2 accelerate==0.19.0
pip install matplotlib opencv-python OmegaConf tqdm
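
After installation, the following quick check (an optional snippet of ours, not part of the repository) verifies that the pinned versions are installed and the GPU is visible:

import torch, torchvision, diffusers, transformers

# Expected: torch 1.11.0+cu113, torchvision 0.12.0+cu113, diffusers 0.13.1, transformers 4.29.2
print(torch.__version__, torchvision.__version__, diffusers.__version__, transformers.__version__)
print("CUDA available:", torch.cuda.is_available())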

4.2 Dataset and Files Preparation

To train GenPromp from the pre-trained weights and run inference with the given weights, download the files in the table and arrange them according to the file tree below. (Uploading)

Dataset & Files | Download | Usage
data/ImageNet_ILSVRC2012 (146GB) | Official Link | Benchmark dataset
data/CUB_200_2011 (1.2GB) | Official Link | Benchmark dataset
ckpts/pretrains (5.2GB) | Official Link, Google Drive, Baidu Drive(o9ei) | Stable Diffusion pre-trained weights
ckpts/classifications (2.3GB) | Google Drive, Baidu Drive(o9ei) | Classification results on the benchmark datasets
ckpts/imagenet750 (3.3GB) | Google Drive, Baidu Drive(o9ei) | Weights that achieve 75.0% GT-known Loc on ImageNet
ckpts/cub983 (3.3GB) | Google Drive, Baidu Drive(o9ei) | Weights that achieve 98.3% GT-known Loc on CUB
    |--GenPromp/
      |--data/
        |--ImageNet_ILSVRC2012/
           |--ILSVRC2012_list/
           |--train/
           |--val/
        |--CUB_200_2011
           |--attributes/
           |--images/
           ...
      |--ckpts/
        |--pretrains/
          |--stable-diffusion-v1-4/
        |--classifications/
          |--cub_efficientnetb7.json
          |--imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json
        |--imagenet750/
          |--tokens/
             |--49408.bin
             |--49409.bin
             ...
          |--unet/
        |--cub983/
          |--tokens/
             |--49408.bin
             |--49409.bin
             ...
          |--unet/
      |--configs/
      |--datasets
      |--models
      |--main.py
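
Before training, a small helper (ours, assuming the repository root is the current working directory and using the paths from the tree above) can confirm that the datasets and checkpoints are in place:

import os

expected = [
    "data/ImageNet_ILSVRC2012/train",
    "data/ImageNet_ILSVRC2012/val",
    "data/CUB_200_2011/images",
    "ckpts/pretrains/stable-diffusion-v1-4",
    "ckpts/classifications/cub_efficientnetb7.json",
    "ckpts/classifications/imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json",
]

for path in expected:
    status = "ok" if os.path.exists(path) else "MISSING"
    print(f"{status:8s}{path}")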

4.3 Training

Here is a training example of GenPromp on ImageNet.

accelerate config
accelerate launch main.py --function train_token --config configs/imagenet.yml --opt "{'train': {'save_path': 'ckpts/imagenet/'}}"
accelerate launch main.py --function train_unet --config configs/imagenet_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/imagenet/tokens/', 'save_path': 'ckpts/imagenet/'}}"

accelerate is used for multi-GPU training. In the first training stage, the weights of the concept tokens of the representative embeddings are learned and saved to ckpts/imagenet/. In the second training stage, the weights of the learned concept tokens are loaded from ckpts/imagenet/tokens/, and the weights of the UNet are then fine-tuned and saved to ckpts/imagenet/. Other configurations can be found in the config files (i.e. configs/imagenet.yml and configs/imagenet_stage2.yml) and can be overridden via --opt with a parameter dict (see Extra Options for details).
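
The --opt string is a nested dict that is merged over the YAML config. A minimal sketch of this behavior (illustrative only; it relies on the OmegaConf dependency installed above, and the actual parsing in main.py may differ):

import ast
from omegaconf import OmegaConf

config = OmegaConf.load("configs/imagenet.yml")                              # default options from the yml file
override = ast.literal_eval("{'train': {'save_path': 'ckpts/imagenet/'}}")   # the --opt string
config = OmegaConf.merge(config, OmegaConf.create(override))                 # --opt overrides the defaults
print(config.train.save_path)                                                # ckpts/imagenet/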

Here is a training example of GenPromp on CUB_200_2011.

accelerate config
accelerate launch main.py --function train_token --config configs/cub.yml --opt "{'train': {'save_path': 'ckpts/cub/'}}"
accelerate launch main.py --function train_unet --config configs/cub_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/cub/tokens/', 'save_path': 'ckpts/cub/'}}"

4.4 Inference

Here is an inference example of GenPromp on ImageNet.

python main.py --function test --config configs/imagenet_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path': 'ckpts/imagenet750/log.txt'}}"

In the inference stage, the weights of the learned concept tokens are loaded from ckpts/imagenet750/tokens/, the weights of the fine-tuned UNet are loaded from ckpts/imagenet750/unet/, and the log file is saved to ckpts/imagenet750/log.txt. Due to the random noise added to the tested images, the results may fluctuate within a small range ($\pm$ 0.1).
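
Conceptually, the combination of the two kinds of embeddings at inference looks like the sketch below (a hypothetical helper of ours with assumed tensor names, not the repository's exact code; see test.combine_ratio in Extra Options):

import torch

def combine_embeddings(f_r: torch.Tensor, f_d: torch.Tensor, w: float = 0.6) -> torch.Tensor:
    # f_r: embedding of the learned concept token (representative)
    # f_d: embedding of the category name queried from the vision-language model (discriminative)
    # w:   test.combine_ratio; the combined embedding conditions the denoising UNet
    return w * f_r + (1.0 - w) * f_d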

Here is an inference example of GenPromp on CUB_200_2011.

python main.py --function test --config configs/cub_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/cub983/tokens/', 'load_unet_path': 'ckpts/cub983/unet/', 'save_log_path': 'ckpts/cub983/log.txt'}}"

4.5 Extra Options

There are many extra options during training and inference. The default options are configured in the yml files. We can use --opt to add or override the default options with a parameter dict. The most commonly used options are listed below.

Option | Scope | Usage
{'data': {'keep_class': [0, 9]}} | data | keep only the data whose category id is in [0, 1, 2, ..., 9].
{'train': {'batch_size': 2}} | train | train with batch size 2.
{'train': {'num_train_epochs': 1}} | train | train the model for 1 epoch.
{'train': {'save_steps': 200}} | train_unet | save the trained UNet every 200 steps.
{'train': {'max_train_steps': 600}} | train_unet | terminate training within 600 steps.
{'train': {'gradient_accumulation_steps': 2}} | train | doubles the effective batch size when GPU memory is limited.
{'train': {'learning_rate': 5.0e-08}} | train | set the learning rate to 5.0e-8.
{'train': {'scale_lr': True}} | train | if True, the learning rate is multiplied by the batch size.
{'train': {'load_pretrain_path': 'stable-diffusion/'}} | train | load the pretrained model from stable-diffusion/.
{'train': {'load_token_path': 'ckpt/tokens/'}} | train | load the trained concept tokens from ckpt/tokens/.
{'train': {'save_path': 'ckpt/'}} | train | save the trained weights to ckpt/.
{'test': {'batch_size': 2}} | test | test with batch size 2.
{'test': {'cam_thr': 0.25}} | test | test with CAM threshold 0.25.
{'test': {'combine_ratio': 0.6}} | test | set the combine ratio between $f_r$ and $f_d$ to 0.6.
{'test': {'load_class_path': 'imagenet_efficientnet.json'}} | test | load classification results from imagenet_efficientnet.json.
{'test': {'load_pretrain_path': 'stable-diffusion/'}} | test | load the pretrained model from stable-diffusion/.
{'test': {'load_token_path': 'ckpt/tokens/'}} | test | load the trained concept tokens from ckpt/tokens/.
{'test': {'load_unet_path': 'ckpt/unet/'}} | test | load the trained UNet from ckpt/unet/.
{'test': {'save_vis_path': 'ckpt/vis/'}} | test | save the visualized predictions to ckpt/vis/.
{'test': {'save_log_path': 'ckpt/log.txt'}} | test | save the log file to ckpt/log.txt.
{'test': {'eval_mode': 'top1'}} | test | top1 evaluates with the predicted top-1 cls category of the test image, top5 with the predicted top-5 cls categories, and gtk with the ground-truth category (which can be tested without classification results). The default eval mode is top1.

These options can be combined by simply merging the dicts. For example, to evaluate GenPromp with the config file configs/imagenet_stage2.yml, on categories [0, 1, 2, ..., 9], with concept tokens loaded from ckpts/imagenet750/tokens/, the UNet loaded from ckpts/imagenet750/unet/, the log file of the evaluated metrics saved to ckpts/imagenet750/log0-9.txt, a combine ratio of 0, and visualization results saved to ckpts/imagenet750/vis, use the following command:

python main.py --function test --config configs/imagenet_stage2.yml --opt "{'data': {'keep_class': [0, 9]}, 'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path':'ckpts/imagenet750/log0-9.txt', 'combine_ratio': 0, 'save_vis_path': 'ckpts/imagenet750/vis'}}"

5. Contacts

If you have any questions about our work or this repository, please don't hesitate to contact us by email or open an issue in this project.

6. Acknowledgment

7. Citation

@article{zhao2023generative,
  title={Generative Prompt Model for Weakly Supervised Object Localization},
  author={Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang},
  journal={arXiv preprint arXiv:2307.09756},
  year={2023}
}
@InProceedings{Zhao_2023_ICCV,
    author    = {Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang},
    title     = {Generative Prompt Model for Weakly Supervised Object Localization},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {6351-6361}
}


genpromp's Issues

Test phase error (tuple index out of range)

Hi,

I tried running your code on my custom dataset. Training works well, but I keep encountering a "tuple index out of range" error during the test phase at the following line of code:
representative_embeddings = [text_encoder(ids.to(device))[0] for ids in data["caption_ids_concept_token"][top_idx]]

To ensure that I didn't make any mistakes in my data loading process, I ran the code with the CUB dataset you provided, but I encountered the same error during the test phase.

I noticed in the GitHub README that you suggest saving the models in ckpts/cub during training, but for testing, you load from the ckpts/cub983 path. Could you explain why this discrepancy exists?

Also, do you have any insights into why this error might be occurring?

Thank you

Google Drive link does not work

Thank you for your great work.
I want to use your pretrained weights, but the given Google Drive link does not work.
I also tried to download the Baidu file, but the page is all in Chinese and I failed to download it.
Could you check the Google Drive link?

Thank you!

Clarification Needed on Model Selection Strategy Across Epochs

I am currently looking into the implementation details of the model training process, particularly the model saving mechanism. In the if block on line 489, the model is saved at the end of each training epoch. However, how the optimal model is selected based on test/validation set performance remains unclear.

Could you kindly provide an elucidation on the criteria or algorithm used for identifying the most effective epoch based on the validation/test set? This clarification will significantly aid in understanding the overall model selection strategy within the training loop.

Thank you for your assistance.

The embeddings in the training process

Thanks for the great work!

I have some questions regarding the two types of embeddings, or tokens, mentioned in the paper.

Prior to the training process, the concept tokens are initialized using the meta tokens.
However, I would like to clarify what happens once the training commences.

Do the meta tokens remain static and not participate in the entire training process? Is it solely the concept tokens that are involved throughout the entire training process?

Google Drive download path is invalid.

I think your Google Drive download link path is invalid.
Please check your README.md.
When I click the attached Google Drive download link, it only reloads the repo page.
Thank you :)

Test visualization

Thanks for your great work and sharing.

How do you do the visualization of the attention activations?

Thanks.

RuntimeError: CUDA out of memory.

Hello,

When I run

python main.py --function test --config configs/cub_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/cub983/tokens/', 'load_unet_path': 'ckpts/cub983/unet/', 'save_log_path': 'ckpts/cub983/log.txt'}}"

I am encountering this error
Traceback (most recent call last):
  File "/p/project/atmlaml/benassou1/ega/GenPromp/main.py", line 646, in <module>
    eval(args.function)(config)
  File "/p/project/atmlaml/benassou1/ega/GenPromp/main.py", line 300, in test
    noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/project/atmlaml/benassou1/ega/GenPromp/sc_venv_template/venv/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 615, in forward
    sample = upsample_block(
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/project/atmlaml/benassou1/ega/GenPromp/sc_venv_template/venv/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py", line 1813, in forward
    hidden_states = attn(
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/project/atmlaml/benassou1/ega/GenPromp/sc_venv_template/venv/lib/python3.10/site-packages/diffusers/models/transformer_2d.py", line 265, in forward
    hidden_states = block(
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/project/atmlaml/benassou1/ega/GenPromp/sc_venv_template/venv/lib/python3.10/site-packages/diffusers/models/attention.py", line 321, in forward
    ff_output = self.ff(norm_hidden_states)
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/project/atmlaml/benassou1/ega/GenPromp/sc_venv_template/venv/lib/python3.10/site-packages/diffusers/models/attention.py", line 379, in forward
    hidden_states = module(hidden_states)
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 39.56 GiB total capacity; 7.06 GiB already allocated; 1.94 MiB free; 17.07 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I changed the batch size to 1, reduced the image size, and set max_split_size_mb, but it still does not work. Could you please help me fix this problem?

I'm confused about how to make only the concept token embedding learnable

Thank you for the excellent paper and the implemented code. I have a point of confusion. In Figure 3 of the paper, only v_r is colored in orange.

Does this mean that among the embeddings for each word in "a photo of a ", only the word embedding corresponding to is trainable?

However, I find the following part of your code confusing:

# in datasets/base.py
def init_embeddings(self, text_encoder):
    token_embeds = text_encoder.get_input_embeddings().weight.data.clone()
    for token in self.cat2tokens:
        meta_token_id = self.tokenizer.encode(token['meta_token'], add_special_tokens=False)[0]
        concept_token_id = self.tokenizer.encode(token['concept_token'], add_special_tokens=False)[0]
        token_embeds[concept_token_id] = token_embeds[meta_token_id]
    text_encoder.get_input_embeddings().weight = torch.nn.Parameter(token_embeds)
    return text_encoder

This code sets token_embeds as trainable by making it a torch.nn.Parameter.
However, contrary to what is shown in Figure 3, this seems to make the entire token_embeds trainable, not just the concept token vector corresponding to .

Could you please clarify my confusion? Thank you very much.

P.S. If my understanding is correct, I believe the following lines of code would be necessary to make only the concept token vector trainable:

text_encoder.get_input_embeddings().weight.requires_grad = False
text_encoder.get_input_embeddings().weight[concept_token_id].requires_grad = True
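
A related, commonly used alternative (a sketch under our own assumptions, not code from this repository) is to keep the whole embedding matrix as a parameter but mask its gradient, so that only the concept token row is actually updated:

import torch

embedding_weight = text_encoder.get_input_embeddings().weight  # nn.Parameter, shape (vocab_size, dim)

def keep_only_concept_grad(grad):
    # Zero the gradient of every row except the concept token, so only that row changes.
    mask = torch.zeros_like(grad)
    mask[concept_token_id] = 1.0
    return grad * mask

embedding_weight.register_hook(keep_only_concept_grad)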

When doing `train_unet`, why don't you use the pretrained weights of the tokens?

First and foremost, I'd like to express my profound gratitude for the outstanding paper and the code implementation. I have one point of curiosity.

In attempting train_unet, isn't it the case that the initial weights of each category token pretrained in train_token are not used?

In the code, train_unet is executed with split="train". Given this, just like when running train_token, wouldn't the initial weights of the concept_token in the text_encoder be initialized identically to the initial weights of the meta_token?

Since all parameters of the text_encoder are frozen during train_unet, wouldn't this mean that the unet is fine-tuned with the initial weights of both the meta_token and concept_token being the same?

In the Loss formula (5) mentioned in the paper, it is depicted as in the linked image. This Loss seems to utilize f* (pretrained initial weight), hence my query.

Thank you always for your hard work.

The parameter `test.combine_ratio` seems to have no effect when running inference

Hello,

The problem is that the parameter w, i.e. "{'test': {'combine_ratio': 0.6}}" from the README.md, doesn't seem to have any effect when running inference.
I tried setting the values 0, 1, 0.1 and 0.9 and compared them to the default result of w=0.6 (also tested on 3 trained instances); these 11 results are all the same, like:

Cls@1:0.938     Cls@5:0.988     Loc@1:0.481     Loc@5:0.506     Loc_gt:0.514
M-ins:0.000     Part:0.008      More:0.440      Right:0.481     Wrong:0.008     Cls:0.062

(I'm trying to replicate how fr or fd works alone.)

I confirmed that I successfully adjusted the parameter (source code location and output INFO below):
"INFO: Test Class [0-9]: [dataset: cub] [eval mode: top1] [cam thr: 0.23] [combine ratio: 0.9]".
So I don't understand why the results are the same.

After reading the paper, I understood w as:

$f_c = (1 - w) \cdot f_d + w \cdot f_r$
As $w \to 0$, $f_c \to f_d$, and the heatmap box becomes small and incomplete;
as $w \to 1$, $f_c \to f_r$, and the heatmap box becomes much bigger because it is deteriorated by background noise.

I successfully trained this project on a single 3090 (24G) and got close to the paper's results,
by reducing train.batch_size and setting a larger gradient_accumulation_steps.
(This parameter was set to w=0.5 for train_unet.)

Asking for questions about evaluation

Thanks for your great work! There is an issue during testing.
When using python main.py --function test --config configs/cub_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/cub983/tokens/', 'load_unet_path': 'ckpts/cub983/unet/', 'save_log_path': 'ckpts/cub983/log.txt'}}" for evaluation, I found that self.step_store, self.attention_store and self.attention_maps are all empty. Could you please tell me where it goes wrong?
Looking forward to your reply!
