

Editing Models with Task Arithmetic

This repository contains code for the ICLR 2023 paper Editing Models with Task Arithmetic, by Gabriel Ilharco, Marco Tulio Ribeiro, Mitchell Wortsman, Suchin Gururangan, Ludwig Schmidt, Hannaneh Hajishirzi and Ali Farhadi.

Abstract

Changing how pre-trained models behave (e.g., improving their performance on a downstream task or mitigating biases learned during pre-training) is a common practice when developing machine learning systems. In this work, we propose a new paradigm for steering the behavior of neural networks, centered around task vectors. A task vector specifies a direction in the weight space of a pre-trained model, such that movement in that direction improves performance on the task. We build task vectors by subtracting the weights of a pre-trained model from the weights of the same model after fine-tuning on a task. We show that these task vectors can be modified and combined together through arithmetic operations such as negation and addition, and the behavior of the resulting model is steered accordingly. Negating a task vector decreases performance on the target task, with little change in model behavior on control tasks. Moreover, adding task vectors together can improve performance on multiple tasks at once. Finally, when tasks are linked by an analogy relationship of the form "A is to B as C is to D", combining task vectors from three of the tasks can improve performance on the fourth, even when no data from the fourth task is used for training. Overall, our experiments with several models, modalities and tasks show that task arithmetic is a simple, efficient and effective way of editing models.

Summary figure


An illustration of task vectors and the arithmetic operations we study for editing models. (a) A task vector is obtained by subtracting the weights of a pre-trained model from the weights of the same model after fine-tuning. (b) Negating a task vector degrades performance on the task, without substantial changes in control tasks. (c) Adding task vectors together improves the performance of the pre-trained model on the tasks under consideration. (d) When tasks form an analogy relationship such as supervised and unsupervised learning on two different data sources, it is possible to improve performance on a supervised target task using only vectors from the remaining three combinations of objectives and datasets.
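Written out with notation introduced here for illustration only (θ for a model's weights, λ for the scaling coefficient that the code examples pass to apply_to as scaling_coef), the four panels correspond to the following operations, with the edited weights always obtained from the last line:

```latex
\begin{aligned}
\text{(a)}\quad & \tau = \theta_{\mathrm{ft}} - \theta_{\mathrm{pre}} \\
\text{(b)}\quad & \tau_{\mathrm{new}} = -\tau \\
\text{(c)}\quad & \tau_{\mathrm{new}} = \sum_{i} \tau_i \\
\text{(d)}\quad & \tau_{\mathrm{new}} = \tau_C + (\tau_B - \tau_A) \\
& \theta_{\mathrm{new}} = \theta_{\mathrm{pre}} + \lambda\, \tau_{\mathrm{new}}
\end{aligned}
```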

Code

Install dependencies

conda env create
conda activate task-vectors

Add the repository directory to PYTHONPATH:

cd task_vectors
export PYTHONPATH="$PYTHONPATH:$PWD"

Using task vectors

The task vector logic can be found at src/task_vectors.py.

To create a task vector, you will need a pre-trained checkpoint and a fine-tuned checkpoint:

from task_vectors import TaskVector
task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint)
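Conceptually, a task vector is the per-parameter difference between the fine-tuned and the pre-trained weights. The sketch below illustrates only that subtraction, on toy "state dicts" of plain Python lists; it is not the repository's TaskVector (which loads real PyTorch checkpoints), and make_task_vector is a hypothetical helper.

```python
from typing import Dict, List

# Hypothetical toy representation: parameter name -> flat list of float weights.
StateDict = Dict[str, List[float]]


def make_task_vector(pretrained: StateDict, finetuned: StateDict) -> StateDict:
    """Return finetuned - pretrained for every parameter present in both state dicts."""
    vector = {}
    for name, pre_values in pretrained.items():
        if name not in finetuned:
            continue  # nothing to subtract for a parameter the fine-tuned model lacks
        ft_values = finetuned[name]
        if len(ft_values) != len(pre_values):
            raise ValueError(f"shape mismatch for {name}")
        vector[name] = [ft - pre for ft, pre in zip(ft_values, pre_values)]
    return vector


pretrained = {"layer.weight": [1.0, 2.0], "layer.bias": [0.5]}
finetuned = {"layer.weight": [1.5, 1.0], "layer.bias": [0.75]}
tau = make_task_vector(pretrained, finetuned)
print(tau)  # {'layer.weight': [0.5, -1.0], 'layer.bias': [0.25]}
```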

Once created, task vectors can be modified and combined through arithmetic operations! For instance, to negate a task vector, simply use the - operator:

# Negating a task vector
new_task_vector = -task_vector
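In Python, the unary - operator is dispatched to a class's __neg__ method. A minimal sketch, assuming a hypothetical ToyTaskVector class that stores deltas as plain lists rather than the repository's tensors:

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyTaskVector:
    """Hypothetical stand-in for TaskVector: parameter name -> list of float deltas."""
    vector: Dict[str, List[float]]

    def __neg__(self) -> "ToyTaskVector":
        # Flip the sign of every delta, pointing away from the fine-tuning direction.
        return ToyTaskVector({k: [-x for x in v] for k, v in self.vector.items()})


tau = ToyTaskVector({"w": [0.5, -1.0]})
print((-tau).vector)  # {'w': [-0.5, 1.0]}
```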

To add task vectors, you can use the + operator, or sum:

# Adding two task vectors
new_task_vector = task_vector_A + task_vector_B
# Adding multiple task vectors
new_task_vector = sum(list_of_task_vectors)
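Python's built-in sum starts from the integer 0, so sum(list_of_task_vectors) first evaluates 0 + task_vector, which only works if the class handles that case in __radd__. The sketch below shows one way to do this with a hypothetical ToyTaskVector; whether the repository's TaskVector is written exactly this way is not confirmed here.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyTaskVector:
    """Hypothetical stand-in for TaskVector: parameter name -> list of float deltas."""
    vector: Dict[str, List[float]]

    def __add__(self, other: "ToyTaskVector") -> "ToyTaskVector":
        # Element-wise sum over the parameters both vectors share.
        return ToyTaskVector({
            k: [a + b for a, b in zip(v, other.vector[k])]
            for k, v in self.vector.items()
            if k in other.vector
        })

    def __radd__(self, other):
        # sum() begins with the int 0; treat it as the additive identity.
        if isinstance(other, int):
            return self
        return self.__add__(other)


tv_a = ToyTaskVector({"w": [0.5, -1.0]})
tv_b = ToyTaskVector({"w": [0.25, 0.5]})
tv_c = ToyTaskVector({"w": [1.0, 0.0]})
print((tv_a + tv_b).vector)             # {'w': [0.75, -0.5]}
print(sum([tv_a, tv_b, tv_c]).vector)   # {'w': [1.75, -0.5]}
```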

Task analogies are just as simple:

# Task analogies
new_task_vector = task_vector_C + task_vector_B - task_vector_A
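Binary subtraction is dispatched to __sub__ in Python; defining unary negation alone does not make a - b work. A minimal sketch, assuming a hypothetical ToyTaskVector in which a - b is defined as a + (-b):

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyTaskVector:
    """Hypothetical stand-in for TaskVector: parameter name -> list of float deltas."""
    vector: Dict[str, List[float]]

    def __add__(self, other: "ToyTaskVector") -> "ToyTaskVector":
        return ToyTaskVector({
            k: [a + b for a, b in zip(v, other.vector[k])]
            for k, v in self.vector.items()
            if k in other.vector
        })

    def __neg__(self) -> "ToyTaskVector":
        return ToyTaskVector({k: [-x for x in v] for k, v in self.vector.items()})

    def __sub__(self, other: "ToyTaskVector") -> "ToyTaskVector":
        # Reuse addition and negation: a - b == a + (-b).
        return self + (-other)


tv_a = ToyTaskVector({"w": [1.0, 0.0]})
tv_b = ToyTaskVector({"w": [1.0, 1.0]})
tv_c = ToyTaskVector({"w": [0.0, 2.0]})
analogy = tv_c + tv_b - tv_a  # evaluated as (tv_c + tv_b) - tv_a
print(analogy.vector)  # {'w': [0.0, 3.0]}
```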

Checkpoints

Checkpoints for CLIP ViT-B/32, ViT-B/16 and ViT-L/14 are available at the link below, including fine-tuned checkpoints on eight downstream tasks: Stanford Cars, DTD, EuroSAT, GTSRB, MNIST, RESISC45, SUN397 and SVHN.

Download here

Examples

Below is an example of negating a task vector from MNIST, then evaluating on MNIST and on ImageNet:

import torch
from task_vectors import TaskVector
from eval import eval_single_dataset
from args import parse_arguments

# Config
dataset = 'MNIST'
model = 'ViT-L-14'
args = parse_arguments()
args.data_location = '/path/to/data'
args.model = model
args.save = f'checkpoints/{model}'
pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt'
finetuned_checkpoint = f'checkpoints/{model}/{dataset}/finetuned.pt'


# Create the task vector
task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint)
# Negate the task vector
neg_task_vector = -task_vector
# Apply the task vector
image_encoder = neg_task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.5)
# Evaluate
eval_single_dataset(image_encoder, dataset, args)
eval_single_dataset(image_encoder, 'ImageNet', args)
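In the example above, apply_to(pretrained_checkpoint, scaling_coef=0.5) adds a scaled copy of the (negated) task vector to the pre-trained weights, i.e. new = pretrained + scaling_coef * vector, so each weight moves half a task vector away from the fine-tuning direction. The sketch below shows that arithmetic on toy dicts; apply_to here is a hypothetical stand-in, not the repository's method, and in this sketch any key missing from the task vector simply keeps its pre-trained value (with a warning worded like the one quoted in the issues).

```python
import warnings
from typing import Dict, List

StateDict = Dict[str, List[float]]


def apply_to(pretrained: StateDict, task_vector: StateDict, scaling_coef: float = 1.0) -> StateDict:
    """Hypothetical sketch: return pretrained + scaling_coef * task_vector, per parameter."""
    edited = {}
    for name, weights in pretrained.items():
        if name not in task_vector:
            # Keep the pre-trained value unchanged and flag the gap.
            warnings.warn(f"key {name} is present in the pretrained state dict but not in the task vector")
            edited[name] = list(weights)
            continue
        edited[name] = [w + scaling_coef * d for w, d in zip(weights, task_vector[name])]
    return edited


pretrained = {"w": [1.0, 2.0], "mask": [1.0, 0.0]}
neg_tau = {"w": [-0.5, 1.0]}  # a negated toy task vector
edited = apply_to(pretrained, neg_tau, scaling_coef=0.5)
print(edited)  # {'w': [0.75, 2.5], 'mask': [1.0, 0.0]}
```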

You can also find an example of adding task vectors together below, using the MNIST and RESISC45 datasets:

import torch
from task_vectors import TaskVector
from eval import eval_single_dataset
from args import parse_arguments

# Config
datasets = ['MNIST', 'RESISC45']
model = 'ViT-L-14'
args = parse_arguments()
args.data_location = '/path/to/data'
args.model = model
args.save = f'checkpoints/{model}'
pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt'

# Create the task vectors
task_vectors = [
    TaskVector(pretrained_checkpoint, f'checkpoints/{model}/{dataset}/finetuned.pt')
    for dataset in datasets
]
# Sum the task vectors
task_vector_sum = sum(task_vectors)
# Apply the resulting task vector
image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=0.8)
# Evaluate
for dataset in datasets:
    eval_single_dataset(image_encoder, dataset, args)
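The examples hard-code scaling_coef (0.5 for negation, 0.8 for addition). One way to choose it is a small grid search scored on held-out validation data. The sketch below assumes a hypothetical score_fn that applies the task vector with a given coefficient and returns a validation accuracy; here it is replaced by a toy function so the selection logic can run on its own.

```python
from typing import Callable, Iterable, Optional, Tuple


def pick_scaling_coef(
    score_fn: Callable[[float], float],
    candidates: Iterable[float],
) -> Tuple[Optional[float], float]:
    """Return (best_coef, best_score) maximizing the hypothetical score_fn over candidates."""
    best_coef, best_score = None, float("-inf")
    for coef in candidates:
        score = score_fn(coef)
        if score > best_score:
            best_coef, best_score = coef, score
    return best_coef, best_score


def toy_score(coef: float) -> float:
    # Toy stand-in for validation accuracy that peaks at 0.8.
    return 1.0 - (coef - 0.8) ** 2


grid = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
print(pick_scaling_coef(toy_score, grid))  # (0.8, 1.0)
```

In the repository setting, score_fn might wrap task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=coef) followed by an evaluation on a validation split; how eval_single_dataset reports its metrics is not shown in this README, so that wiring is left as an assumption.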


task_vectors's Issues

Corrupted Checkpoint ViT-L-14/Cars/finetuned.pt

Hi @gabrielilharco, when I download the ViT-L-14/Cars/finetuned.pt checkpoint from Google Drive and try to load it, I get an error saying that the checkpoint is corrupted. All the other checkpoints work fine. I have tried downloading it multiple times and it still does not work. Would it be possible for you to re-upload the ViT-L-14/Cars/finetuned.pt checkpoint?

Split definition of DTD, EuroSAT and SUN397

Hi, awesome work!

I'm trying to reproduce your results but I cannot find the split definitions you use for DTD, EuroSAT and SUN397. Would you mind pointing me to the right resources to download the versions of these datasets compatible with your code?

Thanks a lot!

Clarification about classification heads

Hi @gabrielilharco,
thanks for the great work!

I have a question about the classification heads used in your experiments and available here. How exactly did you train them? Looking at the code, I can see that you manually construct a zero-shot classifier based on embeddings of class names inserted into various templates. Importantly, you use the pretrained OpenCLIP model to calculate the embeddings for the classifier, not the model fine-tuned for a particular task. Am I right about that?

What is the reason for that? To me, the most natural way to obtain the classifier would be to take model.classification_head after fine-tuning a task-specific model (here). This classification head is aligned with the fine-tuned model, while the zero-shot head is aligned with the pretrained model, so the head from fine-tuning seems more suitable. Did you consider such an approach?

Task vectors for GPT2 model: attn.bias weights ignored

I am trying to use task vectors for GPT2 models.

In the TaskVector class, when the task vector is created, there is a condition that ignores keys in the state dict whose dtype is uint8.

Because of this condition, when I call the apply_to() method of a task vector instance with a GPT-2 model checkpoint, I get the following warnings:

Warning: key transformer.h.0.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.1.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.2.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.3.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.4.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.5.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.6.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.7.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.8.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.9.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.10.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.11.attn.bias is present in the pretrained state dict but not in the task vector

Is there a reason why state dict keys with dtype torch.uint8 are ignored?

When that condition is removed, the code runs without any errors.

Please advise on the best way to handle this.
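If, as in Hugging Face's GPT-2, keys like attn.bias hold a fixed causal mask rather than learned weights, their pre-trained and fine-tuned values are identical and subtracting them contributes nothing. One option, sketched below on toy data, is to skip non-floating-point entries when building the vector and copy them from the pre-trained model when applying it. The helpers build_vector and apply_vector are hypothetical, and is_floating stands in for a real dtype check such as torch.is_floating_point.

```python
from typing import Dict, List, Tuple, Union

Value = Union[float, int, bool]
StateDict = Dict[str, List[Value]]


def is_floating(values: List[Value]) -> bool:
    # Toy stand-in for a tensor dtype check such as torch.is_floating_point.
    return all(isinstance(v, float) for v in values)


def build_vector(pretrained: StateDict, finetuned: StateDict) -> Tuple[StateDict, List[str]]:
    """Subtract floating-point entries only; return the vector and the skipped keys."""
    vector, skipped = {}, []
    for name, pre in pretrained.items():
        if not is_floating(pre):
            skipped.append(name)  # e.g. integer/boolean masks: no meaningful delta
            continue
        vector[name] = [f - p for f, p in zip(finetuned[name], pre)]
    return vector, skipped


def apply_vector(pretrained: StateDict, vector: StateDict, scaling_coef: float) -> StateDict:
    """Add the scaled vector; copy skipped (non-float) entries from the pre-trained model."""
    return {
        name: [p + scaling_coef * d for p, d in zip(pre, vector[name])] if name in vector else list(pre)
        for name, pre in pretrained.items()
    }


# Toy parameter names loosely modeled on GPT-2's; values are illustrative only.
pre = {"h.0.attn.c_attn.weight": [0.5, 1.0], "h.0.attn.bias": [1, 0]}
ft = {"h.0.attn.c_attn.weight": [1.0, 0.5], "h.0.attn.bias": [1, 0]}
vec, skipped = build_vector(pre, ft)
print(skipped)                      # ['h.0.attn.bias']
print(apply_vector(pre, vec, 1.0))  # {'h.0.attn.c_attn.weight': [1.0, 0.5], 'h.0.attn.bias': [1, 0]}
```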

ImageNet and its split

Dear authors,

I am trying to reproduce your work based on this repo, and I have run into a problem. It seems that ImageNet is not downloaded automatically by your code. Which version of ImageNet did you use? ILSVRC_2012? And do any other changes need to be applied to the datasets?

Best regards,
Hongduan

Classification Heads for Val Split.

Hi @gabrielilharco, I was trying to run some experiments on the validation split. I learned from other issues that I need to use [dataset]Val to select the validation split. When I tried this, the code started to create a new classification head for the validation split of the dataset. Is this the expected behavior? I thought a single classifier would work for all splits. Am I missing something here?

Requesting the Multitask Checkpoint For Learning via Addition

Hi @gabrielilharco, I see that in Appendix D.2 you mention training a multitask checkpoint on the eight vision tasks (SUN397, Cars, RESISC45, EuroSAT, SVHN, GTSRB, MNIST, DTD) for the learning-via-addition experiments, and that you report a multitask normalized performance of 99.4. Did you share the raw per-task numbers somewhere? If you could share these checkpoints for ViT-B-32 and ViT-L-14, that would be really helpful.

Thanks in advance,
Prateek

Provide a pre-trained model that can be loaded directly with model_weights = torch.load(file_path, map_location='cpu')

While trying to load the file using Python's pickle module, I encountered a _pickle.UnpicklingError stating that persistent IDs in protocol 0 must be ASCII strings. Here is the exact error message:

_pickle.UnpicklingError: persistent IDs in protocol 0 must be ASCII strings

I attempted to resolve the issue by employing various methods, including utilizing the persistent_load parameter with pickle.Unpickler and trying to load the file in different environments, but unfortunately, all efforts have been in vain.

Request for Assistance:
Given the circumstances, I was hoping you could provide some insights or guidance on the following points:

Creation Environment: Could you share details about the environment in which the file was created, including the Python and PyTorch versions used?
Persistent IDs: Any information or context regarding the persistent IDs encountered in the file would be immensely helpful.
Loading Method: If there is a specific method or procedure to correctly load the file, could you please share it with me?
Additional Details: Any other details or specifications about the file that you think might assist in resolving the issue would be greatly appreciated.

Can't get attribute 'VisualTransformer' on <module 'open_clip.model' from '~/open_clip/model.py'>

Hi @gabrielilharco ,
Thank you for your exciting work! I tried to replicate the result using code from README.md. It showed an error when running

# Create the task vector
task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint)

The error is:
AttributeError: Can't get attribute 'VisualTransformer' on <module 'open_clip.model' from '/srv/home/<user_name>/anaconda3/envs/task-vectors/lib/python3.10/site-packages/open_clip/model.py'>

The error comes from loading the model weights in pretrained_state_dict = torch.load(pretrained_checkpoint).state_dict().

Could you help me with this? Thanks!

Value for args.data_location

It's unclear to me what value I should assign to args.data_location. The README sets args.data_location = '/path/to/data', but I'm not sure which folder that refers to.

checkpoint of linearized models

I found your analysis of linearized fine-tuning to be both fascinating and helpful, and I’m currently attempting to reproduce it. However, I’m encountering some issues because I’m using a Mac with an M1 chip, which doesn’t support DDP and causes compatibility problems with JAX.
Would it be possible for you to share the checkpoint for the linearized fine-tuning? It would be incredibly convenient and helpful for my work as a contribution to interpretability. I would greatly appreciate it!
