szacho / augmix-tf

Implementation of AugMix (2020) in TensorFlow

Home Page: https://pypi.org/project/augmix-tf/

License: MIT License

Python 100.00%
data-augmentation augmix

augmix-tf

NOTE: this implementation is no longer supported. You can use keras-team/keras-cv#407 as a replacement (it is based on this repo).

Augmix-tf is a TensorFlow implementation of AugMix (2020), a novel data augmentation method. It runs on TPU.

AugMix uses simple augmentation operations that are stochastically sampled and layered to produce a high diversity of augmented images. The process of mixing basic transformations into an augmented image is shown below (picture taken from the original paper). This augmentation performs better when used together with the Jensen-Shannon Divergence Consistency Loss.

[Figure: AugMix pipeline]
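The mixing step described above can be sketched in plain NumPy. This is a minimal illustration of the scheme from the AugMix paper (Dirichlet weights over the chains, a Beta-distributed blend with the original image), not the code of this package. The function name `augmix_sketch`, the `alpha` parameter and the toy operations in `TOY_OPS` are assumptions made for the example.

```python
import numpy as np

# Toy stand-ins for the real transformations; each keeps values in [0, 1].
TOY_OPS = [
    lambda img: img[:, ::-1, :],  # horizontal flip
    lambda img: 1.0 - img,        # invert
    lambda img: 0.5 * img,        # darken
]


def augmix_sketch(image, ops, width=3, depth=-1, alpha=1.0, rng=None):
    """Mix `width` randomly composed chains of `ops` with the original image."""
    rng = np.random.default_rng() if rng is None else rng
    image = np.asarray(image, dtype=np.float32)

    # Convex weights for the chains and a blend factor for the original image.
    chain_weights = rng.dirichlet([alpha] * width)
    blend = rng.beta(alpha, alpha)

    mixed = np.zeros_like(image)
    for weight in chain_weights:
        # depth == -1 means a random chain length from 1 to 3.
        chain_depth = depth if depth > 0 else int(rng.integers(1, 4))
        augmented = image
        for _ in range(chain_depth):
            op = ops[int(rng.integers(len(ops)))]
            augmented = op(augmented)
        mixed += weight * augmented

    # Convex combination of the original image and the mixed chains.
    return (blend * image + (1.0 - blend) * mixed).astype(np.float32)
```

Because every step is a convex combination of images in [0, 1], the result stays in [0, 1] as long as each operation does.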

Installation

pip install augmix-tf

Usage

AugMix

The main function that performs the augmentation is AugMix.transform. Let's print its docstring.

from augmix import AugMix
print(AugMix.transform.__doc__)
	Performs AugMix data augmentation on given image.

	Parameters:
	image (tf tensor): an image tensor with shape (x, x, 3) and values scaled to range [0, 1]
	severity (int): level of a strength of transformations (integer from 1 to 10)
	width (int): number of different chains of transformations to be mixed
	depth (int): number of transformations in one chain, -1 means random from 1 to 3

	Returns:
	tensor: augmented image

Example 1 - transforming a single image

from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from augmix import AugMix

# precalculated means and stds of the dataset (in RGB order)
means = [0.44892993872313053, 0.4148519066242368, 0.301880284715257]
stds = [0.24393544875614917, 0.2108791383467354, 0.220427056859487]
ag = AugMix(means, stds)

# preprocess
image = np.asarray(Image.open('geranium.jpg').convert('RGB'))  # force 3 channels
image = tf.convert_to_tensor(image)
image = tf.cast(image, dtype=tf.float32)
image = tf.image.resize(image, (331, 331))  # resize to square
image /= 255.0  # scale to [0, 1]

# augment
augmented = ag.transform(image)

# visualize
comparison = tf.concat([image, augmented], axis=1)
plt.imshow(comparison.numpy())
plt.title("Original image (left) and augmented image (right).")
plt.show()

[Figure: result of example 1]
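Example 1 uses precalculated per-channel means and standard deviations. One way to obtain such values is sketched below in NumPy. It is an assumption about how they could be computed, not the method used for the numbers above. The helper name `channel_stats` is hypothetical, and the images are expected as H x W x 3 arrays already scaled to [0, 1].

```python
import numpy as np


def channel_stats(images):
    """Return per-channel (means, stds) for an iterable of HxWx3 arrays."""
    pixel_count = 0
    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)

    # Accumulate running sums so the whole dataset never has to fit in memory.
    for image in images:
        pixels = np.asarray(image, dtype=np.float64).reshape(-1, 3)
        pixel_count += pixels.shape[0]
        channel_sum += pixels.sum(axis=0)
        channel_sq_sum += (pixels ** 2).sum(axis=0)

    means = channel_sum / pixel_count
    # Population variance over all pixels: Var[x] = E[x^2] - E[x]^2.
    variances = np.maximum(channel_sq_sum / pixel_count - means ** 2, 0.0)
    return means.tolist(), np.sqrt(variances).tolist()
```

The two returned lists are in RGB order and can be passed as `AugMix(means, stds)`.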

Example 2 - transforming a dataset of images

# here a dataset is a tf.data.Dataset object
# assuming images are properly preprocessed (see example 1)
dataset = dataset.map(lambda img: ag.transform(img))

Example 3 - transforming a dataset to use with the Jensen-Shannon loss

# here a dataset is a tf.data.Dataset object
# assuming images are properly preprocessed (see example 1)
dataset = dataset.map(lambda img: (img, ag.transform(img), ag.transform(img)))
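The triplet produced in example 3 (clean image plus two augmented views) is what the Jensen-Shannon consistency term consumes. This package does not ship the loss, so the NumPy sketch below only illustrates the formula from the AugMix paper: the mean KL divergence of the three predicted distributions from their mixture. The function names and the `eps` clipping value are assumptions for the example.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def js_consistency(logits_clean, logits_aug1, logits_aug2, eps=1e-7):
    """Jensen-Shannon divergence between three batches of predictions."""
    probs = [softmax(np.asarray(l, dtype=np.float64))
             for l in (logits_clean, logits_aug1, logits_aug2)]
    # Mixture distribution M = (p_clean + p_aug1 + p_aug2) / 3.
    mixture = np.clip(sum(probs) / 3.0, eps, 1.0)

    # KL(p || M) for each of the three distributions, per example.
    kl_terms = [
        np.sum(p * (np.log(np.clip(p, eps, 1.0)) - np.log(mixture)), axis=-1)
        for p in probs
    ]
    # Average over the three distributions, then over the batch.
    return float(np.mean(sum(kl_terms) / 3.0))
```

In training, this term would be scaled by a weight and added to the usual classification loss on the clean image.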

Visualization

AugMix

[Figure: original images]

[Figure: augmented images (visualization of AugMix)]

Simple transformations

AugMix mixes images transformed by the simple augmentations defined in the transformations.py file. Every transformation function takes an image and a level parameter that determines the strength of the transformation. The level parameter has the same value as the severity parameter of AugMix.transform, so it is again an integer between 1 and 10, where 10 means the strongest augmentation. These functions can also be used on their own. Below is a visualization of what every simple augmentation does to a batch of images (at level 10).

translate_x, translate_y [Figure: translate]

rotate [Figure: rotate]

shear_x, shear_y [Figure: shear]

solarize [Figure: solarize]

solarize_add [Figure: solarize add]

posterize [Figure: posterize]

autocontrast [Figure: autocontrast]

contrast [Figure: contrast]

equalize [Figure: equalize]

brightness [Figure: brightness]

color [Figure: color]
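To illustrate how a level parameter can control the strength of a simple transformation, here is a NumPy sketch of a solarize-style operation. It is not the code from transformations.py. The name `solarize_sketch` and the linear mapping from level to threshold are assumptions chosen for the example.

```python
import numpy as np

MAX_LEVEL = 10  # same scale as the severity parameter (1 to 10)


def solarize_sketch(image, level):
    """Invert every pixel value at or above a threshold that drops with level."""
    image = np.asarray(image, dtype=np.float32)
    # Hypothetical mapping: level 1 -> threshold 0.9, level 10 -> threshold 0.0.
    threshold = 1.0 - level / MAX_LEVEL
    return np.where(image >= threshold, 1.0 - image, image)


# Usage: a dark image with one bright pixel, values in [0, 1].
image = np.full((2, 2, 3), 0.2, dtype=np.float32)
image[0, 0] = 0.8
mild = solarize_sketch(image, level=5)     # only the bright pixel is inverted
strong = solarize_sketch(image, level=10)  # every pixel is inverted
```

With this mapping a higher level lowers the threshold, so more of the image is inverted and the augmentation is stronger.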

More information

TODO

  • batch implementation of AugMix
  • possibility to choose basic transformations easily
  • appendix
    • calculation of mean and standard deviation on a dataset
    • implementation of Jensen-Shannon Divergence Consistency Loss

License

MIT
