
visual-ai / sptnet


The official repository for ICLR2024 paper "SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning"

Home Page: https://visual-ai.github.io/sptnet/

License: Other

Python 100.00%

sptnet's Introduction

SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning (ICLR 2024)

By Hongjun Wang, Sagar Vaze, and Kai Han.


Update

[05.2024] We have updated the results of SPTNet with DINOv2 on CUB; please check the latest version of the paper on arXiv.

Dataset (backbone)   All    Old    New
CUB (DINO)           65.8   68.8   65.1
CUB (DINOv2)         76.3   79.5   74.6

Prerequisites 🛠️

First, you need to clone the SPTNet repository from GitHub. Open your terminal and run the following command:

git clone https://github.com/Visual-AI/SPTNet.git
cd SPTNet

We recommend setting up a conda environment for the project:

conda create --name=spt python=3.9
conda activate spt
pip install -r requirements.txt

Running 🏃

Config

Set the paths to the datasets and the desired log directories in config.py.
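Before launching a long run, it can help to confirm that every configured path actually exists. The snippet below is a minimal sketch of such a check; the variable names (DATASET_ROOTS, LOG_DIR) and the placeholder paths are hypothetical and are not taken from SPTNet's config.py, so adapt them to whatever that file actually defines.

```python
from pathlib import Path

# Hypothetical names and placeholder paths: SPTNet's real config.py may use
# different variable names, so reuse the idea rather than the identifiers.
DATASET_ROOTS = {
    "cifar10": "/path/to/cifar10",
    "cifar100": "/path/to/cifar100",
    "imagenet": "/path/to/imagenet",
    "cub": "/path/to/cub",
    "scars": "/path/to/stanford_cars",
    "aircraft": "/path/to/fgvc_aircraft",
    "herbarium19": "/path/to/herbarium19",
}
LOG_DIR = "/path/to/logs"  # hypothetical log directory


def missing_paths(paths: dict) -> list:
    """Return the (sorted) names of configured paths that do not exist on disk."""
    return sorted(
        name for name, p in paths.items() if not Path(p).expanduser().exists()
    )


if __name__ == "__main__":
    missing = missing_paths({**DATASET_ROOTS, "log_dir": LOG_DIR})
    if missing:
        print("Fix these entries before training:", ", ".join(missing))
    else:
        print("All configured paths exist.")
```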

Datasets

We use generic object recognition datasets, including CIFAR-10/100 and ImageNet-100/1K:

We also use fine-grained benchmarks (CUB, Stanford-cars, FGVC-aircraft, Herbarium-19). You can find the datasets in:

Checkpoints

Download the SPTNet checkpoints for the different datasets and put them in the "checkpoints" folder (they are only used during evaluation).

Scripts

Evaluate the model

# Use --prompt_type 'patch' instead of 'all' for 'cifar10' and 'cifar100'
CUDA_VISIBLE_DEVICES=0 python eval.py \
    --dataset_name 'aircraft' \
    --pretrained_model_path ./checkpoints/fgvc/dinoB16_best.pt \
    --prompt_type 'all' \
    --eval_funcs 'v2'

To reproduce all the main results in the paper, change --dataset_name and set --pretrained_model_path to the corresponding checkpoint you downloaded from the link above.
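When evaluating several datasets, only --dataset_name, --pretrained_model_path and --prompt_type change between runs. The helper below is a minimal sketch that assembles the eval.py arguments and applies the README's rule ('patch' for CIFAR-10/100, 'all' otherwise). Only the aircraft checkpoint path comes from the README; the CHECKPOINTS dictionary is meant to be filled in with the files you actually downloaded.

```python
import shlex

# The README uses --prompt_type 'patch' for these two datasets and 'all' otherwise.
PATCH_PROMPT_DATASETS = {"cifar10", "cifar100"}


def build_eval_cmd(dataset_name: str, checkpoint_path: str, eval_funcs: str = "v2") -> list:
    """Assemble the eval.py argument list for one dataset/checkpoint pair."""
    prompt_type = "patch" if dataset_name in PATCH_PROMPT_DATASETS else "all"
    return [
        "python", "eval.py",
        "--dataset_name", dataset_name,
        "--pretrained_model_path", checkpoint_path,
        "--prompt_type", prompt_type,
        "--eval_funcs", eval_funcs,
    ]


# Only the aircraft path is from the README; add entries for other datasets
# using the checkpoint files you downloaded.
CHECKPOINTS = {
    "aircraft": "./checkpoints/fgvc/dinoB16_best.pt",
}

if __name__ == "__main__":
    for name, ckpt in CHECKPOINTS.items():
        # Print a copy-pasteable shell command for each dataset.
        print(shlex.join(build_eval_cmd(name, ckpt)))
```

Running it prints `python eval.py --dataset_name aircraft --pretrained_model_path ./checkpoints/fgvc/dinoB16_best.pt --prompt_type all --eval_funcs v2`. To launch the runs directly instead, pass each list to `subprocess.run(cmd, env={**os.environ, "CUDA_VISIBLE_DEVICES": "0"}, check=True)`.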

Train the model:

CUDA_VISIBLE_DEVICES=0 python train_spt.py \
    --dataset_name 'aircraft' \
    --batch_size 128 \
    --grad_from_block 11 \
    --epochs 1000 \
    --num_workers 8 \
    --use_ssb_splits \
    --sup_weight 0.35 \
    --weight_decay 5e-4 \
    --transform 'imagenet' \
    --lr 1 \
    --lr2 0.05 \
    --prompt_size 1 \
    --freq_rep_learn 20 \
    --pretrained_model_path ${YOUR_OWN_PRETRAINED_PATH} \
    --prompt_type 'all' \
    --eval_funcs 'v2' \
    --warmup_teacher_temp 0.07 \
    --teacher_temp 0.04 \
    --warmup_teacher_temp_epochs 10 \
    --memax_weight 1 \
    --model_path ${YOUR_OWN_SAVE_DIR}

Remember to change --dataset_name and set --pretrained_model_path to the matching pretrained model. SPTNet can be applied on top of different pretrained models: the underlying model is swapped simply by changing --pretrained_model_path, which makes it easy to build on the current state-of-the-art (SOTA) method. Our default setting builds on SimGCD.
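SPTNet alternates between learning the spatial prompts and learning the model representation. The sketch below illustrates such a schedule, under the assumption (not confirmed by this README) that --freq_rep_learn is the number of epochs spent in one phase before switching, and that training starts with prompt learning. The parameter-group names are hypothetical.

```python
from collections import Counter


def phase_for_epoch(epoch: int, freq_rep_learn: int) -> str:
    """Return which parameters are updated at `epoch` (0-indexed).

    Assumption: training switches between prompt learning and representation
    learning every `freq_rep_learn` epochs, starting with prompt learning.
    """
    if freq_rep_learn <= 0:
        raise ValueError("freq_rep_learn must be a positive integer")
    return "prompt" if (epoch // freq_rep_learn) % 2 == 0 else "representation"


def trainable_flags(phase: str) -> dict:
    """Map a phase to requires_grad-style flags for two hypothetical parameter groups."""
    return {
        "spatial_prompts": phase == "prompt",
        "backbone_and_head": phase == "representation",
    }


if __name__ == "__main__":
    # With the README's settings (--epochs 1000, --freq_rep_learn 20):
    counts = Counter(phase_for_epoch(e, 20) for e in range(1000))
    print(dict(counts))  # {'prompt': 500, 'representation': 500}
    print(trainable_flags(phase_for_epoch(25, 20)))
```

In a PyTorch training loop, these flags would typically be applied by setting `requires_grad` on the parameters of each group at the start of every epoch, so that only one group receives gradient updates at a time.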

Results

Generic results (accuracy, %):

Dataset         All    Old    New
CIFAR-10        97.3   95.0   98.6
CIFAR-100       81.3   84.3   75.6
ImageNet-100    85.4   93.2   81.4

Fine-grained results:

Dataset         All    Old    New
CUB             65.8   68.8   65.1
Stanford Cars   59.0   79.2   49.3
FGVC-Aircraft   59.3   61.8   58.1
Herbarium19     43.4   58.7   35.2
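As a quick arithmetic check on the fine-grained table, the snippet averages each column over the three Semantic Shift Benchmark (SSB) datasets, assuming "SSB" here means CUB, Stanford Cars and FGVC-Aircraft (the splits selected by --use_ssb_splits). The numbers are copied from the table; nothing new is measured.

```python
# "All / Old / New" accuracies (%) copied from the fine-grained results table.
# Assumption: the SSB average covers these three datasets only.
SSB_RESULTS = {
    "CUB": (65.8, 68.8, 65.1),
    "Stanford Cars": (59.0, 79.2, 49.3),
    "FGVC-Aircraft": (59.3, 61.8, 58.1),
}


def column_means(rows: dict) -> dict:
    """Average each (All, Old, New) column across the given datasets."""
    n = len(rows)
    sums = [sum(col) for col in zip(*rows.values())]
    return {name: s / n for name, s in zip(("All", "Old", "New"), sums)}


if __name__ == "__main__":
    for split, acc in column_means(SSB_RESULTS).items():
        print(f"{split}: {acc:.1f}")
```

This prints All: 61.4, Old: 69.9 and New: 57.5.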

Citing this work

If you find this repo useful for your research, please consider citing our paper:

@inproceedings{wang2024sptnet,
    author    = {Wang, Hongjun and Vaze, Sagar and Han, Kai},
    title     = {SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning},
    booktitle = {International Conference on Learning Representations (ICLR)},
    year      = {2024}
}


sptnet's Issues

About errors in the project code

Thanks for your code.
However, I found some errors:
from project_utils.general_utils from methods.vpt.utils import cosine_lr
I found general_utils in util, but couldn't find cosine_lr.

About the checkpoints of CIFAR-100

Thanks for your great work. I downloaded the checkpoint files you provided and found a problem when running the CIFAR-10 and CIFAR-100 checkpoints with eval.py. Did you accidentally upload the wrong weight files, or do I need to modify the code before running it? Looking forward to your reply.
The results of CIFAR-100
100%|██████████| 118/118 [01:17<00:00, 1.53it/s]
Epoch 0, Train ACC Unlabelled_v2: All 0.1460 | Old 0.1101 | New 0.2177
Best Accuracies: All 0.1460 | Old 0.1101 | New 0.2177

The results of CIFAR-10
100%|██████████| 147/147 [01:38<00:00, 1.50it/s]
Epoch 0, Train ACC Unlabelled_v2: All 0.4842 | Old 0.2973 | New 0.5777
Best Accuracies: All 0.4842 | Old 0.2973 | New 0.5777

feasibility for general classification

Hi,

Thanks for your excellent work and doc. Currently, I am actively involved in a GCD-related project, and your work is instrumental in reproducing the results. Based on my trials, SPTNet indeed shows contribution to the GCD task. However, I am curious to know if you have investigated the potential of utilizing the alternative method for enhancing representation in broader classification tasks, such as Flowers-102. Have you conducted any experiments in this context?

the pretrained model weights

How can I load the parameters of both the backbone and the projector together? I can only load the backbone parameters; loading both results in an error, even when using the full checkpoint, which contains weights for both the backbone and the projection head.

About unfreeze and freeze function.

Thanks for your code.
"
from util.general_utils import str2bool, get_params_groups, finetune_params, freeze, unfreeze, cosine_lr
"
I can't find the code for 'freeze' and 'unfreeze'.
Also, the pretrained ViT can't be loaded into the classifier:
"
projector = DINOHead(in_dim=args.feat_dim, out_dim=args.num_ctgs, nlayers=args.num_mlp_layers)
classifier = nn.Sequential(backbone, projector).cuda()
state_dict = torch.load(args.pretrained_model_path, map_location='cpu')
classifier.load_state_dict(state_dict)
"
Maybe it should be classifier.load_state_dict(state_dict, strict=False) or backbone.load_state_dict(state_dict)?

About abstract description

I would like to ask which SOTA method the abstract refers to in "Notably, we find our method achieves an average accuracy of 61.4% on the SSB, surpassing prior state-of-the-art methods by approximately 10%." Looking forward to your reply.

Out of memory problem

Hello, I've tried to run the training code and succeeded in loading the pretrained SimGCD weights, but when the code reaches

feats = backbone(prompter(images))

it already uses all of the 3090's memory and crashes.
I did not change any training-related code. Do you have any suggestions? Thank you very much!

About the performance

Hello, I have tried to train the model from the SimGCD pretrained weights on CUB, SCARS and FGVC, but the performance I get is almost the same as that of the pretrained SimGCD weights. Are there any details I overlooked? I ran the training with exactly the provided command.

How to properly train SPTNet

Thank you for your cool work on GCD.

I used the training script from the README to train a model based on DINO pretraining on the CUB dataset, but there seem to be issues with the results.

CUDA_VISIBLE_DEVICES=0 python train_spt.py \
    --dataset_name 'CUB' \
    --batch_size 128 \
    --grad_from_block 11 \
    --epochs 1000 \
    --num_workers 8 \
    --use_ssb_splits \
    --sup_weight 0.35 \
    --weight_decay 5e-4 \
    --transform 'imagenet' \
    --lr 1 \
    --lr2 0.05 \
    --prompt_size 1 \
    --freq_rep_learn 20 \
    --pretrained_model_path ./pretrained/dino_vitbase16_pretrain.pth \
    --prompt_type 'all' \
    --eval_funcs 'v2' \
    --warmup_teacher_temp 0.07 \
    --teacher_temp 0.04 \
    --warmup_teacher_temp_epochs 10 \
    --memax_weight 1 \
    --model_path ./model_save

Here is results.txt, which records the accuracy at every epoch over the 1000 training epochs.

result.txt

What parameters do I need to modify to reproduce the results in the paper?

I look forward to your response, and thank you once again for your great work!
