
PaddleChainRules

The idea comes from PyCallChainRules.jl.

A small demo package that wraps a fully connected (Dense) PaddlePaddle network in Julia and makes it differentiable by defining a ChainRulesCore.rrule for it.
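The core mechanism, making a wrapped network differentiable through an rrule, can be illustrated with a minimal pure-Julia sketch. It does not call Paddle or ChainRulesCore: a hypothetical single sigmoid Dense layer stands in for the wrapped network, and `dense_rrule` and `sigmoid` are illustrative names, not part of this package. The function only follows the rrule contract of returning the forward output together with a pullback that maps the output cotangent to cotangents of the inputs. In the package itself the forward and backward passes are presumably delegated to Paddle, which this sketch does not reproduce.

```julia
# Hypothetical stand-in for the wrapped Paddle network: one Dense layer with a
# sigmoid activation, y = sigmoid.(W * x .+ b). Columns of x are samples,
# matching the (dim_ins, batch_size) inputs used in the examples.
sigmoid(z) = 1 / (1 + exp(-z))

# Mirrors the contract of ChainRulesCore.rrule (without depending on it):
# return the primal output and a pullback mapping the output cotangent ybar
# to cotangents of each argument.
function dense_rrule(W::AbstractMatrix, b::AbstractVector, x::AbstractMatrix)
    y = sigmoid.(W * x .+ b)
    function dense_pullback(ybar)
        dz = ybar .* y .* (1 .- y)      # chain rule through sigmoid'(z) = y(1 - y)
        dW = dz * x'                    # cotangent of the weight matrix
        db = vec(sum(dz; dims=2))       # b is broadcast over the batch, so sum it
        dx = W' * dz                    # cotangent of the input
        return dW, db, dx
    end
    return y, dense_pullback
end

# Usage: gradient of loss = sum(abs2, y .- target), whose cotangent is 2(y - target).
W = [0.1 -0.2 0.3; 0.4 0.5 -0.6]             # dim_outs = 2, dim_ins = 3
b = [0.05, -0.1]
x = reshape(collect(1.0:12.0), 3, 4) ./ 10   # batch_size = 4
target = [0.2 0.4 0.6 0.8; 0.1 0.3 0.5 0.7]
y, back = dense_rrule(W, b, x)
dW, db, dx = back(2 .* (y .- target))
```

A central finite-difference check on a few entries of `dW` or `db` is a cheap way to validate a pullback like this, in the same spirit as the TODO item about comparing forward and backward results with Paddle's API.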

Example

CPU

# Install paddlepaddle (develop build for Linux CPU with MKL)
using PyCall
run(`$(PyCall.pyprogramname) -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html`)

using PaddleChainRules.Paddle: paddle, PaddleModuleWrapper, PaddleFCNet
using Zygote

dim_ins = 3
hidden_size = 16
dim_outs = 2
batch_size = 32
num_layers = 2

# Currently only fully connected (Dense) networks are supported
NN = paddle.nn.Sequential(
        paddle.nn.Linear(dim_ins, hidden_size),
        paddle.nn.Sigmoid(),
        paddle.nn.Linear(hidden_size, dim_outs)
    )

jlwrap = PaddleModuleWrapper(NN)

# Alternatively, use the constructor for a fully connected network
jlwrap = PaddleFCNet(dim_ins, dim_outs, num_layers, hidden_size; activation="sigmoid")

input = rand(Float32, dim_ins, batch_size)

output = jlwrap(input)

target = rand(Float32, dim_outs, batch_size)
loss(m, x, y) = sum(abs2.(m(x) .- y))

# Gradient with respect to the network parameters
grad_params, = Zygote.gradient(m -> loss(m, input, target), jlwrap)
# Gradient with respect to the input
grad_input, = Zygote.gradient(x -> loss(jlwrap, x, target), input)

GPU

# Install paddlepaddle-gpu
using PyCall
run(`$(PyCall.pyprogramname) -m pip install paddlepaddle-gpu`)

using PaddleChainRules.Paddle: paddle, PaddleModuleWrapper, PaddleFCNet
using CUDA
using Zygote
# paddlepaddle-gpu uses CUDA by default when it is available,
# or the device can be selected explicitly:
paddle.device.set_device("gpu:0")

dim_ins = 3
hidden_size = 16
dim_outs = 2
batch_size = 32
num_layers = 2

# Currently only fully connected (Dense) networks are supported
NN = paddle.nn.Sequential(
        paddle.nn.Linear(dim_ins, hidden_size),
        paddle.nn.Sigmoid(),
        paddle.nn.Linear(hidden_size, dim_outs)
    )

jlwrap = PaddleModuleWrapper(NN)

# Alternatively, use the constructor for a fully connected network
jlwrap = PaddleFCNet(dim_ins, dim_outs, num_layers, hidden_size; activation="sigmoid")

input = CUDA.cu(rand(Float32, dim_ins, batch_size))

output = jlwrap(input)

target = CUDA.cu(rand(Float32, dim_outs, batch_size))
loss(m, x, y) = sum(abs2.(m(x) .- y))

# Gradient with respect to the network parameters
grad_params, = Zygote.gradient(m -> loss(m, input, target), jlwrap)
# Gradient with respect to the input
grad_input, = Zygote.gradient(x -> loss(jlwrap, x, target), input)

There is also a demo for NeuralPDE.

TODO

  • In the NeuralPDE demo, this package is much slower than Flux.jl; its speed needs to be improved.
  • Only Dense networks are supported for now. Support more general network structures? (rough solution in #2)
  • Test code: compare the outputs of the forward and backward passes with the results from Paddle's API. (done)
  • Some benchmarks:
    • forward and backward passes. (done)
    • Poisson equation with NeuralPDE, compared with PyCallChainRules and Flux.

Contributors

songjhaha
