
Adaptive Neural Trees

MIT License

This repository contains our PyTorch implementation of Adaptive Neural Trees (ANTs).

The code was written by Ryutaro Tanno and supported by Kai Arulkumaran.

Paper (ICML'19) | Video from London ML meetup

Prerequisites

  • Linux or macOS
  • Python 2.7
  • Anaconda >= 4.5
  • CPU or NVIDIA GPU + CUDA 8.0

Installation

  • Clone this repo:
git clone https://github.com/rtanno21609/AdaptiveNeuralTrees.git
cd AdaptiveNeuralTrees
  • (Optional) create a new Conda environment and activate it:
conda create -n ANT python=2.7
source activate ANT
  • Run the following to install required packages.
bash ./install.sh
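
After installation, you can check that PyTorch (the core dependency of this implementation) imports correctly and whether a CUDA device is visible. This is a generic sanity check, not a script shipped with the repository:

python -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"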

Usage

An example command for training/testing an ANT is given below.

# Experiment identifiers:
#   --experiment        name of the experiment
#   --subexperiment     name of the sub-experiment
#   --dataset           dataset to train on
# Model details:
#   --router_ver        type of router module
#   --router_ngf        no. of kernels in routers
#   --router_k          spatial size of kernels in routers
#   --transformer_ver   type of transformer module
#   --transformer_ngf   no. of kernels in transformers
#   --transformer_k     spatial size of kernels in transformers
#   --solver_ver        type of solver module
#   --batch_norm        apply batch normalisation
#   --maxdepth          maximum depth of the tree structure
# Training details:
#   --batch-size        batch size
#   --augmentation_on   apply data augmentation
#   --scheduler         learning-rate schedule
#   --criteria          splitting criterion
#   --epochs_patience   no. of epochs of patience per node during the growth phase
#   --epochs_node       max no. of epochs per node during the growth phase
#   --epochs_finetune   no. of epochs for the fine-tuning phase
# Others:
#   --seed              randomisation seed
#   --num_workers       no. of CPU subprocesses used for data loading
#   --visualise_split   save the tree structure every epoch
python tree.py --experiment test_ant_cifar10 \
               --subexperiment myant \
               --dataset cifar10 \
               --router_ver 3 \
               --router_ngf 128 \
               --router_k 3 \
               --transformer_ver 5 \
               --transformer_ngf 128 \
               --transformer_k 3 \
               --solver_ver 6 \
               --batch_norm \
               --maxdepth 10 \
               --batch-size 512 \
               --augmentation_on \
               --scheduler step_lr \
               --criteria avg_valid_loss \
               --epochs_patience 5 \
               --epochs_node 100 \
               --epochs_finetune 200 \
               --seed 0 \
               --num_workers 0 \
               --visualise_split

The model configuration and optimisation trajectory (e.g. the value of the train/validation loss at each time point) are saved in records.json in the directory ./experiments/dataset/experiment/subexperiment/checkpoints. Similarly, the tree structure and the best trained model are saved as tree_structures.json and model.pth, respectively, in the same directory. If the visualisation option --visualise_split is used, the tree architecture of the ANT is saved in PNG format in the directory ./experiments/dataset/experiment/subexperiment/figures.

By default, the average classification accuracy is also computed on the train/validation/test sets at every epoch and saved in the records.json file, so running tree.py suffices for both training and testing an ANT with a given configuration.
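
Because records.json is plain JSON, the logged quantities can also be inspected outside the training script. Below is a minimal sketch, assuming the example command above was used (so the path contains cifar10/test_ant_cifar10/myant); the exact keys stored in the file are determined by tree.py:

import json

# Path follows the layout described above:
# ./experiments/<dataset>/<experiment>/<subexperiment>/checkpoints/records.json
path = './experiments/cifar10/test_ant_cifar10/myant/checkpoints/records.json'
with open(path) as f:
    records = json.load(f)

# List the logged quantities (per-epoch losses, accuracies, etc.),
# assuming the top level of the file is a JSON object.
print(sorted(records.keys()))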

Jupyter Notebooks

We have also included two Jupyter notebooks, ./notebooks/example_mnist.ipynb and ./notebooks/example_cifar10.ipynb, which illustrate how this repository can be used to train ANTs on the MNIST and CIFAR-10 image recognition datasets.

Primitive modules

Defining an ANT amounts to specifying the forms of its primitive modules: routers, transformers and solvers. The table below lists the currently implemented primitive modules. You can try any combination of the three to construct an ANT.

Type | Router                                   | Transformer                                          | Solver
-----|------------------------------------------|------------------------------------------------------|-------------------------------------------------------
1    | 1 x Conv + GAP + Sigmoid                 | Identity function                                    | Linear classifier
2    | 1 x Conv + GAP + 1 x FC                  | 1 x Conv                                             | MLP with 2 hidden layers
3    | 2 x Conv + GAP + 1 x FC                  | 1 x Conv + 1 x MaxPool                               | MLP with 1 hidden layer
4    | MLP with 1 hidden layer                  | Bottleneck residual block (He et al., 2015)          | GAP + 2 FC layers + Softmax
5    | GAP + 2 x FC layers (Veit et al., 2017)  | 2 x Conv + 1 x MaxPool                               | MLP with 1 hidden layer in AlexNet (layers-80sec.cfg)
6    | 1 x Conv + GAP + 2 x FC                  | Whole VGG13 architecture (without the linear layer)  | GAP + 1 FC layer + Softmax

For the detailed definitions of respective modules, please see utils.py and models.py.
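
To make the composition concrete, here is a hypothetical PyTorch sketch of a type-1 router (1 x Conv + GAP + Sigmoid). It is an illustration only, not the repository's implementation (see utils.py and models.py for the real modules), and the choice of a single output channel for the conv layer is an assumption:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Type1Router(nn.Module):
    # Hypothetical sketch of a type-1 router: 1 x Conv + GAP + Sigmoid.
    # Assumption: the conv layer reduces the input to one routing channel.
    def __init__(self, in_channels, k=3):
        super(Type1Router, self).__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=k, padding=k // 2)

    def forward(self, x):
        h = self.conv(x)                  # 1 x Conv
        h = F.adaptive_avg_pool2d(h, 1)   # GAP: (N, 1, H, W) -> (N, 1, 1, 1)
        p = torch.sigmoid(h.view(-1))     # one routing probability per sample
        return p                          # probability of sending each sample left

Each internal node of an ANT holds one such router, whose output probability decides (softly, during training) which child a sample is sent to.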

Citation

If you use this code for your research, please cite our ICML paper:

@inproceedings{AdaptiveNeuralTrees19,
  title={Adaptive Neural Trees},
  author={Tanno, Ryutaro and Arulkumaran, Kai and Alexander, Daniel and Criminisi, Antonio and Nori, Aditya},
  booktitle={Proceedings of the 36th International Conference on Machine Learning (ICML)},
  year={2019},
}

Acknowledgements

I would like to thank Daniel C. Alexander at University College London, UK, Antonio Criminisi at Amazon Research, and Aditya Nori at Microsoft Research Cambridge for their valuable contributions to this paper.
