This project forked from jeongsoop/rgb-no-more

An official code release of the paper RGB no more: Minimally Decoded JPEG Vision Transformers

Home Page: http://www-personal.umich.edu/~jespark/rgbnomore-2023/rgbnomore_web.html

License: Other

RGB no more: Minimally Decoded JPEG Vision Transformers

This repository contains the pipeline for training JPEG ViT, including DCT augmentations.

If you use this repo, please cite our paper:

BibTeX:

@InProceedings{Park_2023_CVPR,
    author    = {Park, Jeongsoo and Johnson, Justin},
    title     = {RGB No More: Minimally-Decoded JPEG Vision Transformers},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {22334-22346}
}

Plain text:

Jeongsoo Park and Justin Johnson. "RGB no more: Minimally-decoded JPEG Vision Transformers." In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 22334-22346. 2023.

The format of this README is heavily borrowed from SwinV2.

Pretrained Models

The following pretrained models are available:

Model                     Val Acc (%)  Link
JPEG-Ti                   75.1         link
ViT-Ti                    74.1         link
JPEG-S                    76.5         link
ViT-S                     76.5         link
SwinV2-T (DCT, window=8)  79.4         link
SwinV2-T (RGB, window=8)  79.0         link

The models can be downloaded with wget [link]. For example:

wget http://www-personal.umich.edu/~jespark/rgbnomore-2023/imgnetDCTViTTi_ep300_75.1.pth

Note: RGB ViT-S is trained on raw ImageNet to reproduce the recipe by Google. Others are trained on 512x512 resized ImageNet.
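
The wget download above can also be scripted from Python. The following is a minimal sketch using only the standard library; local_name and download_checkpoint are hypothetical helper names that are not part of this repository, and the URL is the single one spelled out in this README.

```python
import urllib.request
from pathlib import Path
from urllib.parse import urlparse

# The only checkpoint URL written out in this README (JPEG-Ti, 75.1% val acc).
JPEG_TI_URL = (
    "http://www-personal.umich.edu/~jespark/rgbnomore-2023/"
    "imgnetDCTViTTi_ep300_75.1.pth"
)


def local_name(url: str) -> str:
    """Return the file-name component of a checkpoint URL."""
    return Path(urlparse(url).path).name


def download_checkpoint(url: str, dest_dir: str = ".") -> Path:
    """Download url into dest_dir unless the file already exists there.

    Hypothetical helper for illustration; returns the local path.
    """
    dest = Path(dest_dir) / local_name(url)
    if not dest.exists():
        dest.parent.mkdir(parents=True, exist_ok=True)
        urllib.request.urlretrieve(url, str(dest))
    return dest


# Example call (performs a network download, so it is left commented out):
# path = download_checkpoint(JPEG_TI_URL, "checkpoints")
```

The returned path can then be passed to train.py through the --loadpath flag described in the Pipeline section.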

Usage

Install

Note that the modified libjpeg requires Linux to compile properly. Other operating systems are not supported.

  • Clone this repo:
git clone https://github.com/JeongsooP/RGB-no-more.git
cd RGB-no-more
  • Create a conda virtual environment and activate it:
conda create -n rgbnomore python=3.10
conda activate rgbnomore
  • Install PyTorch>=1.12.1 and CUDA>=11.3 (we used PyTorch=1.12.1 and CUDA=11.3 in our paper):
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 "llvm-openmp<16.0" -c pytorch

We recommend using llvm-openmp<16.0 until the multiprocessing issue is fixed.

  • Install gcc_linux-64==12.2.0, gxx_linux-64==12.2.0 and other necessary libraries:
conda install -c conda-forge gcc_linux-64==12.2.0 gxx_linux-64==12.2.0 torchmetrics torchinfo tensorboard einops scipy yacs pandas timm imageio iopath psutil

Make sure the versions for gcc_linux-64 and gxx_linux-64 are exactly 12.2.0.

  • Install other requirements:
pip install -r requirements.txt
  • Compile dct_manip -- a modified libjpeg handler:

    • Open dct_manip/setup.py

    • Modify include_dirs and library_dirs to include your include and library folder.

    • Modify extra_objects to the path containing libjpeg.so

    • Modify headers to the path containing jpeglib.h

    • Run cd dct_manip

    • Run pip install .
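
Before editing dct_manip/setup.py, it can help to confirm where jpeglib.h and libjpeg.so actually live. The sketch below assumes they sit under the include and lib folders of a single prefix (for example the active conda environment, $CONDA_PREFIX); that location is an assumption, not something this README states. find_jpeg_paths is a hypothetical helper, and the dictionary keys simply mirror the setup.py fields named above. Whether extra_objects and headers expect a file path or a folder is not specified here, so compare against the existing values in setup.py before copying anything over.

```python
from pathlib import Path


def find_jpeg_paths(prefix: str) -> dict:
    """Locate jpeglib.h and libjpeg.so under prefix/include and prefix/lib.

    Hypothetical helper: raises FileNotFoundError if either file is missing.
    """
    root = Path(prefix)
    header = root / "include" / "jpeglib.h"
    library = root / "lib" / "libjpeg.so"
    missing = [str(p) for p in (header, library) if not p.exists()]
    if missing:
        raise FileNotFoundError("not found: " + ", ".join(missing))
    # Keys mirror the setup.py fields mentioned in the steps above.
    return {
        "include_dirs": [str(header.parent)],
        "library_dirs": [str(library.parent)],
        "extra_objects": [str(library)],
        "headers": [str(header)],
    }


# Example call (assumes an activated conda environment):
# import os
# print(find_jpeg_paths(os.environ["CONDA_PREFIX"]))
```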

Data preparation

  • Download the ImageNet dataset from http://image-net.org/.
  • Place the tar files and valprep.sh in the same folder.
  • The dataloader will automatically organize the data as follows:
    imgnet_data
    ├── train
    │   ├── n01440764
    │   │   ├── n01440764_10026.jpeg
    │   │   ├── n01440764_10027.jpeg
    │   │   └── ...
    │   ├── n01443537
    │   │   ├── n01443537_10007.jpeg
    │   │   └── ...
    │   └── ...
    └── val
        ├── n01440764
        │   ├── ILSVRC2012_val_00000293.jpeg
        │   ├── ILSVRC2012_val_00002138.jpeg
        │   └── ...
        ├── n01443537
        │   ├── ILSVRC2012_val_00000236.jpeg
        │   └── ...
        └── ...

Images will be resized to 512x512 unless the --no_resize flag is used.
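
Once extraction has finished, the folder layout shown above can be sanity-checked with a few lines of Python. This is a minimal sketch, not part of the repository: summarize_split and check_layout are hypothetical names, and the code assumes the lowercase .jpeg extension and the train/val folder names exactly as drawn in the tree.

```python
from pathlib import Path


def summarize_split(split_dir: str) -> dict:
    """Map each class folder (e.g. n01440764) to its number of .jpeg files."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(1 for _ in class_dir.glob("*.jpeg"))
    return counts


def check_layout(root: str) -> dict:
    """Return {split: (num_classes, num_images)} for train and val under root."""
    report = {}
    for split in ("train", "val"):
        counts = summarize_split(str(Path(root) / split))
        report[split] = (len(counts), sum(counts.values()))
    return report


# Example call on the extracted dataset (path is a placeholder):
# print(check_layout("/tmp/temp_path/imgnet_data"))
```

For the full ImageNet-1K dataset one would expect 1000 class folders in each split, with 1,281,167 training images and 50,000 validation images.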

Pipeline

Arguments:
usage: train.py [-h] [--port PORT] [--model_arch MODEL_ARCH] [--no_subblock] [--embed_type EMBED_TYPE] [--domain DOMAIN] [--datapath DATAPATH] [--temp_datapath TEMP_DATAPATH] [--indexpaths INDEXPATHS] [--delete_dataset] [--no_extract] [--no_resize] [--num_gpus NUM_GPUS] [--num_cpus NUM_CPUS] [--train] [--eval] [--benchmark BENCHMARK] [--savepath SAVEPATH]
              [--loadpath LOADPATH] [--load_ckpt LOAD_CKPT] [--deterministic] [--verbose VERBOSE] [--epochs EPOCHS] [--batch BATCH] [--lr LR] [--wd WD] [--drop DROP] [--warmup_steps WARMUP_STEPS] [--ops_list OPS_LIST] [--num_ops NUM_OPS] [--ops_magnitude OPS_MAGNITUDE] [--amp AMP] [--ampdtype AMPDTYPE] [--seed SEED] [--use_msrsync]

options:
-h, --help            show this help message and exit
--port PORT           Port for pytorch distributed dataparallel
--model_arch MODEL_ARCH
                      Model architecture (vitti, vits, vitb, vitl, swinv2)
--no_subblock         If set, disable subblock conversion
--embed_type EMBED_TYPE
                      Embedding layer type. (1: grouped, 2: separate, 3: concatenate). Default 1
--domain DOMAIN       (DCT/RGB) Choose domain type
--datapath DATAPATH   Path to folder containing the .tar files
--temp_datapath TEMP_DATAPATH
                      Path to extract .tar files
--indexpaths INDEXPATHS
                      Path to train/val index files. Separated by comma.
--delete_dataset      If set, delete dataset after train or eval
--no_extract          If set, assume .tar is already extracted to temp_datapath
--no_resize           If set, do not resize images in temp_datapath to 512x512
--num_gpus NUM_GPUS   number of GPUs to use. If not set, automatically use all available GPUs
--num_cpus NUM_CPUS   number of total available cpu threads
--train               Train new model
--eval                Evaluate model loaded from ``savepath``
--benchmark BENCHMARK
                      If set, benchmark for the set iterations
--savepath SAVEPATH   Save path for model. Also saves checkpoint at this path
--loadpath LOADPATH   Load path for model. Used during evaluation. If empty, copy savepath
--load_ckpt LOAD_CKPT
                      If set, load checkpoint from this path
--deterministic       If set, use deterministic mode
--verbose VERBOSE     (0/1/2) 0: no output, 1: output per epoch, 2: output per iteration
--epochs EPOCHS       Override the number of epochs
--batch BATCH         Override the size of batch (overall batch size)
--lr LR               Override the learning rate
--wd WD               Override the weight decay strength
--drop DROP           Override dropout probability
--warmup_steps WARMUP_STEPS
                      Override warmup steps
--ops_list OPS_LIST   Override augmentation list
--num_ops NUM_OPS     Override number of operations
--ops_magnitude OPS_MAGNITUDE
                      Override augmentation magnitude
--amp AMP             (True:1/False:0) Override automatic mixed precision
--ampdtype AMPDTYPE   Override amp dtype casting
--seed SEED           Override random seed
--use_msrsync         If set, use msrsync instead of .tar

Sample bash scripts for all models are included in ./job_bash. Make sure to modify --datapath, --temp_datapath, --num_gpus, --num_cpus, --savepath, --loadpath, --verbose and any other flags to match your setup.

  • Training from scratch on ImageNet-1K

To train a JPEG-Ti on ImageNet from scratch, run:

python train.py --model_arch=vitti --embed_type=1 --domain=dct --datapath=/path_to/tarfiles --temp_datapath=/tmp/temp_path --indexpaths=assets/indexbase_train.csv,assets/indexbase_val.csv --num_cpus=16 --savepath=/path_to_save/jpegti_model.pth --verbose=2 --train --eval
  • Evaluation

To evaluate a pre-trained JPEG-Ti on ImageNet val, run:

python train.py --model_arch=vitti --embed_type=1 --domain=dct --datapath=/path_to/tarfiles --temp_datapath=/tmp/temp_path --indexpaths=assets/indexbase_train.csv,assets/indexbase_val.csv --num_cpus=16 --loadpath=/model_savedpath/jpegti_model.pth --verbose=2 --eval

You can change the --model_arch option to the corresponding model architecture (vitti, vits, vitb, swinv2).

  • Throughput

To measure the throughput for 128 iterations, run:

python train.py --model_arch=vitti --embed_type=1 --domain=dct --datapath=/path_to/tarfiles --temp_datapath=/tmp/temp_path --indexpaths=assets/indexbase_train.csv,assets/indexbase_val.csv --num_cpus=16 --loadpath=/model_savedpath/jpegti_model.pth --verbose=2 --benchmark 128
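
The three commands above share most of their flags, so they can be generated from one set of options. The sketch below is one way to do that from Python; build_command is a hypothetical helper that is not part of this repository, and it only assembles flags that appear in the argument list above. The paths are the same placeholders used in the commands.

```python
import shlex


def build_command(mode: str, iterations: int = 128, **options) -> list:
    """Build a train.py argument list.

    mode is "train" (train, then evaluate), "eval", or "benchmark".
    options are passed through as --name=value flags.
    """
    cmd = ["python", "train.py"]
    for name, value in options.items():
        cmd.append(f"--{name}={value}")
    if mode == "train":
        cmd += ["--train", "--eval"]
    elif mode == "eval":
        cmd.append("--eval")
    elif mode == "benchmark":
        cmd += ["--benchmark", str(iterations)]
    else:
        raise ValueError(f"unknown mode: {mode}")
    return cmd


# Flags shared by the training, evaluation and throughput commands.
COMMON = dict(
    model_arch="vitti",
    embed_type=1,
    domain="dct",
    datapath="/path_to/tarfiles",
    temp_datapath="/tmp/temp_path",
    indexpaths="assets/indexbase_train.csv,assets/indexbase_val.csv",
    num_cpus=16,
    verbose=2,
)

eval_cmd = build_command(
    "eval", loadpath="/model_savedpath/jpegti_model.pth", **COMMON
)
# shlex.join turns the list back into a copy-pasteable shell line.
eval_line = shlex.join(eval_cmd)
```

The resulting list can be handed to subprocess.run(eval_cmd, check=True) from the repository root. Keeping the arguments as a list rather than a single string avoids shell-quoting issues with paths that contain spaces.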
