This project forked from google-deepmind/spiral

License: Apache License 2.0

SPIRAL

Overview

This repository contains the agents and environments described in the ICML'18 paper "Synthesizing Programs for Images using Reinforced Adversarial Learning". It currently provides the libmypaint-based simulator (more coming soon), a Sonnet module for the unconditional agent, and pre-trained model snapshots (9 agents from a single population) available from TF-Hub.

If you feel an immediate urge to dive into the code, the most relevant files are:

Path                                 Description
spiral/agents/default.py             The architecture of the agent
spiral/environments/libmypaint.py    The libmypaint-based environment

Reference

If this repository is helpful for your research, please cite the following publication:

@inproceedings{ganin2018synthesizing,
  title={Synthesizing Programs for Images using Reinforced Adversarial Learning},
  author={Ganin, Yaroslav and Kulkarni, Tejas and Babuschkin, Igor and Eslami, SM Ali and Vinyals, Oriol},
  booktitle={ICML},
  year={2018}
}

Installation

Clone this repository and fetch the external submodules:

git clone https://github.com/deepmind/spiral.git
cd spiral
git submodule update --init --recursive

Install necessary packages:

apt-get install cmake pkg-config libjson-c-dev intltool libpython3-dev python3-pip
pip3 install six setuptools numpy tensorflow==1.14 tensorflow-hub dm-sonnet

WARNING: Make sure that you have cmake 3.14 or later, since the build relies on its ability to find numpy libraries. If your package manager doesn't provide it, follow the official CMake installation instructions. You can check the version by running cmake --version.
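The version requirement above can also be checked from a script. The following is a minimal sketch, not part of the SPIRAL package: the helper names `parse_cmake_version` and `cmake_is_recent_enough` are made up for this example, and it assumes `cmake --version` prints a line of the form `cmake version X.Y.Z`.

```python
import re
import subprocess

MIN_CMAKE = (3, 14)  # minimum version stated in the warning above


def parse_cmake_version(output):
    """Extract (major, minor, patch) from the text printed by `cmake --version`."""
    match = re.search(r"cmake version (\d+)\.(\d+)(?:\.(\d+))?", output)
    if match is None:
        raise ValueError("could not find a version number in: %r" % output)
    # A missing patch component is treated as 0.
    return tuple(int(part) if part else 0 for part in match.groups())


def cmake_is_recent_enough(output, minimum=MIN_CMAKE):
    """Return True if the reported version is at least `minimum`."""
    return parse_cmake_version(output)[:len(minimum)] >= minimum


if __name__ == "__main__":
    try:
        reported = subprocess.check_output(["cmake", "--version"]).decode()
    except (OSError, subprocess.CalledProcessError):
        print("cmake not found on PATH")
    else:
        print(parse_cmake_version(reported), cmake_is_recent_enough(reported))
```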

Finally, run the following command to install the SPIRAL package itself:

python3 setup.py develop --user

You will also need to obtain the brush files for the libmypaint environment to work properly. These are available from the mypaint-brushes repository (https://github.com/mypaint/mypaint-brushes). For example, you can place them in the third_party folder like this:

wget -c https://github.com/mypaint/mypaint-brushes/archive/v1.3.0.tar.gz -O - | tar -xz -C third_party
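Before creating an environment, it can help to confirm that the brush you intend to use is actually present in the extracted folder. The sketch below is an illustration only: `brush_file_path` and `check_brush` are hypothetical helpers, and the layout `<brushes_basedir>/<brush_type>.myb` is an assumption about how the brush name maps to a file (`.myb` is the MyPaint brush file extension).

```python
import os


def brush_file_path(brushes_basedir, brush_type):
    """Build the expected path of a brush definition file.

    Assumption: the file lives at `<brushes_basedir>/<brush_type>.myb`.
    """
    return os.path.join(brushes_basedir, *brush_type.split("/")) + ".myb"


def check_brush(brushes_basedir, brush_type):
    """Return the brush file path, or raise if the file is missing."""
    path = brush_file_path(brushes_basedir, brush_type)
    if not os.path.isfile(path):
        raise FileNotFoundError("brush file not found: " + path)
    return path
```

With the archive unpacked as above, the base directory is probably a `brushes` subfolder inside the extracted directory under third_party; check the actual folder name on disk before passing it in.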

Optionally, to try out the package in the provided Jupyter notebook, you'll need to install the following packages:

pip3 install matplotlib jupyter
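Once everything is installed, a quick check that the packages can be located may save a confusing import error later. This is a small sketch using only the standard library; the helper name `missing_modules` is made up for this example. The import names differ from the pip names in two cases: `tensorflow-hub` is imported as `tensorflow_hub` and `dm-sonnet` as `sonnet`. `jupyter_core` is used as a stand-in for the `jupyter` metapackage, on the assumption that it is installed as one of its dependencies.

```python
import importlib.util

# Import names for the packages installed in the steps above.
REQUIRED = ["six", "setuptools", "numpy", "tensorflow",
            "tensorflow_hub", "sonnet", "spiral"]
# Only needed for the notebook.
OPTIONAL = ["matplotlib", "jupyter_core"]


def missing_modules(names):
    """Return the subset of top-level module `names` that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("missing required:", missing_modules(REQUIRED))
    print("missing optional:", missing_modules(OPTIONAL))
```

The check only locates the modules without importing them, so it does not verify the installed versions (for example, that TensorFlow is 1.14).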

Usage

For a basic example of how to use the package, see the Jupyter notebook provided in the repository.

Sampling from a pre-trained model

We provide a pre-trained model for unconditional 19-step generation of CelebA-HQ images. Here is how you can sample from it:

import matplotlib.pyplot as plt
import numpy as np

import spiral.agents.default as default_agent
import spiral.agents.utils as agent_utils
import spiral.environments.libmypaint as libmypaint


# The path to a TF-Hub module.
MODULE_PATH = "https://tfhub.dev/deepmind/spiral/default-wgangp-celebahq64-gen-19steps/agent4/1"
# The folder containing `libmypaint` brushes.
BRUSHES_PATH = "the/path/to/libmypaint-brushes"

# Here, we create an environment.
env = libmypaint.LibMyPaint(episode_length=20,
                            canvas_width=64,
                            grid_width=32,
                            brush_type="classic/dry_brush",
                            brush_sizes=[1, 2, 4, 8, 12, 24],
                            use_color=True,
                            use_pressure=True,
                            use_alpha=False,
                            background="white",
                            brushes_basedir=BRUSHES_PATH)


# Now we load the agent from a snapshot.
initial_state, step = agent_utils.get_module_wrappers(MODULE_PATH)

# Everything is ready for sampling.
state = initial_state()
noise_sample = np.random.normal(size=(10,)).astype(np.float32)

time_step = env.reset()
for t in range(19):
    time_step.observation["noise_sample"] = noise_sample
    action, state = step(time_step.step_type, time_step.observation, state)
    time_step = env.step(action)

# Show the sample.
plt.close("all")
plt.imshow(time_step.observation["canvas"], interpolation="nearest")
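To draw several samples, the sampling loop above can be wrapped in a function. The following is a sketch rather than part of the package: `rollout` is a hypothetical helper that repackages the same calls (`env.reset`, `step`, `env.step`) and assumes `initial_state` can be called again for each new episode.

```python
def rollout(env, initial_state, step, noise_sample, num_steps=19):
    """Run one episode and return the final canvas.

    `env`, `initial_state` and `step` are the objects created in the
    sampling example; this function only repackages that loop.
    """
    state = initial_state()
    time_step = env.reset()
    for _ in range(num_steps):
        # The agent reads the noise vector from the observation dictionary.
        time_step.observation["noise_sample"] = noise_sample
        action, state = step(time_step.step_type, time_step.observation, state)
        time_step = env.step(action)
    return time_step.observation["canvas"]
```

Calling `rollout(env, initial_state, step, np.random.normal(size=(10,)).astype(np.float32))` in a loop would then produce one canvas per freshly drawn noise vector.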

Converting a trained agent into a TF-Hub module

import spiral.agents.default as default_agent
import spiral.agents.utils as agent_utils
import spiral.environments.libmypaint as libmypaint


# This is where we're going to put our TF-Hub module.
TARGET_PATH = ...
# A path to a checkpoint of the trained model.
CHECKPOINT_PATH = ...

# We will need to create an environment in order to obtain the specifications
# for the agent's action and the observation.
env = libmypaint.LibMyPaint(...)

# Here, we wrap a Sonnet module constructor for our agent in a function.
# This is to avoid contaminating the default tensorflow graph.
def agent_ctor():
  return default_agent.Agent(action_spec=env.action_spec(),
                             input_shape=(64, 64),
                             grid_shape=(32, 32),
                             action_order="libmypaint")

# Finally, export a TF-Hub module. We need to specify which checkpoint to use
# to extract the weights for the agent. Since the variable names in the
# checkpoint may differ from the names in the Sonnet module produced by
# `agent_ctor`, we may also want to provide an appropriate name mapping
# function.
agent_utils.export_hub_module(agent_ctor=agent_ctor,
                              observation_spec=env.observation_spec(),
                              noise_dim=10,
                              module_path=TARGET_PATH,
                              checkpoint_path=CHECKPOINT_PATH,
                              name_transform_fn=lambda name: ...)
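The name mapping function is a plain string-to-string function. As a purely hypothetical illustration, suppose the names in the checkpoint and in the Sonnet module differ only by a scope prefix; a mapper could then look like the sketch below. The prefix `learner/` and the variable name are made up, and whether the mapping runs from module names to checkpoint names or the reverse is not stated here, so check it against the source of `agent_utils.export_hub_module` before relying on this.

```python
def make_prefix_mapper(prefix):
    """Return a str -> str function that prepends `prefix` to a variable name."""
    def name_transform_fn(name):
        return prefix + name
    return name_transform_fn


# "learner/" and "agent/conv_0/w" are made-up names used only for illustration.
example_fn = make_prefix_mapper("learner/")
print(example_fn("agent/conv_0/w"))
```

The mapping actually needed depends on how the checkpoint was produced; listing the variable names stored in the checkpoint is a reasonable first step.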

Disclaimer

This is not an official Google product.
