
jakeret / unet

Stars: 252 | Watchers: 5 | Forks: 88 | Size: 6.89 MB

Generic U-Net Tensorflow 2 implementation for semantic segmentation

Home Page: https://u-net.readthedocs.io/en/latest/?badge=latest

License: GNU General Public License v3.0

Python 0.47% Jupyter Notebook 99.53%
semantic-segmentation deep-learning tensorflow keras-tensorflow

unet's Introduction

Tensorflow Unet


This is a generic U-Net implementation, as proposed by Ronneberger et al., developed with TensorFlow 2. This project is a reimplementation of the original tf_unet.

Originally, the code was developed and used for "Radio Frequency Interference mitigation using deep convolutional neural networks".

The network can be trained to perform image segmentation on arbitrary imaging data. Check out the Usage section, the included Jupyter notebooks for a toy problem (also available on Google Colab), or the Oxford Pet Segmentation example, which is also available on Google Colab.

The code is not tied to a specific segmentation task. For example, it can be used for a toy problem such as detecting circles in a noisy image.

Segmentation of a toy problem.

It can also be applied to more complex problems, such as detecting radio frequency interference (RFI) in radio astronomy.

Segmentation of RFI in radio data.

It can likewise be used to detect galaxies and stars in wide-field imaging data.

Segmentation of galaxies.
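The circles toy problem can be made concrete with a short example that generates one (image, mask) training pair in NumPy. This is a minimal sketch under stated assumptions. The function name noisy_circles and its parameters are hypothetical and chosen for illustration. It is not the data module shipped with the package.

```python
import numpy as np


def noisy_circles(size=64, n_circles=3, noise=0.3, seed=0):
    """Hypothetical toy generator for the circle-detection problem.

    Returns (image, mask). The image has shape (size, size, 1) and contains
    bright discs plus Gaussian noise. The mask has shape (size, size), with 1
    for circle pixels and 0 for background.
    """
    rng = np.random.default_rng(seed)
    yy, xx = np.mgrid[:size, :size]
    mask = np.zeros((size, size), dtype=bool)
    for _ in range(n_circles):
        radius = int(rng.integers(3, max(4, size // 6)))
        cy, cx = rng.integers(radius, size - radius, size=2)
        # Mark every pixel within `radius` of the centre as foreground.
        mask |= (yy - cy) ** 2 + (xx - cx) ** 2 <= radius ** 2
    # Foreground pixels start at 1.0, background at 0.0, then noise is added.
    image = mask.astype(np.float32) + rng.normal(0.0, noise, (size, size)).astype(np.float32)
    return image[..., np.newaxis], mask.astype(np.int64)
```

To build a dataset, stack many such pairs and one-hot encode the masks. This yields the (image, label) layout used by the reproduction script further down this page.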

The architectural elements of a U-Net consist of a contracting and expanding path:

Unet architecture.
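A compact Keras model can illustrate the contracting and expanding paths. This is a minimal sketch and not the package's build_model. It assumes the following:

- "same" padding, so no cropping is needed at the skip connections;
- two 3x3 convolutions per level;
- 2x2 max pooling;
- transposed-convolution upsampling.

The name tiny_unet and its parameters are hypothetical.

```python
import tensorflow as tf


def tiny_unet(size=64, channels=1, num_classes=2, layer_depth=3, filters_root=8):
    """Minimal U-Net sketch with 'same' padding (not the package's build_model)."""
    inputs = tf.keras.Input(shape=(size, size, channels))
    x = inputs
    skips = []
    # Contracting path: two convolutions per level, then 2x2 max pooling.
    for depth in range(layer_depth):
        filters = filters_root * 2 ** depth
        x = tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        if depth < layer_depth - 1:
            skips.append(x)  # kept for the matching level of the expanding path
            x = tf.keras.layers.MaxPooling2D(2)(x)
    # Expanding path: upsample, concatenate the skip connection, two convolutions.
    for depth in reversed(range(layer_depth - 1)):
        filters = filters_root * 2 ** depth
        x = tf.keras.layers.Conv2DTranspose(filters, 2, strides=2, padding="same")(x)
        x = tf.keras.layers.Concatenate()([skips[depth], x])
        x = tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    # 1x1 convolution maps features to per-pixel class probabilities.
    outputs = tf.keras.layers.Conv2D(num_classes, 1, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs, name="tiny_unet")
```

With "same" padding, the input size in this sketch must be divisible by 2 ** (layer_depth - 1). Otherwise the upsampled maps and the skip connections will not line up.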

As you use unet for your exciting discoveries, please cite the paper that describes the package:

@article{akeret2017radio,
  title={Radio frequency interference mitigation using deep convolutional neural networks},
  author={Akeret, Joel and Chang, Chihway and Lucchi, Aurelien and Refregier, Alexandre},
  journal={Astronomy and Computing},
  volume={18},
  pages={35--39},
  year={2017},
  publisher={Elsevier}
}

unet's People

Contributors

gokarslan, jakeret, tdrobbins


unet's Issues

Unable to restore custom object of type _tf_keras_metric currently while loading previously saved model without custom layers

I ran scripts/oxford_iiit_pet.py and got a saved model in model_path.

Now I would like to load this model with:
model = tf.keras.models.load_model(model_path)

but I get:

ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements get_configand from_config when saving. In addition, please use the custom_objects arg when calling load_model().

I searched for this on Stack Overflow and found that it works if I change the code to:
model = tf.keras.models.load_model(model_path, custom_objects = {"mean_iou": mean_iou, "dice_coefficient": dice_coefficient})

So I think these two metrics should subclass tf.keras.metrics.Metric, which would make deserializing the model more convenient.
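The issue proposes turning the metrics into subclasses of tf.keras.metrics.Metric. A Dice metric written that way could look like the sketch below. This is a hypothetical illustration, not the package's dice_coefficient. It accumulates the intersection and total over batches and ignores sample_weight. The get_config method lets Keras serialize the metric's constructor argument.

```python
import tensorflow as tf


class DiceCoefficient(tf.keras.metrics.Metric):
    """Hypothetical stateful Dice metric; serializable through get_config."""

    def __init__(self, smooth=1e-5, name="dice_coefficient", **kwargs):
        super().__init__(name=name, **kwargs)
        self.smooth = smooth
        # Running sums accumulated across batches.
        self.intersection = self.add_weight(name="intersection", initializer="zeros")
        self.total = self.add_weight(name="total", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is ignored in this sketch.
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        self.intersection.assign_add(tf.reduce_sum(y_true * y_pred))
        self.total.assign_add(tf.reduce_sum(y_true) + tf.reduce_sum(y_pred))

    def result(self):
        return (2.0 * self.intersection + self.smooth) / (self.total + self.smooth)

    def reset_state(self):
        self.intersection.assign(0.0)
        self.total.assign(0.0)

    def get_config(self):
        config = super().get_config()
        config.update({"smooth": self.smooth})
        return config
```

Subclassing alone does not remove the need to tell Keras where the class lives at load time. You still need one of two approaches:

- pass the class through custom_objects, for example custom_objects={"DiceCoefficient": DiceCoefficient}; or
- register it with the tf.keras.utils.register_keras_serializable decorator.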

ValueError when channels > 3

Trying to train on images with more than three channels raises a ValueError. I fixed the bug with a really simple patch to unet.utils.to_rgb() and can submit a pull request, if you like.

Here's the full error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-8-220be538e83e> in <module>
----> 1 trainer.fit(model4, data)

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/trainer.py in fit(self, model, train_dataset, validation_dataset, test_dataset, epochs, batch_size, **fit_kwargs)
     94                             epochs=epochs,
     95                             callbacks=callbacks,
---> 96                             **fit_kwargs)
     97 
     98         self.evaluate(model, test_dataset, prediction_shape)

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    874           epoch_logs.update(val_logs)
    875 
--> 876         callbacks.on_epoch_end(epoch, epoch_logs)
    877         if self.stop_training:
    878           break

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
    363     logs = self._process_logs(logs)
    364     for callback in self.callbacks:
--> 365       callback.on_epoch_end(epoch, logs)
    366 
    367   def on_train_batch_begin(self, batch, logs=None):

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/callbacks.py in on_epoch_end(self, epoch, logs)
     32         self._log_histogramms(epoch, predictions)
     33 
---> 34         self._log_image_summaries(epoch, predictions)
     35 
     36         self.file_writer.flush()

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/callbacks.py in _log_image_summaries(self, epoch, predictions)
     50                                  utils.to_rgb(cropped_labels[..., :1].numpy()),
     51                                  utils.to_rgb(mask)),
---> 52                                 axis=2)
     53 
     54         with self.file_writer.as_default():

<__array_function__ internals> in concatenate(*args, **kwargs)

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 3, the array at index 0 has size 4 and the array at index 1 has size 3

And a quick script to reproduce the error:

import unet
import tensorflow as tf

X = tf.random.normal((100,256,256,4))
Y_flat = tf.random.categorical(tf.math.log([[0.5, 0.5]]),100*256*256)
Y = tf.reshape(Y_flat,(100,256,256))
Y_onehot = tf.one_hot(Y,2)

data = tf.data.Dataset.from_tensor_slices((X,Y_onehot))
train_data = data.take(75)
test_data = data.skip(75)

model4 = unet.build_model(256,256,channels=4,padding="same")
unet.finalize_model(model4,loss=tf.keras.losses.categorical_crossentropy)
trainer = unet.Trainer()

trainer.fit(model4, data)
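In the traceback, np.concatenate fails because the input image keeps all 4 channels while the label and prediction panels are converted to 3. The reporter's actual patch is not shown in the issue.

The sketch below is a hypothetical to_rgb-style helper that always returns exactly 3 channels. It assumes channels-last arrays and drops extra channels for display purposes only, so the panels can be concatenated. The name to_rgb_sketch is illustrative.

```python
import numpy as np


def to_rgb_sketch(img):
    """Hypothetical helper: convert a channels-last array to exactly 3 channels
    scaled to [0, 255], so image, label and prediction panels can be
    concatenated side by side."""
    img = np.asarray(img, dtype=np.float32).copy()
    channels = img.shape[-1]
    if channels == 1:
        img = np.repeat(img, 3, axis=-1)          # grayscale -> RGB
    elif channels == 2:
        pad = np.zeros_like(img[..., :1])         # add an empty third channel
        img = np.concatenate([img, pad], axis=-1)
    elif channels > 3:
        img = img[..., :3]                        # display only the first three channels
    img = np.nan_to_num(img)                      # NaNs would break the scaling
    img -= img.min()
    max_value = img.max()
    if max_value != 0:
        img /= max_value
    return img * 255.0
```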

Error loading saved model

I am getting the error below right after executing this command:

unet_model = tf.keras.models.load_model("M:\\ant\\2021-06-04T22-22_27", custom_objects=custom_objects)

  File "segmenter2-generator.py", line 152, in evaluate
    unet_model = tf.keras.models.load_model("M:\\anthurium\\ant\\2021-06-04T22-22_27", custom_objects=None, compile=False)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\saving\save.py", line 206, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\saving\saved_model\load.py", line 155, in load
    keras_loader.finalize_objects()
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\saving\saved_model\load.py", line 626, in finalize_objects
    self._reconstruct_all_models()
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\saving\saved_model\load.py", line 645, in _reconstruct_all_models
    self._reconstruct_model(model_id, model, layers)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\saving\saved_model\load.py", line 692, in _reconstruct_model
    config, created_layers={layer.name: layer for layer in layers})
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\engine\functional.py", line 1289, in reconstruct_from_config
    process_node(layer, node_data)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\engine\functional.py", line 1237, in process_node
    output_tensors = layer(input_tensors, **kwargs)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 970, in __call__
    input_list)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1108, in _functional_construction_call
    inputs, input_masks, args, kwargs)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 840, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 880, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 69, in return_outputs_and_add_losses
    outputs, losses = fn(*args, **kwargs)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 167, in wrap_with_training_arg
    lambda: replace_training_and_call(False))
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\utils\control_flow_util.py", line 110, in smart_cond
    pred, true_fn=true_fn, false_fn=false_fn, name=name)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\smart_cond.py", line 56, in smart_cond
    return false_fn()
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 167, in <lambda>
    lambda: replace_training_and_call(False))
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 163, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 764, in _initialize
    *args, **kwds))
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 3050, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 3397, in _maybe_define_function
    self._function_spec.canonicalize_function_inputs(*args, **kwargs)
  File "C:\Users\JP\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 2752, in canonicalize_function_inputs
    self.signature_summary(), missing_args[0]))
TypeError: f(inputs, training, training, training, training, *, training, training) missing 1 required argument: training

Any tips on what is causing this and how to solve it?

Thanks

How to install this package

I want to install this in Miniconda by running 'python setup.py install' in the Anaconda Prompt. However, I get this error: AttributeError: 'Configuration' object has no attribute 'parentdir_prefix_version'. Could you tell me why? Thanks!

What is the meaning of the classes in the Oxford Pets dataset?

In your code you state that there are 3 classes in the dataset. Where does that number come from? The TF Datasets page says there are 37 classes.

Also, what do the masks in the dataset mean? Each mask has only one channel, so it isn't clear how that corresponds to 3 classes.
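As far as I understand, the two numbers describe different labels:

- The 37 classes on the TF Datasets page are pet breeds, which are image-level labels.
- The segmentation masks are trimaps that assign each pixel one of three values (1, 2 or 3), stored in a single channel.

A one-hot encoding turns such a mask into 3 channels, one per pixel class. The sketch below assumes exactly that label range. Check the dataset documentation to confirm which value means pet, border or background. The helper name is hypothetical.

```python
import numpy as np


def trimap_to_one_hot(mask, num_classes=3):
    """Hypothetical helper: convert a (height, width, 1) mask with labels
    1..num_classes into a (height, width, num_classes) one-hot array."""
    labels = mask[..., 0].astype(np.int64) - 1  # shift labels 1..3 to 0..2
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("mask contains labels outside 1..num_classes")
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes, dtype=np.float32)[labels]
```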

AttributeError: module 'unet' has no attribute 'build_model'

Hi!

While running the notebook 'oxford_pets.ipynb', I ran into the following issue:

AttributeError                            Traceback (most recent call last)
in
      1 LEARNING_RATE = 1e-3
----> 2 unet_model = unet.build_model(*oxford_iiit_pet.IMAGE_SIZE,
      3                               channels=oxford_iiit_pet.channels,
      4                               num_classes=oxford_iiit_pet.classes,
      5                               layer_depth=5,

AttributeError: module 'unet' has no attribute 'build_model'

What could it be?

Thanks!
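The issue does not confirm the cause. One plausible explanation is that Python imports a different "unet" than intended, for example a local file or folder of that name that shadows the installed package. The small helpers below check this with the standard library. locate_module reports where a module would be loaded from without executing it. module_has checks whether the imported module defines a given attribute. Both helper names are hypothetical.

```python
import importlib
import importlib.util


def locate_module(name):
    """Return where Python would load top-level module `name` from, or None."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    if spec.origin not in (None, "namespace"):
        return spec.origin
    # Namespace package, e.g. a bare "unet" folder without __init__.py.
    return ", ".join(spec.submodule_search_locations or [])


def module_has(name, attribute):
    """Import `name` and report whether it defines `attribute`."""
    module = importlib.import_module(name)
    return hasattr(module, attribute)
```

For example, print(locate_module("unet")) shows which file or folder is actually being picked up. If that path is not the installed package, renaming the local file or folder would be the first thing to try.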

Take() got an unexpected keyword argument 'count'.

Heya,
I'm trying to use the model to perform some image separation using the MNIST database.
I literally copied your code and modified some parameters to match the MNIST database, but I get this weird message:

[Screenshot of the error message]

If you're wondering, here's how I create the train and test datasets:

(mnist_x_train, mnist_y_train), (mnist_x_test, mnist_y_test) = mnist.load_data()

Not sure what this is related to, any ideas?

Thanks,
Mat
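The screenshot is not available, so the following is only an assumption. The error suggests the trainer called .take(count=...) on whatever it was given. tf.data.Dataset.take accepts a count argument, whereas NumPy's ndarray.take does not. One plausible cause is therefore passing the raw NumPy arrays from mnist.load_data() instead of tf.data.Dataset objects like those built with from_tensor_slices in the channels reproduction script above.

The sketch below wraps MNIST-style arrays in such a dataset. The helper name is hypothetical. The thresholded foreground mask is a stand-in for real labels, used purely for illustration.

```python
import numpy as np
import tensorflow as tf


def make_segmentation_dataset(images, threshold=0.5):
    """Hypothetical: wrap (N, H, W) uint8 images in a tf.data.Dataset of
    (image, one-hot mask) pairs. Pixels brighter than `threshold` are labelled
    foreground purely for illustration; real masks would come from your data."""
    x = images.astype(np.float32)[..., np.newaxis] / 255.0  # (N, H, W, 1) in [0, 1]
    labels = (x[..., 0] > threshold).astype(np.int32)      # (N, H, W) with values 0/1
    y = tf.one_hot(labels, depth=2)                         # (N, H, W, 2)
    return tf.data.Dataset.from_tensor_slices((x, y))


# Example usage (downloads MNIST on first call):
# (mnist_x_train, _), (mnist_x_test, _) = tf.keras.datasets.mnist.load_data()
# train_dataset = make_segmentation_dataset(mnist_x_train)
# test_dataset = make_segmentation_dataset(mnist_x_test)
```

MNIST images are only 28x28 pixels, so they will likely also need padding="same" and a small layer_depth to pass through the network.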

JSONDecodeError: Expecting value: line 1 column 1 (char 0)

After saving my model like this, I can't reopen it:
unet_model.save(r'C:\Users\ggrimard\Documents\Models/UnetTake2')

modelPath = r'C:\Users\ggrimard\Documents\Models/UnetTake2'
from unet import custom_objects
reconstructed_model = tf.keras.models.load_model(modelPath, custom_objects=custom_objects)
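When a saved model refuses to load, a common workaround is to save only the weights and rebuild the architecture in code before loading them. This is not a fix for the underlying bug, and the issue does not confirm that it helps here.

The sketch uses a tiny stand-in model. With this package you would presumably rebuild the model with the same unet.build_model arguments used for training. The helper names and the file name are hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_stand_in_model():
    """Tiny stand-in for the real network (hypothetical; use your own builder)."""
    inputs = tf.keras.Input(shape=(32, 32, 1))
    x = tf.keras.layers.Conv2D(4, 3, padding="same", activation="relu")(inputs)
    outputs = tf.keras.layers.Conv2D(2, 1, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)


def restore_via_weights(model, build_fn, directory):
    """Save only the weights, rebuild the architecture with build_fn, reload the weights."""
    path = os.path.join(directory, "model.weights.h5")
    model.save_weights(path)
    restored = build_fn()
    restored.load_weights(path)
    return restored
```

This approach does not need custom_objects for loading the weights. You only need to recompile the rebuilt model, with its losses and metrics, if you want to continue training it.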

ConcatOp error when changing layer_depth=3 to a larger value in circle demo.

Thanks very much for your project and for making your source available to others! :) I'm wondering if someone may be able to help with an issue I'm having.

I took the code from the circle.ipynb demo and made a .py demo from it (copy / paste). Everything is working great there. However, when I change layer_depth=3 to a new value, I get ConcatOp errors.

I change:
unet_model = unet.build_model(channels=circles.channels, num_classes=circles.classes, layer_depth=3, filters_root=16)

To be:
unet_model = unet.build_model(channels=circles.channels, num_classes=circles.classes, layer_depth=4, filters_root=16)

Then I see errors when I try to run:
tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [1,64,35,35] vs. shape[1] = [1,64,34,34] [[node unet/crop_concat_block/concat (defined at /home/avose/workspace/butterfly/synth/bf_unet/unet/unet.py:130) ]] [Op:__inference_distributed_function_2008]

Any tips would be greatly appreciated. I have also been playing with the TF1.x version of unet, and that version does seem to allow me to change the layer depth option without giving any errors.
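Shape mismatches like this one typically appear with unpadded ("valid") convolutions. When a feature map with an odd size reaches a 2x2 pooling step, the pooled size is rounded down. After upsampling, the two branches of a skip connection then no longer line up.

The sketch below traces the spatial size through a U-Net under assumptions taken from the original Ronneberger et al. design: two unpadded 3x3 convolutions per level, 2x2 pooling and 2x upsampling. It has not been checked against this package's exact layer arrangement, and the function names are hypothetical.

```python
def valid_unet_sizes(input_size, layer_depth, kernel_size=3):
    """Return the output size of a 'valid'-padding U-Net, or None if some
    pooling step would receive an odd-sized (or empty) feature map."""
    shrink = 2 * (kernel_size - 1)  # two unpadded convolutions per level
    size = input_size
    # Contracting path: convolutions, then 2x2 pooling on all but the bottom level.
    for _ in range(layer_depth - 1):
        size -= shrink
        if size <= 0 or size % 2 != 0:
            return None
        size //= 2
    size -= shrink  # bottom level
    if size <= 0:
        return None
    # Expanding path: upsample by 2, then two more unpadded convolutions.
    for _ in range(layer_depth - 1):
        size = size * 2 - shrink
        if size <= 0:
            return None
    return size


def nearest_valid_size(target, layer_depth, kernel_size=3):
    """Smallest input size >= target that passes through the network cleanly."""
    size = target
    while valid_unet_sizes(size, layer_depth, kernel_size) is None:
        size += 1
    return size
```

For example, with layer_depth=5 these assumptions reproduce the classic 572 -> 388 mapping from the original paper, and 573 is rejected. Under the same assumptions, the padding="same" option passed to unet.build_model in the channels issue avoids the per-convolution shrinkage. The input size must then still be divisible by 2 ** (layer_depth - 1).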

Load Unet model in OpenCV?

I'm trying to load a saved Unet model in OpenCV using dnn.readNetFromTensorflow and dnn.readNetFromONNX, but both methods are failing. The first one hangs and the second results in an obscure error.

> Node [Slice@ai.onnx]:(onnx_node!StatefulPartitionedCall/unet/crop_concat_block/strided_slice_6) parse error: OpenCV(4.6.0) /Users/xperience/actions-runner/_work/opencv-python/opencv-python/opencv/modules/dnn/src/onnx/onnx_importer.cpp:1345: error: (-215:Assertion failed) axes[i - 1] == axes[i] - 1 in function 'parseSlice'

Has anyone done this before?

Example Notebook Cleanup

It looks like the predictions in the example notebook are all completely masked, resulting in black images. When I ran it on my end, I got accurate predictions. I would love to fix it.

Making the code work for 6-channel inputs

As per the code, if the number of input channels is greater than 3, it just uses the first 3 channels.
The input I want to give consists of 6 channels: the first three are RGB, and the other three are specific inputs I want to give the model for my use case.
How can we make the model input all 6 channels?
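Assume the network itself is built with channels=6; the channels argument of unet.build_model appears in the reproduction script of the channels issue above. The remaining step is then to feed it a 6-channel array. It is not confirmed here whether the "first 3 channels" behaviour applies to the network input or only to the image-logging helper.

A minimal NumPy sketch for combining the RGB image with three extra feature channels follows. The helper name is hypothetical.

```python
import numpy as np


def stack_inputs(rgb, extra):
    """Concatenate an RGB batch (N, H, W, 3) with extra feature channels
    (N, H, W, k) into a single (N, H, W, 3 + k) float32 input array."""
    if rgb.shape[:-1] != extra.shape[:-1]:
        raise ValueError(f"spatial shapes differ: {rgb.shape} vs {extra.shape}")
    return np.concatenate([rgb.astype(np.float32), extra.astype(np.float32)], axis=-1)
```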
