Concrete Dropout in Julia

Implementation of the Concrete Dropout layer by Y. Gal et al. in Julia, using the deep learning packages Flux.jl and Lux.jl.

The example notebooks regression_MCDropout_Lux.ipynb and regression_MCDropout_Flux.ipynb showcase the usage of Concrete Dropout layers in the context of Bayesian Neural Networks (see this paper).

Warning: I tried to use package extensions to provide a Flux or a Lux version depending on which one you load. Unfortunately, it was not as easy as I thought (see, e.g., this PR and many related questions on Discourse). I am not sure that what I was aiming for is currently easy to achieve.

For the Flux version, the initial release v0.0.0 should work. For Lux, please use the latest version.

Download

Add this module as an unregistered Julia package or via my local registry:

import Pkg
Pkg.add(url="https://github.com/dmetivie/ConcreteDropoutLayer.jl") # master branch
# or
using LocalRegistry
Pkg.pkg"registry add https://github.com/dmetivie/LocalRegistry"
Pkg.add("ConcreteDropoutLayer") # you can select the version depending on whether you want Flux or Lux

Lux usage

using Lux, Random
using ConcreteDropoutLayer

channel = 10

model = Chain(
        Conv((3,), channel => 64, relu),
        ConcreteDropout(; dims=(2, 3)), # ConcreteDropout for Conv1D layer
        FlattenLayer(),
        Dense(6272 => 100, relu),
        ConcreteDropout(), # ConcreteDropout for Dense layer
    )

See the notebook for a complete example.

Flux Usage

For now, Flux is only supported by version v0.0.0 of the package!

Adding a Concrete Dropout layer

Add the layer as you would any other layer:

using Flux
using ConcreteDropoutLayer

channel = 10

model = Chain(
        Conv((3,), channel => 64, relu),
        ConcreteDropout(; dims=(2, 3)), # ConcreteDropout for Conv1D layer
        Flux.MLUtils.flatten,
        Dense(6272 => 100, relu),
        ConcreteDropout(), # ConcreteDropout for Dense layer
    )
X = rand(Float32, 100, channel, 258)
output = model(X)

If you want Concrete Dropout to stay active outside training, e.g., for Monte Carlo Dropout, call Flux.testmode!(model, false), which forces the dropout layers to remain in their active (training) mode.
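Monte Carlo Dropout amounts to running several stochastic forward passes on the same input and summarizing them. The sketch below only illustrates that idea with the standard library: `stochastic_model` is a hypothetical stand-in for a network whose dropout is active, and `mc_dropout_predict` is an illustrative helper, not part of this package. With Flux, one would presumably pass `model` itself after calling Flux.testmode!(model, false).

```julia
using Random, Statistics

# Hypothetical stand-in for a model with active dropout:
# each call draws a new random mask, so outputs differ between calls.
function stochastic_model(rng, x; p=0.5)
    mask = rand(rng, length(x)) .> p      # keep each entry with probability 1 - p
    return sum(x .* mask) / (1 - p)       # rescale so the expectation is sum(x)
end

# Illustrative helper: run `n_samples` stochastic passes of `f` on `x`
# and return the predictive mean and standard deviation.
function mc_dropout_predict(f, x; n_samples=100)
    samples = [f(x) for _ in 1:n_samples]
    return mean(samples), std(samples)
end

rng = MersenneTwister(1)
x = ones(10)
μ, σ = mc_dropout_predict(x -> stochastic_model(rng, x), x; n_samples=500)
```

The spread σ across passes is what serves as an uncertainty estimate.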

Training

To add the regularization to the loss, as proposed in the Concrete Dropout paper, use:

wr = get_weight_regularizer(n_train, l=1.0f-2, τ=1.0f0) # weight regularization hyperparameter
dr = get_dropout_regularizer(n_train, τ=1.0f0, cross_entropy_loss=false) # dropout hyperparameter

full_loss(model, x, y; kwargs...) = original_loss(model(x), y) + add_CD_regularization(model; kwargs...)
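To give an idea of what these terms compute, here is a minimal sketch of the regularization for a single Dense-style layer. It assumes the formulas used in the reference Python implementations of the Concrete Dropout paper; the package's own definitions may differ. All function names ending in `_sketch` are hypothetical and not exported by the package.

```julia
# Hypothetical re-implementations for illustration only; not the package's code.

# Assumed weight regularization coefficient: l^2 / (τ * N)
weight_regularizer_sketch(n_train; l=1.0f-2, τ=1.0f0) = l^2 / (τ * n_train)

# Assumed dropout regularization coefficient: 2 / (τ * N) for regression,
# 1 / (τ * N) when a cross-entropy loss is used.
function dropout_regularizer_sketch(n_train; τ=1.0f0, cross_entropy_loss=false)
    reg = 1 / (τ * n_train)
    return cross_entropy_loss ? reg : 2 * reg
end

# Regularization for one layer with weight matrix W and dropout probability p.
function cd_regularization_sketch(W, p, wr, dr)
    input_dim = size(W, 2)                      # assumes weights of size (out, in)
    l2_term = wr * sum(abs2, W) / (1 - p)       # L2 penalty scaled by 1 / (1 - p)
    entropy_term = dr * input_dim * (p * log(p) + (1 - p) * log(1 - p))
    return l2_term + entropy_term
end

wr = weight_regularizer_sketch(1000)
dr = dropout_regularizer_sketch(1000)
reg = cd_regularization_sketch(ones(Float32, 3, 4), 0.5f0, wr, dr)
```

The entropy term is negative and largest in magnitude at p = 0.5, while the L2 term grows as p approaches 1, so the two terms trade off against each other when p is learned.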

API

mutable struct ConcreteDropout{F,D,R<:AbstractRNG}
  p_logit::F # logit value of the dropout probability
  dims::D # dimension to which the Dropout is applied
  active::Union{Bool,Nothing} # whether dropout is active or not
  rng::R # rng used for the dropout
end
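The layer stores the dropout probability as a logit, presumably so that the trained parameter is unconstrained while the probability itself stays in (0, 1). A minimal sketch of the conversion, with hypothetical helper names that are not exported by the package:

```julia
# Illustrative helpers (hypothetical names, not part of the package).
logit(p) = log(p / (1 - p))          # probability -> unconstrained real
inv_logit(x) = 1 / (1 + exp(-x))     # unconstrained real -> probability

p_init = 0.1f0
p_logit = logit(p_init)              # value that would be stored and trained
p = inv_logit(p_logit)               # recovered dropout probability
```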

Here is a reminder of the typical dims setting depending on the type of the previous layer:

  • On a Dense layer, use dims = :, i.e., it acts on all neurons and samples independently
  • On "Conv1D", use dims = (2,3), i.e., it applies Concrete Dropout independently to each feature (channel) and each sample, with the same mask shared along the first dimension
  • On "Conv2D", use dims = (3,4)
  • On "Conv3D", use dims = (4,5)
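The effect of dims can be illustrated with a standalone sketch of the concrete relaxation described in the paper. This is not the package's implementation: `noise_shape` and `concrete_dropout_sketch` are hypothetical names, and the temperature of 0.1 is an assumption. The noise keeps the input's size along dims and has size 1 elsewhere, so the same mask is broadcast over the remaining dimensions.

```julia
using Random

# Hypothetical helper: noise keeps the size of `x` along `dims`, 1 elsewhere.
noise_shape(x, dims) = ntuple(i -> i in dims ? size(x, i) : 1, ndims(x))
noise_shape(x, ::Colon) = size(x)

sigmoid(z) = 1 / (1 + exp(-z))

# Sketch of the concrete (relaxed Bernoulli) dropout with drop probability `p`.
function concrete_dropout_sketch(rng, x, p; dims=:, temperature=0.1f0)
    tiny = 1.0f-7                                    # avoids log(0)
    u = rand(rng, Float32, noise_shape(x, dims))     # uniform noise in [0, 1)
    z = @. log(p + tiny) - log(1 - p + tiny) + log(u + tiny) - log(1 - u + tiny)
    drop = @. sigmoid(z / temperature)               # relaxed "drop" indicator in (0, 1)
    return @. x * (1 - drop) / (1 - p)               # mask and rescale
end

rng = MersenneTwister(0)
x = ones(Float32, 5, 4, 3)                           # (width, channel, batch)
y = concrete_dropout_sketch(rng, x, 0.2f0; dims=(2, 3))
```

Here the noise has size (1, 4, 3), so every position along the first dimension of a given channel and sample is scaled by the same factor.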

TODO

  • Clean regularization
      - Ideally, the L2 term should be handled directly by the optimizer, with something like OptimiserChain(WeightDecay(lw/(1-p)), Adam(0.1)), where the value of p is updated at each time step with adjust!. Or maybe, with another normalization, one could get rid of the 1/(1-p) factor.
      - The entropy and L2 regularization are handled automatically, i.e., all relevant layers (nested or not) are found quickly and adjusted at every step. (Done for Lux)

Acknowledgments

This code is inspired by the Python (tensorflow/pytorch) implementations of @aurelio-amerio, see his module. Thanks to @ToucheSir for some useful comments on the Flux version.
