
Sparse Learning Library and Sparse Momentum Resources

This repo contains a sparse learning library which allows you to wrap any PyTorch neural network with a sparse mask to emulate the training of sparse neural networks. It also contains the code to replicate our work Sparse Networks from Scratch: Faster Training without Losing Performance.

Requirements

The library requires PyTorch v1.2. You can install it via anaconda or pip; see PyTorch/get-started for further information. For CUDA versions < 9.2 you need to either compile PyTorch from source or install a newer CUDA version along with a compatible video driver.
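
If you are unsure which PyTorch and CUDA versions your environment provides, a quick check (this snippet is just a convenience, not part of the library) is:

    import torch

    print(torch.__version__)          # the library expects 1.2.x
    print(torch.version.cuda)         # CUDA toolkit PyTorch was built against
    print(torch.cuda.is_available())  # True if a compatible GPU and driver are present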

Installation

  1. Install PyTorch.
  2. Install other dependencies: pip install -r requirements.txt
  3. Install the sparse learning library: python setup.py install

Basic Usage

MNIST & CIFAR-10 models

MNIST and CIFAR-10 code can be found in the mnist_cifar subfolder. Run python main.py --data DATASET_NAME --model MODEL_NAME to train a model on MNIST (--data mnist) or CIFAR-10 (--data cifar).

The following models can be specified out-of-the-box with the --model flag:

 MNIST:

	lenet5
	lenet300-100

 CIFAR-10:

	alexnet-s
	alexnet-b
	vgg-c
	vgg-d
	vgg-like
	wrn-28-2
	wrn-22-8
	wrn-16-8
	wrn-16-10

Beyond standard parameters such as batch size and learning rate (see python main.py --help for their usage), the following sparse-learning-specific parameters are available:

--save-features       Resumes a saved model and saves its feature data to
                      disk for plotting.
--bench               Enables the benchmarking of layers and estimates
                      sparse speedups
--growth GROWTH       Growth mode. Choose from: momentum, random, and
                      momentum_neuron.
--death DEATH         Death mode / pruning mode. Choose from: magnitude,
                      SET, threshold.
--redistribution REDISTRIBUTION
                      Redistribution mode. Choose from: momentum, magnitude,
                      nonzeros, or none.
--death-rate DEATH_RATE
                      The pruning rate / death rate.
--density DENSITY     The density of the overall sparse network.
--sparse              Enable sparse mode. Default: True.
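
As an example, a sparse momentum run on CIFAR-10 with a WRN-28-2 at 5% density could combine these flags as follows (the values are illustrative, not recommended settings):

    python main.py --data cifar --model wrn-28-2 --density 0.05 --growth momentum --death magnitude --redistribution momentum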

Running an ImageNet Model

To run ImageNet with 16-bit precision you need to install Apex. Installing Apex from pip currently does not work for me, but installing it from the repository works just fine.

The ImageNet code for sparse momentum can be found in the imagenet sub-folder, which contains two different ResNet-50 ImageNet models: a baseline used by Mostafa & Wang (2019), which reaches 74.9% accuracy with 100% weights, and a tuned ResNet-50 version that is identical to the baseline but adds a warmup learning rate and label smoothing, reaching 77.0% accuracy with 100% weights. The tuned version builds on NVIDIA Deep Learning Examples: RN50v1.5, while the baseline builds on Intel/dynamic-reparameterization.

Running Your Own Model

With the sparse learning library it is easy to run sparse momentum on your own model. All you need to do is follow the code template below:

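A minimal sketch of that pattern is shown here. It follows the structure of the mnist_cifar example in this repo; the exact constructor arguments of Masking and CosineDecay, and the dummy model and data, are assumptions for illustration, so treat this as an outline rather than a drop-in snippet.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset
    # Names below follow the mnist_cifar example; check that example if they differ.
    from sparselearning.core import Masking, CosineDecay

    # A small dense model and some dummy data, just so the template is self-contained.
    model = nn.Sequential(nn.Linear(784, 300), nn.ReLU(), nn.Linear(300, 10))
    train_loader = DataLoader(TensorDataset(torch.randn(256, 784),
                                            torch.randint(0, 10, (256,))), batch_size=64)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    epochs = 10

    # Decay schedule for the pruning (death) rate over training.
    decay = CosineDecay(0.3, len(train_loader) * epochs)

    # Wrap the optimizer in a sparse mask and register the model at the target density.
    mask = Masking(optimizer,
                   death_rate=0.3,
                   death_mode='magnitude',
                   growth_mode='momentum',
                   redistribution_mode='momentum',
                   death_rate_decay=decay)
    mask.add_module(model, density=0.05)

    for epoch in range(epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(model(data), target)
            loss.backward()
            mask.step()  # use instead of optimizer.step(): steps the optimizer and re-applies the mask

The important change compared to dense training is that the optimizer is wrapped by the mask object and mask.step() is called in place of optimizer.step(), so pruning and regrowth happen during training rather than afterwards.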

Extending the Library

It is easy to extend the library with your own functions for growth, redistribution and pruning. See The Extension Tutorial for more information about how you can add your own functions.
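For intuition, the kind of logic a custom growth function implements looks roughly like the sketch below: given a layer's weight tensor, its current binary mask, and a regrowth budget, it decides which pruned positions to re-enable. The actual function signature and how it is registered with the library are described in the extension tutorial; the standalone form here is only illustrative.

    import torch

    def random_growth(weight, mask, num_regrow):
        """Illustrative growth step: randomly re-enable num_regrow currently
        pruned (mask == 0) positions and return the new binary mask."""
        new_mask = mask.clone()
        pruned = (new_mask == 0).flatten().nonzero().flatten()
        if len(pruned) == 0 or num_regrow <= 0:
            return new_mask
        chosen = pruned[torch.randperm(len(pruned))[:num_regrow]]
        new_mask.view(-1)[chosen] = 1.0
        return new_mask

A momentum-based growth mode works the same way structurally, except that pruned positions are ranked by the magnitude of their momentum rather than chosen at random.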
