gt-ripl / coda-prompt

PyTorch code for the CVPR'23 paper: "CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning"

License: MIT License

Python 91.60% Shell 8.40%

coda-prompt's Introduction

CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning

PyTorch code for the CVPR 2023 paper:
CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning
James Smith, Leonid Karlinsky, Vyshnavi Gutta, Paola Cascante-Bonilla
Donghyun Kim, Assaf Arbelle, Rameswar Panda, Rogerio Feris, Zsolt Kira
IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2023
[arXiv]

Abstract

Computer vision models suffer from a phenomenon known as catastrophic forgetting when learning novel concepts from continuously shifting training data. Typical solutions for this continual learning problem require extensive rehearsal of previously seen data, which increases memory costs and may violate data privacy. Recently, the emergence of large-scale pre-trained vision transformer models has enabled prompting approaches as an alternative to data-rehearsal. These approaches rely on a key-query mechanism to generate prompts and have been found to be highly resistant to catastrophic forgetting in the well-established rehearsal-free continual learning setting. However, the key mechanism of these methods is not trained end-to-end with the task sequence. Our experiments show that this leads to a reduction in their plasticity, hence sacrificing new task accuracy, and an inability to benefit from expanded parameter capacity. We instead propose to learn a set of prompt components which are assembled with input-conditioned weights to produce input-conditioned prompts, resulting in a novel attention-based end-to-end key-query scheme. Our experiments show that we outperform the current SOTA method DualPrompt on established benchmarks by as much as 5.4% in average accuracy. We also outperform the state of the art by as much as 6.6% accuracy on a continual learning benchmark which contains both class-incremental and domain-incremental task shifts, corresponding to many practical settings.

Important notice!

We received feedback from other researchers that our orthogonality initialization might not be “in the spirit” of continual learning. So, we fixed it! Our code now uses the Gram-Schmidt process at the start of each new task to initialize new prompting parameters. We found that, with this initialization, we can even set the ortho penalty of our method to 0.0! We hope this might inspire future work, as well. Send me a message if you want to discuss more (email preferred).
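
For readers unfamiliar with the Gram-Schmidt process mentioned above, here is a minimal NumPy sketch of orthonormalizing newly added rows against rows that are kept fixed. It is an illustration under stated assumptions, not the repository's implementation: the function name gram_schmidt_extend, its arguments, and the random Gaussian starting vectors are all hypothetical.

```python
import numpy as np

def gram_schmidt_extend(old_rows, num_new, seed=0):
    """Return `num_new` unit-norm rows orthogonal to `old_rows` and to each other.

    Hypothetical helper: `old_rows` (shape m x d) is assumed to be orthonormal
    already and is left untouched. Requires m + num_new <= d.
    """
    rng = np.random.default_rng(seed)
    basis = [row for row in old_rows]
    d = old_rows.shape[1]
    new_rows = []
    while len(new_rows) < num_new:
        v = rng.standard_normal(d)
        # Subtract the projection onto every vector already in the basis.
        for u in basis:
            v = v - (v @ u) * u
        norm = np.linalg.norm(v)
        if norm < 1e-8:  # degenerate draw, try again
            continue
        v = v / norm
        basis.append(v)
        new_rows.append(v)
    return np.stack(new_rows)

old = np.eye(8)[:3]  # three orthonormal rows standing in for an earlier task
new = gram_schmidt_extend(old, num_new=2)
full = np.vstack([old, new])
print(np.allclose(full @ full.T, np.eye(5)))
```

When the new rows are orthonormal to the old ones by construction, the Gram matrix of the stacked rows is the identity at initialization, which is the quantity an orthogonality penalty measures.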

Setup

  • Install anaconda: https://www.anaconda.com/distribution/
  • Set up a conda environment with Python 3.8, e.g.: conda create --name coda python=3.8
  • conda activate coda
  • sh install_requirements.sh
  • NOTE: this framework was tested using torch == 2.0.0 but should work for previous versions

Datasets

Training

All commands should be run under the project root directory. The scripts are set up for 4 GPUs but can be modified for your hardware.

sh experiments/cifar100.sh
sh experiments/imagenet-r.sh
sh experiments/domainnet.sh

Results

Results will be saved in a folder named outputs/. To get the final average accuracy, retrieve the final number in the file outputs/**/results-acc/global.yaml

Ready to create your next method?

Create your new prompting method in models/zoo.py, which will require you to create a new class in learners/prompt.py as well. Hopefully, you can create your next method while only modifying these two files! I also recommend you develop with the ImageNet-R benchmark and use fewer epochs for faster results. Cannot wait to see what method you develop!

Model backbone

For fair comparisons with our method and results, please see models/zoo.py to take or replace the exact pre-trained vit_base_patch16_224 weights used in our repo.

Note on setting

Our setting is rehearsal-free class-incremental continual learning. Our method has not been tested for other settings such as domain-incremental continual learning.

Acknowledgement

This material is based upon work supported by the National Science Foundation under Grant No. 2239292.

Citation

If you found our work useful for your research, please cite our work:

@InProceedings{Smith_2023_CVPR,
    author    = {Smith, James Seale and Karlinsky, Leonid and Gutta, Vyshnavi and Cascante-Bonilla, Paola and Kim, Donghyun and Arbelle, Assaf and Panda, Rameswar and Feris, Rogerio and Kira, Zsolt},
    title     = {CODA-Prompt: COntinual Decomposed Attention-Based Prompting for Rehearsal-Free Continual Learning},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {11909-11919}
}

coda-prompt's People

Contributors

jamessealesmith

coda-prompt's Issues

Regarding prompt selection frequency

Hi @jamessealesmith , thanks for your amazing work!

I was wondering if you checked the prompt selection frequency in CODA prompt and L2P?

Any idea whether your reproduction of L2P has the issue reported here and here, where the same set of prompts is selected every time?

Question about the imagenet-r with covariate domain shift.

Hello, as you reported in the paper, ImageNet-R with covariate domain shift is a new benchmark that evaluates both semantic and covariate shift, and it can be seen as a task-agnostic incremental learning evaluation. Could you provide this benchmark in the codebase?

On orthogonality penalty loss

Hi authors, thanks for your wonderful work and code.
I see that the weight coefficient of the orthogonal loss term in your code defaults to 0. I tried setting it to 0.1, as in your paper, and found that the final average accuracy actually decreased. Is there a problem with this? I would greatly appreciate a reply.

A question about data preprocessing

Thank you very much for providing the source code! I found that in the data preprocessing stage, the normalization mean and variance are set to 0 and 1 instead of the ImageNet mean and variance. Is this a bug? Is it consistent with the preprocessing used in the ImageNet pretraining stage? I would be very grateful for a reply.

Orthogonal penalty calculation

def ortho_penalty(t):
    return ((t @ t.T - torch.eye(t.shape[0]).cuda())**2).mean() * 1e-6

The above is the orthogonal penalty calculation in your zoo.py. Why is it multiplied by 1e-6? This factor is not explained in the paper.
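
To make the quantity in question concrete, here is a small NumPy restatement of the quoted function. It is a sketch for illustration only and does not explain why the authors chose 1e-6. Exposing the factor as a `scale` argument is an assumption made here so that its effect is easy to see.

```python
import numpy as np

def ortho_penalty(t, scale=1e-6):
    """Mean squared deviation of t @ t.T from the identity, times `scale`.

    Mirrors the quoted function; the `scale` parameter is an illustrative
    addition, not part of the original signature.
    """
    gram = t @ t.T
    return ((gram - np.eye(t.shape[0])) ** 2).mean() * scale

orthonormal = np.eye(4)[:2]                      # two orthonormal rows
duplicated = np.array([[1.0, 0.0], [1.0, 0.0]])  # two identical unit rows

print(ortho_penalty(orthonormal))            # zero: rows are orthonormal
print(ortho_penalty(duplicated, scale=1.0))  # unscaled penalty
print(ortho_penalty(duplicated))             # same penalty times 1e-6
```

Because the factor multiplies the whole term, it combines with the loss-weight coefficient: only their product affects the gradient.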

Reproducing issue for upper-bound

Hi, thank you for your amazing paper.

I added two new parameters, 'oracle_flag' and 'upper_bound_flag,' to the cifar-100.sh and set the schedule to 200 in cifar-100_prompt.yaml. I conducted experiments both with and without coda-prompt, and the obtained results were significantly better than those reported in the paper, achieving accuracies of 91% and 94%, respectively. Moreover, it appears that the results can still be further improved with additional iterations. Similarly, on the ImageNet-R dataset, the accuracy achieved far surpasses what is documented in the paper.

So, I am curious to understand how the 'upper_bound' is calculated.

error when setting prompt length to an odd number

Hi authors, thanks for your wonderful work and code.

I tried running the CODA-Prompt code to find the best prompt length.
When the prompt length is an even number, the code works.
However, when the prompt length is an odd number, a runtime error occurs, which seems to result from a shape mismatch between two tensors.
Here is the specific error when setting the prompt length to 5:
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [768, 199] but got: [768, 200].

I'd appreciate it if you could fix this.
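
One plausible reading of the sizes in this error, offered as an assumption rather than a confirmed diagnosis: if the prompt is split into a key prefix and a value prefix by integer halving, as in prefix tuning, an odd length gives halves of different sizes, so the key and value sequences end up with different lengths. The sketch below uses 197 tokens (14 x 14 patches plus one class token for vit_base_patch16_224). The helper name split_prompt_length is hypothetical.

```python
def split_prompt_length(prompt_length):
    """Split a prompt length into (key_prefix, value_prefix) sizes by integer halving."""
    key_len = prompt_length // 2
    value_len = prompt_length - key_len
    return key_len, value_len

NUM_TOKENS = 197  # 14 x 14 patches + 1 class token

for length in (4, 5, 8):
    key_len, value_len = split_prompt_length(length)
    keys = NUM_TOKENS + key_len      # key sequence length after prepending the prefix
    values = NUM_TOKENS + value_len  # value sequence length after prepending the prefix
    status = "ok" if keys == values else "mismatch"
    print(length, keys, values, status)
```

Under this assumption, a prompt length of 5 yields 199 keys and 200 values, which matches the two sizes in the message above. Keeping the prompt length even would avoid the mismatch.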

Question about the implementation.

Hi, you did impressive work and the codebase you released is great. However, I am confused about the implementation. I met some obstacles and I really value your opinions.

As described in the paper, the equation $q(x) \odot A$ is the Hadamard product while the implementation looks like the outer product:

# with attention and cosine sim
# (b x 1 x d) * soft([1 x k x d]) = (b x k x d) -> attention = k x d
a_querry = torch.einsum('bd,kd->bkd', x_querry, A)

Another question is that I don't know how to get the subsequent cosine similarity, i.e. how to transform a $b\times k\times d$ matrix and a $k\times d$ matrix to a $b\times k$ matrix:

aq_k = torch.einsum('bkd,kd->bk', q, n_K)

Could you please explain how torch.einsum produces the cosine similarity matrix here? I would be grateful if you could clear up my confusion.

Thank you so much.
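
The two einsum strings quoted in this issue can be checked against explicit broadcasting. The sketch below uses NumPy, whose np.einsum accepts the same subscript notation as torch.einsum, so it runs without a GPU. The normalization step is an assumption inferred from the variable names q and n_K, since the issue does not show it, and the tensor sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
b, k, d = 2, 3, 4
x_querry = rng.standard_normal((b, d))  # one query vector per input
A = rng.standard_normal((k, d))         # one attention vector per component
K = rng.standard_normal((k, d))         # one key per component

# 'bd,kd->bkd' has no summed index: it is a broadcast element-wise product.
a_querry = np.einsum('bd,kd->bkd', x_querry, A)
assert np.allclose(a_querry, x_querry[:, None, :] * A[None, :, :])

# Normalize along d so that the dot product below is a cosine similarity.
q = a_querry / np.linalg.norm(a_querry, axis=2, keepdims=True)
n_K = K / np.linalg.norm(K, axis=1, keepdims=True)

# 'bkd,kd->bk' sums over d only: entry (i, j) is the dot product q[i, j] . n_K[j].
aq_k = np.einsum('bkd,kd->bk', q, n_K)
assert np.allclose(aq_k, (q * n_K[None, :, :]).sum(axis=2))
print(aq_k.shape)  # (2, 3)
```

Read this way, the first einsum applies the Hadamard product between each query and each of the k attention vectors separately, which is why the result carries an extra k axis rather than being a single d-dimensional vector.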

Calculation for Forgetting

Hi, thanks for the amazing work!
We ran the code you provided, but the results did not include the forgetting metric reported in the paper. Have you provided code to calculate forgetting that we may have missed? If not, will you upload it soon?
Thanks again!
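
Until an official script is available, forgetting can be computed from a table of per-task accuracies. The sketch below uses a widely used definition (best accuracy previously reached on a task minus its accuracy after the final task, averaged over all tasks but the last). This is an assumption: the paper's exact formula may differ, and the function name and accuracy values are made up for illustration.

```python
def average_forgetting(acc):
    """acc[i][j] = accuracy on task j after training on task i (j <= i).

    Returns the mean, over all tasks but the last, of the best accuracy
    reached on that task before the final stage minus its final accuracy.
    """
    num_tasks = len(acc)
    if num_tasks < 2:
        return 0.0
    drops = []
    for j in range(num_tasks - 1):
        best_earlier = max(acc[i][j] for i in range(j, num_tasks - 1))
        drops.append(best_earlier - acc[num_tasks - 1][j])
    return sum(drops) / len(drops)

# Made-up accuracies for three tasks (lower-triangular layout).
acc = [
    [0.90],
    [0.80, 0.85],
    [0.70, 0.75, 0.88],
]
print(average_forgetting(acc))
```

In this toy table, task 0 drops from 0.90 to 0.70 and task 1 from 0.85 to 0.75, so the average is about 0.15.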

About the hardware requirements

Hi, thanks for the code, and it is awesome work.
I am trying to run the code but unfortunately, an OOM error was thrown when running ./experiments/cifar-100.sh.

My GPU:
2x Nvidia RTX 3090

Could you please share the hardware requirements?
Thanks.

Question about the cosine scheduler

Dear author,

Thank you for your fantastic work. I have a question regarding the number "200" inside the cosine scheduler. Is there a reason for setting it to 200 here, and do I have to change this value if I train with a different number of epochs?

import math
from torch.optim.lr_scheduler import _LRScheduler

class CosineSchedule(_LRScheduler):

    def __init__(self, optimizer, K):
        self.K = K
        super().__init__(optimizer, -1)

    def cosine(self, base_lr):
        return base_lr * math.cos((99 * math.pi * (self.last_epoch)) / (200 * (self.K-1)))

    def get_lr(self):
        return [self.cosine(base_lr) for base_lr in self.base_lrs]

Best regards,
Louis
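
Reading the quoted formula directly: the constants 99 and 200 appear only as the ratio 99/200, which makes the cosine argument end at 0.495π. The sketch below evaluates the multiplier in plain Python, assuming K is the total number of scheduler steps and last_epoch runs from 0 to K - 1. That interpretation of K is an assumption, and the function name lr_factor is hypothetical.

```python
import math

def lr_factor(epoch, K):
    """Multiplier applied to the base learning rate at `epoch` (0-based)."""
    return math.cos((99 * math.pi * epoch) / (200 * (K - 1)))

for K in (10, 50):
    # First step uses the full base rate; the last step is close to zero.
    print(K, lr_factor(0, K), round(lr_factor(K - 1, K), 4))
```

Under that assumption, the final multiplier is about 0.0157 for any K, so 200 does not look like an epoch count and would not need to change when training for a different number of epochs.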

Question about FT and FT++

Hi, thank you for your fantastic paper. Your work has truly impressed me. However, I am uncertain about one aspect and would greatly appreciate your guidance. Specifically, I am wondering about the difference between FT and FT++.

# FT
python -u run.py --config $CONFIG_FT --gpuid $GPUID --repeat $REPEAT --overwrite $OVERWRITE \
    --learner_type default --learner_name FinetunePlus \
    --log_dir ${OUTDIR}/ft++

# FT++
python -u run.py --config $CONFIG_FT --gpuid $GPUID --repeat $REPEAT --overwrite $OVERWRITE \
    --learner_type default --learner_name NormalNN \
    --log_dir ${OUTDIR}/ft

Reproducing issue for ImageNet-R dataset

Thanks for sharing your wonderful work!

I'm trying to run the code using the ImageNet-R dataset.
However, I obtained a lower performance than the results reported in the paper.
I checked all the details listed in the supplementary material.
Are additional techniques needed to reach the performance reported in the paper?

Sincerely,

domainNet setting?

Hi, I am very interested in your wonderful work. Could you provide the code to run DomainNet? Thanks so much.

Does CODA-Prompt need a memory buffer?

Hello, first of all, thank you for your amazing paper!
I found that train_dataset has an "append_coreset" function in trainer.py and an "update_coreset" function in learners/default.py. According to the paper, CODA-Prompt targets rehearsal-free continual learning. Why, then, are these coreset functions needed?

Sincerely

class-incremental settings?

Interesting work!
During inference, the CodaPrompt.forward function is called, and it requires the variable self.task_count. Does this mean that the CodaPrompt method needs task identifiers at test time and is therefore not truly class-incremental learning?

New prompt keys for each layer.

This repo uses a separate pool of keys for each layer in its DualPrompt implementation, which differs from the original implementation in the DualPrompt paper.
