holbsmasto / optax

This project forked from google-deepmind/optax

Optax is a gradient processing and optimization library for JAX.

Home Page: https://optax.readthedocs.io

License: Apache License 2.0

Optax

Introduction

Optax is a gradient processing and optimization library for JAX.

Optax is designed to facilitate research by providing building blocks that can be easily recombined in custom ways.

Our goals are to

  • Provide simple, well-tested, efficient implementations of core components.
  • Improve research productivity by making it easy to combine low-level ingredients into custom optimisers (or other gradient processing components).
  • Accelerate adoption of new ideas by making it easy for anyone to contribute.

We favour focusing on small composable building blocks that can be effectively combined into custom solutions. Others may build more complicated abstractions upon these basic components. Whenever reasonable, implementations prioritise readability, and structure the code to match standard equations, over code reuse.

An initial prototype of this library was made available in JAX's experimental folder as jax.experimental.optix. Given the wide adoption of optix across DeepMind, and after a few iterations on the API, optix was eventually moved out of experimental as a standalone open-source library and renamed optax.

Documentation on Optax can be found at optax.readthedocs.io.

Installation

You can install the latest released version of Optax from PyPI via:

pip install optax

or you can install the latest development version from GitHub:

pip install git+https://github.com/deepmind/optax.git

Quickstart

Optax contains implementations of many popular optimizers and loss functions. For example, the following code snippet uses the Adam optimizer from optax.adam and the (halved) squared-error loss from optax.l2_loss. We initialize the optimizer state using the init function and the params of the model.

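import jax
import jax.numpy as jnp
import optax
learning_rate, num_weights = 1e-2, 3  # Example values.
optimizer = optax.adam(learning_rate)
# Obtain the `opt_state` that contains statistics for the optimizer.
params = {'w': jnp.ones((num_weights,))}
opt_state = optimizer.init(params)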

To write the update loop, we need a loss function that can be differentiated by JAX (with jax.grad in this example) to obtain the gradients.

# jax.grad needs a scalar, so average optax.l2_loss over the batch `xs`, `ys`.
compute_loss = lambda params, x, y: optax.l2_loss(x.dot(params['w']), y).mean()
grads = jax.grad(compute_loss)(params, xs, ys)

The gradients are then transformed by optimizer.update into the updates that should be applied to the current params to obtain the new ones. optax.apply_updates is a convenience utility that applies them.

updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

You can continue the quick start in the Optax quickstart notebook.

Components

See the documentation for a detailed list of available Optax components. Here, we highlight the main categories of building blocks provided by Optax.

Gradient Transformations (transform.py)

One of the key building blocks of optax is a GradientTransformation.

Each transformation is defined by two functions:

  • state = init(params)
  • grads, state = update(grads, state, params=None)

The init function initializes a (possibly empty) set of statistics (aka state) and the update function transforms a candidate gradient given some statistics, and (optionally) the current value of the parameters.

For example:

tx = scale_by_rms()
state = tx.init(params)  # init stats
grads, state = tx.update(grads, state, params)  # transform & update stats.

Composing Gradient Transformations (combine.py)

The fact that transformations take candidate gradients as input and return processed gradients as output (in contrast to returning the updated parameters) is critical: it allows arbitrary transformations to be combined into a custom optimiser / gradient processor, and it also allows transformations for different gradients that operate on a shared set of variables to be combined.

For instance, chain returns a new GradientTransformation that applies several transformations in sequence.

For example:

my_optimiser = chain(
    clip_by_global_norm(max_norm),
    scale_by_adam(eps=1e-4),
    scale(-learning_rate))

Wrapping Gradient Transformations (wrappers.py)

Optax also provides several wrappers that take a GradientTransformation as input and return a new GradientTransformation that modifies the behaviour of the inner transformation in a specific way.

For instance, the flatten wrapper flattens gradients into a single large vector before applying the inner GradientTransformation. The transformed updates are then unflattened before being returned to the user. This can reduce the overhead of performing many calculations on lots of small variables, at the cost of increased memory usage.

For example:

my_optimiser = flatten(adam(learning_rate))

Other examples of wrappers include accumulating gradients over multiple steps, or applying the inner transformation only to specific parameters or at specific steps.

Schedules (schedule.py)

Many popular transformations use time-dependent components, e.g. to anneal some hyper-parameter (such as the learning rate). For this purpose, Optax provides schedules that can be used to decay scalars as a function of a step count.

For example you may use a polynomial schedule (with power=1) to decay a hyper-parameter linearly over a number of steps:

schedule_fn = polynomial_schedule(
    init_value=1., end_value=0., power=1, transition_steps=5)

for step_count in range(6):
  print(schedule_fn(step_count))  # [1., 0.8, 0.6, 0.4, 0.2, 0.]

Schedules are used by certain gradient transformations, for instance:

schedule_fn = polynomial_schedule(
    init_value=-learning_rate, end_value=0., power=1, transition_steps=5)
optimiser = chain(
    clip_by_global_norm(max_norm),
    scale_by_adam(eps=1e-4),
    scale_by_schedule(schedule_fn))

Popular optimisers (alias.py)

In addition to the low-level building blocks, we also provide aliases for popular optimisers built using these components (e.g. RMSProp, Adam, AdamW, etc.). These are all still instances of a GradientTransformation, and can therefore be further combined with any of the individual building blocks.

For example:

def adamw(learning_rate, b1, b2, eps, weight_decay):
  return chain(
      scale_by_adam(b1=b1, b2=b2, eps=eps),
      add_decayed_weights(weight_decay),  # Adds weight_decay * params.
      scale(-learning_rate))

Applying updates (update.py)

After transforming an update using a GradientTransformation or any custom manipulation of the update, you will typically apply the update to a set of parameters. This can be done trivially using tree_map.

For convenience, we expose an apply_updates function to apply updates to parameters. The function just adds the updates and the parameters together, i.e. tree_map(lambda p, u: p + u, params, updates).

updates, state = tx.update(grads, state, params)  # transform & update stats.
new_params = optax.apply_updates(params, updates)  # update the parameters.

Note that separating gradient transformations from the parameter update is critical to support composing sequences of transformations (e.g. chain), as well as combining multiple updates to the same parameters (e.g. in multi-task settings where different tasks need different sets of gradient transformations).

Losses (loss.py)

Optax provides a number of standard losses used in deep learning, such as l2_loss, softmax_cross_entropy, cosine_distance, etc.

loss = huber_loss(predictions, targets)

The losses accept batches as inputs, however they perform no reduction across the batch dimension(s). This is trivial to do in JAX, for example:

avg_loss = jnp.mean(huber_loss(predictions, targets))
sum_loss = jnp.sum(huber_loss(predictions, targets))

Second Order (second_order.py)

Computing the full Hessian or Fisher information matrix for a neural network is typically intractable, because memory grows quadratically with the number of parameters. Computing only the diagonals of these matrices is often a practical alternative. The library offers functions for computing these diagonals with sub-quadratic memory requirements.
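To illustrate why the diagonal can be computed without materialising the full matrix, here is a plain-JAX sketch (not Optax's own second-order implementation, which may use a different strategy). The hypothetical helper `hessian_diag` takes one Hessian-vector product per coordinate, so each step needs only O(n) memory. It assumes a scalar function of a flat parameter vector.

```python
import jax
import jax.numpy as jnp

def hessian_diag(f, w):
  """Hypothetical helper: diagonal of the Hessian of scalar f at flat vector w."""
  n = w.shape[0]

  def hvp(v):
    # Forward-over-reverse Hessian-vector product: H @ v without forming H.
    return jax.jvp(jax.grad(f), (w,), (v,))[1]

  def diag_entry(i):
    e_i = jnp.zeros(n, dtype=w.dtype).at[i].set(1.0)
    return hvp(e_i)[i]  # Keep only the i-th entry of the i-th column.

  # lax.map runs the n products sequentially, keeping peak memory O(n).
  return jax.lax.map(diag_entry, jnp.arange(n))

def f(w):
  # Hessian is [[2, 1], [1, 6]]: the off-diagonal term must not leak in.
  return w[0] * w[1] + w[0] ** 2 + 3.0 * w[1] ** 2

diag = hessian_diag(f, jnp.array([1.0, -2.0]))
```

The trade-off is time: n sequential products, each costing roughly as much as a gradient evaluation.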

Stochastic gradient estimators (stochastic_gradient_estimators.py)

Stochastic gradient estimators compute Monte Carlo estimates of gradients of the expectation of a function under a distribution with respect to the distribution's parameters.

Unbiased estimators, such as the score function estimator (REINFORCE), pathwise estimator (reparameterization trick) or measure valued estimator, are implemented: score_function_jacobians, pathwise_jacobians and measure_valued_jacobians. Their applicability (both in terms of functions and distributions) is discussed in their respective documentation.

Stochastic gradient estimators can be combined with common control variates for variance reduction via control_variates_jacobians. For provided control variates see delta and moving_avg_baseline.

The result of a gradient estimator or control_variates_jacobians contains the Jacobians of the function with respect to the samples from the input distribution. These can then be used to update distributional parameters, or to assess gradient variance.

Example of how to use the pathwise_jacobians estimator:

dist_params = [mean, log_scale]
function = lambda x: jnp.sum(x * weights)
jacobians = pathwise_jacobians(
      function, dist_params,
      utils.multi_normal, rng, num_samples)

mean_grads = jnp.mean(jacobians[0], axis=0)
log_scale_grads = jnp.mean(jacobians[1], axis=0)
grads = [mean_grads, log_scale_grads]
optim_update, optim_state = optim.update(grads, optim_state)
updated_dist_params = optax.apply_updates(dist_params, optim_update)

where optim is an optax optimizer.
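For intuition about what the pathwise estimator computes, here is a plain-JAX sketch for a diagonal Gaussian parameterised by a mean and a log-scale. It is not Optax's `pathwise_jacobians`: the hypothetical helper `pathwise_grads` averages the per-sample gradients directly instead of returning the per-sample Jacobians. Since `E[sum(x * weights)] = sum(mean * weights)` does not depend on the scale, the mean gradient should be exactly `weights` and the log-scale gradient close to zero.

```python
import jax
import jax.numpy as jnp

def pathwise_grads(function, mean, log_scale, rng, num_samples):
  """Hypothetical helper: pathwise estimate of the gradient of
  E_{x ~ N(mean, exp(log_scale)^2)}[function(x)] w.r.t. (mean, log_scale)."""
  eps = jax.random.normal(rng, (num_samples,) + mean.shape)

  def sample_objective(params, e):
    m, s = params
    x = m + jnp.exp(s) * e  # Reparameterised sample, differentiable in (m, s).
    return function(x)

  # One gradient per sample (the Jacobian rows), then a Monte Carlo average.
  per_sample = jax.vmap(jax.grad(sample_objective), in_axes=(None, 0))(
      (mean, log_scale), eps)
  return jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), per_sample)

weights = jnp.array([1.0, 2.0, 3.0])
function = lambda x: jnp.sum(x * weights)
mean, log_scale = jnp.zeros(3), jnp.zeros(3)
mean_grads, log_scale_grads = pathwise_grads(
    function, mean, log_scale, jax.random.PRNGKey(0), num_samples=10_000)
```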

Citing Optax

This repository is part of the DeepMind JAX Ecosystem. To cite Optax, please use the following citation:

@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/google-deepmind},
  year = {2020},
}
