
CVPR 2024: AllSpark: Reborn Labeled Features from Unlabeled in Transformer for Semi-Supervised Semantic Segmentation

License: MIT License

Python 97.68% Shell 2.32%
attention cvpr2024 semi-supervised-segmentation transformer semantic-segmentation

[CVPR-2024] AllSpark: Reborn Labeled Features from Unlabeled in Transformer for Semi-Supervised Semantic Segmentation

This repo is the official implementation of AllSpark: Reborn Labeled Features from Unlabeled in Transformer for Semi-Supervised Semantic Segmentation, which was accepted at CVPR 2024.

The AllSpark is a powerful Cybertronian artifact in the Transformers film series. It was used to bring Optimus Prime back to life in Transformers: Revenge of the Fallen, which aligns well with our core idea.


💥 Motivation

In this work, we found that simply converting existing semi-supervised segmentation methods into a pure-transformer framework is ineffective.

  • First, transformers inherently have a weaker inductive bias than CNNs, so they rely heavily on a large volume of training data to perform well.

  • The more critical issue lies in the existing semi-supervised segmentation frameworks. These frameworks separate the training flows for labeled and unlabeled data, which aggravates the overfitting issue of transformers on the limited labeled data.

Thus, we propose to intervene and diversify the labeled data flow with unlabeled data in the feature domain, leading to improvements in generalizability.
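
To make the idea of mixing unlabeled information into the labeled flow more concrete, here is a toy sketch in plain Python. It assumes an attention-style mixing in which labeled features act as queries and unlabeled features act as keys and values, so each labeled feature is rebuilt as a weighted combination of unlabeled ones. The function names and the example vectors are hypothetical, and this is not the repository's implementation, which may differ (for instance in the axis over which attention is computed).

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attend(labeled, unlabeled):
    """Rebuild each labeled feature as an attention-weighted mix of unlabeled ones.

    labeled:   list of query vectors (features from labeled images)
    unlabeled: list of key/value vectors (features from unlabeled images)
    """
    dim = len(labeled[0])
    scale = 1.0 / math.sqrt(dim)
    rebuilt = []
    for query in labeled:
        # Scaled dot-product similarity between this query and every key.
        scores = [scale * sum(q * k for q, k in zip(query, key)) for key in unlabeled]
        weights = softmax(scores)
        # Weighted sum of the unlabeled vectors (used here as the values).
        rebuilt.append([
            sum(w * value[i] for w, value in zip(weights, unlabeled))
            for i in range(dim)
        ])
    return rebuilt


# Hypothetical 2-D features, for illustration only.
labeled_feats = [[1.0, 0.0], [0.0, 1.0]]
unlabeled_feats = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
rebuilt_feats = cross_attend(labeled_feats, unlabeled_feats)
```

Because the attention weights are non-negative and sum to one, every rebuilt labeled feature lies inside the span of the unlabeled features, which is the sense in which the labeled flow is diversified here.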


๐Ÿ› ๏ธ Usage

โ€ผ๏ธ IMPORTANT: This version is not the final version. We made some mistakes when re-organizing the code. We will release the correct version soon. Sorry for any inconvenience this may cause.

1. Environment

First, clone this repo:

git clone https://github.com/xmed-lab/AllSpark.git
cd AllSpark/

Then, create a new environment and install the requirements:

conda create -n allspark python=3.7
conda activate allspark
pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu116
pip install tensorboard
pip install six
pip install pyyaml
pip install -U openmim
mim install mmcv==1.6.2
pip install einops
pip install timm

2. Data Preparation & Pre-trained Weights

2.1 Pascal VOC 2012 Dataset

Download the dataset with wget:

wget https://hkustconnect-my.sharepoint.com/:u:/g/personal/hwanggr_connect_ust_hk/EcgD_nffqThPvSVXQz6-8T0B3K9BeUiJLkY_J-NvGscBVA\?e\=2b0MdI\&download\=1 -O pascal.zip
unzip pascal.zip

2.2 Cityscapes Dataset

Download the dataset with wget:

wget https://hkustconnect-my.sharepoint.com/:u:/g/personal/hwanggr_connect_ust_hk/EWoa_9YSu6RHlDpRw_eZiPUBjcY0ZU6ZpRCEG0Xp03WFxg\?e\=LtHLyB\&download\=1 -O cityscapes.zip
unzip cityscapes.zip

2.3 COCO Dataset

Download the dataset with wget:

wget https://hkustconnect-my.sharepoint.com/:u:/g/personal/hwanggr_connect_ust_hk/EXCErskA_WFLgGTqOMgHcAABiwH_ncy7IBg7jMYn963BpA\?e\=SQTCWg\&download\=1 -O coco.zip
unzip coco.zip

Then your file structure will be like:

├── VOC2012
    ├── JPEGImages
    └── SegmentationClass
    
├── cityscapes
    ├── leftImg8bit
    └── gtFine
    
├── coco
    ├── train2017
    ├── val2017
    └── masks
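
Before training, it can help to confirm that the unzipped archives produced this layout. The following is a small sketch using only the Python standard library. The folder names come from the structure above, while `missing_dirs`, `EXPECTED_LAYOUT`, and the choice of `"."` as the data root are assumptions made for illustration.

```python
from pathlib import Path

# Sub-directories listed in the structure above, keyed by dataset folder.
EXPECTED_LAYOUT = {
    "VOC2012": ["JPEGImages", "SegmentationClass"],
    "cityscapes": ["leftImg8bit", "gtFine"],
    "coco": ["train2017", "val2017", "masks"],
}


def missing_dirs(data_root, layout=EXPECTED_LAYOUT):
    """Return the expected sub-directories that do not exist under data_root."""
    root = Path(data_root)
    return [
        f"{dataset}/{sub}"
        for dataset, subs in layout.items()
        for sub in subs
        if not (root / dataset / sub).is_dir()
    ]


if __name__ == "__main__":
    # "." is an assumption: run from the folder where the archives were unzipped.
    for path in missing_dirs("."):
        print(f"missing: {path}")
```

If you only use one dataset, the entries reported as missing for the other two can be ignored.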

Next, download the following pretrained weights.

├── ./pretrained_weights
    ├── mit_b2.pth
    ├── mit_b3.pth
    ├── mit_b4.pth
    └── mit_b5.pth

For example, mit-B5:

mkdir pretrained_weights
wget https://hkustconnect-my.sharepoint.com/:u:/g/personal/hwanggr_connect_ust_hk/ET0iubvDmcBGnE43-nPQopMBw9oVLsrynjISyFeGwqXQpw\?e\=9wXgso\&download\=1 -O ./pretrained_weights/mit_b5.pth

3. Training & Evaluating

# use torch.distributed.launch
sh scripts/train.sh <num_gpu> <port>
# to fully reproduce our results, <num_gpu> should be set to 4 on all three datasets
# otherwise, you need to adjust the learning rate accordingly

# or use slurm
# sh scripts/slurm_train.sh <num_gpu> <port> <partition>
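
The comments above say the learning rate needs adjusting when <num_gpu> is not 4, but not how. One common heuristic is the linear scaling rule: keep the per-GPU batch size fixed and scale the learning rate in proportion to the number of GPUs. The sketch below assumes that rule. It is not a recipe stated by the authors, and `scaled_lr` and the example base value are hypothetical.

```python
def scaled_lr(base_lr, num_gpu, reference_gpus=4):
    """Scale a learning rate linearly with the number of GPUs.

    Assumes the per-GPU batch size stays fixed, so the total batch size
    grows in proportion to num_gpu. reference_gpus=4 matches the setting
    recommended above for reproducing the reported results.
    """
    if num_gpu < 1:
        raise ValueError("num_gpu must be at least 1")
    return base_lr * num_gpu / reference_gpus


# Hypothetical base value for illustration only; read the real one from the repo's config.
example_base_lr = 1e-4
for gpus in (1, 2, 4, 8):
    print(gpus, scaled_lr(example_base_lr, gpus))
```

Results obtained with a rescaled learning rate may still differ from the 4-GPU numbers, so treat this as a starting point for tuning.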

To train on other datasets or splits, modify dataset and split in scripts/train.sh.

4. Results

Model weights and training logs will be released soon.

4.1 PASCAL VOC 2012 original

Splits               1/16         1/8    1/4          1/2          Full
Weights of AllSpark  76.07        78.41  79.77        80.75        82.12
Reproduced           76.06 | log  78.41  79.93 | log  80.70 | log  82.56 | log

4.2 PASCAL VOC 2012 augmented

Splits               1/16   1/8    1/4    1/2
Weights of AllSpark  78.32  79.98  80.42  81.14

4.3 Cityscapes

Splits               1/16   1/8    1/4    1/2
Weights of AllSpark  78.33  79.24  80.56  81.39

4.4 COCO

Splits               1/512        1/256        1/128        1/64
Weights of AllSpark  34.10 | log  41.65 | log  45.48 | log  49.56 | log

Citation

If you find this project useful, please consider citing:

@inproceedings{allspark,
  title={AllSpark: Reborn Labeled Features from Unlabeled in Transformer for Semi-Supervised Semantic Segmentation},
  author={Wang, Haonan and Zhang, Qixiang and Li, Yi and Li, Xiaomeng},
  booktitle={CVPR},
  year={2024}
}

Acknowledgement

AllSpark is built upon UniMatch and SegFormer. We thank their authors for making the source code publicly available.
