aghdamamir / 3d-unet
A PyTorch implementation of 3D UNet for 3D MRI segmentation.
Hello Amir, I think your code is quite good: it looks clean and concise.
import torch
import torch.nn as nn

class UpConv3DBlock(nn.Module):
    def __init__(self, in_channels, res_channels=0, last_layer=False, num_classes=None) -> None:
        super(UpConv3DBlock, self).__init__()
        assert (last_layer==False and num_classes==None) or (last_layer==True and num_classes!=None), 'Invalid arguments'
        # Transposed conv with kernel 2 and stride 2 doubles each spatial dimension
        self.upconv1 = nn.ConvTranspose3d(in_channels=in_channels, out_channels=in_channels, kernel_size=(2, 2, 2), stride=2)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm3d(num_features=in_channels//2)
        # conv1 sees the upsampled features concatenated with the skip connection
        self.conv1 = nn.Conv3d(in_channels=in_channels+res_channels, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv2 = nn.Conv3d(in_channels=in_channels//2, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.last_layer = last_layer
        if last_layer:
            # 1x1x1 projection from features to one channel per class
            self.conv3 = nn.Conv3d(in_channels=in_channels//2, out_channels=num_classes, kernel_size=(1,1,1))

    def forward(self, input, residual=None):
        out = self.upconv1(input)
        if residual is not None:
            # Concatenate the skip connection along the channel dimension
            out = torch.cat((out, residual), 1)
        out = self.relu(self.bn(self.conv1(out)))
        out = self.relu(self.bn(self.conv2(out)))
        if self.last_layer:
            out = self.conv3(out)
        return out
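A quick way to see what UpConv3DBlock does to tensor shapes is to redo its arithmetic by hand. The helper below, upconv_block_shape, is hypothetical (it is not part of the repository) and is plain Python, so it runs without PyTorch. It assumes the layer settings shown above: ConvTranspose3d with kernel 2, stride 2 and no padding, whose output size per axis is (n - 1) * 2 + 2 = 2n, and 3x3x3 convolutions with padding 1, which keep the spatial size.

```python
def upconv_block_shape(in_channels, spatial, res_channels=0,
                       last_layer=False, num_classes=None):
    """Hypothetical shape trace of one UpConv3DBlock forward pass.

    Returns (conv1_in_channels, out_channels, out_spatial).
    """
    # ConvTranspose3d(kernel_size=2, stride=2, padding=0): (n - 1) * 2 + 2 = 2n per axis
    out_spatial = tuple((n - 1) * 2 + 2 for n in spatial)
    # upconv1 keeps in_channels; torch.cat on dim 1 appends the skip channels
    conv1_in = in_channels + res_channels
    # conv1 and conv2 (kernel 3, padding 1) keep the spatial size, output in_channels // 2
    out_channels = in_channels // 2
    if last_layer:
        # conv3 is a 1x1x1 projection to one channel per class
        out_channels = num_classes
    return conv1_in, out_channels, out_spatial


# Like s_block3 in UNet3D: 256 bottleneck channels, 128-channel skip, 8^3 input (example size)
print(upconv_block_shape(256, (8, 8, 8), res_channels=128))
# (384, 128, (16, 16, 16))
```

This also shows why conv1 is built with in_channels + res_channels inputs: the skip connection is concatenated before conv1, not added.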
In UpConv3DBlock, I don't quite understand this line:
assert (last_layer==False and num_classes==None) or (last_layer==True and num_classes!=None), 'Invalid arguments'
Why do you single out last_layer in this function?
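Reading the constructor, the assert appears to guard one invariant: num_classes is needed exactly when the block builds conv3, the 1x1x1 convolution that maps features to class scores, and that only happens when last_layer is True. So an intermediate block must not receive num_classes, and the final block must. The sketch below restates the condition as a hypothetical function (valid_args is not in the repository) and checks all four combinations; it uses "is None" instead of "== None", which gives the same results for these inputs.

```python
from itertools import product


def valid_args(last_layer, num_classes):
    # Same condition as the assert in UpConv3DBlock.__init__
    return (not last_layer and num_classes is None) or (
        last_layer and num_classes is not None
    )


for last_layer, num_classes in product([False, True], [None, 3]):
    print(last_layer, num_classes, valid_args(last_layer, num_classes))
# False None True
# False 3 False
# True None False
# True 3 True
```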
Also, I didn't see an activation function after the last layer, so I added one (marked below). Is it more complete this way?
class UNet3D(nn.Module):
    def __init__(self, in_channels, num_classes, level_channels=[32, 64, 128], bottleneck_channel=256) -> None:
        super(UNet3D, self).__init__()
        level_1_chnls, level_2_chnls, level_3_chnls = level_channels[0], level_channels[1], level_channels[2]
        self.a_block1 = Conv3DBlock(in_channels=in_channels, out_channels=level_1_chnls)
        self.a_block2 = Conv3DBlock(in_channels=level_1_chnls, out_channels=level_2_chnls)
        self.a_block3 = Conv3DBlock(in_channels=level_2_chnls, out_channels=level_3_chnls)
        self.bottleNeck = Conv3DBlock(in_channels=level_3_chnls, out_channels=bottleneck_channel, bottleneck=True)
        self.s_block3 = UpConv3DBlock(in_channels=bottleneck_channel, res_channels=level_3_chnls)
        self.s_block2 = UpConv3DBlock(in_channels=level_3_chnls, res_channels=level_2_chnls)
        self.s_block1 = UpConv3DBlock(in_channels=level_2_chnls, res_channels=level_1_chnls, num_classes=num_classes, last_layer=True)
        self.final_activation = nn.Sigmoid()  # here

    def forward(self, input):
        # Analysis path forward feed
        out, residual_level1 = self.a_block1(input)
        out, residual_level2 = self.a_block2(out)
        out, residual_level3 = self.a_block3(out)
        out, _ = self.bottleNeck(out)
        # Synthesis path forward feed
        out = self.s_block3(out, residual_level3)
        out = self.s_block2(out, residual_level2)
        out = self.s_block1(out, residual_level1)
        out = self.final_activation(out)  # and here
        return out
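Whether the added nn.Sigmoid() belongs in forward depends on the training loss, which this thread does not show. Without it, the last block's conv3 outputs raw per-voxel scores (logits). PyTorch's nn.CrossEntropyLoss combines log-softmax with NLLLoss, and nn.BCEWithLogitsLoss combines a sigmoid with binary cross-entropy, so feeding either of them a tensor that has already gone through nn.Sigmoid() applies an activation twice. The choice also matters for predictions: sigmoid scores each class channel independently (multi-label), while softmax over the channel dimension makes the classes mutually exclusive. The plain-Python sketch below uses made-up logits for a single voxel to show the difference; sigmoid and softmax here are local helpers, not PyTorch functions, and the three classes are an assumption.

```python
import math


def sigmoid(x):
    # Independent probability for one class channel
    return 1.0 / (1.0 + math.exp(-x))


def softmax(logits):
    # Subtract the max for numerical stability; the results sum to 1
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for one voxel, one value per class channel (3 classes assumed)
voxel_logits = [2.0, -1.0, 0.5]

sig = [sigmoid(v) for v in voxel_logits]
soft = softmax(voxel_logits)

print([round(p, 3) for p in sig])   # per-class scores, need not sum to 1
print([round(p, 3) for p in soft])  # class distribution, sums to 1
```

If the model is trained with one of the logits-based losses, one option is to keep forward returning logits and apply torch.sigmoid(out) or torch.softmax(out, dim=1) only when producing predictions.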
Thank you!
Jing
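Because the synthesis path concatenates each upsampled tensor with the matching residual, the spatial sizes have to line up at every level. The sketch below traces spatial shapes through UNet3D in plain Python. It relies on an assumption about Conv3DBlock, which is not shown in this thread: that each non-bottleneck block returns a MaxPool3d(kernel_size=2, stride=2) output together with its pre-pooling feature map as the residual, and that the bottleneck block skips pooling. Under that assumption, three poolings mean each spatial dimension should be divisible by 8; otherwise floor division in pooling and exact doubling in ConvTranspose3d disagree and torch.cat fails. The function unet3d_output_spatial is hypothetical.

```python
def unet3d_output_spatial(spatial, num_levels=3):
    """Hypothetical spatial-shape trace through UNet3D (channels omitted)."""
    size = tuple(spatial)
    residual_sizes = []
    # Analysis path (assumed): keep the residual, then MaxPool3d(2, stride=2) floors each axis
    for _ in range(num_levels):
        residual_sizes.append(size)
        size = tuple(n // 2 for n in size)
    # Bottleneck (assumed): no pooling, spatial size unchanged
    # Synthesis path: ConvTranspose3d(kernel 2, stride 2) doubles each axis,
    # and torch.cat needs the result to match the residual exactly
    for residual in reversed(residual_sizes):
        size = tuple(2 * n for n in size)
        if size != residual:
            raise ValueError(f"upsampled {size} != residual {residual}; torch.cat would fail")
    return size


print(unet3d_output_spatial((240, 240, 160)))  # every axis divisible by 8
try:
    unet3d_output_spatial((240, 240, 155))     # 155 is not divisible by 8
except ValueError as err:
    print("shape mismatch:", err)
```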
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
Why, when I use the code above to select GPU 1, does the error message still say the code is running on GPU 0?
RuntimeError: CUDA out of memory. Tried to allocate 7.00 GiB (GPU 0; 23.70 GiB total capacity; 21.29 GiB already allocated; 870.81 MiB free; 21.30 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
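The "GPU 0" in the message does not by itself mean the setting was ignored. The CUDA runtime only exposes the GPUs listed in CUDA_VISIBLE_DEVICES to the process and renumbers them from 0, so with CUDA_VISIBLE_DEVICES='1' the physical GPU 1 appears inside PyTorch as cuda:0 and error messages call it GPU 0. The variable only takes effect if it is set before CUDA is initialized in the process, so it is safest to set it before importing torch, or on the command line (CUDA_VISIBLE_DEVICES=1 python train.py). CUDA also orders devices fastest-first by default, which can differ from the order nvidia-smi shows unless CUDA_DEVICE_ORDER=PCI_BUS_ID is set. The helper below is a hypothetical plain-Python illustration of the index mapping, not a PyTorch or CUDA API.

```python
import os


def physical_gpu_id(logical_index, env=None):
    """Hypothetical helper: map a process-local CUDA index to the id listed in
    CUDA_VISIBLE_DEVICES. Returns a string, since entries may also be GPU UUIDs."""
    env = os.environ if env is None else env
    visible = env.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        # Variable unset: all GPUs are visible in CUDA's own device order
        return str(logical_index)
    ids = [part.strip() for part in visible.split(",") if part.strip()]
    # Raises IndexError if the process asks for a device that is not visible
    return ids[logical_index]


# With CUDA_VISIBLE_DEVICES=1 the process sees one device, "GPU 0",
# which is the GPU listed as 1
print(physical_gpu_id(0, {"CUDA_VISIBLE_DEVICES": "1"}))    # 1
print(physical_gpu_id(1, {"CUDA_VISIBLE_DEVICES": "2,3"}))  # 3
```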
Traceback (most recent call last):
  File "train.py", line 45, in <module>
    target = model(image)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/task/3DUNet/unet3d.py", line 117, in forward
    out, residual_level1 = self.a_block1(input)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/task/3DUNet/unet3d.py", line 39, in forward
    res = self.relu(self.bn1(self.conv1(input)))
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 610, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 605, in _conv_forward
    return F.conv3d(
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/monai/data/meta_tensor.py", line 249, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "/dssg/home/acct-phywd/phywd/.conda/envs/py38/lib/python3.8/site-packages/torch/_tensor.py", line 1386, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 1, 16, 240, 240, 160]
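The last line says the tensor reaching the first Conv3d has six dimensions, [1, 1, 16, 240, 240, 160], while conv3d accepts only (C, D, H, W) or (N, C, D, H, W). An extra leading dimension like this often comes from adding a batch dimension (for example with unsqueeze(0)) to samples that the DataLoader then batches again, but the data pipeline is not shown here, so that is only a guess. The plain-Python sketch below (hypothetical helpers, no PyTorch) checks the rank and drops surplus leading size-1 dimensions; the PyTorch counterpart for a single surplus dimension would be image.squeeze(0). It only fixes the rank: whether the remaining dimensions really are (N, C, D, H, W), with C equal to the model's in_channels, still has to be checked against the dataset.

```python
def conv3d_rank_ok(shape):
    # conv3d takes (C, D, H, W) unbatched or (N, C, D, H, W) batched
    return len(shape) in (4, 5)


def drop_leading_singletons(shape, target_rank=5):
    """Remove leading size-1 dims until the shape has target_rank dims."""
    shape = list(shape)
    while len(shape) > target_rank and shape[0] == 1:
        shape.pop(0)
    if len(shape) != target_rank:
        raise ValueError(f"cannot turn {shape} into a rank-{target_rank} shape by dropping leading 1s")
    return tuple(shape)


bad = (1, 1, 16, 240, 240, 160)  # shape from the traceback above
fixed = drop_leading_singletons(bad)
print(conv3d_rank_ok(bad), fixed, conv3d_rank_ok(fixed))
# False (1, 16, 240, 240, 160) True
```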