
Implementation of MAXIM in TensorFlow.

Home Page: https://huggingface.co/models?pipeline_tag=image-to-image&sort=downloads&search=maxim

License: Apache License 2.0


MAXIM in TensorFlow


Implementation of MAXIM [1] in TensorFlow. This project received the #TFCommunitySpotlight Award.

MAXIM introduces a backbone that can tackle image denoising, dehazing, deblurring, deraining, and enhancement.

Taken from the MAXIM paper

The weights of different MAXIM variants are in JAX and they're available in [2].

You can find all the TensorFlow MAXIM models here on TensorFlow Hub as well as on Hugging Face Hub.

You can try out the models on Hugging Face Spaces.

If you prefer Colab Notebooks, then you can check them out here.

Model conversion to TensorFlow from JAX

Blocks and layers related to MAXIM are implemented in the maxim directory.

The convert_to_tf.py script initializes a particular MAXIM model variant with a pre-trained checkpoint and then runs the conversion to TensorFlow. Refer to the usage section of the script to know more.

This script serializes the model weights in .h5 and also pushes the SavedModel to the Hugging Face Hub. For the latter, you need to authenticate yourself if you haven't already (huggingface-cli login).

This TensorFlow implementation closely follows [2]. The author of this repository has reused some code blocks from [2] (with credits) to do so.

Results and model variants

A comprehensive table is available here. The author of this repository qualitatively validated the results of the converted models.

Inference with the provided sample images

You can run the run_eval.py script for this purpose.

Image Denoising
python3 maxim/run_eval.py --task Denoising --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-3_denoising_sidd/1/uncompressed \
  --input_dir images/Denoising --output_dir images/Results --has_target=False --dynamic_resize=True
Image Deblurring
python3 maxim/run_eval.py --task Deblurring --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-3_deblurring_gopro/1/uncompressed \
  --input_dir images/Deblurring --output_dir images/Results --has_target=False --dynamic_resize=True
Image Deraining

Rain streak:

python3 maxim/run_eval.py --task Deraining --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_deraining_rain13k/1/uncompressed \
  --input_dir images/Deraining --output_dir images/Results --has_target=False --dynamic_resize=True

Rain drop:

python3 maxim/run_eval.py --task Deraining --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_deraining_raindrop/1/uncompressed \
  --input_dir images/Deraining --output_dir images/Results --has_target=False --dynamic_resize=True
Image Dehazing

Indoor:

python3 maxim/run_eval.py --task Dehazing --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_dehazing_sots-indoor/1/uncompressed \
  --input_dir images/Dehazing --output_dir images/Results --has_target=False --dynamic_resize=True

Outdoor:

python3 maxim/run_eval.py --task Dehazing --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_dehazing_sots-outdoor/1/uncompressed \
  --input_dir images/Dehazing --output_dir images/Results --has_target=False --dynamic_resize=True
Image Enhancement

Low-light enhancement:

python3 maxim/run_eval.py --task Enhancement --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_enhancement_lol/1/uncompressed \
  --input_dir images/Enhancement --output_dir images/Results --has_target=False --dynamic_resize=True

Retouching:

python3 maxim/run_eval.py --task Enhancement --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_enhancement_fivek/1/uncompressed \
  --input_dir images/Enhancement --output_dir images/Results --has_target=False --dynamic_resize=True

Notes:

  • The run_eval.py script is heavily inspired by the original one.
  • You can set dynamic_resize to False to obtain lower latency at the cost of prediction quality.

XLA support

The models are XLA-supported. XLA can dramatically reduce latency. Refer to the benchmark_xla.py script for more.
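As a rough illustration of the pattern (using a tiny stand-in model, not MAXIM itself; the actual benchmark lives in benchmark_xla.py), enabling XLA amounts to wrapping the inference call in a tf.function with jit_compile=True:

```python
import tensorflow as tf

# Tiny stand-in model; any Keras model works the same way.
model = tf.keras.Sequential([tf.keras.layers.Conv2D(3, 3, padding="same")])

@tf.function(jit_compile=True)  # jit_compile=True routes the call through XLA
def infer(x):
    return model(x, training=False)

out = infer(tf.zeros((1, 64, 64, 3)))  # first call triggers XLA compilation
```

The first call pays the compilation cost, so benchmark steady-state latency only after a warm-up call.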

Known limitations

These are some of the known limitations of the current implementation. These are all open for contributions.

Supporting arbitrary image resolutions

MAXIM supports arbitrary image resolutions. However, the available TensorFlow models were exported with a (256, 256, 3) resolution. So, a crude form of resizing is applied to the input images to perform inference with the available models, which impacts the results quite a bit. This issue is discussed in more detail here. Some work has been started to fix this behaviour (no ETA yet). I am thankful to Amy Roberts from Hugging Face for guiding me in the right direction.

But these models can be extended to support arbitrary resolution. Refer to this notebook for more details. Specifically, for a given task and an image, a new version of the model is instantiated and the weights of the available model are copied into the new model instance. This is a time-consuming process and isn't very efficient.
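The weight-copying idea can be sketched with a toy fully convolutional model standing in for MAXIM (the real procedure lives in the linked notebook): build a second model at the target resolution and copy the trained weights into it.

```python
import tensorflow as tf

# Toy fully convolutional model; its weight shapes don't depend on the
# input resolution, so weights transfer between differently sized builds.
def build(height, width):
    inputs = tf.keras.Input((height, width, 3))
    outputs = tf.keras.layers.Conv2D(3, 3, padding="same")(inputs)
    return tf.keras.Model(inputs, outputs)

trained = build(256, 256)
resized = build(512, 512)                    # new instance, new resolution
resized.set_weights(trained.get_weights())   # weights are shape-compatible
```

Instantiating and re-wiring a fresh model per image is why this approach is time-consuming.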

Output mismatches

The outputs of the TF and JAX models vary slightly. This is because of the differences in the implementation of different layers (resizing layer mainly). Even though the differences in the outputs of individual blocks of TF and JAX models are small, they add up, in the end, to be larger than one might expect.

With all that said, the qualitative performance doesn't appear to be affected.

Call for contributions

  • Add a minimal training notebook.
  • Fix any of the known limitations stated above.

Acknowledgements

  • ML Developer Programs' team at Google for providing Google Cloud credits.
  • Gustavo Martins from Google for initial discussions and reviews of the codebase.
  • Amy Roberts from Hugging Face for guiding me in the right direction for handling arbitrary input shapes.

References

[1] MAXIM paper: https://arxiv.org/abs/2201.02973

[2] MAXIM official GitHub: https://github.com/google-research/maxim


maxim-tf's Issues

Obtaining results yourself

Steps

  1. Let's say we're interested in the image enhancement task. For that, we first need to get the serialized SavedModel ported from JAX:
python convert_to_tf.py \
    --task Enhancement \
    --ckpt_path gs://gresearch/maxim/ckpt/Enhancement/LOL/checkpoint.npz

This will save the model weights in .h5 as well as the entire model as a SavedModel.

All supported tasks and checkpoints are available here.

  2. The next step is to run the SavedModel to obtain results:
python3 run_eval.py --task Enhancement --ckpt_path S-2_enhancement_lol \
  --input_dir images/Enhancement --output_dir images/Results --has_target=False

Please ensure you're using the appropriate checkpoint with the appropriate input images.

In the current form, run_eval.py aggressively resizes the input image to 256x256 (refer to the script for more details), which can be sub-optimal quality-wise. To follow the original logic of MAXIM instead, a bit of hacking is needed; run_eval.py contains plenty of comments indicating exactly which lines to uncomment.
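For reference, the aggressive resizing amounts to something like the following sketch (the function name and the [0, 1] scaling are illustrative assumptions; see run_eval.py for the exact preprocessing):

```python
import tensorflow as tf

# Sketch of the crude preprocessing: every input is resized to the fixed
# 256x256 training resolution, which is lossy for larger images.
def preprocess(image):
    image = tf.image.resize(image, (256, 256))   # returns float32
    image = tf.cast(image, tf.float32) / 255.0   # scale to [0, 1] (assumed)
    return image[tf.newaxis, ...]                # add a batch dimension

batch = preprocess(tf.zeros((720, 1280, 3), dtype=tf.uint8))
```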

Here's how the results look (see the result images attached to this issue).

You can follow this workflow to generate more results on other tasks such as denoising, dehazing, etc. I will, of course, work on this when we publish the models on the Hub.

P.S.: Please let me know if you have suggestions for making run_eval.py cleaner and more easily customizable.

@gustheman

3D images

Is it possible to convert this architecture to a 3D model that can handle medical images (without exploding gpu memory due to the extra dimension)?

Building the model with `(None, None, 3)`

The original MAXIM model was trained on 256x256x3 images, but that doesn't constrain it to that resolution at inference time. As long as the input image's spatial dimensions are divisible by 64, it can accept images of any size.
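One common way to satisfy the divisibility constraint is to pad the input up to the next multiple of 64 and crop the output back afterwards. This is only an illustrative sketch, not the authors' exact code:

```python
import tensorflow as tf

# Pad an (H, W, 3) image so both spatial dimensions become multiples of 64.
def pad_to_multiple(image, multiple=64):
    h, w = image.shape[0], image.shape[1]
    pad_h = (multiple - h % multiple) % multiple   # 0 if already divisible
    pad_w = (multiple - w % multiple) % multiple
    # REFLECT padding avoids hard black borders at the image edges.
    return tf.pad(image, [[0, pad_h], [0, pad_w], [0, 0]], mode="REFLECT")

padded = pad_to_multiple(tf.zeros((300, 500, 3)))
print(padded.shape)  # (320, 512, 3)
```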

This is how the authors do it:

In our case, the model is built with layers.Input((256, 256, 3)):

inputs = keras.Input((input_resolution, input_resolution, 3))

If we use (None, None, 3), it throws:

Traceback (most recent call last):
  File "convert_to_tf.py", line 234, in <module>
    main(args)
  File "convert_to_tf.py", line 192, in main
    _, tf_model = port_jax_params(configs, args.ckpt_path)
  File "convert_to_tf.py", line 140, in port_jax_params
    tf_model = Model(**configs)
  File "/Users/sayakpaul/Downloads/maxim-tf/create_maxim_model.py", line 31, in Model
    outputs = maxim_model(inputs)
  File "/Users/sayakpaul/Downloads/maxim-tf/maxim/maxim.py", line 99, in apply
    height=h // (2 ** i),
TypeError: unsupported operand type(s) for //: 'NoneType' and 'int'

From the logs, it might seem obvious that we cannot build the Keras model with (None, None, 3) since there are calculations inside the model that require us to specify the spatial dimensions.
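A minimal reproduction of the root cause, independent of MAXIM: with (None, None, 3), the symbolic spatial dimensions are None, so any Python-level arithmetic on them fails.

```python
import tensorflow as tf

inputs = tf.keras.Input((None, None, 3))
h = inputs.shape[1]  # None for a dynamic dimension
try:
    h // 2              # same operation the model performs internally
except TypeError as err:
    print(err)          # unsupported operand type(s) for //: 'NoneType' and 'int'
```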

Do you know of any way to mitigate this problem or any other approach?

@gustheman

Write README

  • Model reference
  • One line to describe why the model is interesting
  • Model ckpts and Colabs
  • Challenges faced
  • Note around dynamic resizing
  • Contributions
  • Acknowledgements

MAXIM model pre-trained not working

I tried to use the MAXIM denoising pre-trained models from https://tfhub.dev/sayakpaul/collections/maxim/1 and https://colab.research.google.com/github/sayakpaul/maxim-tf/blob/main/notebooks/inference.ipynb#scrollTo=ECGdFWQBw8E2.
Neither of the following ways works.

In the Colab case, the error is: gs://maxim-tf/S-2_dehazing_sots-indoor does not exist.

In the Python case, the warnings are: WARNING:tensorflow:No training configuration found in save file, so the model was not compiled. Compile it manually.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. model.compile_metrics will be empty until you train or evaluate the model.

Finetuning enhancement task

Hi, thanks for porting this to TF.
I'm impressed with the enhancement results; however, I'm getting weird artifacts in the original repo (here is the issue).
I'm wondering if you plan on adding training code? Currently I'm unable to find anything except the PyTorch version, which has training for all tasks except enhancement, which is what I'm most interested in. Thanks!

Errors during saving the model as a SavedModel resource

Issue

Traceback (most recent call last):
  File "convert_to_tf.py", line 234, in <module>
    main(args)
  File "convert_to_tf.py", line 200, in main
    tf_model.save(saved_model_path)
  File "/Users/sayakpaul/.local/bin/.virtualenvs/keras-io/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.8/lib/python3.8/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.8/lib/python3.8/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
TypeError: Unable to serialize 64 to JSON. Unrecognized type <class 'tensorflow.python.framework.ops.EagerTensor'>.

Steps to reproduce

From the root of the repository, run python convert_to_tf.py.

Note that we can serialize the model params as h5 but not the entire model as a SavedModel resource.
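The traceback suggests a layer config contains an EagerTensor, which Keras cannot serialize to JSON when exporting a SavedModel. A hedged sketch of the kind of fix this calls for (the height attribute is illustrative, not the actual MAXIM code): cast such values to plain Python types before they reach get_config.

```python
import tensorflow as tf

class Block(tf.keras.layers.Layer):
    def __init__(self, height, **kwargs):
        super().__init__(**kwargs)
        self.height = int(height)  # cast tensors / NumPy scalars to plain int

    def get_config(self):
        config = super().get_config()
        config.update({"height": self.height})  # JSON-serializable now
        return config
```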

Also, note that serialization to SavedModel takes time. So, please be patient.
