mldl / only_train_once

This project forked from tianyic/only_train_once

[ICLR 2023] OTOv2: Automatic, Generic, User-Friendly; [NeurIPS 2021] Only Train Once: A One-Shot Neural Network Training And Pruning Framework

Home Page: https://openreview.net/pdf?id=7ynoX1ojPMt

License: MIT License

Python 91.22% Jupyter Notebook 8.78%

Only Train Once (OTO): Automatic One-Shot DNN Training And Compression Framework

This repository is the PyTorch implementation of Only Train Once (OTO). OTO is an automatic, general framework for DNN training and compression (via structured pruning). With OTO, users can train a general DNN from scratch, in one shot, to achieve both high performance and a slimmer architecture, without pretraining or fine-tuning.

Working Items.

We will release detailed documentation for the OTO API in the coming week.

Publications

Please find our series of works.

Installation

This package runs under PyTorch 1.9+, except 1.12 (1.11 or 1.13 is recommended). Use pip or git clone to install.

pip install only_train_once

or

git clone https://github.com/tianyic/only_train_once.git
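Before installing, it can help to confirm that the local PyTorch version meets the constraint above. The following is a minimal sketch in plain Python; `is_supported_torch` is a hypothetical helper, not part of the OTO package, and it reads "1.9+ except 1.12" literally, so later releases such as 2.x pass the check even though the README only recommends 1.11 and 1.13.

```python
def is_supported_torch(version: str) -> bool:
    """Return True if `version` satisfies "PyTorch 1.9+ except 1.12".

    Hypothetical helper for illustration; not part of only_train_once.
    """
    # Drop local build suffixes such as "+cu116" before parsing.
    core = version.split("+")[0]
    # Only the major and minor components matter for this constraint.
    major, minor = (int(part) for part in core.split(".")[:2])
    if (major, minor) == (1, 12):
        return False
    return (major, minor) >= (1, 9)
```

In practice the function would be called as `is_supported_torch(torch.__version__)` after `import torch`.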

Quick Start

We provide an example of how to use the OTO framework. More detailed explanations can be found in the tutorials.

Minimal usage example.

import torch
from backends import DemoNet
from only_train_once import OTO

# Create the OTO instance from the model and a dummy input of the expected shape
model = DemoNet()
dummy_input = torch.zeros(1, 3, 32, 32)
oto = OTO(model=model.cuda(), dummy_input=dummy_input.cuda())

# Create the DHSPG optimizer
optimizer = oto.dhspg(lr=0.1, target_group_sparsity=0.7)

# Train the DNN as usual via DHSPG.
# max_epoch and trainloader are not defined here; supply your own
# epoch count and training DataLoader.
model.train()
model.cuda()
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(max_epoch):
    for X, y in trainloader:
        X, y = X.cuda(), y.cuda()
        y_pred = model(X)
        f = criterion(y_pred, y)
        optimizer.zero_grad()
        f.backward()
        optimizer.step()

# A DemoNet_compressed.onnx will be generated.
oto.compress()
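The `target_group_sparsity=0.7` argument in the example above sets how much of the model should be pruned. As a rough illustration, and assuming that group sparsity means the fraction of variable groups whose entries are all zero, it could be measured as in the sketch below. The `group_sparsity` function and `example_groups` data are hypothetical and are not part of the OTO API.

```python
from typing import Sequence


def group_sparsity(groups: Sequence[Sequence[float]]) -> float:
    """Fraction of groups whose entries are all exactly zero.

    Assumption for illustration: a group counts as pruned only when
    every variable in it is zero.
    """
    if not groups:
        raise ValueError("at least one group is required")
    zero_groups = sum(1 for g in groups if all(v == 0.0 for v in g))
    return zero_groups / len(groups)


# Ten hypothetical groups; seven of them have been driven to zero.
example_groups = [[0.0, 0.0, 0.0] for _ in range(7)] + [
    [0.3, -1.2, 0.0],
    [0.5, 0.1, 0.9],
    [0.0, 2.0, 0.0],
]
sparsity = group_sparsity(example_groups)
```

Note that a group containing some zeros, such as the last one, does not count as sparse: only whole groups can be removed structurally.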

How OTO works.

  • Zero-Invariant Group Partition. OTO first automatically identifies the dependencies inside the target DNN and partitions the DNN's trainable variables into so-called Zero-Invariant Groups (ZIGs). A ZIG is a minimally removable structure of the DNN; it can largely be interpreted as the smallest group of variables that must be pruned together.

  • Dual Half-Space Projected Gradient (DHSPG). A structured sparsity optimization problem is formulated. DHSPG is then employed to determine which ZIGs are redundant and which are important for the model's predictions. DHSPG explores group sparsity more reliably and typically achieves higher generalization performance than other optimizers.

  • Construct compressed model. The structures corresponding to redundant ZIGs (those that are zero) are removed to form the compressed model. Due to the zero-invariance property of ZIGs, the compressed model returns exactly the same output as the full model. Therefore, no further fine-tuning is required.
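The claim that removing zero groups leaves the output unchanged can be illustrated on a toy two-layer network. The sketch below uses plain Python so that it runs without PyTorch. It assumes, purely for illustration, that one group consists of a hidden unit's incoming weights plus its bias; the helper names (`forward`, `prune_zero_groups`) and the weights are hypothetical, and this is not how the library partitions or compresses a real model.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def relu(v: Vector) -> Vector:
    return [max(0.0, x) for x in v]


def linear(W: Matrix, b: Vector, x: Vector) -> Vector:
    """Compute W @ x + b for a row-major matrix W."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def forward(W1: Matrix, b1: Vector, W2: Matrix, b2: Vector, x: Vector) -> Vector:
    """Two-layer network: linear -> ReLU -> linear."""
    return linear(W2, b2, relu(linear(W1, b1, x)))


def prune_zero_groups(W1: Matrix, b1: Vector, W2: Matrix) -> Tuple[Matrix, Vector, Matrix]:
    """Drop hidden units whose group (row of W1 plus its bias) is all zero.

    The matching columns of W2 are dropped too, since they only ever
    multiply a hidden activation that is always zero.
    """
    keep = [i for i, (row, bi) in enumerate(zip(W1, b1))
            if bi != 0.0 or any(w != 0.0 for w in row)]
    W1_small = [W1[i] for i in keep]
    b1_small = [b1[i] for i in keep]
    W2_small = [[row[i] for i in keep] for row in W2]
    return W1_small, b1_small, W2_small


# Toy weights: hidden units 1 and 3 are zero groups.
W1 = [[0.5, -1.0], [0.0, 0.0], [1.5, 2.0], [0.0, 0.0]]
b1 = [0.1, 0.0, -0.2, 0.0]
W2 = [[1.0, 3.0, -2.0, 4.0], [0.5, -1.0, 1.0, 2.0]]
b2 = [0.0, 1.0]

W1_s, b1_s, W2_s = prune_zero_groups(W1, b1, W2)
x = [2.0, 0.5]
full_out = forward(W1, b1, W2, b2, x)
small_out = forward(W1_s, b1_s, W2_s, b2, x)
```

Here `full_out` and `small_out` agree even though the pruned network has half as many hidden units: a zero group always produces a zero activation, so the weights that consume it contribute nothing and can be removed along with it.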

More full and compressed models

More full and compressed models produced by OTO can be found in checkpoints. The full and compressed models return exactly the same outputs given the same inputs.

The dependency graphs for ZIG partition can be found at Dependency Graphs.

Remarks and to do list

The current OTO library has the following requirements and caveats.

  • The target model must be convertible to ONNX format, which OTO uses to construct the dependency graph.

  • If you run into errors, please check our supported operators list.

  • The effectiveness (final compression ratio and model performance) relies on proper usage of the DHSPG optimizer. Please go through our tutorials for setup (they will be kept updated).

We will continue to work on the following items.

  • Provide more tutorials to cover more use cases and applications of OTO.

  • Provide documentation for the OTO API.

  • Optimize the dependency list.

Welcome Contributions

We greatly appreciate contributions from our open-source community to make DNN training and compression more automatic and convenient.

Citation

If you find this repo useful, please star the repository and cite our papers:

@inproceedings{chen2023otov2,
  title={OTOv2: Automatic, Generic, User-Friendly},
  author={Chen, Tianyi and Liang, Luming and Ding, Tianyu and Zhu, Zhihui and Zharkov, Ilya},
  booktitle={International Conference on Learning Representations},
  year={2023}
}

@inproceedings{chen2021only,
  title={Only Train Once: A One-Shot Neural Network Training And Pruning Framework},
  author={Chen, Tianyi and Ji, Bo and Ding, Tianyu and Fang, Biyi and Wang, Guanyi and Zhu, Zhihui and Liang, Luming and Shi, Yixin and Yi, Sheng and Tu, Xiao},
  booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
  year={2021}
}
