This repository is the official PyTorch implementation of "CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances" by Jihoon Tack*, Sangwoo Mo*, Jongheon Jeong and Jinwoo Shin.
The code currently requires the following packages:
- python 3.6+
- torch 1.4+
- torchvision 0.5+
- CUDA 10.1+
- scikit-learn 0.22+
- tensorboard 2.0+
- torchlars == 0.1.2
- pytorch-gradual-warmup-lr
- apex == 0.1
- diffdist == 0.1
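A quick way to compare an environment against the minimum versions above is sketched below. The helper names are hypothetical and not part of this repository. The sketch assumes Python 3.8+ (for importlib.metadata) and that the installed distribution names match the names used here.

```python
from importlib import metadata  # Python 3.8+; an assumption beyond the 3.6+ minimum above

# Minimum versions copied from the list above (hypothetical helper, not part of the repo).
MINIMUM_VERSIONS = {
    "torch": (1, 4),
    "torchvision": (0, 5),
    "scikit-learn": (0, 22),
    "tensorboard": (2, 0),
}


def version_tuple(version):
    """Turn '1.4.0+cu101' into (1, 4, 0); stops at the first non-numeric part."""
    parts = []
    for piece in version.split("+")[0].split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def check_requirements(minimums=MINIMUM_VERSIONS):
    """Return {package: problem} for every missing or too-old package."""
    problems = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if version_tuple(installed) < minimum:
            needed = ".".join(str(part) for part in minimum)
            problems[name] = "found {}, need {}+".format(installed, needed)
    return problems


if __name__ == "__main__":
    for name, problem in check_requirements().items():
        print("{}: {}".format(name, problem))
```

The pinned packages (torchlars, apex, diffdist) are left out because an exact-version check is a simple string comparison.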
For CIFAR, please download the following datasets to ~/data.
For ImageNet-30, please download the following datasets to ~/data.
- ImageNet-30-train, ImageNet-30-test
- CUB-200, Stanford Dogs, Oxford Pets, Oxford flowers, Food-101, Places-365, Caltech-256, DTD
For Food-101, remove the hotdog class to avoid overlap.
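One way to drop a class without deleting files is to filter it out when listing the class folders. The sketch below assumes a one-folder-per-class layout and the folder name hot_dog. Both are assumptions to check against your local copy of Food-101, and the function name is hypothetical.

```python
from pathlib import Path


def classes_without(root, excluded=("hot_dog",)):
    """List class sub-folders under `root`, skipping the excluded names.

    Assumes one sub-folder per class; "hot_dog" is an assumed folder name.
    """
    excluded = set(excluded)
    return sorted(
        entry.name
        for entry in Path(root).iterdir()
        if entry.is_dir() and entry.name not in excluded
    )
```

The returned list can then be used to restrict whichever dataset loader you use to the remaining classes.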
Currently, all code examples assume a distributed launch on 4 GPUs. To run the code on a single GPU, remove -m torch.distributed.launch --nproc_per_node=4.
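The difference between the multi-GPU and single-GPU invocations can be sketched as a small command builder. The function build_command is hypothetical and not part of this repository. It only assembles the argument list used in this README.

```python
def build_command(script, script_args, num_gpus=4):
    """Return the argv list for running `script` on `num_gpus` GPUs (hypothetical helper)."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    command = ["python"]
    if num_gpus > 1:
        # Multi-GPU runs go through the distributed launcher.
        command += [
            "-m",
            "torch.distributed.launch",
            "--nproc_per_node={}".format(num_gpus),
        ]
    # Single-GPU runs call the script directly, without the launcher arguments.
    return command + [script] + list(script_args)
```

For example, build_command("train.py", ["--mode", "simclr_CSI"], num_gpus=1) yields the plain python train.py invocation.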
To train unlabeled one-class & multi-class models in the paper, run this command:
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode simclr_CSI --shift_trans_type rotation --batch_size 32 --one_class_idx <One-Class-Index>
Option --one_class_idx denotes the in-distribution class for one-class training. For multi-class training, set --one_class_idx to None. To run SimCLR, simply change --mode to simclr. The total batch size should be 512 = 4 (GPUs) * 32 (--batch_size option) * 4 (cardinality of the shifted transformation set).
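The batch-size arithmetic above can be checked, and adapted to a different GPU count, with a short sketch. The function names are hypothetical. Whether a single GPU has enough memory for the resulting --batch_size is not something this README states.

```python
def total_batch_size(num_gpus, batch_size, num_shifts):
    """Total batch size = GPUs * per-GPU --batch_size * number of shifted transformations."""
    return num_gpus * batch_size * num_shifts


def batch_size_per_gpu(target_total, num_gpus, num_shifts):
    """--batch_size value that keeps the total fixed when the GPU count changes."""
    per_gpu, remainder = divmod(target_total, num_gpus * num_shifts)
    if remainder:
        raise ValueError("target_total is not divisible by num_gpus * num_shifts")
    return per_gpu


# Setting from this README: 4 GPUs, --batch_size 32, 4 shifted transformations.
print(total_batch_size(4, 32, 4))      # 512
# Keeping the total at 512 with 2 GPUs would need --batch_size 64.
print(batch_size_per_gpu(512, 2, 4))   # 64
```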
To train the labeled multi-class model (confidence-calibrated classifier) in the paper, run these commands:
# Representation train
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode sup_simclr_CSI --shift_trans_type rotation --batch_size 32 --epoch 700
# Linear layer train
python train.py --mode sup_linear --dataset <DATASET> --model <NETWORK> --batch_size 32 --epoch 100 --shift_trans_type rotation --load_path <MODEL_PATH>
To run SupCLR, simply change --mode to sup_simclr. The total batch size should be the same as above. Currently, only rotation is supported as the shifted transformation.
We provide checkpoints of the CSI pre-trained models. Download them from the following links:
- Unlabeled CIFAR-10 multi-class: ResNet-18
- Unlabeled ImageNet-30 multi-class: ResNet-18
- Labeled CIFAR-10 multi-class: ResNet-18
We will update other checkpoints in the future.
To evaluate a model in the unlabeled one-class & multi-class out-of-distribution (OOD) detection setting, run this command:
python eval.py --mode ood_pre --dataset <DATASET> --model <NETWORK> --ood_score CSI --shift_trans_type rotation --print_score --ood_samples 10 --resize_factor 0.54 --resize_fix --one_class_idx <One-Class-Index> --load_path <MODEL_PATH>
Option --one_class_idx denotes the in-distribution class for one-class evaluation. For multi-class evaluation, set --one_class_idx to None. The --resize_factor and --resize_fix options fix the crop size of RandomResizedCrop(). For SimCLR evaluation, change --ood_score to simclr.
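Why a fixed factor fixes the crop size can be sketched as follows. The sketch assumes that --resize_fix sets both ends of the RandomResizedCrop scale range to --resize_factor, which is our reading of the option and not something this README confirms. The arithmetic mirrors how torchvision derives the crop width and height from an area fraction and an aspect ratio. The function name is hypothetical.

```python
import math


def crop_size(height, width, area_fraction, aspect_ratio=1.0):
    """Crop (height, width) for an area fraction and an aspect ratio (width / height)."""
    target_area = height * width * area_fraction
    crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
    crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
    return crop_h, crop_w


# A 32x32 image with the factor 0.54 from the command above, assuming a square crop.
print(crop_size(32, 32, 0.54))  # (24, 24)
```

With a scale range instead of a single value, the area fraction would be sampled per image, so the crop size would vary between evaluation samples.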
To evaluate a model on accuracy, ECE, and OOD detection in the labeled multi-class setting, run these commands:
# OOD AUROC
python eval.py --mode ood --ood_score baseline_marginalized --print_score --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
# Accuracy & ECE
python eval.py --mode test_marginalized_acc --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
These commands use marginalized inference. For single inference (also used for SupCLR), change --ood_score to baseline in the first command and --mode to test_acc in the second command.
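For readers unfamiliar with the ECE reported by the second command, the standard definition of expected calibration error can be sketched in plain Python. This is a generic illustration and not the repository's implementation. The number of bins (10) and the equal-width binning are assumptions.

```python
def expected_calibration_error(confidences, correct, num_bins=10):
    """ECE = sum over bins of (bin size / n) * |bin accuracy - bin mean confidence|."""
    n = len(confidences)
    if n == 0:
        raise ValueError("need at least one prediction")
    conf_sum = [0.0] * num_bins
    hit_sum = [0.0] * num_bins
    count = [0] * num_bins
    for conf, hit in zip(confidences, correct):
        # Equal-width bins on [0, 1]; a confidence of 1.0 goes into the last bin.
        index = min(int(conf * num_bins), num_bins - 1)
        conf_sum[index] += conf
        hit_sum[index] += 1.0 if hit else 0.0
        count[index] += 1
    ece = 0.0
    for index in range(num_bins):
        if count[index] == 0:
            continue  # empty bins contribute nothing
        accuracy = hit_sum[index] / count[index]
        mean_conf = conf_sum[index] / count[index]
        ece += (count[index] / n) * abs(accuracy - mean_conf)
    return ece
```

Here confidences holds the maximum predicted class probability per sample, and correct holds whether each prediction matched the label.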
Our model achieves the following performance:
Method | Dataset | AUROC (Mean) |
---|---|---|
SimCLR | CIFAR-10-OC | 87.9% |
Rot+Trans | CIFAR-10-OC | 90.0% |
CSI (ours) | CIFAR-10-OC | 94.3% |
We only show the CIFAR-10 one-class results in this repo. For other settings, please see our paper.
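The AUROC values in the table above measure how well a detection score separates in-distribution from OOD samples. As a generic illustration (not the repository's evaluation code), AUROC equals the probability that a randomly chosen in-distribution sample scores higher than a randomly chosen OOD sample. The sketch assumes that a higher score means "more in-distribution", and the function name is hypothetical.

```python
def auroc(in_scores, ood_scores):
    """Fraction of (in, OOD) pairs ranked correctly; ties count as half."""
    if not in_scores or not ood_scores:
        raise ValueError("both score lists must be non-empty")
    wins = 0.0
    for s_in in in_scores:
        for s_ood in ood_scores:
            if s_in > s_ood:
                wins += 1.0
            elif s_in == s_ood:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))


print(auroc([0.9, 0.8], [0.1, 0.2]))  # 1.0, perfectly separated
print(auroc([0.9, 0.3], [0.5]))       # 0.5, one of two pairs ranked correctly
```

The pairwise loop is quadratic and only meant for illustration. For real evaluations, scikit-learn's roc_auc_score computes the same quantity efficiently.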
Method | Dataset | OOD Dataset | AUROC (Mean) |
---|---|---|---|
Rot+Trans | CIFAR-10 | CIFAR-100 | 82.5% |
CSI (ours) | CIFAR-10 | CIFAR-100 | 89.3% |
We only show the CIFAR-10 to CIFAR-100 OOD detection results in this repo. For other OOD dataset results, see our paper.
Method | Dataset | OOD Dataset | Acc | ECE | AUROC (Mean) |
---|---|---|---|---|---|
SupCLR | CIFAR-10 | CIFAR-100 | 93.9% | 5.54% | 88.3% |
CSI (ours) | CIFAR-10 | CIFAR-100 | 94.8% | 4.24% | 90.6% |
CSI-ensem (ours) | CIFAR-10 | CIFAR-100 | 96.0% | 3.64% | 92.3% |
We only show CIFAR-10 with CIFAR-100 as OOD in this repo. For other dataset results, please see our paper.
We find that current benchmark datasets for OOD detection are visually far from in-distribution datasets (e.g. CIFAR).
To address this issue, we provide new datasets for OOD detection evaluation: LSUN_fix, ImageNet_fix. See the above figure for a visualization of the current benchmarks and our datasets.
To generate the OOD datasets, run the following commands inside the ./datasets folder:
# ImageNet FIX generation code
python imagenet_fix_preprocess.py
# LSUN FIX generation code
python lsun_fix_preprocess.py