Super Resolution GAN

A naive implementation of a Super-Resolution GAN, based on the work by Ledig et al. 2017: arxiv.org/abs/1609.04802

Figures: example of SRGAN on the DIV2K dataset; example of SRGAN on the UxLES dataset.

Instructions

  1. Copy the config_sample.yaml file and rename it to config.yaml, so you can modify any configuration parameter without touching the version-controlled file.

  2. Download either the DIV2K or the UxLES dataset and extract it into data/

  3. Optionally, download the pretrained Generators and Discriminators.

Now you should be ready to use the repository.

config.yaml

This is the configuration file. It is divided into the following sections:

data

Here you specify the root directory of the data samples; the dataset (DIV2K, UxLES, or another of your choice); the high_res size, i.e. the crop size of the high-resolution images fed to the training network (96 is recommended for DIV2K and 64 for UxLES); the upscale_factor (4 by default); and the img_channels, which should be 3 (RGB) in most cases.

models

The pretrained VGG network used for one of the Generator loss terms can be downloaded from torch at runtime by setting dwnld_vgg to False; if you set it to True, you must first download it yourself from here: Pretrained VGG19
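
As a rough sketch (not necessarily how this repository builds it), a frozen VGG19 feature extractor for such a perceptual loss term can be obtained from torchvision as follows; the conv5_4 cut-off is an assumption based on the VGG loss described in Ledig et al. 2017.

# Hypothetical sketch of a VGG19 feature extractor for the perceptual (content)
# loss term; the exact layer cut-off and loading details in this repo may differ.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

vgg = models.vgg19(pretrained=True)  # newer torchvision uses weights= instead of pretrained=
feature_extractor = nn.Sequential(*list(vgg.features)[:36]).eval()  # up to conv5_4
for p in feature_extractor.parameters():
    p.requires_grad = False          # VGG stays frozen during training

def vgg_loss(sr, hr):
    """MSE between VGG feature maps of super-resolved and high-resolution images."""
    return F.mse_loss(feature_extractor(sr), feature_extractor(hr))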

You should also specify whether to load a pretrained Generator and Discriminator during training and whether to save them at the end of each epoch, as well as the filename and directory of the models to load or save.

train

Here the usual hyperparameters are specified: learning_rate, num_epochs, batch_size and num_workers.

validation

Whether or not to apply k-fold cross-validation during training, and how many splits to use.

figures

Just state the directory where you want to save the figures that result from training or testing.
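
Putting the sections above together, reading the configuration in Python might look like the sketch below; every key name here is an assumption inferred from the descriptions above, so check config_sample.yaml for the authoritative names and defaults.

# Hypothetical sketch of reading config.yaml; key names are guesses based on
# the section descriptions above and may not match config_sample.yaml exactly.
import yaml  # PyYAML

with open("config.yaml") as f:
    config = yaml.safe_load(f)

# data section
root = config["data"]["root"]                 # root directory of the data samples
dataset = config["data"]["dataset"]           # "DIV2K", "UxLES", or your own
high_res = config["data"]["high_res"]         # HR crop size: 96 (DIV2K) / 64 (UxLES)
upscale = config["data"]["upscale_factor"]    # 4 by default
channels = config["data"]["img_channels"]     # 3 (RGB) in most cases

# models / train / validation / figures sections
dwnld_vgg = config["models"]["dwnld_vgg"]     # False: fetch VGG from torch at runtime
lr = config["train"]["learning_rate"]
epochs = config["train"]["num_epochs"]
batch_size = config["train"]["batch_size"]
workers = config["train"]["num_workers"]
n_splits = config["validation"]["n_splits"]   # hypothetical key: number of k-fold splits
fig_dir = config["figures"]["dir"]            # hypothetical key: where figures are saved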

Train the network

Set up the training session and the hyperparameters in config.yaml. Check that the data root directory and dataset are the ones you want to train with. If you are training with UxLES, check that high_res is 64 or less.

Run the train_model.py script directly, or run run.sh & to run the training in the background and dump the output into a train.log file.

python3 ./src/train_model.py
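
For orientation, below is a minimal, hypothetical sketch of what a single SRGAN training step looks like, following the adversarial scheme of Ledig et al. 2017 (VGG-based content loss plus adversarial loss). The tiny placeholder networks and the placeholder VGG feature function are illustrative assumptions, not the models used by train_model.py.

# Hypothetical sketch of one SRGAN training step; not the repo's actual code.
import torch
import torch.nn as nn

# Tiny placeholder networks so the sketch runs; the real models are much deeper.
netG = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1), nn.Upsample(scale_factor=4))
netD = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2), nn.AdaptiveAvgPool2d(1),
                     nn.Flatten(), nn.Linear(8, 1))

opt_g = torch.optim.Adam(netG.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(netD.parameters(), lr=1e-4)
bce = nn.BCEWithLogitsLoss()
mse = nn.MSELoss()

def vgg_features(x):
    # Placeholder for the frozen VGG feature extractor sketched above.
    return x

low_res = torch.rand(4, 3, 24, 24)    # cropped LR patches
high_res = torch.rand(4, 3, 96, 96)   # matching HR patches (upscale_factor = 4)

# Discriminator step: real HR patches vs. generated SR patches.
fake = netG(low_res)
d_real = netD(high_res)
d_fake = netD(fake.detach())
loss_d = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))
opt_d.zero_grad()
loss_d.backward()
opt_d.step()

# Generator step: VGG/content loss plus adversarial loss (1e-3 weight as in the paper).
d_fake = netD(fake)
loss_g = mse(vgg_features(fake), vgg_features(high_res)) \
         + 1e-3 * bce(d_fake, torch.ones_like(d_fake))
opt_g.zero_grad()
loss_g.backward()
opt_g.step()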

Test the network

To test the network on the testing dataset, make sure you have extracted a dataset into the data directory.

The script will lower the resolution of all testing images and evaluate the Generator on them. It will also generate a comparison figure, and create PSNR and SSIM histograms.

Run the test_model.py script

python3 ./src/test_model.py
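
For reference, PSNR and SSIM for a single image pair can be computed with scikit-image as in the hypothetical snippet below; test_model.py may compute them differently.

# Hypothetical sketch of computing PSNR and SSIM for one image pair with
# scikit-image; not necessarily how test_model.py implements these metrics.
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

hr = np.random.rand(96, 96, 3)   # ground-truth high-resolution image in [0, 1]
sr = np.random.rand(96, 96, 3)   # super-resolved output of the Generator

psnr = peak_signal_noise_ratio(hr, sr, data_range=1.0)
# Older scikit-image versions use multichannel=True instead of channel_axis.
ssim = structural_similarity(hr, sr, channel_axis=-1, data_range=1.0)
print(f"PSNR: {psnr:.2f} dB, SSIM: {ssim:.3f}")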

Google Colab

If you want to run this network in Google Colab:

After you clone the repository into a Google Drive folder, mount the drive in your Colab session in order to use the repository features.

from google.colab import drive
drive.mount('/content/drive')

# Navigate to repository
%cd /content/drive/MyDrive/<Path_to_repo>

# For some reason we need to explicitly install this package
!pip install albumentations==0.4.6

K-fold

K-fold cross-validation is a technique for evaluating a model's performance.

By default, we split the data into two parts, training and testing, and then check whether a model trained on the training data performs well on the testing data. The issue with this is that there is a risk of overfitting to the test set, since it is used to tweak the hyperparameters. In this way, knowledge about the test set can leak into the model, and the evaluation metrics no longer report generalization performance. To solve this problem, yet another part of the dataset can be held out as a so-called 'validation set': training proceeds on the training set, after which evaluation is done on the validation set, and when the experiment seems to be successful, final evaluation can be done on the test set.

However, by partitioning the available data into three sets, we drastically reduce the number of samples which can be used for learning the model, and the results can depend on a particular random choice for the pair of (train, validation) sets.

A solution to this problem is a procedure called cross-validation (CV for short). A test set should still be held out for final evaluation, but the validation set is no longer needed when doing CV. In the basic approach, called k-fold CV, the training set is split into k smaller sets. The following procedure is followed for each of the k “folds”:

  • A model is trained using k-1 of the folds as training data;

  • the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute a performance measure such as accuracy).

The performance measure reported by k-fold cross-validation is then the average of the values computed in the loop. This approach can be computationally expensive, but does not waste too much data (as is the case when fixing an arbitrary validation set).
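
As a minimal illustration of the splitting logic (not necessarily the repository's own implementation), k-fold indices can be generated with scikit-learn:

# Minimal illustration of k-fold cross-validation splits with scikit-learn;
# the repo's own splitting code may differ.
import numpy as np
from sklearn.model_selection import KFold

indices = np.arange(100)                 # indices of the training samples
kf = KFold(n_splits=5, shuffle=True, random_state=0)

scores = []
for fold, (train_idx, val_idx) in enumerate(kf.split(indices)):
    # Train a model on indices[train_idx], evaluate it on indices[val_idx].
    score = len(val_idx) / len(indices)  # placeholder for a real metric (e.g. PSNR)
    scores.append(score)
    print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} val samples")

print("k-fold estimate:", np.mean(scores))  # average of the per-fold metrics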
