alexxnica / learning-to-learn

This project forked from google-deepmind/learning-to-learn

Learning to Learn in TensorFlow

Home Page: https://arxiv.org/abs/1606.04474

License: Apache License 2.0

Python 100.00%

Dependencies

Training

python train.py --problem=mnist --save_path=./mnist

Command-line flags:

  • save_path: If present, the optimizer is saved to the specified path every time the evaluation performance improves.
  • num_epochs: Number of training epochs.
  • log_period: Number of epochs between reports of mean performance and time.
  • evaluation_period: Number of epochs between evaluations of the optimizer.
  • evaluation_epochs: Number of evaluation epochs.
  • problem: Problem to train on. See the Problems section below.
  • num_steps: Number of optimization steps.
  • unroll_length: Number of unroll steps for the optimizer.
  • learning_rate: Learning rate.
  • second_derivatives: If true, the optimizer will try to compute second derivatives through the loss function specified by the problem.
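To make the relationship between the training flags more concrete, here is a small pure-Python sketch. It is not taken from train.py: the helper names training_schedule and save_decisions are hypothetical, and the sketch assumes that each epoch runs num_steps optimization steps split into num_steps // unroll_length truncated unrolls, that reporting and evaluation fire every log_period and evaluation_period epochs, and that a lower evaluation loss counts as an improvement for save_path.

```python
def training_schedule(num_epochs, log_period, evaluation_period,
                      num_steps, unroll_length):
    """Hypothetical helper showing how the training flags could interact.

    Assumes each epoch optimizes the problem for num_steps steps, split into
    truncated unrolls of unroll_length steps each, and that epochs are
    counted from 1 when checking log_period and evaluation_period.
    """
    # Floor division: leftover steps (num_steps % unroll_length) are dropped.
    num_unrolls = num_steps // unroll_length
    epochs = range(1, num_epochs + 1)
    log_epochs = [e for e in epochs if e % log_period == 0]
    eval_epochs = [e for e in epochs if e % evaluation_period == 0]
    return num_unrolls, log_epochs, eval_epochs


def save_decisions(eval_losses):
    """Return, for each evaluation, whether the optimizer would be saved.

    Follows the save_path description: save only when evaluation performance
    improves. A lower mean loss is assumed to be better.
    """
    best = float("inf")
    decisions = []
    for loss in eval_losses:
        improved = loss < best
        if improved:
            best = loss
        decisions.append(improved)
    return decisions


print(training_schedule(num_epochs=10, log_period=2, evaluation_period=5,
                        num_steps=100, unroll_length=20))
# (5, [2, 4, 6, 8, 10], [5, 10])
print(save_decisions([3.0, 2.5, 2.7, 1.9]))
# [True, True, False, True]
```

Under these assumptions, unroll_length trades memory and gradient bias against each other: shorter unrolls mean more (cheaper) meta-updates per epoch.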

Evaluation

python evaluate.py --problem=mnist --optimizer=L2L --path=./mnist

Command-line flags:

  • optimizer: Adam or L2L.
  • path: Path to saved optimizer, only relevant if using the L2L optimizer.
  • learning_rate: Learning rate, only relevant if using Adam optimizer.
  • num_epochs: Number of evaluation epochs.
  • seed: Seed for random number generation.
  • problem: Problem to evaluate on. See the Problems section below.
  • num_steps: Number of optimization steps.
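The flag list above says that path only matters for L2L and learning_rate only for Adam. The hypothetical helper below (not part of evaluate.py) filters a dictionary of flag values down to the settings that actually apply to the chosen optimizer, which is one way to sanity-check an evaluation command before running it.

```python
# Flags that apply to both optimizers, per the evaluate.py flag list.
COMMON_FLAGS = ("num_epochs", "seed", "problem", "num_steps")

# Optimizer-specific flags: path for L2L, learning_rate for Adam.
OPTIMIZER_FLAGS = {"L2L": ("path",), "Adam": ("learning_rate",)}


def relevant_flags(flags):
    """Drop flags that the evaluate.py help text marks as irrelevant.

    Hypothetical helper; raises ValueError for an unknown optimizer.
    """
    optimizer = flags.get("optimizer")
    if optimizer not in OPTIMIZER_FLAGS:
        raise ValueError("optimizer must be 'Adam' or 'L2L', got %r" % optimizer)
    keep = ("optimizer",) + OPTIMIZER_FLAGS[optimizer] + COMMON_FLAGS
    return {name: value for name, value in flags.items() if name in keep}


l2l = relevant_flags({"optimizer": "L2L", "path": "./mnist",
                      "learning_rate": 0.01, "problem": "mnist"})
print(l2l)  # {'optimizer': 'L2L', 'path': './mnist', 'problem': 'mnist'}
```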

Problems

The training and evaluation scripts support the following problems (see util.py for more details):

  • simple: One-variable quadratic function.
  • simple-multi: Two-variable quadratic function, where one of the variables is optimized using a learned optimizer and the other one using Adam.
  • quadratic: Batched ten-variable quadratic function.
  • mnist: MNIST classification using a two-layer fully connected network.
  • cifar: CIFAR-10 classification using a convolutional neural network.
  • cifar-multi: CIFAR-10 classification using a convolutional neural network with two independent learned optimizers: one optimizes the parameters of the convolutional layers and the other optimizes the parameters of the fully connected layers.
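For intuition about what one of these problems looks like, here is a NumPy sketch of the quadratic problem. It assumes the form f(θ) = ||Wθ − y||² with Gaussian W and y, as used for the quadratic experiments in the linked paper; the batch size of 128, the seed, and the helper name make_quadratic are assumptions, and the repository's TensorFlow version may differ in its details. Each batch row is an independent ten-variable problem, and plain gradient descent stands in for the optimizer being trained.

```python
import numpy as np


def make_quadratic(batch_size=128, num_dims=10, seed=0):
    """Batched quadratic problems f(x) = ||W x - y||^2, one per batch row.

    W and y are drawn from a standard normal distribution; batch_size and
    the sampling scheme are assumptions of this sketch.
    """
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(batch_size, num_dims, num_dims))
    y = rng.normal(size=(batch_size, num_dims))

    def loss_and_grad(x):
        # x has shape (batch_size, num_dims): one parameter vector per problem.
        residual = np.einsum("bij,bj->bi", W, x) - y
        per_problem = np.sum(residual ** 2, axis=1)
        # Gradient of each problem's own loss: 2 W^T (W x - y).
        grad = 2.0 * np.einsum("bji,bj->bi", W, residual)
        return per_problem.mean(), grad

    return loss_and_grad


loss_and_grad = make_quadratic()
x = np.zeros((128, 10))
initial_loss, _ = loss_and_grad(x)
for _ in range(200):
    _, grad = loss_and_grad(x)
    x -= 0.001 * grad  # plain gradient descent as a baseline update rule
final_loss, _ = loss_and_grad(x)
print(f"mean loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

The returned gradient is per problem (not of the batch mean), which matches how a coordinatewise learned optimizer would see each problem's parameters.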

New problems are straightforward to implement. In train.py, the meta_minimize method of the MetaOptimizer class is given a function that returns the TensorFlow operation computing the loss to be minimized (see problems.py for examples).

All operations with Python side effects (e.g. queue creation) must be done outside the function passed to meta_minimize. The cifar10 function in problems.py is a good example of a loss function that uses TensorFlow queues.
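The problem-definition pattern described above can be illustrated without TensorFlow. In this pure-Python analogue (all names are hypothetical and are not the repository's API), linear_regression_problem performs its side-effecting setup once, and the zero-argument build function it returns only reads the prepared data. toy_meta_minimize stands in for meta_minimize and assumes that a meta-optimizer may call the builder more than once while constructing its graph, which is why setup such as queue creation must not live inside the builder.

```python
import numpy as np

# Counters that record how often each stage runs (instrumentation only).
CALLS = {"setup": 0, "build": 0}


def load_dataset(seed=0):
    """Stand-in for side-effecting setup such as creating an input queue."""
    CALLS["setup"] += 1
    rng = np.random.default_rng(seed)
    return rng.normal(size=(32, 3)), rng.normal(size=32)


def linear_regression_problem():
    """Return a zero-argument loss builder, analogous to a problems.py entry."""
    features, targets = load_dataset()  # side effects happen here, once

    def build():
        # Only the demo counter is touched; the data are read, not created.
        CALLS["build"] += 1

        def loss(weights):
            residual = features @ weights - targets
            return float(np.mean(residual ** 2))

        return loss

    return build


def toy_meta_minimize(make_loss, num_builds=3):
    """Hypothetical driver that rebuilds the loss several times."""
    return [make_loss()(np.zeros(3)) for _ in range(num_builds)]


losses = toy_meta_minimize(linear_regression_problem())
print(CALLS)  # {'setup': 1, 'build': 3}
```

Had load_dataset been called inside build, the setup counter would reach 3; in the TensorFlow setting this would correspond to creating a new queue on every rebuild.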

Disclaimer: This is not an official Google product.

Contributors: sergomezcol, normanheckscher, ncoronges