
pytorch-2d-3d-unet-tutorial's Introduction

PyTorch-2D-3D-UNet-Tutorial

A beginner-friendly tutorial to start a 2D or 3D image segmentation deep learning project with PyTorch & the U-Net architecture. Based on the blog series "Creating and training a U-Net model with PyTorch for 2D & 3D semantic segmentation - A guide to semantic segmentation with PyTorch and the U-Net".


Installation

  1. Set up a new environment with an environment manager (recommended):
    1. conda:
      1. conda create --name unet-tutorial -y
      2. conda activate unet-tutorial
      3. conda install python=3.8 -y
    2. venv:
      1. python3 -m venv unet-tutorial
      2. source unet-tutorial/bin/activate
  2. Install the libraries: pip install -r requirements.txt
  3. Start a jupyter server: jupyter-notebook OR jupyter-lab

Note: This will install the CPU version of torch. If you want to use a GPU or TPU, please refer to the installation instructions on the PyTorch website.
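Once a GPU-enabled build is installed, device selection is a one-liner; a minimal sketch that falls back to the CPU when no GPU is present:

```python
import torch

# Pick whichever device is available; falls back to CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```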

Summary

  • Part I: Building a dataset in PyTorch & visualizing it with napari
  • Part II: Creating the U-Net model in PyTorch & information about model input and output
  • Part III: Training a 2D U-Net model on a sample of the Carvana dataset, with dataset improvements (caching, multiprocessing)
  • Part IV: Running inference on test data
  • Part V: Building a 3D dataset
  • Part VI: Running an experiment with the experiment tracker neptune.ai and the high-level PyTorch library PyTorch Lightning

Note: Due to updates in the neptune API, Part VI will probably not work and needs to be migrated to the new neptune API first.

Dataset samples

This repository contains a sample of a 2D and a 3D dataset for semantic segmentation.

U-Net model

If you are unsure which arguments to pass to the Unet class, please take a look at the enums in unet.py and view the examples in test_unet.py.

Note: Weights are initialized randomly (Xavier initialization), so training can take some time. To train a segmentation model faster, it is recommended to use a pretrained backbone such as ResNet for 2D or even 3D tasks.
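Whatever arguments you end up choosing, a quick forward-pass shape check catches mismatches early. A minimal sketch using a plain torch.nn.Conv2d as a stand-in (not the tutorial's actual Unet class):

```python
import torch

# Stand-in for a segmentation model: maps [N, C_in, H, W] -> [N, num_classes, H, W].
model = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3, padding=1)

x = torch.rand(1, 3, 64, 64)        # one RGB image, 64x64
out = model(x)
assert out.shape == (1, 2, 64, 64)  # [N, num_classes, H, W]
```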

pytorch-2d-3d-unet-tutorial's People

Contributors

johschmidt42, narenakash, schmiddi-75


pytorch-2d-3d-unet-tutorial's Issues

Missing License

Dear Mr. Schmidt,

first off: Thanks for your amazing tutorial on Medium and the code provided in this repository! It is a really helpful introduction to image segmentation with PyTorch. Even now, I still use some of your code without adaptation (e.g. transformations.py).
This leads to my question regarding code usage. The sources you refer to published their code under the MIT License, but your repository contains no LICENSE file. Would you consider adding an official license in the future? I assume you intend people to use, learn from, and adapt your code, so this would be really helpful in legal terms.
I would be happy to reference your article and this repository in the future.

Best regards

CPU execution fails

My GPU is too old (compute capability 3.0, while the minimum needed is 3.5), so in Part 3 I changed the code to force use of the CPU:

device

#GB if torch.cuda.is_available():
#GB device = torch.device('cuda')
#GB else:
#GB torch.device('cpu')
torch.device('cpu')

I get a runtime error, a string of messages that ends with:

RuntimeError: CUDA error: no kernel image is available for execution on the device

Although I have set 'cpu' as the device, it still seems to try to use 'cuda'.
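The snippet above creates a device object but never assigns it, so `device` keeps its earlier 'cuda' value. A minimal sketch of the fix, with a hypothetical stand-in model in place of the tutorial's U-Net:

```python
import torch

# torch.device('cpu') on its own does nothing: the result must be assigned
# and then used to move both the model and the input tensors.
device = torch.device("cpu")

model = torch.nn.Linear(4, 2).to(device)  # stand-in for the tutorial's model
x = torch.rand(8, 4).to(device)
y = model(x)
assert y.device.type == "cpu"
```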

Resize 3D images

I want to display the images predicted by the model.
I ran a 3D U-Net following your code from Part 5.
However, I have a problem when displaying the predictions.
The data used is Microblues3D.

My resize code is below:
images_res = [resize(img, (32, 200, 200, 1)) for img in images]
resize_kwargs = {"order": 0, "anti_aliasing": False, "preserve_range": True}
#targets_res = [resize(tar, (128, 128), **resize_kwargs) for tar in targets]
targets_res = [resize(tar, (200, 200, 32), **resize_kwargs) for tar in targets]
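One likely source of trouble above is that the image and target shapes disagree: (32, 200, 200, 1) versus (200, 200, 32). A sketch of a consistent resize, assuming skimage.transform.resize as used in the snippet:

```python
import numpy as np
from skimage.transform import resize

# Resize image and mask to the same (D, H, W) target shape; order=0 keeps
# the mask's labels intact (no interpolation between class values).
img = np.random.rand(16, 100, 100)             # toy 3D volume
tar = np.random.randint(0, 2, (16, 100, 100))  # toy 3D mask

img_res = resize(img, (32, 200, 200))
tar_res = resize(tar, (32, 200, 200), order=0,
                 anti_aliasing=False, preserve_range=True)
assert img_res.shape == tar_res.shape == (32, 200, 200)
```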

transformations.py

I think the normalize function in the Normalize class should be passed two arguments (mean and std). Calling it as is results in an error.

Part3,CUDA error: device-side assert triggered

I'm trying to train the U-Net model on 3-channel input images and a 1-channel mask (values 1, 0), similar to the data structure in this tutorial. I used the same code from Part 3 with a GPU, but the training cell outputs the following error:
CUDA error: device-side assert triggered.
However, when I change the device to CPU, the model trains fine. Do you know the cause of this issue?
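A device-side assert from torch.nn.CrossEntropyLoss is very often an out-of-range class index: a 1-channel model output cannot be indexed by labels {0, 1}, which need out_channels=2. A quick CPU-side sanity check, as a sketch:

```python
import torch

num_classes = 2                          # the model's out_channels
target = torch.tensor([[0, 1], [1, 0]])  # toy mask of class indices

# CrossEntropyLoss indexes classes 0..num_classes-1; any value outside
# that range triggers the device-side assert when running on the GPU.
assert target.dtype == torch.long
assert target.min() >= 0 and target.max() < num_classes
```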

Input and Target Shape

Hi @johschmidt42, thank you very much for your effort, I really appreciate your work. I have some basic questions about the 3D U-Net side. In your blog you said:
The types of x and y are already correct.
The x should have a shape of [N, C, H, W], so the channel dimension should be second instead of last.
The y is supposed to have a shape of [N, H, W] (this is because torch's loss function torch.nn.CrossEntropyLoss expects it, even though it will internally one-hot encode it).
But on my side, the input and target (x, y) have the same shape [N, C, H, W] because both my images and masks are 3D. I think that could be okay, but when I try to train with this data, the weight and input sizes don't match, and I couldn't figure out how to fix it. If you can help me it would be perfect, thanks a lot in advance.
These are my input and mask shapes: (screenshot)
This is the problem I have faced: (screenshot)
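If the masks are loaded with an explicit channel axis, dropping it is usually enough; a sketch assuming 3D volumes:

```python
import torch

# Masks loaded as [N, 1, D, H, W] must lose the channel axis so that
# CrossEntropyLoss receives class indices of shape [N, D, H, W].
y = torch.zeros(2, 1, 16, 64, 64, dtype=torch.long)  # toy 3D masks
y = y.squeeze(1)
assert y.shape == (2, 16, 64, 64)
```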

DiceLoss implementation is missing

Hey @johschmidt42, thanks for your efforts.

The losses module with the DiceLoss implementation used in unet_lightning.py is not available in the repository. Could you please add the loss functions as well? Thanks.

Part1 error: in [4]: no module named 'napari'

I have installed napari, and if I open a new notebook and type "import napari" there is no error. But when I run the Part 1 notebook, there is an error at this line:

ModuleNotFoundError Traceback (most recent call last)
in
3 from visual import show_input_target_pair_napari
4 gen = Input_Target_Pair_Generator(dataloader_training, rgb=True)
----> 5 show_input_target_pair_napari(gen)

F:\PyTorch-2D-3D-UNet-Tutorial\visual.py in show_input_target_pair_napari(gen_training, gen_validation)
11
12 # Napari
---> 13 import napari
14 with napari.gui_qt():
15 viewer = napari.Viewer()

ModuleNotFoundError: No module named 'napari'

what am I doing wrong?
As you can tell, I am a complete newbie to Jupyter. To clarify my first paragraph: when I run the import inside the notebook, I do get an error. But the fact remains that napari is installed; I can use it to view an image from the command line, or open a napari window inside Python. Why doesn't it work inside the Jupyter notebook?
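A module that imports on the command line but not in Jupyter usually means the notebook kernel runs a different Python interpreter than your shell. A quick diagnostic:

```python
import sys

# Compare this path (run inside the notebook) with `which python` in the
# terminal where `import napari` works. If they differ, install napari
# into the kernel's environment:  <path printed below> -m pip install napari
print(sys.executable)
```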

Problem with x and y dimensions

After running through your tutorials up to Tutorial 3, I've noticed that the shape of my target/mask tensors is one dimension higher than yours.

Mine:

(screenshot)

Yours:

(screenshot)

And after getting to the learning rate finder in tutorial 3, I now get the following error:

(screenshot)

Which I am guessing might have something to do with torch's loss function torch.nn.CrossEntropyLoss expecting a target tensor shape of [N, H, W].

I'm a bit confused as to how your mask tensors can be one dimension lower than your input tensors, i.e. 3D vs. 4D. And why is it that the x should have a shape of [N, C, H, W], but the y is supposed to have a shape of [N, H, W]? Any idea how I could solve this?
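For reference, torch.nn.CrossEntropyLoss takes logits with a class axis and integer targets without one, which is why y has one dimension fewer than x. A minimal sketch:

```python
import torch

logits = torch.randn(2, 3, 8, 8)         # x-like: [N, C, H, W], C = 3 classes
target = torch.randint(0, 3, (2, 8, 8))  # y-like: [N, H, W], dtype long
loss = torch.nn.CrossEntropyLoss()(logits, target)
assert loss.dim() == 0                   # a single scalar loss value
```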

Unexpected keyword argument

I forgot to mention that there was another error in the code:

TypeError Traceback (most recent call last)
in
35
36 # start training
---> 37 training_losses, validation_losses, lr_rates = trainer.run_trainer(notebook=True)
38 #training_losses, validation_losses, lr_rates = trainer.run_trainer()

TypeError: run_trainer() got an unexpected keyword argument 'notebook'

I fixed this by removing 'notebook=True' (i.e. by using the commented line 38 above).
Presumably this is a software version issue.
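Version mismatches like this can also be checked up front with inspect.signature; a sketch using a hypothetical stand-in for the repo's run_trainer method:

```python
import inspect

def run_trainer(epochs: int = 1):  # hypothetical stand-in for trainer.run_trainer
    return "done"

# Only pass notebook=True if this version of the method accepts it.
params = inspect.signature(run_trainer).parameters
kwargs = {"notebook": True} if "notebook" in params else {}
result = run_trainer(**kwargs)
```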
