anandbhattad / inferno

This project forked from inferno-pytorch/inferno

A utility library around PyTorch

License: Other

Python 99.45% Shell 0.10% Makefile 0.46%

Inferno

Inferno is a little library providing utilities and convenience functions/classes around PyTorch. It's a work-in-progress, but the latest release (0.3.1) should be fairly stable!

Features

Current features include:
import torch.nn as nn
from inferno.io.box.cifar import get_cifar10_loaders
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
from inferno.extensions.layers.convolutional import ConvELU2D
from inferno.extensions.layers.reshape import Flatten

# Fill these in:
LOG_DIRECTORY = '...'
SAVE_DIRECTORY = '...'
DATASET_DIRECTORY = '...'
DOWNLOAD_CIFAR = True
USE_CUDA = True

# Build torch model
model = nn.Sequential(
    ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    Flatten(),
    nn.Linear(in_features=(256 * 4 * 4), out_features=10),
    nn.LogSoftmax(dim=1)
)

# Load loaders
train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
                                                    download=DOWNLOAD_CIFAR)

# Build trainer
trainer = Trainer(model) \
    .build_criterion('NLLLoss') \
    .build_metric('CategoricalError') \
    .build_optimizer('Adam') \
    .validate_every((2, 'epochs')) \
    .save_every((5, 'epochs')) \
    .save_to_directory(SAVE_DIRECTORY) \
    .set_max_num_epochs(10) \
    .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                    log_images_every='never'),
                  log_directory=LOG_DIRECTORY)

# Bind loaders
trainer \
    .bind_loader('train', train_loader) \
    .bind_loader('validate', validate_loader)

if USE_CUDA:
    trainer.cuda()

# Go!
trainer.fit()
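The `in_features=(256 * 4 * 4)` of the final `nn.Linear` follows from the 32x32 CIFAR-10 images being halved by each of the three max-pooling layers (32 → 16 → 8 → 4). The sketch below checks that arithmetic in plain Python using the per-dimension output-size formula from the PyTorch `Conv2d`/`MaxPool2d` documentation. It assumes that `ConvELU2D` pads its 3x3 convolutions so they preserve spatial size (padding 1); the helper names `conv2d_out` and `flattened_features` are hypothetical, not part of inferno.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    # Output size along one spatial dimension, as documented for
    # torch.nn.Conv2d and torch.nn.MaxPool2d (with ceil_mode=False).
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_features(input_size=32, channels=256, num_blocks=3, conv_padding=1):
    size = input_size
    for _ in range(num_blocks):
        # ConvELU2D(kernel_size=3): assumed "same" padding of 1, so size is kept.
        size = conv2d_out(size, kernel_size=3, padding=conv_padding)
        # nn.MaxPool2d(kernel_size=2, stride=2) halves the spatial size.
        size = conv2d_out(size, kernel_size=2, stride=2)
    # Flatten() turns a (channels, size, size) map into one feature vector.
    return channels * size * size


print(flattened_features())  # 4096, i.e. 256 * 4 * 4
```

With `conv_padding=0` the same arithmetic yields 2x2 feature maps and only 1024 features, so the `nn.Linear` layer in the example would then fail with a shape mismatch; this is why the padding assumption matters.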

To visualize the training progress, navigate to LOG_DIRECTORY and fire up tensorboard with

$ tensorboard --logdir=${PWD} --port=6007

and open http://localhost:6007 in your browser.
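The chained `Trainer` calls in the example describe a schedule rather than a hand-written loop: train for at most 10 epochs, validate every 2 epochs, and save a checkpoint every 5. As a rough mental model only, the sketch below spells that schedule out in plain Python. It is not inferno's implementation; the helper names are hypothetical, and the rule that an `(n, 'epochs')` frequency fires at the end of every n-th epoch is an assumption.

```python
def fires(every, count):
    # Hypothetical trigger rule: an (every, 'epochs') frequency fires
    # whenever the epoch count is a multiple of `every`.
    return count % every == 0


def simulate_schedule(max_epochs=10, validate_every=2, save_every=5):
    # Mirrors .set_max_num_epochs(10), .validate_every((2, 'epochs'))
    # and .save_every((5, 'epochs')) from the example, under the rule above.
    schedule = []
    for epoch in range(1, max_epochs + 1):
        actions = ["train"]
        if fires(validate_every, epoch):
            actions.append("validate")
        if fires(save_every, epoch):
            actions.append("save")
        schedule.append((epoch, actions))
    return schedule


for epoch, actions in simulate_schedule():
    print(epoch, actions)
```

Under this assumption, validation runs after epochs 2, 4, 6, 8 and 10, and checkpoints are written after epochs 5 and 10.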

Installation

Conda packages for Python >= 3.6 are available on conda-forge for all distributions:

$ conda install -c pytorch -c conda-forge inferno

Future Features

Planned features include:
  • a class to encapsulate Hogwild! training over multiple GPUs,
  • minimal shape inference with a dry-run,
  • proper packaging and documentation,
  • cutting-edge fresh-off-the-press implementations of what the future has in store. :)

Credits

All contributors are listed at https://inferno-pytorch.github.io/inferno/html/authors.html.

This package was partially generated with Cookiecutter and the audreyr/cookiecutter-pypackage project template + lots of work by Thorsten.
