
Image Upsampling with PyTorch

This project trains an image upsampling model using PyTorch and the Microsoft DeepSpeed framework. The model is designed to increase the resolution of images by a factor of 2. For the most part this is currently a replication of the results of the authors of the VapSR ( https://github.com/zhoumumu/VapSR ) super-resolution model in my own training framework.

Technology explored in this project:

  • Image upsampling model training and evaluation with PSNR, SSIM, and LPIPS.
  • Dataset construction with multi-threaded Python tools.
  • Loading training data from image folders, video files, and ImageNet dataset.
  • Data loading with Nvidia DALI, with benchmarks of nvJPEG2000, WebP, and PNG for lossless loading.
  • Model optimization with TorchDynamo and Inductor.
  • Model training with DeepSpeed, Accelerate and DistributedDataParallel using the NCCL backend across a LAN cluster of nodes, each with multiple GPUs.
  • Evaluated whether OKLAB/YUV colorspace improves model performance.
  • Ablation study of different network architectures and training strategies.
  • Exploration of ImageNet and other normalization strategies, and in-network normalization.
  • Exploration of L1, L2(MSE), and LPIPS loss functions.
  • Conversion to ONNX and OpenVINO IR model formats and execution on Intel CPU/GPU.

There is a blog post that accompanies this project: https://catid.io/posts/upsampler/

Example

[Side-by-side example image] Top-left: original image. Top-right: my model output. Bottom-left: input image. Bottom-right: baseline bicubic upsampler.

Prerequisites

  • Linux
  • Python 3.8 or higher
  • Conda
  • FFmpeg
  • Nvidia GPU

How-To

  1. Install the system dependencies and clone this repository:
sudo apt install ffmpeg gcc-11 g++-11 libopenmpi-dev

git clone https://github.com/catid/upsampling.git
cd upsampling
  2. Create a new conda environment and install dependencies:
conda create --name upsampling python=3.8
conda activate upsampling

# Update this from https://pytorch.org/get-started/locally/
pip3 install torch torchvision functorch --extra-index-url https://download.pytorch.org/whl/cu118

# Update this from https://github.com/NVIDIA/DALI#installing-dali
pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110

pip install -r requirements.txt

All the scripts use ~/dataset as the default dataset location.

  3. Extract training images from a folder of image files:

For training I used only the training sets of these datasets: DIV2K, Flickr2K, KODAK, Sunhays80. I left out their validation sets so that comparisons with other algorithms stay fair. You can download these datasets from my Google Drive: https://drive.google.com/drive/folders/1kILRdaxD2o273E2NnrXzhaLbQCXDxz9l?usp=share_link

To incorporate one of these training sets into the dataset, create a directory containing only the *_HR.png images.

# Use the example KODAK dataset included with this repo under kodak/
python images_to_data.py kodak ~/dataset/kodak

Random crops are taken at different scales.

JPEG input images are downsampled 2x by default, because most JPEGs are 4:2:0 chroma subsampled and contain some compression artifacts, and downsampling reduces the impact of both. PNG images are not downsampled, because they are lossless and have no chroma subsampling.
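The cropping and JPEG rule above can be sketched as follows. This is a minimal illustration, not the code in images_to_data.py: the helper name plan_crop and the file names in the usage lines are hypothetical, and reading "random crops at different scales" as "a random square window between 256 pixels and the shorter image side, later resized to 256x256" is an assumption.

```python
import random
from pathlib import Path

CROP = 256  # training crop size used elsewhere in this README


def plan_crop(path, width, height, rng, crop=CROP):
    """Hypothetical helper: return (prescale, box) for one source image, or None.

    box is (left, top, right, bottom) in the prescaled image.
    """
    # JPEGs are usually 4:2:0 chroma subsampled and carry compression artifacts,
    # so halve them before cropping; lossless PNGs keep full resolution.
    prescale = 0.5 if Path(path).suffix.lower() in {".jpg", ".jpeg"} else 1.0
    w, h = int(width * prescale), int(height * prescale)
    if min(w, h) < crop:
        return None  # too small to yield a full crop
    # Assumed meaning of "different scales": pick a random square window size,
    # which the loader would then resize to crop x crop.
    side = rng.randint(crop, min(w, h))
    left = rng.randint(0, w - side)
    top = rng.randint(0, h - side)
    return prescale, (left, top, left + side, top + side)


rng = random.Random(0)
print(plan_crop("kodak/kodim01_HR.png", 768, 512, rng))
print(plan_crop("photos/cat.jpg", 1024, 768, rng))
print(plan_crop("photos/tiny.jpg", 400, 400, rng))  # None: only 200x200 after halving
```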

  4. Validate the training data to make sure every image is a usable size and not truncated:
# Print a list of all the data that is included in the dataset, to make sure nothing was accidentally left out
python list_data_folders.py

# Check the dataset for any images that are too small to use or truncated, and delete them
python check_training_data.py --fast --clean

You can remove the --fast parameter to do full image decoding during the check, but it is about 100x slower.

The --clean parameter deletes all images that are too small or too large to use. Training currently takes random 256x256 crops of the input images in each epoch, so every image must be at least that large. Images must also be smaller than about 2048x2048 to avoid crashing DALI during loading.
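A header-only size check, in the spirit of --fast, can be sketched with the standard library for PNG files. It is an assumption that check_training_data.py works this way; the function names are hypothetical, and applying 256 to the shorter side and 2048 to the longer side is an illustrative reading of the limits above. A header check cannot detect a truncated file, which is what full decoding is for.

```python
import struct

MIN_SIDE = 256   # training uses random 256x256 crops
MAX_SIDE = 2048  # larger images can crash DALI, per this README
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(header: bytes):
    """Return (width, height) from the first 24 bytes of a PNG file.

    The IHDR chunk always comes first: 8-byte signature, 4-byte length,
    4-byte type "IHDR", then big-endian 32-bit width and height.
    """
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError("not a PNG file (or truncated header)")
    return struct.unpack(">II", header[16:24])


def is_usable(width, height, min_side=MIN_SIDE, max_side=MAX_SIDE):
    """True if the image can supply a full crop and stays within the loader limit."""
    return min(width, height) >= min_side and max(width, height) <= max_side


def check_png(path):
    """Header-only ("fast") check of one PNG file on disk."""
    with open(path, "rb") as f:
        width, height = png_size(f.read(24))
    return is_usable(width, height)
```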

  5. Verify that training runs:

Select the training/validation files to use from the dataset:

python select_training_files.py

In the ~/sources/upsampling directory of each node, launch training locally to verify that it runs correctly on multiple GPUs:

./launch_local_train.sh --reset
  6. Train using the whole GPU cluster:

Update the hostfile to include all the nodes on your network. Modify the launch_distributed_train.sh script to use the correct number of GPUs and nodes.

./launch_distributed_train.sh --reset
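If the hostfile follows DeepSpeed's format (one node per line, `hostname slots=N`, where N is that node's GPU count), its entries and the resulting total GPU count can be generated and checked with a short script. The hostnames and GPU counts below are hypothetical, and it is an assumption that this repo's hostfile uses exactly this format.

```python
def make_hostfile(nodes):
    """Render a DeepSpeed-style hostfile from {hostname: gpu_count}."""
    return "".join(f"{host} slots={gpus}\n" for host, gpus in nodes.items())


def parse_hostfile(text):
    """Parse 'hostname slots=N' lines back into {hostname: N}, skipping comments."""
    nodes = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()
        if not line:
            continue
        host, slots = line.split()
        key, value = slots.split("=", 1)
        if key != "slots":
            raise ValueError(f"unexpected field in line: {raw!r}")
        nodes[host] = int(value)
    return nodes


cluster = {"node1": 4, "node2": 4, "node3": 2}  # hypothetical LAN nodes
text = make_hostfile(cluster)
print(text, end="")
nodes = parse_hostfile(text)
print("nodes:", len(nodes), "total GPUs:", sum(nodes.values()))  # nodes: 3 total GPUs: 10
```

The total GPU count is the number the launch script needs to agree with.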

To stop training, press Ctrl+C at any time. When you run the script again without the --reset argument, it resumes from where it left off:

./launch_distributed_train.sh
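The --reset/resume behaviour can be illustrated with a minimal checkpoint sketch. It uses pickle and a hypothetical checkpoints/latest.pkl layout purely for illustration; the real scripts presumably save model and optimizer state through their training framework (DeepSpeed checkpoints, for example), so treat every name here as an assumption.

```python
import pickle
import shutil
import tempfile
from pathlib import Path


def load_or_reset(ckpt_dir: Path, reset: bool) -> dict:
    """Start fresh when reset is set, otherwise resume from the latest checkpoint."""
    if reset and ckpt_dir.exists():
        shutil.rmtree(ckpt_dir)  # --reset discards earlier progress
    latest = ckpt_dir / "latest.pkl"
    if latest.exists():
        with latest.open("rb") as f:
            return pickle.load(f)
    return {"epoch": 0, "best_val_loss": float("inf")}


def save_checkpoint(ckpt_dir: Path, state: dict) -> None:
    """Write to a temp file, then rename it over the old checkpoint.

    On POSIX the rename is atomic, so Ctrl+C during a save leaves the
    previous checkpoint intact.
    """
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    tmp = ckpt_dir / "latest.pkl.tmp"
    with tmp.open("wb") as f:
        pickle.dump(state, f)
    tmp.replace(ckpt_dir / "latest.pkl")


# Usage: simulate an interrupted run followed by a resumed run.
scratch = Path(tempfile.mkdtemp()) / "checkpoints"
state = load_or_reset(scratch, reset=True)
state["epoch"] = 3
save_checkpoint(scratch, state)
resumed = load_or_reset(scratch, reset=False)
print(resumed["epoch"])  # 3
```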

To monitor training in TensorBoard:

./run_tensorboard.sh

Export the model to PyTorch .pth format:

python export_trained_model.py

This produces a file named upsampling.pth in the current directory.

  7. Evaluate the trained model:
python evaluate.py

This evaluates the model on the Urban100 dataset in this repo under urban100/. The side-by-side results are saved to outputs/.
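PSNR compares an output to the original as 10·log10(MAX²/MSE). A minimal pure-Python version follows as a sketch; whether evaluate.py computes it on RGB or on a luma channel, and on which pixel range, is not stated here, so max_val=1.0 is an assumption.

```python
import math


def psnr(reference, estimate, max_val=1.0):
    """Peak signal-to-noise ratio in dB for two equal-length pixel sequences."""
    if len(reference) != len(estimate) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_val ** 2 / mse)


# Tiny example: every pixel off by 0.1 gives MSE 0.01, i.e. about 20 dB.
print(round(psnr([0.0] * 4, [0.1] * 4), 6))  # 20.0

# The "gap" quoted in the Results section is model PSNR minus bicubic PSNR,
# e.g. for the final run: 30.6207 - 24.8719 = 5.7488 dB.
model_db, bicubic_db = 30.6207257848213, 24.87192185623686
print(round(model_db - bicubic_db, 3))  # 5.749
```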

  8. Export the model to ONNX so that it can be used in other frameworks:
python convert_to_onnx.py

Results

For the following results we use validation loss as the metric for model selection. We stop training after 100 or 200 epochs with no improvement in validation loss, so these are all fully trained.
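Patience-based early stopping with model selection on validation loss can be sketched as follows. The class name is hypothetical, and patience=100 is simply one of the two values mentioned above.

```python
import math


class EarlyStopping:
    """Track the best validation loss; signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience=100):
        self.patience = patience
        self.best_loss = math.inf
        self.best_epoch = None
        self.epochs_without_improvement = 0

    def step(self, epoch, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss:
            # New best: this is the checkpoint that model selection would keep.
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
            return False
        self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.85, 0.7]):
    if stopper.step(epoch, loss):
        print(f"stopping at epoch {epoch}; best epoch {stopper.best_epoch}")
        break
```

With patience 2 this stops at epoch 3 and keeps epoch 1, even though a later epoch would have improved; a larger patience such as 100 or 200 makes that less likely.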

I tried converting from RGB to a different colorspace inside the network (YUV, and the perceptually uniform OKLAB), but neither improved the results over the RGB baseline below. My conclusion is that using the native colorspace of the training data is the best option:

YUV PSNR: 29.25841712579564
OKLAB PSNR: 29.416221449440155

RGB:
commit 684cdf6f2a31af9311b20792e10afea80082892a
Trained for 41M iterations
2023-05-10 15:29:32,658 [INFO] Model PSNR: 30.183498600102137 - Bicubic PSNR: 24.87192185623686
2023-05-10 15:29:32,658 [INFO] Model SSIM: 0.9237528208169244 - Bicubic SSIM: 0.8233347716777945
2023-05-10 15:29:32,658 [INFO] Model LPIPS: 7.562230394501671e-05 - Bicubic LPIPS: 0.0013645301913965267
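For reference, a per-pixel sRGB-to-OKLAB conversion looks like this, using the coefficients Björn Ottosson published with OKLAB. This is a scalar Python sketch; whether the in-network conversion tested above linearized sRGB first, and how it was vectorized, is not stated, so both are assumptions.

```python
import math


def srgb_to_linear(c):
    """Undo the sRGB transfer curve for one channel value in [0, 1]."""
    return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4


def srgb_to_oklab(r, g, b):
    """Convert one sRGB pixel (values in [0, 1]) to OKLAB (L, a, b)."""
    r, g, b = srgb_to_linear(r), srgb_to_linear(g), srgb_to_linear(b)
    # Linear sRGB -> approximate cone (LMS) responses.
    l = 0.4122214708 * r + 0.5363325363 * g + 0.0514459929 * b
    m = 0.2119034982 * r + 0.6806995451 * g + 0.1073969566 * b
    s = 0.0883024619 * r + 0.2817188376 * g + 0.6299787005 * b
    # Cube-root nonlinearity (sign-preserving for safety).
    l_, m_, s_ = (math.copysign(abs(x) ** (1 / 3), x) for x in (l, m, s))
    return (
        0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_,
        1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_,
        0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_,
    )


print(srgb_to_oklab(1.0, 1.0, 1.0))  # approximately (1, 0, 0): white
print(srgb_to_oklab(1.0, 0.0, 0.0))  # pure red
```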

I used the VapSR upsampling network for my project because it is the best "small" model I could find, and it uses modern ideas such as attention. I tried modifying the VapSR network to use 3x3 convolutions instead of pointwise 1x1 convolutions in the residual blocks, and it improved the results significantly for such a small change. The network is about 10% slower but gains about 0.24 dB PSNR (30.42 vs 30.18 dB), roughly a 4% larger margin over bicubic, which is a good tradeoff:

Conv2D-3x3 for each block
commit e05fb91dcc9948f343d44a9747b0564ce47289ad
2023-05-11 00:53:21,679 [INFO] Model PSNR: 30.419869491633694 - Bicubic PSNR: 24.87192185623686
2023-05-11 00:53:21,679 [INFO] Model SSIM: 0.9261785885967156 - Bicubic SSIM: 0.8233347716777945
2023-05-11 00:53:21,679 [INFO] Model LPIPS: 6.864146791633323e-05 - Bicubic LPIPS: 0.0013645301913965267
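To put the 1x1-to-3x3 change in numbers: for one convolution, a 3x3 kernel has 9 times the weights and 9 times the multiply-accumulates per output pixel of a 1x1 kernel. The whole network was only about 10% slower, which suggests (but does not prove) that these layers are a modest share of total runtime. The channel width of 48 below is hypothetical; check the model code for the real value.

```python
def conv2d_params(in_ch, out_ch, kernel, bias=True):
    """Number of weights in a 2D convolution with a square kernel."""
    return kernel * kernel * in_ch * out_ch + (out_ch if bias else 0)


def conv2d_macs_per_pixel(in_ch, out_ch, kernel):
    """Multiply-accumulates per output pixel (stride 1, no groups)."""
    return kernel * kernel * in_ch * out_ch


channels = 48  # hypothetical width
for k in (1, 3):
    print(f"{k}x{k}: {conv2d_params(channels, channels, k)} params, "
          f"{conv2d_macs_per_pixel(channels, channels, k)} MACs/pixel")
# 1x1: 2352 params, 2304 MACs/pixel
# 3x3: 20784 params, 20736 MACs/pixel
```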

I also tried training in full FP32 precision instead of FP16. It improved PSNR only slightly (30.43 vs 30.42 dB), which does not seem worth the extra training cost:

FP32
commit 1241bad70f747ab3debc2b0295c4de92a1a0086e
2023-05-11 15:38:53,131 [INFO] Model PSNR: 30.42989522500928 - Bicubic PSNR: 24.87192185623686
2023-05-11 15:38:53,131 [INFO] Model SSIM: 0.9264403109565023 - Bicubic SSIM: 0.8233347716777945
2023-05-11 15:38:53,131 [INFO] Model LPIPS: 6.918588840497184e-05 - Bicubic LPIPS: 0.0013339655822680986

I tried using MSE loss instead of L1 loss and found that it did not improve the results, so I stuck with L1 loss.

MSE loss
commit b27d12c527385b2014b83566be105f839fe51327
2023-05-11 22:40:08,426 [INFO] Model PSNR: 30.103380391483654 - Bicubic PSNR: 24.87192185623686
2023-05-11 22:40:08,426 [INFO] Model SSIM: 0.9215354474245302 - Bicubic SSIM: 0.8233347716777945
2023-05-11 22:40:08,426 [INFO] Model LPIPS: 7.704998459073747e-05 - Bicubic LPIPS: 0.0013645301913965267

I also tried LPIPS loss and mixes of LPIPS and L1 loss, but they did not improve the results either (results not shown).

I found that adding random 90-degree rotations during training improved PSNR significantly, by more than 0.16 dB, though training took longer.

Image rotations in data loader
2023-05-15 18:57:32,264 [INFO] Model PSNR: 30.596167203837407 - Bicubic PSNR: 24.87192185623686
2023-05-15 18:57:32,264 [INFO] Model SSIM: 0.9274689177510691 - Bicubic SSIM: 0.8233347716777945
2023-05-15 18:57:32,264 [INFO] Model LPIPS: 6.428865776182942e-05 - Bicubic LPIPS: 0.0013645301913965267

I tried using smaller 128x128 crops instead of 256x256 crops with a 4x larger batch size (so each batch covers the same number of pixels), on the intuition that this might generalize better. The results were not as good:

2023-05-16 00:43:52,728 [INFO] Model PSNR: 30.464718024898396 - Bicubic PSNR: 24.87192185623686
2023-05-16 00:43:52,728 [INFO] Model SSIM: 0.9261055387180951 - Bicubic SSIM: 0.8233347716777945
2023-05-16 00:43:52,728 [INFO] Model LPIPS: 6.534671619856208e-05 - Bicubic LPIPS: 0.0013645301913965267

Improving the data augmentation further raises quality again (30.62 vs 30.60 dB PSNR), but training now takes 11 hours instead of 5-6 hours. The augmentation settings are:

  • 40% horizontal flip
  • 20% random brightness adjust between 50% and 120%
  • 50% random rotation of 90, 180, or 270 degrees
2023-05-18 16:44:34,113 [INFO] Model PSNR: 30.6207257848213 - Bicubic PSNR: 24.87192185623686
2023-05-18 16:44:34,113 [INFO] Model SSIM: 0.9277041514352384 - Bicubic SSIM: 0.8233347716777945
2023-05-18 16:44:34,113 [INFO] Model LPIPS: 6.412160131001836e-05 - Bicubic LPIPS: 0.0013645301913965267
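The three augmentations above can be sketched on nested lists. The key detail is that the low-resolution input and its high-resolution target must receive the identical transform, or the pair no longer matches. This is a pure-Python illustration with hypothetical function names; reading "50% random rotation" as "with probability 0.5, rotate by one of 90, 180, or 270 degrees" is an assumption, and the actual data loader (a DALI pipeline, for example) will look different.

```python
import random


def hflip(img):
    """Mirror a 2D image (list of rows) left to right."""
    return [row[::-1] for row in img]


def rot90(img):
    """Rotate a 2D image 90 degrees counter-clockwise."""
    return [list(col) for col in zip(*img)][::-1]


def scale_brightness(img, factor):
    """Multiply pixel values (in [0, 1]) by factor and clip to 1.0."""
    return [[min(1.0, v * factor) for v in row] for row in img]


def augment(lr, hr, rng):
    """Apply the same random flip/brightness/rotation to an LR input and its HR target."""
    if rng.random() < 0.4:
        lr, hr = hflip(lr), hflip(hr)
    if rng.random() < 0.2:
        factor = rng.uniform(0.5, 1.2)
        lr, hr = scale_brightness(lr, factor), scale_brightness(hr, factor)
    if rng.random() < 0.5:
        for _ in range(rng.choice((1, 2, 3))):  # 90, 180 or 270 degrees
            lr, hr = rot90(lr), rot90(hr)
    return lr, hr


rng = random.Random(0)
lr = [[0.1, 0.2], [0.3, 0.4]]
hr = [[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
      [0.3, 0.3, 0.4, 0.4], [0.3, 0.3, 0.4, 0.4]]
print(augment(lr, hr, rng))
```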

The gap between the model and bicubic PSNR (5.749 dB) is about the same as in the VapSR paper (5.57 dB), so I consider these results a successful replication, although the absolute numbers differ because I evaluate on crops of the Urban100 test set. I did use a slightly larger model and additional augmentations to reach this quality; the authors did not consider these changes in their paper, and I consider them an improvement.

OpenVINO Inference

Follow the README instructions in the OpenVino directory.

License

This project is licensed under the MIT License.
