
Learning the Unlearned: Mitigating Feature Suppression in Contrastive Learning

This is the official implementation of Multistage Contrastive Learning (MCL), proposed in Learning the Unlearned: Mitigating Feature Suppression in Contrastive Learning (ECCV 2024).

Setup

To get started, create and activate a Conda environment using the provided environment.yml file:

conda env create -f environment.yml
conda activate fs

Evaluation on MMVP

Here we provide the weights of the CLIP ViT models obtained with MCL, as described in Section 5.5 of our main paper. To evaluate the MCL fine-tuned Vision Transformer (ViT) versions of CLIP on MMVP, download the MMVP_VLM Benchmark to ./MMVP_VLM and refer to MMVP.ipynb. The fine-tuned model weights are available on OSF; download them and extract them with tar -xzvf into ./weights.

Fine-Tuning CLIP with MCL

Begin by downloading the CC12M dataset using img2dataset. Note that we currently do not support the webdataset format; please download the dataset in the standard image and text file format (--output_format files).
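The download might look roughly like the following. This is a hedged sketch: the flags are from img2dataset's documented CLI, but the metadata filename and its column names (url, caption) are assumptions to be matched against your copy of the CC12M metadata.

```shell
# Sketch of the CC12M download in the file format MCL expects.
# cc12m.tsv and its column names are assumptions -- adjust to your metadata.
img2dataset --url_list cc12m.tsv --input_format tsv \
    --url_col url --caption_col caption \
    --output_format files --output_folder <path to CC12M>
```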

To initialize the MCL process, perform inference using a vanilla model and conduct clustering to generate pseudo labels for Stage 1 training:

python cluster.py --img-file-path <path to CC12M> --modelarch ViT-L-14 --pretrained openai --stage 0 --num-clusters 10

This command saves the cluster centroids, labels, and pseudo labels of Stage 0 into the ./save directory.
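Conceptually, the clustering step embeds every image with the vanilla model and runs K-means over the embeddings. The following is a minimal sketch of that idea only, assuming scikit-learn is installed and using synthetic features in place of real CLIP embeddings; the actual logic lives in cluster.py and may differ in detail.

```python
import numpy as np
from sklearn.cluster import KMeans

# Stand-in for CLIP image embeddings (cluster.py obtains these from the
# vanilla ViT-L-14 image encoder); shape: (num_images, embed_dim).
rng = np.random.default_rng(0)
features = rng.standard_normal((200, 64)).astype(np.float32)

# L2-normalize, since CLIP embeddings are compared by cosine similarity.
features /= np.linalg.norm(features, axis=1, keepdims=True)

# K-means with the same cluster count as the --num-clusters flag.
kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(features)

pseudo_labels = kmeans.labels_        # one cluster id per image
centroids = kmeans.cluster_centers_   # saved alongside the labels in ./save
```

In the real pipeline these arrays are what gets written to ./save and later consumed through --MCL-label-path.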

The resulting pseudo labels can then be used for Stage 1 tuning:

torchrun --nproc_per_node 8 main.py \
    --train-data <path to CC12M> \
    --batch-size 1000 \
    --precision amp \
    --workers 16 \
    --MCL-label-path './save/ViT-L-14_0_pseudo_labels.pt' \
    --epochs 20 \
    --pretrained openai \
    --model ViT-L-14 \
    --force-quick-gelu \
    --report-to tensorboard \
    --zeroshot-frequency 1 \
    --dataset-type files \
    --ddp-static-graph \
    --gather-with-grad \
    --lock-text \
    --lr-scheduler const \
    --warmup 600 \
    --lock-image \
    --lock-image-unlocked-groups 6

The fine-tuned model will be saved in ./log. You can then repeat the clustering step on the Stage 1 checkpoint to generate pseudo labels for Stage 2 tuning, and so on for later stages.
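One iteration of the loop might look as follows. This is a hedged sketch: the checkpoint path under ./log and whether --pretrained accepts a local checkpoint here are assumptions to be checked against your run.

```shell
# Sketch of one multistage iteration; substitute the actual checkpoint
# filename produced by your Stage 1 run.
python cluster.py --img-file-path <path to CC12M> --modelarch ViT-L-14 \
    --pretrained <Stage 1 checkpoint in ./log> --stage 1 --num-clusters 10
# ...then rerun the torchrun command above with
#   --MCL-label-path './save/ViT-L-14_1_pseudo_labels.pt'
```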

Adapting MCL to Other Projects

MCL is model-agnostic and can easily be adapted to other contrastive learning models (e.g., variants of CLIP and SimCLR). Without modifying the model's code, you only need to add a clustering step (e.g., K-means) to generate pseudo labels, and add feature-aware negative sampling to your training loop so that negative samples are drawn according to the pseudo labels. The feature-aware negative sampling mechanism is detailed in ./my_training/data.py at line 587.
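The core idea of drawing negatives according to pseudo labels can be sketched as a batch sampler that groups samples by cluster, so every in-batch negative shares the anchor's pseudo label and the already-learned (cluster-defining) feature cannot discriminate the pair. This is only an illustrative sketch of that sampling idea, not the implementation in ./my_training/data.py, whose details may differ.

```python
import numpy as np

def feature_aware_batches(pseudo_labels, batch_size, seed=0):
    """Yield index batches whose samples all share one pseudo label, so
    every in-batch negative comes from the same cluster.  Illustrative
    sketch only; incomplete tail groups are dropped for simplicity."""
    rng = np.random.default_rng(seed)
    for label in np.unique(pseudo_labels):
        idx = np.flatnonzero(pseudo_labels == label)
        rng.shuffle(idx)
        for start in range(0, len(idx) - batch_size + 1, batch_size):
            yield idx[start:start + batch_size]

# Toy pseudo labels for eight samples in two clusters.
labels = np.array([0, 0, 0, 0, 1, 1, 1, 1])
batches = list(feature_aware_batches(labels, batch_size=2))
```

Feeding these batches to any InfoNCE-style loss leaves the model code untouched; only the data pipeline changes.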

Acknowledgement

  • open_clip: the codebase we built upon.
  • MMVP: we use MMVP for evaluation.
