
Padé Activation Units: End-to-end Learning of Flexible Activation Functions in Deep Networks

License: MIT

Unofficial PyTorch reimplementation of the paper Padé Activation Units: End-to-end Learning of Flexible Activation Functions in Deep Networks (https://openreview.net/pdf?id=BJlBSkHtDS) by Molina et al., published at ICLR 2020.

This repository includes an easy-to-use pure PyTorch implementation of the Padé Activation Unit (PAU).

Please note that the official implementation provides a CUDA implementation, which is likely faster than this pure PyTorch one!

Installation

The PAU can be installed with pip:

pip install git+https://github.com/ChristophReich1996/Pade-Activation-Unit

Example Usage

The PAU can simply be used as a standard nn.Module:

import torch
import torch.nn as nn
from pau import PAU

network: nn.Module = nn.Sequential(
    nn.Linear(2, 2),
    PAU(),  # learnable Padé activation between the linear layers
    nn.Linear(2, 2)
)

# Forward pass with a random batch of 16 two-dimensional samples
output: torch.Tensor = network(torch.rand(16, 2))

The PAU is implemented in a memory-efficient way (checkpointing and sequential computation of the Vandermonde matrix). If you want to use the faster but more memory-intensive variant, please use PAU(efficient=False).
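
A minimal sketch of both variants (the constructor arguments follow the parameter table below; the variable names are just illustrative):

from pau import PAU

pau_checkpointed = PAU()          # default: memory-efficient variant with checkpointing
pau_fast = PAU(efficient=False)   # faster, but more memory intensive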

If a numerator degree of 5 and a denominator degree of 4 is used, the following initializations are available:

- ReLU (initial_shape="relu")
- Leaky ReLU with negative slope 0.01 (initial_shape="leaky_relu_0_01")
- Leaky ReLU with negative slope 0.2 (initial_shape="leaky_relu_0_2")
- Leaky ReLU with negative slope 0.25 (initial_shape="leaky_relu_0_25")
- Leaky ReLU with negative slope 0.3 (initial_shape="leaky_relu_0_3")
- Leaky ReLU with negative slope -0.5 (initial_shape="leaky_relu_m0_5")
- Tanh (initial_shape="tanh")
- Swish (initial_shape="swish")
- Sigmoid (initial_shape="sigmoid")

If a different numerator or denominator degree is used, or initial_shape=None is passed, the PAU is initialized with random weights.
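
For example (a minimal sketch; the shape names follow the list above, and m=7/n=6 are just illustrative non-default degrees):

from pau import PAU

pau_tanh = PAU(m=5, n=4, initial_shape="tanh")  # initialized to approximate Tanh
pau_random = PAU(m=7, n=6)                      # non-default degrees, random initialization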

If you would like to fix the weights of multiple PAUs in an nn.Module, just call module = pau.freeze_pau(module).
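
Continuing the network from the example above (a short sketch):

import pau

# Fix the weights of every PAU contained in the module
network = pau.freeze_pau(network)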

For a more detailed example of how to use this implementation, please refer to the example file (requires Matplotlib to be installed).

The PAU takes the following parameters:

| Parameter | Description | Type |
| --- | --- | --- |
| m | Degree of the numerator polynomial. Default 5. | int |
| n | Degree of the denominator polynomial. Default 4. | int |
| initial_shape | Initial shape of the PAU. If None, a random shape is used; a random shape is also used if m and n differ from the defaults (5 and 4). Default "leaky_relu_0_2". | Optional[str] |
| efficient | If True, the memory-efficient variant with checkpointing is used. Default True. | bool |
| eps | Constant for numerical stability. Default 1e-08. | float |
| **kwargs | Unused additional keyword arguments. | Any |
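
Putting it together, a sketch of the full constructor with its default values (as documented in the table above):

from pau import PAU

pau = PAU(m=5, n=4, initial_shape="leaky_relu_0_2", efficient=True, eps=1e-08)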

Reference

@inproceedings{Molina2020,
    title={{Padé Activation Units: End-to-end Learning of Flexible Activation Functions in Deep Networks}},
    author={Alejandro Molina and Patrick Schramowski and Kristian Kersting},
    booktitle={International Conference on Learning Representations},
    year={2020}
}
