
This project forked from jleinonen/snow-gan-classification

Classification of snowflakes with a GAN and K-medoids

GAN-based unsupervised classification of snowflake images

This Python/TensorFlow code demonstrates unsupervised classification of snowflake images obtained with the Multi-Angle Snowflake Camera (MASC), using a GAN and K-medoids classification. It supports the paper "Unsupervised classification of snowflake images using a generative adversarial network and K-medoids classification", to be submitted to Atmospheric Measurement Techniques, and provides all code needed to replicate the results.
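For readers unfamiliar with K-medoids, here is a minimal, generic sketch of the alternating (assign/update) variant on a precomputed distance matrix. It is not the repository's implementation, and the function name `k_medoids` and its arguments are hypothetical. It only illustrates the idea that each cluster is represented by one of its own members.

```python
def k_medoids(dist, init_medoids, max_iter=100):
    """Generic alternating K-medoids on a square distance matrix.

    dist: list of lists, dist[i][j] is the distance between items i and j.
    init_medoids: list of k distinct item indices used as starting medoids.
    Returns (medoids, labels), where labels[i] is the cluster index of item i.
    """
    n = len(dist)
    k = len(init_medoids)
    medoids = list(init_medoids)

    def assign(current):
        # Each item joins the cluster of its nearest medoid.
        return [min(range(k), key=lambda c: dist[i][current[c]]) for i in range(n)]

    for _ in range(max_iter):
        labels = assign(medoids)
        new_medoids = []
        for c in range(k):
            members = [i for i in range(n) if labels[i] == c]
            if not members:
                # Keep the old medoid if a cluster ends up empty.
                new_medoids.append(medoids[c])
                continue
            # The new medoid is the member with the smallest total
            # distance to all other members of its cluster.
            best = min(members, key=lambda m: sum(dist[m][j] for j in members))
            new_medoids.append(best)
        if new_medoids == medoids:
            break
        medoids = new_medoids

    return medoids, assign(medoids)


# Toy example: six points on a line forming two obvious groups.
points = [0, 1, 2, 10, 11, 12]
toy_dist = [[abs(a - b) for b in points] for a in points]
toy_medoids, toy_labels = k_medoids(toy_dist, init_medoids=[0, 1])
```

In the toy example the medoids converge to the middle point of each group. In this project the items being clustered would be snowflake images represented through the GAN, but how the distances are computed is defined by the repository code, not by this sketch.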

Instructions

Requirements

You need a Python 3 environment and the following libraries:

  • TensorFlow (not tested with 2.0+)
  • NumPy
  • SciPy
  • Matplotlib
  • Seaborn
  • netCDF4 (the Python interface to NetCDF)
  • h5py
  • imageio
  • Dask

A GPU is highly recommended for training, but the experiments with pre-trained models can also be run on a CPU. At least 16 GB of RAM should be enough.
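A quick way to check that the libraries listed above are installed is sketched below. The import names are the usual ones for these packages. The helper `missing_modules` is hypothetical and not part of this repository.

```python
import importlib.util

# Usual import names for the libraries listed in the requirements.
REQUIRED = [
    "tensorflow", "numpy", "scipy", "matplotlib", "seaborn",
    "netCDF4", "h5py", "imageio", "dask",
]


def missing_modules(names):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    if missing:
        print("Missing libraries:", ", ".join(missing))
    else:
        print("All required libraries were found.")
```

This only checks that the packages can be located. It does not check versions, so it will not tell you whether your TensorFlow is older than 2.0.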

Data

Download the training datasets here (they are too big to include in the repository). Save the .nc and .npy files in the data directory.

If you want to use the pre-trained models, you can download them here. Save the contents of the zip file in the models directory.
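To confirm that the downloads ended up in the right place, something like the following sketch can list what is in the data directory. The helper `find_data_files` is hypothetical, and the exact file names you should see depend on the download, which this README does not list.

```python
from pathlib import Path


def find_data_files(data_dir):
    """Return the .nc and .npy file names found directly in data_dir."""
    data_dir = Path(data_dir)
    return {
        suffix: sorted(p.name for p in data_dir.glob(f"*{suffix}"))
        for suffix in (".nc", ".npy")
    }


# Example (run from the directory that contains "data"):
# print(find_data_files("data"))
```

If either list comes back empty, the files were probably saved to a different directory.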

Running the code

The high-level code that runs the training and evaluation needed to replicate the results can be found in replication.py. This file has a command line interface (see below), but you can also call its functions from an IPython terminal or a Jupyter notebook.
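Before calling the functions interactively, it can help to see what a module offers and which arguments each function takes. The sketch below uses only the standard library. The helper `list_public_functions` is hypothetical, and it is demonstrated on the standard `json` module because the signatures in replication.py are not documented here.

```python
import importlib
import inspect


def list_public_functions(module_name):
    """Map each public function defined in a module to its signature string."""
    module = importlib.import_module(module_name)
    result = {}
    for name, obj in inspect.getmembers(module, inspect.isfunction):
        # Skip private helpers and functions imported from other modules.
        if name.startswith("_") or obj.__module__ != module.__name__:
            continue
        result[name] = str(inspect.signature(obj))
    return result


# Demonstration on a standard-library module.
json_functions = list_public_functions("json")
```

From the snow-gan-classification directory, `list_public_functions("replication")` should work the same way, assuming replication.py imports cleanly in your environment. Note that importing it will also import TensorFlow.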

If you want to modify the training code, you should start by following the code flow in replication.training.

Running the plotting and evaluation

To evaluate the model and generate the plots shown in the paper, using the downloadable datasets and the pre-trained GAN, run the following on the command line in the snow-gan-classification directory:

python replication.py experiments --model_name=../models/masc_infogan_combined

where model_name is the name of the model you want to load (use the default for the pre-trained model). For the pre-trained model, this should replicate the results exactly. If you trained the GAN yourself, you will probably get slightly different results. The plots are saved in the figures directory.

In practice, you may want to run the experiments one by one by copying and pasting the code from replication.experiments into a terminal (for example, an IPython session).

Training the GAN

You can run the training like this:

python replication.py train --model_save_name=../models/masc_infogan

Change the --model_save_name parameter to the name of the model you want to save. You can load a pre-existing model at the start of training using the --model_name parameter. So, for example, to load the pre-trained model and train it further:

python replication.py train --model_name=../models/masc_infogan_combined --model_save_name=../models/masc_infogan

Computing the latent variables

Run the following to compute the latent variables for all snowflakes in the dataset:

python replication.py latents --model_name=../models/masc_infogan_combined --latents_file=../data/masc_latents.nc --latent_dist_file=../data/masc_latent_dist.nc

where the --latents_file and --latent_dist_file parameters control where the latents are saved.
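If you want to script the commands above (for example, to run them in sequence), a small wrapper like the following sketch can build the argument lists. The helpers `build_command` and `run_replication` are hypothetical. The subcommands and option names are the ones shown in this README, and no others are assumed.

```python
import subprocess


def build_command(subcommand, **options):
    """Build the argument list for one replication.py invocation.

    subcommand: "experiments", "train" or "latents" (as shown in the README).
    options: keyword arguments turned into --name=value flags.
    """
    cmd = ["python", "replication.py", subcommand]
    for name, value in options.items():
        cmd.append(f"--{name}={value}")
    return cmd


def run_replication(cmd, cwd="."):
    """Run a command built by build_command; raise if it exits with an error."""
    subprocess.run(cmd, cwd=cwd, check=True)


# Argument lists matching the commands shown in this README.
train_cmd = build_command(
    "train",
    model_name="../models/masc_infogan_combined",
    model_save_name="../models/masc_infogan",
)
latents_cmd = build_command(
    "latents",
    model_name="../models/masc_infogan_combined",
    latents_file="../data/masc_latents.nc",
    latent_dist_file="../data/masc_latent_dist.nc",
)
```

Call `run_replication(train_cmd, cwd=...)` with `cwd` set to the snow-gan-classification directory, since the relative paths above are resolved from there.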

Contributors

jleinonen