sahgerlad / resnext-in-tensorflow

This project forked from wenxinxu/resnext-in-tensorflow

This is an implementation of ResNeXt (by Xie et al.) in tensorflow

ResNeXt in Tensorflow

This is an implementation of ResNeXt in TensorFlow. TensorBoard visualizations of the detailed model structure (using ResNeXt-29, 4x64d as an example) are shown in graph, block, and split.

The blocks are implemented with the split->concat method of Figure 3b in the ResNeXt paper.

**Figure 3c (grouped convolutions) is not available yet.**
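As a rough illustration of the split->concat idea, the sketch below only tracks channel counts through one block; it computes no convolutions and is not taken from resNeXt.py. The function name `split_concat_channels` and the input width of 256 are assumptions made for this example, and the block is assumed to keep its input width so the shortcut can be added directly.

```python
def split_concat_channels(in_channels, cardinality, block_unit_depth):
    """Trace channel counts through one split->concat block (Figure 3b style)."""
    # Each of the `cardinality` branches reduces the input to
    # `block_unit_depth` channels (1x1 conv) and keeps that width (3x3 conv).
    branch_outputs = [block_unit_depth for _ in range(cardinality)]
    # The branch outputs are concatenated along the channel axis.
    concat_channels = sum(branch_outputs)
    # A final 1x1 conv maps back to the input width so the residual
    # shortcut can be added element-wise (assumption: no downsampling).
    out_channels = in_channels
    return {
        "branches": branch_outputs,
        "concat": concat_channels,
        "output": out_channels,
    }


# Hypothetical example values matching a 4x64d block with a 256-channel input.
shapes = split_concat_channels(in_channels=256, cardinality=4, block_unit_depth=64)
print(shapes)
```

With a grouped convolution (Figure 3c), the separate branches would be replaced by a single convolution with `cardinality` groups, which is the variant noted as not available yet.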

TensorFlow 1.0.0 is required. Many functions were renamed in that release, so earlier versions won't work.

Quick guide

You can run cifar10_train.py to train a ResNeXt on CIFAR-10 and follow its progress in the screen output (the code will download the data for you if you don't have it yet). It's better to specify a version identifier before running, since the training logs, checkpoints, and error.csv file are saved in a folder named logs_$version.

python cifar10_train.py --version='test'

You may alter the shape of the model via the hyper-parameters. Typically a ResNeXt is represented as 'ResNeXt-a, b*c'.

a is the total number of layers, defined as 9 * FLAGS.num_resnext_blocks + 2.

b is the cardinality, defined by FLAGS.cardinality.

c is the number of channels in each split, defined by FLAGS.block_unit_depth.
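The naming rule above can be worked through with a small sketch. The helper `resnext_name` is hypothetical and not part of the repository; the formula 9 * n + 2 comes from the text (presumably three stages of n blocks with three convolutional layers each, plus the first convolution and the final fully connected layer).

```python
def resnext_name(num_resnext_blocks, cardinality, block_unit_depth):
    """Build the 'ResNeXt-a, b*c' label from the three hyper-parameters."""
    # a: total number of layers, as defined in the text.
    total_layers = 9 * num_resnext_blocks + 2
    # b: cardinality, c: channels per split.
    return "ResNeXt-{}, {}*{}".format(total_layers, cardinality, block_unit_depth)


print(resnext_name(3, 4, 64))  # ResNeXt-29, 4*64
```

So three blocks per stage with cardinality 4 and 64 channels per split correspond to the ResNeXt-29, 4x64d example mentioned in the introduction.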

To experiment with these hyper-parameters, you can either edit them in hyper-parameters.py or pass them on the command line:

python cifar10_train.py --version='test' --num_resnext_blocks=3 --cardinality=4
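To make the flag handling concrete, here is a minimal sketch that uses Python's argparse as a stand-in for the TensorFlow flags module the scripts presumably rely on. The function `parse_flags` and all default values are assumptions for illustration, not the values in hyper-parameters.py. Each flag is passed as its own `--name=value` argument, and the log folder name follows the logs_$version convention described in the quick guide.

```python
import argparse


def parse_flags(argv):
    """Parse the hyper-parameter flags named in this README (defaults are hypothetical)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", type=str, default="test")
    parser.add_argument("--num_resnext_blocks", type=int, default=3)
    parser.add_argument("--cardinality", type=int, default=4)
    parser.add_argument("--block_unit_depth", type=int, default=64)
    return parser.parse_args(argv)


# The shell strips the quotes around 'test', so the script sees --version=test.
flags = parse_flags(["--version=test", "--num_resnext_blocks=3", "--cardinality=4"])

# Logs, checkpoints, and error.csv go into a folder named logs_$version.
log_dir = "logs_" + flags.version
print(log_dir, flags.num_resnext_blocks, flags.cardinality)
```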

More explanations about the hyper-parameters can be found here

Files included

  1. hyper-parameters.py defines the hyper-parameters related to training, ResNeXt structure, data augmentation, etc.

  2. cifar10_input.py includes the data I/O, pre-processing of images and data augmentation

  3. resNeXt.py is the main body of ResNeXt network

  4. cifar10_train.py is responsible for the training and validation

TODO:

  1. Train the model and generate the learning curve

  2. Data augmentation: cv2 is not compatible w/ tensorflow 1.0.0, so I need to:

    a. Wait for a new version of opencv

    or

    b. Use the queueRunner in tensorflow as data I/O and implement data augmentation
