This project is forked from msight-tech/research-hdas.

H-DAS: Unchain the Search Space with Hierarchical Differentiable Architecture Search

License: CC BY-NC 4.0

PyTorch implementation of "Unchain the Search Space with Hierarchical Differentiable Architecture Search" (AAAI 2021), with support for DDP (DistributedDataParallel) and Apex AMP (Automatic Mixed Precision).

Framework

[Figure: overview of the H-DAS framework]

Requirements

python>=3.6
pytorch>=1.4
tensorboardX
apex
numpy
graphviz
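A quick way to confirm the environment before running any script is to check that each requirement can be imported. This stdlib-only sketch assumes the usual import names (`torch` for pytorch, plus `tensorboardX`, `apex`, `numpy` and `graphviz`); it checks the python>=3.6 requirement and whether each module can be located, but it does not verify the pytorch>=1.4 version.

```python
import importlib.util
import sys

# Assumed import names for the packages listed under Requirements.
REQUIRED_MODULES = ["torch", "tensorboardX", "apex", "numpy", "graphviz"]


def python_ok(min_version=(3, 6)):
    """Return True if the running interpreter satisfies the minimum version."""
    return sys.version_info[:2] >= min_version


def missing_modules(modules=REQUIRED_MODULES):
    """Return the modules that cannot be found on the current sys.path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("python>=3.6:", python_ok())
    print("missing:", ", ".join(missing_modules()) or "none")
```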

Datasets

While CIFAR10 can be automatically downloaded by torchvision, ImageNet needs to be manually downloaded following the instructions here.

Hc-DAS

Search for three stage-specific normal cells and two reduction cells.

Reduce the batch size if you run out of GPU memory; the feasible size depends on your GPU memory and on the genotype.

  • Search micro-architectures (cell-level structures)
python searchCell_main.py --name cifar10_test --dataset cifar10
  • Augment micro-architectures (cell-level structures)
python augmentCell_main.py --name Hc-DAS  --dataset cifar10  \
--genotype "[Genotype3 searched by 'searchCell_main.py']"
  • Run the following command to use DistributedDataParallel (DDP version)
python -m torch.distributed.launch --nproc_per_node=8 \
ddp/augmentCell_main.py --name Hc-DAS  --dataset cifar10  \
--genotype "[Genotype3 searched by 'searchCell_main.py']"  --dist
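The `--genotype` argument expects the genotype string printed by `searchCell_main.py`. The sketch below only illustrates how such a string can be produced and parsed back. The `Genotype3` field names are hypothetical (one normal cell per stage and one reduction cell between stages, following the description above), and the real definition in the repository may differ. Cells use the DARTS convention of (operation, input index) pairs, and parsing uses `eval` restricted to the two names the string needs.

```python
import shlex
from collections import namedtuple

# Hypothetical layout: normal1 | reduce1 | normal2 | reduce2 | normal3.
Genotype3 = namedtuple(
    "Genotype3",
    "normal1 normal1_concat reduce1 reduce1_concat "
    "normal2 normal2_concat reduce2 reduce2_concat "
    "normal3 normal3_concat",
)


def to_cli_arg(genotype):
    """Serialize a genotype to the string form passed via --genotype."""
    return repr(genotype)


def from_cli_arg(text):
    """Parse a genotype string, exposing only the names it should contain.

    eval is only acceptable for trusted strings, e.g. ones printed by your own search run.
    """
    allowed = {"__builtins__": {}, "Genotype3": Genotype3, "range": range}
    return eval(text, allowed)


# Toy cell: each intermediate node keeps two (operation, input_index) edges.
toy_cell = [[("sep_conv_3x3", 0), ("skip_connect", 1)],
            [("max_pool_3x3", 0), ("sep_conv_3x3", 2)]]
g = Genotype3(toy_cell, range(2, 4), toy_cell, range(2, 4),
              toy_cell, range(2, 4), toy_cell, range(2, 4),
              toy_cell, range(2, 4))

arg = to_cli_arg(g)
print("python augmentCell_main.py --name Hc-DAS --dataset cifar10 --genotype "
      + shlex.quote(arg))
```

`shlex.quote` takes care of the single quotes inside the genotype; wrapping the string in double quotes, as in the commands above, also works because `repr` emits single-quoted operation names.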

Hs-DAS

Search for macro-architectures with fixed cell structures.

Search macro-architectures (stage-level structures)
  • Run the following command to perform a stage-level structure search on CIFAR10.
python searchStage_main.py --name macro-cifar10-test \
--w_weight_decay 0.0027  --dataset cifar10 --batch_size 64 \
--workers 0  --genotype "[Genotype searched by DARTS or PDARTS or other Darts-series methods]"
  • Run the following command to perform a stage-level structure search for ImageNet.
python searchStage_ImageNet_main.py --name macro-ImageNet-test  \
--dataset cifar10 --batch_size 128 --w_weight_decay 0.0027  \
--workers 16  --genotype "[Genotype searched by DARTS or PDARTS or other Darts-series methods]"
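Hs-DAS keeps the cells fixed and searches how they are connected inside each stage. To illustrate the general idea only, the following sketch applies a DARTS-style continuous relaxation to cell-to-cell connections: each cell receives a softmax-weighted sum of all earlier outputs in its stage, and a discrete structure is derived by keeping the `k` strongest incoming edges per cell. In a real search the `alphas` would be learned by gradient descent together with the network weights. The scalar stand-in cells, the `alphas` values and `k=2` are assumptions for the example; this is not the code of `searchStage_main.py`.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def relaxed_stage(stage_input, cells, alphas):
    """Run one stage: cell i sees a softmax(alphas[i])-weighted sum of earlier outputs.

    outputs[0] is the stage input; outputs[j] (j >= 1) is the output of cell j - 1.
    """
    outputs = [stage_input]
    for cell, alpha in zip(cells, alphas):
        if len(alpha) != len(outputs):
            raise ValueError("alphas[i] needs one entry per earlier output")
        weights = softmax(alpha)
        mixed = sum(w * h for w, h in zip(weights, outputs))
        outputs.append(cell(mixed))
    return outputs[-1]


def derive_dag(alphas, k=2):
    """Discretize: keep the k largest-alpha incoming edges of each cell."""
    dag = []
    for alpha in alphas:
        ranked = sorted(range(len(alpha)), key=lambda j: alpha[j], reverse=True)
        dag.append(sorted(ranked[:k]))
    return dag


def identity(x):
    """Stand-in for a fixed cell; a real cell would be a neural network module."""
    return x


# Three cells in one stage; alphas[i] has i + 1 entries (stage input + earlier cells).
alphas = [[0.0], [0.5, 1.0], [2.0, -1.0, 0.3]]
out = relaxed_stage(1.0, [identity, identity, identity], alphas)
dag = derive_dag(alphas)  # [[0], [0, 1], [0, 2]]
```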
Augment macro-architectures (stage-level structures)
  • Run the following command to train (augment) the searched stage-level structure on CIFAR10.
python augmentStage_main.py  --name Hs-DAS-cifar10-test  \
--init_channels 36  --workers 16 --lr 0.025  --batch_size 64 \
--dataset cifar10  --epochs 600  \
--genotype "[Genotype searched by DARTS or PDARTS or other Darts-series methods]"  \
--DAG "[Genotype2 searched by 'searchDAG_main.py']"
  • Run the following command to use DistributedDataParallel (DDP version)
python -m torch.distributed.launch --nproc_per_node=8  \
ddp/augmentStage_main.py  --init_channels 36 --workers 16 \
--name Hs-DAS-cifar10-test  --dataset cifar10 \
--genotype "[Genotype searched by DARTS or PDARTS or other Darts-series methods]" \
--DAG "[Genotype2 searched by 'searchDAG_main.py']" --dist
  • Run the following command to train (augment) the searched stage-level structure on ImageNet.
python augmentStage_ImageNet_main.py  --init_channels 39 \
--batch_size 105  --lr 0.1  --workers 16  --dataset imagenet \
--name Hs-DAS_imagenet-test  --print_freq 100  --epochs 250 \
--genotype "[Genotype searched by DARTS or PDARTS or other Darts-series methods]" \
--DAG "[Genotype2 searched by 'searchDAG_ImageNet_main.py']"
  • Run the following command to use DistributedDataParallel (DDP version)
python -m torch.distributed.launch --nproc_per_node=8 \
ddp/augmentStage_ImageNet_main.py --init_channels 39 \
--workers 16 --name Hs-DAS-imagenet-test  --dataset imagenet  \
--print_freq 100  --epochs 250 \
--genotype "[Genotype searched by DARTS or PDARTS or other Darts-series methods]" \
--DAG "[Genotype2 searched by 'searchDAG_ImageNet_main.py']" --dist
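`torch.distributed.launch` starts one process per GPU (`--nproc_per_node=8` above) and passes each process its local rank as a command-line argument, so a script launched this way must accept it. The parser below is a hypothetical sketch: `--name`, `--dataset`, `--batch_size`, `--lr` and `--dist` appear in the commands above, but the defaults and the rank handling are assumptions, not the repository's code. Both spellings of the rank flag are accepted because older PyTorch releases pass `--local_rank` while newer ones pass `--local-rank`.

```python
import argparse


def build_parser():
    """Hypothetical parser for a DDP-capable augment script."""
    parser = argparse.ArgumentParser(description="DDP-capable augment script (sketch)")
    parser.add_argument("--name", required=True)
    parser.add_argument("--dataset", default="cifar10")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.025)
    parser.add_argument("--dist", action="store_true",
                        help="train with DistributedDataParallel")
    # Filled in by torch.distributed.launch, one value per spawned process.
    parser.add_argument("--local_rank", "--local-rank", type=int, default=0)
    return parser


def device_for(args):
    """Under --dist each process drives the GPU matching its local rank."""
    return "cuda:{}".format(args.local_rank) if args.dist else "cuda:0"
```

In an actual training script, `args.local_rank` would typically be passed to `torch.cuda.set_device`, followed by `torch.distributed.init_process_group(backend="nccl", init_method="env://")` before wrapping the model in `DistributedDataParallel`.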

We provide DistributedDataParallel versions of augmentCell and augmentStage (the ddp/ scripts), but some hyperparameters, such as lr and batch_size, need to be adjusted when using them.
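One common starting point for adjusting lr is the linear scaling rule used for large-batch SGD: keep the ratio of learning rate to total batch size constant. This is a general heuristic, not a setting recommended by the authors, and it assumes that `--batch_size` is counted per process, which should be checked against the DDP scripts. The example reuses the single-GPU CIFAR10 values from the augmentStage command (lr 0.025, batch_size 64). The alternative is to keep the total batch size fixed and split it across processes, in which case lr can stay unchanged.

```python
def scaled_lr(base_lr, base_batch_size, per_process_batch_size, world_size):
    """Linear scaling rule: lr grows in proportion to the total batch size."""
    total_batch = per_process_batch_size * world_size
    return base_lr * total_batch / base_batch_size


def per_process_batch(total_batch_size, world_size):
    """Alternative: keep the total batch fixed and split it evenly across processes."""
    if total_batch_size % world_size:
        raise ValueError("total batch size must be divisible by world_size")
    return total_batch_size // world_size


# Single-GPU CIFAR10 settings from the augmentStage command, run on 8 GPUs.
lr_8gpu = scaled_lr(0.025, 64, 64, 8)   # 0.025 * 512 / 64 = 0.2
batch_8gpu = per_process_batch(64, 8)   # 8 samples per GPU, lr stays 0.025
```

For large total batches, a short learning-rate warmup is often combined with this rule to keep early training stable.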

Search for the distribution of cells

The distribution of cells over the three stages is searched under a given computational-complexity budget.

  • Run the following command to search for the distribution of cells.
python searchDistribution_main.py --name cifar10-changeStage \
--w_weight_decay 0.0027  --dataset cifar10 --batch_size 64 --workers 0  \
--genotype "[Genotype searched by DARTS or PDARTS or other Darts-series methods]"
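To make the complexity constraint concrete, the brute-force sketch below enumerates every split of a fixed number of normal cells over the three stages and keeps the splits that fit a budget. It is only an illustration of the constraint, not the method used by `searchDistribution_main.py`. The cost model is an assumption: channels double at each reduction in DARTS-style networks and a convolution's parameter count grows with the square of the channel count, so a cell's parameters grow roughly 4x per stage (relative costs 1, 4, 16). The 18-cell total and the even 6/6/6 baseline are also illustrative.

```python
from itertools import product

# Assumed relative parameter cost of one normal cell in stages 1, 2, 3.
STAGE_COST = (1, 4, 16)


def distributions(total_cells, min_per_stage=1):
    """Yield (n1, n2, n3) with n1 + n2 + n3 == total_cells and each >= min_per_stage."""
    for n1, n2 in product(range(min_per_stage, total_cells + 1), repeat=2):
        n3 = total_cells - n1 - n2
        if n3 >= min_per_stage:
            yield (n1, n2, n3)


def cost(dist, stage_cost=STAGE_COST):
    """Relative cost of a distribution under the assumed per-stage costs."""
    return sum(n * c for n, c in zip(dist, stage_cost))


def feasible(total_cells, budget, min_per_stage=1):
    """Distributions whose relative cost stays within the budget, cheapest first."""
    ok = [d for d in distributions(total_cells, min_per_stage) if cost(d) <= budget]
    return sorted(ok, key=cost)


baseline = (6, 6, 6)             # 18 normal cells split evenly across stages
budget = cost(baseline)          # 6*1 + 6*4 + 6*16 = 126
candidates = feasible(18, budget)
```

Enumeration is cheap here (at most a few hundred splits); a search method becomes necessary once each candidate must actually be trained or evaluated.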

Results

Results on CIFAR10

[Table: results on CIFAR10]

Results on ImageNet

[Table: results on ImageNet]

Searched micro-architectures

  • normal cells (Hc-DAS): [figure]
  • reduction cells (Hc-DAS): [figure]

Searched macro-architectures

  • macro-architecture for CIFAR10 (Hs-DAS): [figure]
  • macro-architecture for ImageNet (Hs-DAS): [figure]

Citations

Please cite our paper if this implementation helps your research. The BibTeX entry is:

@inproceedings{liu2021hdas,
  title={Unchain the Search Space with Hierarchical Differentiable Architecture Search},
  author={Liu, Guanting and Zhong, Yujie and Guo, Sheng and Scott, Matthew R and Huang, Weilin},
  booktitle={AAAI},
  year={2021}
}

Contact

For any questions, please feel free to reach out:

License

H-DAS is licensed under CC BY-NC 4.0, as found in the LICENSE file. It is released for academic research and non-commercial use only. If you wish to use it for commercial purposes, please contact [email protected].
