haitaozhao / mctorch

This project forked from mctorch/mctorch

A manifold optimization library for deep learning

License: MIT License

Python 100.00%

mctorch's Introduction

McTorch Lib, a manifold optimization library for deep learning

McTorch is a Python library that adds manifold optimization functionality to PyTorch.

McTorch:

  • Leverages tensor computation and GPU acceleration from PyTorch.
  • Enables optimization on manifold constrained tensors to address nonlinear optimization problems.
  • Facilitates constrained weight tensors in deep learning layers.

More about McTorch

McTorch builds on top of PyTorch and supports all PyTorch functions in addition to manifold optimization, so researchers and developers who already use PyTorch can experiment with McTorch functions without changing their workflow. McTorch's manifold implementations and optimization methods are derived from the MATLAB toolbox Manopt and the Python toolbox Pymanopt.

Using McTorch for Optimization

  1. Initialize Parameter - McTorch manifold parameters are the same as PyTorch parameters (mctorch.nn.Parameter) and require only one additional property at initialization to constrain the parameter values.
  2. Define Cost - The cost function can be any PyTorch function of the above parameter, optionally mixed with unconstrained parameters.
  3. Optimize - Any optimizer from mctorch.optim can be used to optimize the cost function, with the same usage as in ordinary PyTorch code.

PCA Example

import torch
import mctorch.nn as mnn
import mctorch.optim as moptim

# Random data with high variance in the first two dimensions
X = torch.diag(torch.FloatTensor([3,2,1])).matmul(torch.randn(3,200))

# 1. Initialize Parameter
manifold_param = mnn.Parameter(manifold=mnn.Stiefel(3,2))

# 2. Define Cost - squared reconstruction error
def cost(X, w):
    wTX = torch.matmul(w.transpose(1,0), X)
    wwTX = torch.matmul(w, wTX)
    return torch.sum((X - wwTX)**2)

# 3. Optimize
optimizer = moptim.rAdagrad(params=[manifold_param], lr=1e-2)

for epoch in range(30):
    cost_step = cost(X, manifold_param)
    print(cost_step)
    cost_step.backward()
    optimizer.step()
    optimizer.zero_grad()
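
To make the idea of optimizing under a manifold constraint concrete, here is a minimal pure-Python sketch of Riemannian gradient descent on the unit sphere, which is the single-column case of the Stiefel manifold. It is an illustration under stated assumptions, not McTorch's implementation: the helper names (matvec, dot, normalize, riemannian_step) are hypothetical, the covariance diag(9, 4, 1) is the idealized covariance of the random data above, and the tangent-space projection followed by renormalization is one common scheme that mctorch.optim may or may not use.

```python
import math

def matvec(C, w):
    # Matrix-vector product for a matrix stored as a list of rows.
    return [sum(c * x for c, x in zip(row, w)) for row in C]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def normalize(v):
    # Retraction onto the unit sphere: rescale to unit length.
    n = math.sqrt(dot(v, v))
    return [x / n for x in v]

def riemannian_step(C, w, lr):
    # Euclidean gradient of cost(w) = -w^T C w (negative captured variance).
    egrad = [-2.0 * x for x in matvec(C, w)]
    # Project onto the tangent space at w by removing the radial component.
    radial = dot(w, egrad)
    rgrad = [g - radial * x for g, x in zip(egrad, w)]
    # Take a gradient step, then retract back onto the sphere.
    return normalize([x - lr * g for x, g in zip(w, rgrad)])

# Idealized covariance of the data in the PCA example (assumption).
C = [[9.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 1.0]]
w = normalize([1.0, 1.0, 1.0])
for _ in range(200):
    w = riemannian_step(C, w, lr=0.05)

print(w)          # close to the first coordinate axis
print(dot(w, w))  # stays at 1 up to rounding
```

The point of the sketch is that the constraint is never violated: every iterate lies on the sphere, so no penalty term or post-hoc projection of the final answer is needed.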

Using McTorch for Deep Learning

Multi Layer Perceptron Example

import torch
import torch.nn as nn
import mctorch.nn as mnn
import torch.nn.functional as F

# a torch module using constrained linear layers
class ManifoldMLP(nn.Module):
    def __init__(self):
        super(ManifoldMLP, self).__init__()
        self.layer1 = mnn.rLinear(in_features=28*28, out_features=100, weight_manifold=mnn.Stiefel)
        self.layer2 = mnn.rLinear(in_features=100, out_features=100, weight_manifold=mnn.PositiveDefinite)
        self.output = mnn.rLinear(in_features=100, out_features=10, weight_manifold=mnn.Stiefel)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # normalize over the class dimension, assuming x has shape (batch_size, 10)
        x = F.log_softmax(self.output(x), dim=1)
        return x

# create module object and compute cost by applying module on inputs
# `inputs` is assumed to be a float tensor of shape (batch_size, 28*28)
mlp_module = ManifoldMLP()
cost = mlp_module(inputs)

More examples are available in the project repository.

Functionality Supported

This list will keep growing. McTorch currently supports:

Manifolds

  • Stiefel
  • Positive Definite
  • Hyperbolic
  • Doubly Stochastic

All manifolds also support a k multiplier.
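
For reference, the Stiefel manifold is the set of n x p matrices with orthonormal columns, so that W^T W is the p x p identity. In the PCA example above, mnn.Stiefel(3,2) therefore corresponds to a 3 x 2 matrix. The following pure-Python check is a minimal sketch of that constraint; is_on_stiefel is a hypothetical helper written for illustration and is not part of McTorch.

```python
def is_on_stiefel(W, tol=1e-8):
    """Return True if the columns of W (given as a list of rows) are orthonormal."""
    n_rows, n_cols = len(W), len(W[0])
    for i in range(n_cols):
        for j in range(n_cols):
            # Entry (i, j) of the Gram matrix W^T W.
            gram_ij = sum(W[r][i] * W[r][j] for r in range(n_rows))
            target = 1.0 if i == j else 0.0
            if abs(gram_ij - target) > tol:
                return False
    return True

orthonormal = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
skewed = [[1.0, 1.0], [0.0, 1.0], [0.0, 0.0]]

print(is_on_stiefel(orthonormal))  # True
print(is_on_stiefel(skewed))       # False
```

A check like this can be applied to a trained parameter (after converting it to nested lists) to confirm that the optimizer kept it on the manifold.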

Optimizers

  • rSGD
  • rAdagrad
  • rASA
  • rConjugateGradient

Layers

  • Linear
  • Conv1d, Conv2d, Conv3d

Installation

After installing PyTorch, McTorch can be installed with python setup.py install

Linux

source activate myenv
conda install numpy setuptools
# Add LAPACK support for the GPU if needed
conda install -c pytorch magma-cuda90 # or [magma-cuda80 | magma-cuda92 | magma-cuda100 ] depending on your cuda version
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch # or cudatoolkit=10.0 | cudatoolkit=10.1 | .. depending on your cuda version
pip install mctorch-lib

Release and Contribution

McTorch is currently under development, and contributions, suggestions and feature requests are welcome. We closely follow PyTorch stable versions to keep the base up to date, and maintain our own version numbers for McTorch-specific additions.

McTorch is released under the open source 3-clause BSD License.

Reference

Please cite [1] if you find this code useful.

[1] M. Meghwanshi, P. Jawanpuria, A. Kunchukuttan, H. Kasai, and B. Mishra, McTorch, a manifold optimization library for deep learning, arXiv preprint arXiv:1810.01811, 2018.

@techreport{meghwanshi2018mctorch,
  title={McTorch, a manifold optimization library for deep learning},
  author={Meghwanshi, Mayank and Jawanpuria, Pratik and Kunchukuttan, Anoop and Kasai, Hiroyuki and Mishra, Bamdev},
  institution={arXiv preprint arXiv:1810.01811},
  year={2018}
}
