
aghdamamir / 3d-unet


A pytorch implementation of 3D UNet for 3D MRI Segmentation.

Python 100.00%
segmentation 3dunet medical-image-segmentation mri-segmentation unet

3d-unet's People

Contributors: aghdamamir

3d-unet's Issues

UpConv3DBlock as well as activation function

Hello Amir, your code looks clean and concise. Here is the UpConv3DBlock class:

import torch
import torch.nn as nn

class UpConv3DBlock(nn.Module):

    def __init__(self, in_channels, res_channels=0, last_layer=False, num_classes=None) -> None:
        super(UpConv3DBlock, self).__init__()
        assert (last_layer==False and num_classes==None) or (last_layer==True and num_classes!=None), 'Invalid arguments'
        # Transposed conv doubles each spatial dimension; the channel count is unchanged.
        self.upconv1 = nn.ConvTranspose3d(in_channels=in_channels, out_channels=in_channels, kernel_size=(2, 2, 2), stride=2)
        self.relu = nn.ReLU()
        # Note: this single BatchNorm3d instance is shared by conv1 and conv2.
        self.bn = nn.BatchNorm3d(num_features=in_channels//2)
        # conv1 sees the upsampled features concatenated with the skip connection.
        self.conv1 = nn.Conv3d(in_channels=in_channels+res_channels, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv2 = nn.Conv3d(in_channels=in_channels//2, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.last_layer = last_layer
        if last_layer:
            # 1x1x1 conv maps the features to per-class scores (no activation applied).
            self.conv3 = nn.Conv3d(in_channels=in_channels//2, out_channels=num_classes, kernel_size=(1,1,1))

    def forward(self, input, residual=None):
        out = self.upconv1(input)
        if residual is not None:
            # Concatenate along the channel axis (dim 1 of N, C, D, H, W).
            out = torch.cat((out, residual), 1)
        out = self.relu(self.bn(self.conv1(out)))
        out = self.relu(self.bn(self.conv2(out)))
        if self.last_layer:
            out = self.conv3(out)
        return out
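The channel and size bookkeeping in UpConv3DBlock can be traced by hand: ConvTranspose3d with kernel_size=2 and stride=2 doubles every spatial dimension while keeping in_channels, the skip connection adds res_channels, and conv1/conv2 halve the count to in_channels // 2. A minimal plain-Python sketch (the helper names and num_classes=3 are hypothetical) reproduces those numbers for UNet3D's default level channels 32, 64, 128 and bottleneck channel 256:

```python
def convtranspose3d_out(size, kernel=2, stride=2, padding=0, output_padding=0, dilation=1):
    # Per-axis output size of nn.ConvTranspose3d (formula from the PyTorch docs).
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

def upconv_block_channels(in_channels, res_channels=0, num_classes=None):
    # upconv1 keeps in_channels, torch.cat adds res_channels,
    # conv1/conv2 reduce to in_channels // 2, and conv3 (last layer only)
    # maps to num_classes.
    concat_channels = in_channels + res_channels
    out_channels = in_channels // 2 if num_classes is None else num_classes
    return concat_channels, out_channels

print(upconv_block_channels(256, 128))               # s_block3 -> (384, 128)
print(upconv_block_channels(128, 64))                # s_block2 -> (192, 64)
print(upconv_block_channels(64, 32, num_classes=3))  # s_block1 -> (96, 3)
print(convtranspose3d_out(16))                       # 16 -> 32: spatial size doubles
```

Because the concatenation needs matching spatial sizes, each skip tensor must be exactly twice the size of the block's input along every axis.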

In this block, I don't quite understand the following statement:
assert (last_layer==False and num_classes==None) or (last_layer==True and num_classes!=None), 'Invalid arguments'
Why does the constructor treat last_layer specially?
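The assert enforces that num_classes is supplied exactly when last_layer is True: only the final decoder block owns the extra 1x1x1 conv3 that maps features to class scores, so intermediate blocks must not receive num_classes and the last one must. A minimal plain-Python sketch of the same check (both function names are hypothetical) shows it is equivalent to the compact condition `last_layer == (num_classes is not None)` for boolean last_layer:

```python
def valid_upconv_args(last_layer, num_classes):
    # Same condition as the assert in UpConv3DBlock.__init__,
    # written with identity comparisons against None.
    return (last_layer is False and num_classes is None) or (
        last_layer is True and num_classes is not None
    )

def valid_upconv_args_compact(last_layer, num_classes):
    # Equivalent form, assuming last_layer is a bool.
    return last_layer == (num_classes is not None)

# Check both forms agree on every combination.
for last_layer in (False, True):
    for num_classes in (None, 3):
        assert valid_upconv_args(last_layer, num_classes) == \
            valid_upconv_args_compact(last_layer, num_classes)
```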

Also, I don't see an activation function applied after the last layer.

So I added one (see the lines marked in the code below). Is this more complete?

class UNet3D(nn.Module):

    # Conv3DBlock and UpConv3DBlock are defined elsewhere in the repository.
    def __init__(self, in_channels, num_classes, level_channels=(32, 64, 128), bottleneck_channel=256) -> None:
        super(UNet3D, self).__init__()
        level_1_chnls, level_2_chnls, level_3_chnls = level_channels[0], level_channels[1], level_channels[2]
        self.a_block1 = Conv3DBlock(in_channels=in_channels, out_channels=level_1_chnls)
        self.a_block2 = Conv3DBlock(in_channels=level_1_chnls, out_channels=level_2_chnls)
        self.a_block3 = Conv3DBlock(in_channels=level_2_chnls, out_channels=level_3_chnls)
        self.bottleNeck = Conv3DBlock(in_channels=level_3_chnls, out_channels=bottleneck_channel, bottleneck=True)
        self.s_block3 = UpConv3DBlock(in_channels=bottleneck_channel, res_channels=level_3_chnls)
        self.s_block2 = UpConv3DBlock(in_channels=level_3_chnls, res_channels=level_2_chnls)
        self.s_block1 = UpConv3DBlock(in_channels=level_2_chnls, res_channels=level_1_chnls, num_classes=num_classes, last_layer=True)
        self.final_activation = nn.Sigmoid() # here

    def forward(self, input):
        # Analysis path forward feed
        out, residual_level1 = self.a_block1(input)
        out, residual_level2 = self.a_block2(out)
        out, residual_level3 = self.a_block3(out)
        out, _ = self.bottleNeck(out)

        # Synthesis path forward feed
        out = self.s_block3(out, residual_level3)
        out = self.s_block2(out, residual_level2)
        out = self.s_block1(out, residual_level1)
        out = self.final_activation(out) # and here
        return out

Thank you!
Jing
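On the final activation: whether it belongs inside the model depends on the loss. PyTorch's nn.CrossEntropyLoss applies log-softmax internally and nn.BCEWithLogitsLoss applies a sigmoid internally, so both expect raw logits; adding nn.Sigmoid() in forward and then training with one of them would apply the activation twice. A common pattern is to have forward return logits and apply the activation only at inference. The plain-Python sketch below (the logits are hypothetical values for one voxel with num_classes=3) contrasts the two choices: sigmoid gives an independent probability per class (multi-label masks), while softmax over the class axis, e.g. nn.Softmax(dim=1) on an (N, C, D, H, W) output, gives mutually exclusive probabilities that sum to 1.

```python
import math

def sigmoid(logits):
    # Independent probability per class; classes may overlap.
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]

def softmax(logits):
    # Mutually exclusive classes: probabilities sum to 1.
    # Subtracting the max keeps exp() numerically stable.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

voxel_logits = [2.0, -1.0, 0.5]  # hypothetical class scores for one voxel
print(sigmoid(voxel_logits))
print(softmax(voxel_logits))
```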

Why does the code always run on GPU 0?

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

I use the code above to select GPU 1, so why does the error message still say the code is running on GPU 0?

RuntimeError: CUDA out of memory. Tried to allocate 7.00 GiB (GPU 0; 23.70 GiB total capacity; 21.29 GiB already allocated; 870.81 MiB free; 21.30 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
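CUDA_VISIBLE_DEVICES both hides GPUs and renumbers the remaining ones starting from 0, so with '1' the process sees physical GPU 1 as its device 0, and PyTorch reports it as "GPU 0" in error messages. The message is therefore consistent with the setting having taken effect. The variable must be set before CUDA is initialized in the process; setting it before `import torch` (or on the command line, `CUDA_VISIBLE_DEVICES=1 python train.py`) is the safe choice. A simplified sketch of the renumbering, assuming comma-separated integer indices only (the real driver also accepts GPU UUIDs and has further parsing rules); the function name is hypothetical:

```python
def visible_device_map(cuda_visible_devices):
    """Map logical device index (what PyTorch calls cuda:N / 'GPU N')
    to physical GPU index, for a comma-separated list of integers."""
    physical = [int(tok) for tok in cuda_visible_devices.split(",") if tok.strip()]
    return {logical: phys for logical, phys in enumerate(physical)}

print(visible_device_map("1"))    # {0: 1}
print(visible_device_map("2,0"))  # {0: 2, 1: 0}
```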

Memory Errors

Hello Amir,

I am trying the 3D-UNet code for the first time on an NVIDIA A100 card with 40 GB of memory. When I run train.py, I get the torch.cuda.OutOfMemoryError shown below. Any ideas on how to get past this issue?

[screenshot of the error]
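Out-of-memory errors are common with 3D U-Nets because every feature map stores batch × channels × depth × height × width values, so memory grows with the cube of the patch side. A rough back-of-the-envelope estimate for a single float32 feature map is sketched below; it ignores gradients, optimizer state, and framework overhead, so the real footprint is several times larger. The usual levers are a smaller batch size, smaller random crops (patches) instead of whole volumes, downsampled inputs, or mixed precision. The helper names and example sizes are hypothetical.

```python
def feature_map_bytes(batch, channels, spatial, bytes_per_element=4):
    # float32 uses 4 bytes per element; float16/bfloat16 use 2.
    voxels = 1
    for s in spatial:
        voxels *= s
    return batch * channels * voxels * bytes_per_element

def gib(n_bytes):
    return n_bytes / 2**30

# One 32-channel feature map at two patch sizes: doubling the side costs 8x.
print(gib(feature_map_bytes(1, 32, (128, 128, 128))))  # 0.25
print(gib(feature_map_bytes(1, 32, (256, 256, 256))))  # 2.0
```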

RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 1, 16, 240, 240, 160]

Traceback (most recent call last):
  File "train.py", line 45, in
    target = model(image)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/task/3DUNet/unet3d.py", line 117, in forward
    out, residual_level1 = self.a_block1(input)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/task/3DUNet/unet3d.py", line 39, in forward
    res = self.relu(self.bn1(self.conv1(input)))
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 610, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 605, in _conv_forward
    return F.conv3d(
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/monai/data/meta_tensor.py", line 249, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/_tensor.py", line 1386, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 1, 16, 240, 240, 160]
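conv3d accepts either a 4D unbatched tensor (C, D, H, W) or a 5D batched tensor (N, C, D, H, W); the tensor reaching the model here is 6D, [1, 1, 16, 240, 240, 160], so it carries one axis too many. A plausible cause is a batch or channel axis being added twice (for example once in the dataset transforms and once by DataLoader batching), but which axis is spurious depends on the dataset code, so treat that as an assumption to check. The plain-Python sketch below mirrors the rank check on shape tuples; `squeeze_axis` is a hypothetical helper similar to `Tensor.squeeze(dim)`, except that it raises instead of silently doing nothing when the axis is not size 1. After the extra axis is removed, the second axis (16 here) becomes the channel count and must equal the model's in_channels; if it does not, the 16 is probably a batch or slice axis and a reshape or permute is needed instead.

```python
def conv3d_input_kind(shape):
    """Mirror conv3d's rank check on a shape tuple."""
    if len(shape) == 4:
        return "unbatched"  # (C, D, H, W)
    if len(shape) == 5:
        return "batched"    # (N, C, D, H, W)
    raise ValueError(
        "Expected 4D (unbatched) or 5D (batched) input to conv3d, "
        f"but got input of size: {list(shape)}"
    )

def squeeze_axis(shape, axis):
    """Hypothetical helper: drop a size-1 axis from a shape tuple."""
    if shape[axis] != 1:
        raise ValueError(f"axis {axis} has size {shape[axis]}, not 1")
    return shape[:axis] + shape[axis + 1:]

bad = (1, 1, 16, 240, 240, 160)
try:
    conv3d_input_kind(bad)
except ValueError as err:
    print(err)

fixed = squeeze_axis(bad, 0)
print(fixed, conv3d_input_kind(fixed))
```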
