
bigballon / cifar-zoo

PyTorch implementation of CNNs for CIFAR benchmark

License: MIT License

Python 99.82% Shell 0.18%
pytorch cifar resnet resnext densenet shake-shake mixup cutout learning-rate-decay senet

Awesome CIFAR Zoo

Status: Archive (final test with PyTorch 1.7; no longer maintained. I would recommend using pycls, powered by FAIR, which is a simple and flexible codebase for image classification.)

This repository contains PyTorch code for multiple CNN architectures and improvement methods based on the following papers. I hope the implementation and results will be helpful for your research!

Requirements and Usage

Requirements

  • Python (>=3.6)
  • PyTorch (>=1.1.0)
  • TensorBoard (>=1.4.0) (for visualization)
  • Other dependencies (pyyaml, easydict)
pip install -r requirements.txt

Usage

Simply run the following commands for training and evaluation:

## 1 GPU for lenet
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet

## resume from ckpt
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet --resume

## 2 GPUs for preresnet1202
CUDA_VISIBLE_DEVICES=0,1 python -u train.py --work-path ./experiments/cifar10/preresnet1202

## 4 GPUs for densenet190bc
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py --work-path ./experiments/cifar10/densenet190bc

## 1 GPU for vgg19 inference
CUDA_VISIBLE_DEVICES=0 python -u eval.py --work-path ./experiments/cifar10/vgg19
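The commands above pass the experiment directory with `--work-path` and optionally `--resume`. The sketch below shows one way such a command line could be parsed with the standard `argparse` module; `build_parser` is a hypothetical name and this is an assumption about how `train.py` might do it, not the repository's actual parser. Note that `argparse` turns `--work-path` into the attribute `args.work_path`.

```python
import argparse


def build_parser():
    # Hypothetical sketch of the CLI used above; the real train.py may define more options
    parser = argparse.ArgumentParser(description="CIFAR training (sketch)")
    parser.add_argument("--work-path", required=True, type=str,
                        help="experiment directory containing config.yaml")
    parser.add_argument("--resume", action="store_true",
                        help="resume training from the checkpoint in the work path")
    return parser


args = build_parser().parse_args(["--work-path", "./experiments/cifar10/lenet", "--resume"])
print(args.work_path, args.resume)
```

`CUDA_VISIBLE_DEVICES` is an environment variable read by CUDA rather than a script argument; it only controls which GPUs the process can see.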

We use the YAML file config.yaml to store the parameters; check any of the files in ./experiments for more details.
You can view the training curves in TensorBoard: tensorboard --logdir path-to-event --port your-port.
The training log is written via logging; check log.txt in your work path.
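A log.txt like the one described above can be produced with Python's standard `logging` module. This is a minimal sketch, assuming a hypothetical helper `setup_logger` and an arbitrary message format; it is not the repository's own logging code.

```python
import logging
import os


def setup_logger(work_path, name="train"):
    """Hypothetical helper: log to <work_path>/log.txt and to the console."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep messages out of the root logger
    logger.handlers.clear()  # avoid duplicated lines if called twice
    fmt = logging.Formatter("[%(asctime)s] %(message)s")
    for handler in (logging.FileHandler(os.path.join(work_path, "log.txt")),
                    logging.StreamHandler()):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger
```

`FileHandler` opens the file in append mode by default, so a resumed run would keep adding lines to the same log.txt.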

Results on CIFAR

Vanilla architectures

architecture params batch size epoch C10 test acc (%) C100 test acc (%)
Lecun 62K 128 250 67.46 34.10
alexnet 2.4M 128 250 75.56 38.67
vgg19 20M 128 250 93.00 72.07
preresnet20 0.27M 128 250 91.88 67.03
preresnet110 1.7M 128 250 94.24 72.96
preresnet1202 19.4M 128 250 94.74 75.28
densenet100bc 0.76M 64 300 95.08 77.55
densenet190bc 25.6M 64 300 96.11 82.59
resnext29_16x64d 68.1M 128 300 95.94 83.18
se_resnext29_16x64d 68.6M 128 300 96.15 83.65
cbam_resnext29_16x64d 68.7M 128 300 96.27 83.62
ge_resnext29_16x64d 70.0M 128 300 96.21 83.57

With additional regularization

PS: the default data augmentation methods are RandomCrop + RandomHorizontalFlip + Normalize,
and a √ marks each additional method that was used. 🍰

architecture epoch cutout mixup C10 test acc (%)
preresnet20 250 91.88
preresnet20 250 √ 92.57
preresnet20 250 √ 92.71
preresnet20 250 √ √ 92.66
preresnet110 250 94.24
preresnet110 250 √ 94.67
preresnet110 250 √ 94.94
preresnet110 250 √ √ 95.66
se_resnext29_16x64d 300 96.15
se_resnext29_16x64d 300 √ 96.60
se_resnext29_16x64d 300 √ 96.86
se_resnext29_16x64d 300 √ √ 97.03
cbam_resnext29_16x64d 300 √ √ 97.16
ge_resnext29_16x64d 300 √ √ 97.19
-- -- -- -- --
shake_resnet26_2x64d 1800 96.94
shake_resnet26_2x64d 1800 √ 97.20
shake_resnet26_2x64d 1800 √ 97.42
shake_resnet26_2x64d 1800 √ √ 97.71

PS: shake_resnet26_2x64d achieved 97.71% test accuracy with cutout and mixup!
It's cool, right?
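The two additional regularizers in the table can be sketched in NumPy. This is a minimal illustration assuming CHW float images and integer labels. It follows the published descriptions of Cutout (zero one randomly centred square patch, clipped at the image border) and mixup (a convex combination of a batch with a shuffled copy of itself, with λ drawn from Beta(α, α)). The patch length, α, and the helper names `cutout` and `mixup` are placeholders, not the values or functions used in this repository's configs.

```python
import numpy as np


def cutout(img, length, rng):
    """Return a copy of a CHW image with an (up to) length x length patch set to zero."""
    _, h, w = img.shape
    # The patch centre is sampled uniformly, so the patch may be clipped at the border
    cy, cx = int(rng.integers(h)), int(rng.integers(w))
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, h)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, w)
    out = img.copy()
    out[:, y1:y2, x1:x2] = 0.0
    return out


def mixup(x, y, alpha, num_classes, rng):
    """Mix a batch (N, C, H, W) with a shuffled copy; return inputs, soft labels, lam."""
    lam = float(rng.beta(alpha, alpha)) if alpha > 0 else 1.0
    perm = rng.permutation(x.shape[0])
    mixed_x = lam * x + (1.0 - lam) * x[perm]
    one_hot = np.eye(num_classes)[y]
    # Cross-entropy against these soft labels equals
    # lam * CE(y) + (1 - lam) * CE(y[perm])
    mixed_y = lam * one_hot + (1.0 - lam) * one_hot[perm]
    return mixed_x, mixed_y, lam


rng = np.random.default_rng(0)
images = rng.normal(size=(4, 3, 32, 32))
labels = np.array([0, 1, 2, 3])
patched = cutout(images[0], length=16, rng=rng)
mixed_x, mixed_y, lam = mixup(images, labels, alpha=1.0, num_classes=10, rng=rng)
```

Cutout is a per-image transform applied alongside the default augmentation, whereas mixup operates on a whole batch, which is why the two can be switched on independently as in the table.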

With different LR scheduler

architecture epoch step decay cosine htd(-6,3) cutout mixup C10 test acc (%)
preresnet20 250 √ 91.88
preresnet20 250 √ 92.13
preresnet20 250 √ 92.44
preresnet20 250 √ √ √ 93.30
preresnet110 250 √ 94.24
preresnet110 250 √ 94.48
preresnet110 250 √ 94.82
preresnet110 250 √ √ √ 95.88
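The three schedules compared above can be written as per-epoch learning-rate functions. This sketch rests on stated assumptions. Step decay multiplies by `gamma` at each milestone, and the milestones used here are hypothetical. Cosine annealing follows the usual half-cosine from `base_lr` to `min_lr`. htd(-6,3) is read as hyperbolic-tangent decay with lower bound L = -6 and upper bound U = 3, i.e. lr = min_lr + (base_lr - min_lr) * (1 - tanh(L + (U - L) * t / T)) / 2.

```python
import math


def step_lr(epoch, base_lr, milestones, gamma=0.1):
    # Multiply by gamma once for every milestone already reached
    return base_lr * gamma ** sum(epoch >= m for m in milestones)


def cosine_lr(epoch, total_epochs, base_lr, min_lr=0.0):
    # Half-cosine from base_lr (epoch 0) down to min_lr (epoch == total_epochs)
    ratio = epoch / total_epochs
    return min_lr + (base_lr - min_lr) * (1.0 + math.cos(math.pi * ratio)) / 2.0


def htd_lr(epoch, total_epochs, base_lr, min_lr=0.0, lower=-6.0, upper=3.0):
    # Hyperbolic-tangent decay: the tanh argument sweeps from `lower` to `upper`
    ratio = epoch / total_epochs
    return min_lr + (base_lr - min_lr) * (1.0 - math.tanh(lower + (upper - lower) * ratio)) / 2.0


# Hypothetical settings: 250 epochs, base LR 0.1, step milestones at epochs 125 and 190
for epoch in (0, 125, 190, 249):
    print(epoch,
          round(step_lr(epoch, 0.1, [125, 190]), 5),
          round(cosine_lr(epoch, 250, 0.1), 5),
          round(htd_lr(epoch, 250, 0.1), 5))
```

Because tanh(-6) is close to -1, HTD starts almost exactly at `base_lr` and stays high for longer than cosine before dropping, while U = 3 leaves a small non-zero rate at the end.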

Acknowledgments

The provided code was adapted from:

Feel free to contact me if you have any suggestions or questions. Issues are welcome;
create a PR if you find any bugs or want to contribute. 😊

Citation

@misc{bigballon2019cifarzoo,
  author = {Wei Li},
  title = {CIFAR-ZOO: PyTorch implementation of CNNs for CIFAR dataset},
  howpublished = {\url{https://github.com/BIGBALLON/CIFAR-ZOO}},
  year = {2019}
}

cifar-zoo's Issues

senet improvement

Line 180 of senet.py:
x = F.avg_pool2d(x, 8, 1)

would fit all input sizes better if changed to:
x = F.avg_pool2d(x, x.size(3), 1)
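The issue can be checked with the output-size formula PyTorch documents for average pooling (no padding, `ceil_mode=False`): out = floor((in - kernel) / stride) + 1. The sketch assumes the usual CIFAR layout in which the last stage sees a feature map 1/4 of the input side (32 → 8). With a fixed kernel of 8, only 32x32 inputs end up 1x1, while a kernel equal to `x.size(3)` always does for square maps. For non-square maps, `F.adaptive_avg_pool2d(x, 1)` would be the more general alternative.

```python
def pool_out(size, kernel, stride=1, padding=0):
    # Output side length for avg_pool2d with ceil_mode=False
    return (size + 2 * padding - kernel) // stride + 1


def final_map_side(input_side):
    # Assumption: two stride-2 stages reduce a CIFAR-style input by a factor of 4
    return input_side // 4


for side in (32, 64):
    fmap = final_map_side(side)
    fixed = pool_out(fmap, kernel=8)         # original: F.avg_pool2d(x, 8, 1)
    size_aware = pool_out(fmap, kernel=fmap)  # proposed: F.avg_pool2d(x, x.size(3), 1)
    print(side, fmap, fixed, size_aware)
# prints "32 8 1 1" and then "64 16 9 1"
```

A 9x9 map after pooling would then no longer match the input size expected by the final fully connected layer, which is why the fixed kernel breaks on larger images.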

How are the results reported?

I wonder how the results in your benchmark are reported.
I ran lenet twice with the same config but got best_test_acc values of 67.980% and 67.210%, respectively.
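Run-to-run variation like this is often summarized by reporting the mean and standard deviation over several seeds. The sketch below shows that convention using the two values quoted in the issue; `summarize` is a hypothetical helper, and nothing in this README states that the benchmark numbers were produced this way.

```python
import statistics


def summarize(accs):
    # Mean and sample standard deviation of best test accuracies across runs
    mean = statistics.mean(accs)
    std = statistics.stdev(accs) if len(accs) > 1 else 0.0
    return mean, std


mean, std = summarize([67.980, 67.210])
print(f"{mean:.2f} +/- {std:.2f}")
```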

Layer C5 in LeNet

Thanks for your repo!

A detail about LeNet-5:

Layer C5 in LeNet should be a convolutional layer, as LeCun described in his paper:

Layer C5 is a convolutional layer with 120 feature maps.
Each unit is connected to a 5x5 neighborhood on all 16
of S4's feature maps. Here, because the size of S4 is also
5x5, the size of C5's feature maps is 1x1: this amounts
to a full connection between S4 and C5. C5 is labeled
as a convolutional layer, instead of a fully-connected layer
because if LeNet-5 input were made bigger with everything
else kept constant, the feature map dimension would be
larger than 1x1.

It is nothing serious~
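The point in the quote can be checked with plain shape and parameter arithmetic. This sketch assumes the classic LeNet-5 widths (6, 16, 120, 84), 5x5 convolutions, 2x2 stride-2 pooling, 3-channel input and 10 classes; it does not read the repository's model code. On 32x32 inputs a 5x5 convolutional C5 and a fully connected layer on the flattened 16*5*5 features have identical parameter counts, and they only differ once the input grows.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output side length of a square conv/pool window (no dilation, floor rounding)
    return (size + 2 * padding - kernel) // stride + 1


def conv_params(in_ch, out_ch, kernel):
    # Weights plus one bias per output channel
    return in_ch * out_ch * kernel * kernel + out_ch


def linear_params(in_features, out_features):
    return in_features * out_features + out_features


def s4_side(input_side):
    # C1 (5x5 conv) -> S2 (2x2 pool) -> C3 (5x5 conv) -> S4 (2x2 pool)
    side = conv_out(input_side, 5)
    side = conv_out(side, 2, stride=2)
    side = conv_out(side, 5)
    return conv_out(side, 2, stride=2)


c5_conv = conv_params(16, 120, 5)         # C5 as a 5x5 convolution
c5_fc = linear_params(16 * 5 * 5, 120)    # C5 as a fully connected layer
total = (conv_params(3, 6, 5) + conv_params(6, 16, 5) + c5_conv
         + linear_params(120, 84) + linear_params(84, 10))

print(s4_side(32), conv_out(s4_side(32), 5))  # 5 1
print(s4_side(64), conv_out(s4_side(64), 5))  # 13 9
print(c5_conv, c5_fc, total)                  # 48120 48120 62006
```

Under these assumptions the total of 62,006 parameters is consistent with the 62K listed for Lecun in the results table.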

load() missing 1 required positional argument

def main():
    global args, config, last_epoch, best_prec, writer
    writer = SummaryWriter(log_dir=args.work_path + "/event")

    # read config from yaml file
    with open(args.work_path + "/config.yaml") as f:
        # PyYAML >= 6.0 requires an explicit Loader, so this call fails there
        config = yaml.load(f)

Line 168 in train.py raises TypeError: load() missing 1 required positional argument: 'Loader'
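A fix consistent with this error is to supply the loader explicitly. The sketch below uses `yaml.safe_load`, which is equivalent to `yaml.load(f, Loader=yaml.SafeLoader)` and works on both older and current PyYAML releases; `load_config` is a hypothetical helper name, not a function from train.py. `yaml.load(f, Loader=yaml.FullLoader)` (PyYAML >= 5.1) is another option.

```python
import yaml


def load_config(path):
    """Read an experiment config.yaml into a plain dict."""
    with open(path) as f:
        # safe_load supplies the loader itself, so it works on PyYAML 6.x,
        # where yaml.load(f) without Loader raises TypeError
        return yaml.safe_load(f)
```

Since easydict is listed as a dependency, wrapping the result as `EasyDict(load_config(path))` would give attribute-style access to the config keys.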
