
Zero-Cost-NAS

Companion code for the ICLR2021 paper: Zero-Cost Proxies for Lightweight NAS
tl;dr A single minibatch of data is used to score neural networks for NAS instead of performing full training.

In this README, we provide a summary of our method, instructions for running the code, a description of the API, and steps to reproduce the results in the paper.

If you have any questions, please open an issue or email us. (Last update: 02.02.2021)

Summary

Intro. To perform neural architecture search (NAS), deep neural networks (DNNs) are typically trained until a final validation accuracy is computed, which is then used to compare DNNs to each other and select the best one. However, this is time-consuming because training takes multiple GPU-hours/days/weeks. This is why a proxy for final accuracy is often used to speed up NAS. Typically, this proxy is a reduced form of training (e.g. EcoNAS) in which the number of epochs is reduced, a smaller model is used, or the training data is subsampled.

Proxies. Instead, we propose a series of "zero-cost" proxies that use a single minibatch of data to score a DNN. These metrics are inspired by the recent pruning-at-initialization literature, but are adapted to score an entire DNN and to work within a NAS setting. When compared against EcoNAS (see the orange pentagon in the plot below), our zero-cost metrics take ~1000X less time to run but are better correlated with final validation accuracy (especially synflow and jacob_cov), making them better (and much cheaper!) proxies for use within NAS. Even when EcoNAS is tuned specifically for NAS-Bench-201 (see the econas+ purple circle in the plot), our vote zero-cost proxy is still better correlated and is 3 orders of magnitude cheaper to compute.

Figure 1: Correlation of validation accuracy to final accuracy during the first 12 epochs of training (blue line) for three CIFAR-10 models on the NAS-Bench-201 search space. Zero-cost and EcoNAS proxies are also labeled for comparison.

zero-cost vs econas

Zero-Cost NAS. We use the zero-cost metrics to enhance 4 existing NAS algorithms, and we test them on 3 different NAS benchmarks. In all cases, we achieve a new state-of-the-art (SOTA) result in terms of search speed. We incorporate zero-cost proxies in two ways: (1) warmup: use proxies to initialize NAS algorithms, and (2) move proposal: use proxies to improve the selection of the next model for evaluation. As Figure 2 shows, this yields a significant speedup for all evaluated NAS algorithms.

Figure 2: Zero-cost warmup and move proposal consistently improve the speed and accuracy of 4 different NAS algorithms.

Zero-Cost-NAS speedup

For more details, please take a look at our paper!

Running the Code

  • Install PyTorch for your system (v1.5.0 or later).
  • Install the package: pip install . (add -e for editable mode) -- note that all dependencies other than PyTorch will be installed automatically.

API

The main function is find_measures, shown below. Given a neural net and some information about the input data (dataloader) and the loss function (loss_fn), it returns an array of zero-cost proxy metrics.

def find_measures(net_orig,                  # neural network
                  dataloader,                # a data loader (typically for training data)
                  dataload_info,             # a tuple with (dataload_type = {random, grasp}, number_of_batches_for_random_or_images_per_class_for_grasp, number of classes)
                  device,                    # GPU/CPU device used
                  loss_fn=F.cross_entropy,   # loss function to use within the zero-cost metrics
                  measure_names=None,        # an array of measure names to compute, if left blank, all measures are computed by default
                  measures_arr=None):        # [not used] if the measures are already computed but need to be summarized, pass them here
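
For example, following the usage pattern in nasbench2_pred.py (a sketch, not verbatim repo code: the module paths, get_model_from_arch_str, and get_cifar_dataloaders below mirror that script, so verify their exact signatures against it before relying on them):

import torch
import torch.nn.functional as F
from foresight.models import nasbench2
from foresight.pruners import predictive
from foresight.dataset import get_cifar_dataloaders

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# one NAS-Bench-201 cell encoding; any valid arch string works here
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'
net = nasbench2.get_model_from_arch_str(arch_str, 10)   # 10 classes for CIFAR-10

train_loader, _ = get_cifar_dataloaders(64, 64, 'cifar10', 2)  # batch sizes, dataset, workers

scores = predictive.find_measures(
    net, train_loader,
    ('random', 1, 10),         # dataload_type='random', 1 minibatch, 10 classes
    device,
    loss_fn=F.cross_entropy)   # e.g. {'grad_norm': ..., 'snip': ..., 'synflow': ..., ...}
print(scores)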

The available zero-cost metrics are in the measures directory. You can add a new metric by following one of the examples and then registering it in the load_all function; a sketch of what a new metric file might look like is shown below. More examples of how to use find_measures can be found in the code to reproduce results (below). You can also modify the data loading functions in p_utils.py.
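
As an illustration, here is a hypothetical new metric file for the measures directory; the @measure registration decorator and the get_layer_metric_array helper are modeled on the existing measure files (e.g. synflow.py), so check one of them for the exact signatures:

# hypothetical measures/grad_l2_example.py
import torch

from . import measure                        # registration decorator
from ..p_utils import get_layer_metric_array

@measure('grad_l2_example', bn=True)
def compute_grad_l2(net, inputs, targets, loss_fn, split_data=1, **kwargs):
    # one forward/backward pass on the given minibatch
    net.zero_grad()
    loss = loss_fn(net(inputs), targets)
    loss.backward()
    # per-layer score: L2 norm of each layer's weight gradient
    return get_layer_metric_array(
        net,
        lambda l: l.weight.grad.norm() if l.weight.grad is not None else torch.zeros_like(l.weight),
        mode='param')

The new module then needs to be imported in load_all so that the metric is picked up by find_measures.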

Reproducing Results

NAS-Bench-201

  1. Download the NAS-Bench-201 dataset and put it in the data directory in the root folder of this project.
  2. Run python nasbench2_pred.py with the appropriate command-line options -- a pickle file is produced with the zero-cost metrics (see the notebooks folder for how to use the pickle file; a minimal reading sketch follows this list).
  3. Note that you need to manually download ImageNet16 and put it in the _datasets/ImageNet16 directory in the root folder. CIFAR-10/100 will be downloaded automatically.
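
The produced pickle can then be inspected along these lines (a minimal sketch: it assumes the file holds one or more pickled records of per-architecture scores; the notebooks show the exact layout, and the file name below is the one referenced in nas_examples.ipynb):

import pickle

records = []
with open('nb2_cf100_seed42_dlrandom_dlinfo1_initwnone_initbnone.p', 'rb') as f:
    while True:
        try:
            records.append(pickle.load(f))  # also works if only one object was dumped
        except EOFError:
            break
print(len(records))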

NAS-Bench-101

  1. Download the data directory and save it to the root folder of this repo. This contains pre-cached info from the NAS-Bench-101 repo.
  2. [Optional] Download the NAS-Bench-101 dataset, put it in the data directory in the root folder of this project, then clone the NAS-Bench-101 repo and install its package.
  3. Run python nasbench1_pred.py. Note that this takes a long time to go through ~400k architectures, but precomputed results are in the notebooks folder (with a link to the results).

PyTorchCV

  1. Run python ptcv_pred.py

NAS-Bench-ASR

Coming soon...

NAS with Zero-Cost Proxies

For the full list of NAS algorithms in our paper, we used a different NAS tool which is not publicly released. However, we include a notebook, nas_examples.ipynb, that shows how to use zero-cost proxies to speed up aging evolution and random search with both warmup and move proposal; a minimal sketch of the two modes is shown below.
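
For reference, the two integration modes boil down to the following pattern (a sketch mirroring nas_examples.ipynb: synflow_proxy and spec_to_idx are the notebook's precomputed score array and spec-to-index map, while sample_random_spec and mutate are illustrative stand-ins for its helpers):

def warmup_pool(synflow_proxy, spec_to_idx, pool_size=64, warmup=3000):
    # (1) warmup: proxy-score `warmup` random architectures and seed the
    # evolution pool with the best ones; only these get trained
    scored = [(synflow_proxy[spec_to_idx[str(s)]], s)
              for s in (sample_random_spec() for _ in range(warmup))]
    scored.sort(key=lambda t: t[0], reverse=True)
    return [s for _, s in scored[:pool_size]]

def propose_move(parent, synflow_proxy, spec_to_idx, n_candidates=10):
    # (2) move proposal: mutate the parent several times and send only the
    # proxy-best child to (expensive) training/evaluation
    children = [mutate(parent) for _ in range(n_candidates)]
    return max(children, key=lambda s: synflow_proxy[spec_to_idx[str(s)]])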

Citation

@inproceedings{
  abdelfattah2021zerocost,
  title={{Zero-Cost Proxies for Lightweight NAS}},
  author={Mohamed S. Abdelfattah and Abhinav Mehrotra and {\L}ukasz Dudziak and Nicholas D. Lane},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2021}
}

zero-cost-nas's People

Contributors

mohsaied, vaenyr


zero-cost-nas's Issues

AssertionError: _dist_info should not exist when repo is in place

When I run python nasbench2_pred.py as described in the "Reproducing Results" section, I get the following error:

Traceback (most recent call last):
  File "/home/yuntao/test/zero-cost-nas/nasbench2_pred.py", line 20, in <module>
    from foresight.models import *
  File "/home/yuntao/test/zero-cost-nas/foresight/__init__.py", line 16, in <module>
    from .version import *
  File "/home/yuntao/test/zero-cost-nas/foresight/version.py", line 44, in <module>
    assert not has_repo, '_dist_info should not exist when repo is in place'
AssertionError: _dist_info should not exist when repo is in place

list index out of range

I keep getting this list index out of range error for the example code:

IndexError                                Traceback (most recent call last)
<ipython-input> in <module>()
      6 ae_best_valids, ae_best_tests = run_evolution_search(max_trained_models=length)
      7 ae.append(ae_best_tests)
----> 8 ae_warmup_best_valids, ae_warmup_best_tests = run_evolution_search(max_trained_models=length, zero_cost_warmup=3000)
      9 ae_warmup.append(ae_warmup_best_tests)
     10 ae_move_best_valids, ae_move_best_tests = run_evolution_search(max_trained_models=length, zero_cost_move=True)

<ipython-input> in run_evolution_search(max_trained_models, pool_size, tournament_size, zero_cost_warmup, zero_cost_move)
     16         spec_idx = spec_to_idx[str(spec)]
     17         # try:
---> 18         zero_cost_pool.append((synflow_proxy[spec_idx], spec))
     19         zero_cost_pool = sorted(zero_cost_pool, key=lambda i: i[0], reverse=True)
     20         # except Exception as ex:

IndexError: list index out of range

Is this related to the start/end parameters for this step?

python nasbench2_pred.py --start 0 --end 5

Problem about the calculation of synflow.

Hello! I'm very interested in your work.

In Section 3.2.1, "SNIP, GRASP and Synaptic Flow", synflow is formulated without an abs operation.

However, for the code in "synflow.py", synflow is calculated by "torch.abs(layer.weight * layer.weight.grad)".

I am wondering about the role of the abs operation and where it comes from. Why should we use it? I tried referring to the original paper, "Pruning neural networks without any data by iteratively conserving synaptic flow", but I still have no idea.
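
For context, the single-shot synflow computation can be sketched as follows (following the description in the paper and the synflow.py implementation in this repo; sign restoration and per-layer bookkeeping are omitted for brevity):

import torch

def synflow_scores(net, input_shape):
    # re-parameterize every weight to its absolute value (synflow.py saves the
    # signs and restores them afterwards)
    for p in net.parameters():
        p.data = p.data.abs()
    net.double()               # the repo computes in double precision to avoid overflow
    net.zero_grad()
    x = torch.ones(1, *input_shape, dtype=torch.float64)  # data-independent all-ones input
    net(x).sum().backward()    # R = sum of outputs, the "synaptic flow" objective
    # per-weight saliency, as in the quoted line: |w * dR/dw|
    return [torch.abs(p * p.grad) for p in net.parameters() if p.grad is not None]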

NAS

When I do NAS, I find that the predicted scores differ enormously across network architectures, which makes normalization very difficult. Is there any way to deal with this?

measures['synflow']:2.774976807125541e+43
measures['synflow']:1.811579008753419e+32
measures['synflow']:2.914237836508487e+61

The fitness in nasbench2_pred.py

In nasbench2_pred.py, the accuracy values turn out to come from the data in NAS-Bench-201-v1_0-E61699.pth. As mentioned in the paper, is synflow used as the fitness measure for architectures?

Reproducing results for NAS-Bench-ASR

Currently I am trying to reproduce the results on the NAS-Bench-ASR benchmark; however, I have been unsuccessful so far. It would be really helpful if you could answer some questions I have:

  1. I am querying architectures from the ASR space and directly running the given proxies on them (similar to how it is done for NAS-Bench-201 in nasbench2_pred.py). Are there any intermediary steps specific to the ASR case that I am missing?

  2. Is the original ASR architecture (https://github.com/SamsungLabs/nb-asr) modified in any way or are any of the parameters changed before running the given implementation of the proxies?

  3. Are the zero-cost proxy implementations modified in any way for the ASR case? Specific to Synflow, are all the layers (Conv1d, Linear, LSTM and Layernorm) considered for pruning while calculating the metric for a particular architecture?

In case I am completely off the mark, I would be really grateful if you could provide a broad outline for how I can go about reproducing these results.

Reproduce NAS-Bench-NLP results

Hi, I really enjoyed your paper!

We incorporated your techniques in our recent project. I was wondering about your plan for releasing the code for the NAS-Bench-NLP results, because we would like to include your techniques on that search space as well. Thanks!

Public precomputed results

Thanks for your great work! I have tried to download your precomputed results, but they seem to be private. Could you make them public? Looking forward to your reply.

Can you provide the nasbench101 search notebook?

For example, in Fig. 3, AE+W (15K) achieves 94.22 with 50 trained models. I don't understand this: since 64 models are trained for the initial pool, it seems that AE does not contribute at all in this case. Yet RAND needs only 34 trained models to reach 94.22, so I'm a little confused...

Top-5%/top-64 computation

Hello,

Thanks for a great paper.

When you compute the top-5%/top-64 score (Tables 4 and 11), how many architectures are there in total? Is it 3,000 architectures (warmup only) or the size of the entire dataset?

Cheers,
Ekaterina

About the ImageNet1k datasets

Hi, thanks for the nice work. I would like to ask how to obtain 'imagenet-train-256.h5' and 'imagenet-val-256.h5' for ImageNet1k.

elif dataset == 'ImageNet1k':
    train_dataset = H5Dataset(os.path.join(datadir, 'imagenet-train-256.h5'), transform=train_transform)
    test_dataset = H5Dataset(os.path.join(datadir, 'imagenet-val-256.h5'), transform=test_transform)

pickle file used in nas_examples.ipynb

Thank you very much for this great work!

In nas_examples.ipynb, you precomputed the scores (synflow_proxy) and stored them in nb2_cf100_seed42_dlrandom_dlinfo1_initwnone_initbnone.p for zero_cost_warmup.

To run your notebook, would you mind sharing this pre-computed synflow_proxy score file?

Thank you!

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Excuse me, when I run python nasbench1_pred.py --dataset cifar10 --start 0 --end 1000, the following error occurs at idx=22:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.DoubleTensor [1, 512, 8, 8]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
