beckschen / transunet

This repository includes the official project of TransUNet, presented in our paper: TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation.

License: Apache License 2.0

Python 100.00%

transunet's Introduction

TransUNet

This repo holds code for TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation

📰 News

  • [10/15/2023] 🔥 The 3D version of TransUNet is out! Our 3D TransUNet surpasses nnU-Net with an 88.11% Dice score on the BTCV dataset and outperforms the top-1 solution in the BraTS 2021 challenge. Please take a look at the code and paper.

Usage

1. Download Google pre-trained ViT models

wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz &&
mkdir -p ../model/vit_checkpoint/imagenet21k &&
mv {MODEL_NAME}.npz ../model/vit_checkpoint/imagenet21k/{MODEL_NAME}.npz
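
For example, for the R50+ViT-B_16 hybrid backbone used in the commands below (assuming the checkpoint file on the bucket is named R50+ViT-B_16.npz, as reported by users in the issues further down this page):

wget https://storage.googleapis.com/vit_models/imagenet21k/R50+ViT-B_16.npz &&
mkdir -p ../model/vit_checkpoint/imagenet21k &&
mv R50+ViT-B_16.npz ../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz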

2. Prepare data

Please go to "./datasets/README.md" for details, or use the preprocessed data and data2 for research purposes.

3. Environment

Please prepare an environment with Python 3.7, and then use the command "pip install -r requirements.txt" to install the dependencies.
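
For example, a minimal sketch using conda (conda itself is an assumption; any Python 3.7 environment should work):

conda create -n transunet python=3.7 -y
conda activate transunet
pip install -r requirements.txt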

4. Train/Test

  • Run the train script on the Synapse dataset. The batch size can be reduced to 12 or 6 to save memory (please also decrease base_lr linearly); either setting reaches similar performance.
CUDA_VISIBLE_DEVICES=0 python train.py --dataset Synapse --vit_name R50-ViT-B_16
  • Run the test script on the Synapse dataset. It supports testing for both 2D images and 3D volumes.
python test.py --dataset Synapse --vit_name R50-ViT-B_16
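
For example, a sketch of training with a halved batch size and a linearly scaled learning rate (assuming the repository defaults of batch size 24 and base_lr 0.01, and that train.py exposes --batch_size and --base_lr flags, as suggested by the commands quoted in the issues below):

CUDA_VISIBLE_DEVICES=0 python train.py --dataset Synapse --vit_name R50-ViT-B_16 --batch_size 12 --base_lr 0.005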

Reference

Citations

@article{chen2021transunet,
  title={TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation},
  author={Chen, Jieneng and Lu, Yongyi and Yu, Qihang and Luo, Xiangde and Adeli, Ehsan and Wang, Yan and Lu, Le and Yuille, Alan L. and Zhou, Yuyin},
  journal={arXiv preprint arXiv:2102.04306},
  year={2021}
}

transunet's People

Contributors

andife, beckschen, yucornetto, yuyinzhou


transunet's Issues

ModuleNotFoundError: No module named 'datasets.dataset_synapse'

Hello,
I use the preprocessed data (requested as described in the README) and get the
following error:

TransUNet$ python test.py --dataset Synapse --vit_name R50-ViT-B_16
Traceback (most recent call last):
File "test.py", line 12, in
from datasets.dataset_synapse import Synapse_dataset
ModuleNotFoundError: No module named 'datasets.dataset_synapse'

For me, the cause of this is unclear. Does anyone have an idea?
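
A quick check (a sketch, not an official fix): datasets/dataset_synapse.py ships inside the repository, so the import only resolves when the script is launched from the repository root and the file is actually present:

ls datasets/dataset_synapse.py
python -c "from datasets.dataset_synapse import Synapse_dataset; print('import ok')"
PYTHONPATH=/path/to/TransUNet python test.py --dataset Synapse --vit_name R50-ViT-B_16   # alternative: expose the repository root explicitly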

Request preliminary data

Hello, I sent you an email to request the preprocessed dataset, but perhaps because you are too busy to check your mailbox, or because my email was flagged as spam, I have not received it. I hope you can upload the dataset to GitHub or share it via a Baidu or Google cloud link. Thanks!

Dataset preprocessing

I want to train on my own dataset. For image preprocessing, is it sufficient that the input image is 224*224 and the model input is [bs, channel, 224, 224]?
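
For reference, a sketch of the tensor shapes the Synapse trainer appears to expect, inferred from the tracebacks quoted in later issues on this page rather than from official documentation:

import torch

# shapes observed in the training loop of the Synapse setup (batch size 24 in that report)
image_batch = torch.randn(24, 1, 224, 224)          # single-channel 224x224 slices
label_batch = torch.randint(0, 9, (24, 224, 224))   # integer class map, no channel axis
outputs     = torch.randn(24, 9, 224, 224)          # network logits, num_classes = 9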

Wrong Position Embedding Size

In the original implementation, the position embedding has a dimension of [n_patches+1, hidden_size] to accommodate the additional class token:

https://github.com/jeonsworld/ViT-pytorch/blob/a786151f6ceed00e97ab526772916faec5efb8ed/models/modeling.py#L150

In your implementation, you removed the class token and your position embedding has dimension [n_patches, hidden_size]:

https://github.com/Beckschen/TransUNet/blob/main/networks/vit_seg_modeling.py#L149

When I try to load the pre-trained model with your changes (the new size), I get a mismatch error:

model.load_from(np.load(args.pretrained_dir)), in load_from:
self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
RuntimeError: The size of tensor a (196) must match the size of tensor b (170) at non-singleton dimension 1

Could you please explain how the pretrained checkpoint can be loaded given these changes?
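
A common way to reuse a [n_patches+1, hidden_size] checkpoint embedding in a [n_patches, hidden_size] model is to drop the class-token slot and, if the grids differ, resize the remaining grid; a sketch of that idea (not necessarily identical to what load_from in this repository does):

import numpy as np
from scipy import ndimage

def resize_posemb(posemb, gs_new):
    """posemb: (1, n_old + 1, D) pretrained ViT position embedding including a class token;
    gs_new: target grid side length, so n_patches_new = gs_new * gs_new."""
    posemb_grid = posemb[:, 1:]                              # drop the class-token slot
    d = posemb_grid.shape[-1]
    gs_old = int(np.sqrt(posemb_grid.shape[1]))              # e.g. 14 for 196 patches
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, d)
    zoom = (1, gs_new / gs_old, gs_new / gs_old, 1)
    posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)   # bilinear resize of the grid
    return posemb_grid.reshape(1, gs_new * gs_new, d)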

image preprocess code

Thanks for your great contribution!
Would you like to share your image preprocessing code, i.e., normalizing the 3D images, extracting 2D slices from the 3D volumes, and saving them to .npz files?
Thanks!
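
A rough sketch of what the preprocessing appears to involve, pieced together from other issues on this page (clipping to [-125, 275], normalizing, extracting slices, and saving per-slice .npz files with 'image'/'label' keys); the file layout and the use of nibabel are assumptions:

import numpy as np
import nibabel as nib  # assumption: raw volumes stored as NIfTI

def preprocess_case(img_path, label_path, case_id, out_dir):
    image = nib.load(img_path).get_fdata().astype(np.float32)
    label = nib.load(label_path).get_fdata().astype(np.uint8)
    image = np.clip(image, -125, 275)              # CT window mentioned in a later issue
    image = (image + 125) / 400.0                  # normalize to [0, 1]
    for k in range(image.shape[2]):                # one axial slice per .npz file
        np.savez(f"{out_dir}/{case_id}_slice{k:03d}.npz",
                 image=image[:, :, k], label=label[:, :, k])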

IndexError:index 1 is out of bounds for dimension 0 with size 1

Hi,
I have prepared the dataset as you suggested, but an error occurs during training.
Here is the error info.

File "trainer.py", line 74, in trainer_synapse:
image = image_batch[1, 0:1, :, :]
IndexError: index 1 is out of bounds for dimension 0 with size 1.

My training settings: batch size 4, image size 224, GPU: one 2080Ti.
I have tried changing the batch size, but it does not help.

Looking forward to your reply soon!
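
A hedged guess at the cause (not a maintainer answer): trainer_synapse picks the second sample of the batch for logging (image_batch[1, ...]), which fails whenever a batch contains only one sample, e.g. an incomplete last batch. Two possible workarounds, as sketches:

image = image_batch[0:1, 0:1, :, :]   # log the first sample of the batch instead of the second
# or drop incomplete final batches when building the DataLoader:
# trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, drop_last=True)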

colab error

When I run !CUDA_VISIBLE_DEVICES=0 python /content/gdrive/My\Drive/project_TransUNet/TransUNet/train.py --dataset Synapse --vit_name R50-ViT-B_16 in Colab Pro, it shows this error:

Traceback (most recent call last):
File "/content/gdrive/MyDrive/project_TransUNet/TransUNet/train.py", line 92, in
net.load_from(weights=np.load(config_vit.pretrained_path))
File "/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py", line 432, in load
pickle_kwargs=pickle_kwargs)
File "/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py", line 186, in init
_zip = zipfile_factory(fid)
File "/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py", line 112, in zipfile_factory
return zipfile.ZipFile(file, *args, **kwargs)
File "/usr/lib/python3.7/zipfile.py", line 1258, in init
self._RealGetContents()
File "/usr/lib/python3.7/zipfile.py", line 1325, in _RealGetContents
raise BadZipFile("File is not a zip file")
zipfile.BadZipFile: File is not a zip file
Exception ignored in: <function NpzFile.__del__ at 0x7f766b4b2710>
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py", line 223, in del
self.close()
File "/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py", line 214, in close
if self.zip is not None:
AttributeError: 'NpzFile' object has no attribute 'zip'

I have no idea. Thanks for help.
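
A hedged check (not a maintainer answer): a BadZipFile from np.load usually means the .npz checkpoint is truncated or is an HTML error page saved under the .npz name, so verifying the downloaded file is a good first step:

import os, zipfile
path = "../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz"   # assumed checkpoint location
print(os.path.getsize(path))     # the real checkpoint is on the order of hundreds of MB
print(zipfile.is_zipfile(path))  # should print True for a valid .npz archive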

I have a question about input channel.

Hi! I am very interested in your paper TransUNet, and after reading it I have a question about the input channels.

Medical images often have a single channel (grayscale), but the ResNetV2 you use takes 3 channels as input.

Did you simply stack the grayscale image to make a 3-channel image?
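
Judging from the forward pass quoted in a later issue on this page (x = x.repeat(1,3,1,1) in vit_seg_modeling.py), the model simply replicates the single channel; a sketch of that behaviour (the surrounding condition is an assumption):

if x.size()[1] == 1:           # single-channel input
    x = x.repeat(1, 3, 1, 1)   # replicate grayscale to a 3-channel tensor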

Errors during training

Hi!
I got these errors while running training:

File "train.py", line 93, in <module>
    trainer[dataset_name](args, net, snapshot_path)
  File "/home/ubuntu/projects/TransUNet/trainer.py", line 57, in trainer_synapse
    loss_dice = dice_loss(outputs, label_batch, softmax=True)
  File "/home/ubuntu/miniconda3/envs/transunet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/projects/TransUNet/utils.py", line 43, in forward
    class_wise_dice.append(1.0 - dice.item())
RuntimeError: CUDA error: device-side assert triggered

I inspect the shapes:
label_batch (24, 224, 224);
outputs (24, 9, 224, 224);
image_batch (24, 1, 224, 224)

Then when I try to debug it, it points me to another error:

CUDA error: device-side assert triggered
  File "/home/ubuntu/projects/TransUNet/utils.py", line 34, in forward
    inputs = torch.softmax(inputs, dim=1)
  File "/home/ubuntu/miniconda3/envs/transunet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/projects/TransUNet/trainer.py", line 57, in trainer_synapse
    loss_dice = dice_loss(outputs, label_batch, softmax=True)
  File "/home/ubuntu/projects/TransUNet/train.py", line 93, in <module>
    trainer[dataset_name](args, net, snapshot_path)
  File "/home/ubuntu/miniconda3/envs/transunet/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/miniconda3/envs/transunet/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/home/ubuntu/miniconda3/envs/transunet/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/home/ubuntu/miniconda3/envs/transunet/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/miniconda3/envs/transunet/lib/python3.7/runpy.py", line 193, in _run_module_as_main (Current frame)
    "__main__", mod_spec)

Inspecting the shapes again:
inputs (24, 9, 224, 224)
target (24, 224, 224)

I would appreciate it if you can help me with this :)

P.S. I preprocessed the images as described (i.e., clipped, normalized, and extracted the slices).
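
A common cause of this kind of device-side assert (a hedged guess, not an official answer) is a label value outside [0, num_classes-1]; a quick check on the CPU:

num_classes = 9
labels = label_batch.cpu()                      # label_batch from the training loop above
print(labels.min().item(), labels.max().item(), labels.unique())
assert labels.min() >= 0 and labels.max() < num_classes, \
    "labels must lie in [0, num_classes - 1] for 9 output channels"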

Running the test.py.

Traceback (most recent call last):
File "test.py", line 122, in
net.load_state_dict(torch.load(snapshot))
File "/home/jdj/anaconda3/envs/trans/lib/python3.7/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for VisionTransformer:
Missing key(s) in state_dict:
A long list of keys is reported as missing (omitted here). Could you please tell me how to fix it?
Thanks!

Infer single image

Hi, thanks for your wonderful work. I want to ask how to use the model to predict a single image?
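
A hedged sketch of single-slice inference (not an official script); the model construction mirrors the snippets quoted elsewhere on this page, while the checkpoint and input paths are hypothetical:

import numpy as np
import torch
from networks.vit_seg_modeling import VisionTransformer as ViT_seg, CONFIGS

config_vit = CONFIGS['R50-ViT-B_16']
config_vit.n_classes = 9
config_vit.n_skip = 3
config_vit.patches.grid = (224 // 16, 224 // 16)

net = ViT_seg(config_vit, img_size=224, num_classes=9).cuda()
net.load_state_dict(torch.load('path/to/checkpoint.pth'))       # hypothetical trained checkpoint
net.eval()

image = np.load('path/to/slice.npz')['image']                   # hypothetical preprocessed 224x224 slice
x = torch.from_numpy(image).float().unsqueeze(0).unsqueeze(0).cuda()   # (1, 1, 224, 224)
with torch.no_grad():
    pred = torch.argmax(torch.softmax(net(x), dim=1), dim=1)    # (1, 224, 224) class-index map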

Evaluation metric

Hi,

Thanks a lot for sharing your code. I have a question regarding the computation of the evaluation metrics. Your code is:

def calculate_metric_percase(pred, gt):  
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum()>0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    elif pred.sum() > 0 and gt.sum()==0:
        return 1, 0
    else:
        return 0, 0

I get the first "if" condition because the hausdorff distance requires groundtruth and predictions to be defined.
However I don't get why you set the dice score to 1, when predictions is not empty and groundtruth empty. From the dice formula, the score should be 0 in that case. And it's the same when predictions is empty and ground truth not empty (and that you did).
Concerning the hausdorff distance, it doesn't seem right to me to set a null distance when prediction or groundtruth doesn't exist, cause it should be considered as a mistake and not be taken into account in the evaluation metric.
What do you think about that ?
If I didn't understand your code correctly, sorry in advance

lists_Synapse mismatch to files? / content of all.lst test_vol.txt train.txt

Hello,
how should the lists in project_TransUNet/TransUNet/lists/ be prepared for one's own dataset?

i.e., all.lst, test_vol.txt, train.txt

For me, for example, there seems to be a mismatch between the names in test_vol.txt
case0008
case0022
case0038
case0036
case0032
case0002
case0029
case0003
case0001
case0004
case0025
case0035

and received preprocessed data:
Synapse/test_vol_h5$ ls
case0001.npy.h5 case0003.npy.h5 case0008.npy.h5 case0025.npy.h5 case0032.npy.h5 case0036.npy.h5 case0002.npy.h5 case0004.npy.h5 case0022.npy.h5 case0029.npy.h5 case0035.npy.h5 case0038.npy.h5

case0002 only exists in test_vol.txt?

Thank you

Tensorflow variant

Hi,
Are there any plans to also implement the model in TensorFlow?
I would be very interested.

Thank you

Is there pretrained model of R50+ViT_16 for 512x512 input?

Hi, thanks for the great work.
I was wondering whether you have pretrained models of R50+ViT-B_16 for 512x512 input.
In your article, you say that all R50 and ViT backbones are pretrained on ImageNet, and you show experimental results for 512x512 input.
But I can't find pretrained models for ViT at 512x512 input.

How do you get your Epos?

I want to use the ViT encoder inside another CNN, as you do. After obtaining the hidden feature by downsampling, I don't know how to embed it. The hidden feature has shape H×W×C, and I can flatten it to N×(P·P·C), but I don't know how to add E_pos.
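
For reference, a minimal sketch of the standard learnable position embedding used in ViT-style models (n_patches and hidden_size are placeholders for your own configuration):

import torch
import torch.nn as nn

n_patches, hidden_size = 14 * 14, 768
position_embeddings = nn.Parameter(torch.zeros(1, n_patches, hidden_size))

x = torch.randn(2, n_patches, hidden_size)   # flattened, linearly projected patch tokens
x = x + position_embeddings                  # E_pos is simply added to the patch embeddings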

Questions about your data

First of all, thank you for your contribution. I sent you an email earlier and also downloaded your dataset. I wanted to look at your original data, so I wrote a program to open your NPY files, but the image I got was all black, and printing the ndarray showed only zeros. What is going on, please?

Also, if I want to train my own data set, is there anything I should pay attention to when creating it?

Has the performance of TransUNet been exaggerated?

Hi, thanks for your work. However, I have a big question about the comparison results reported in your paper. In Table 1, V-Net performs the worst among all competing methods. This really contradicts our experience with medical image segmentation. How did you train V-Net, and did you try your best to obtain better results (e.g., a good sampling strategy)?

"num_classes" setting

The initial num_classes value is 9 (8 organs + background). I want to segment a single organ, so I changed num_classes to 2 (1 organ + background), but the code does not run. I tried values from 3 to 13 and it still fails, yet with num_classes >= 14 the code runs successfully.
I have no idea why. Can anyone explain this?
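
A hedged guess (not a maintainer answer): num_classes must be strictly greater than the largest label value in the data, so training only succeeding with num_classes >= 14 suggests the labels still contain values up to 13. A quick check on a preprocessed slice (file name hypothetical):

import numpy as np

data = np.load("case0001_slice000.npz")   # hypothetical training slice
print(np.unique(data["label"]))            # remap these to {0, 1} before setting num_classes = 2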

Train error

Traceback (most recent call last):
File "train.py", line 93, in
trainer[dataset_name](args, net, snapshot_path)
File "/home/lichangyong/Code/TransUNet/trainer.py", line 57, in trainer_synapse
loss_dice = dice_loss(outputs, label_batch, softmax=True)
File "/home/lichangyong/.pyenv/versions/pytorch_gpu/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/lichangyong/Code/TransUNet/utils.py", line 43, in forward
class_wise_dice.append(1.0 - dice.item())
RuntimeError: CUDA error: device-side assert triggered

something about custom dataset

Thanks for your excellent work. While building my own dataset, I ran into some trouble.

My first question is: what are the dimensions of "image" and "label" in the DataLoader?

def __getitem__(self, idx):
        if self.split == "train":
            slice_name = self.sample_list[idx].strip('\n')
            data_path = os.path.join(self.data_dir, slice_name+'.npz')
            data = np.load(data_path)
            image, label = data['image'], data['label']
        else:
            vol_name = self.sample_list[idx].strip('\n')
            filepath = self.data_dir + "/{}.npy.h5".format(vol_name)
            data = h5py.File(filepath)
            image, label = data['image'][:], data['label'][:]

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = self.sample_list[idx].strip('\n')
        return sample   

In the above code, could you tell me the shape of sample['image'] and sample['label']?
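
As far as can be inferred from the preprocessed Synapse data and the tracebacks quoted elsewhere on this page (an inference, not an official spec), the training .npz files hold single 2D slices while the test .npy.h5 files hold full 3D volumes:

import numpy as np
import h5py

train_sample = np.load("case0005_slice012.npz")                    # hypothetical slice file
print(train_sample["image"].shape, train_sample["label"].shape)    # expected 2D arrays: (H, W)

with h5py.File("case0001.npy.h5", "r") as f:                       # file name from the listing above
    print(f["image"].shape, f["label"].shape)                      # expected 3D arrays: (num_slices, H, W)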

train error

Hello, I met an error when running python train.py --dataset Synapse --vit_name R50-ViT-B_16.
The error is as follows:

Traceback (most recent call last):
  File "train.py", line 91, in <module>
    net.load_from(weights=np.load(config_vit.pretrained_path))
  File "D:\Python\papers\TransUNet\networks\vit_seg_modeling.py", line 429, in load_from
    unit.load_from(weights, n_block=uname)
  File "D:\Python\papers\TransUNet\networks\vit_seg_modeling.py", line 192, in load_from
    query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
  File "C:\ProgramData\Anaconda3\envs\transnet\lib\site-packages\numpy\lib\npyio.py", line 260, in __getitem__
    raise KeyError("%s is not a file in the archive" % key)
KeyError: 'Transformer/encoderblock_0\\MultiHeadDotProductAttention_1/query\\kernel is not a file in the archive'

The environment is correct. I use Windows 10 and downloaded the R50+ViT-B_16.npz model to the right place. Can you help me?
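
The backslashes in the reported key hint at a likely cause (a hedged reading of the traceback, not a maintainer answer): on Windows, os.path.join builds the weight name with '\\', while the keys stored inside the .npz checkpoint use '/'. One possible workaround is to join the key with forward slashes:

# if pjoin in vit_seg_modeling.py comes from os.path (as the backslashes suggest),
# importing it from posixpath instead keeps the '/'-separated key that the .npz expects
from posixpath import join as pjoin

key = pjoin("Transformer/encoderblock_0", "MultiHeadDotProductAttention_1/query", "kernel")
print(key)   # Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/kernel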

Length of each patch vector D

I am not very familiar with PyTorch, but is config.hidden_size the length of the vector that each patch has, i.e., D = config.hidden_size?

But then, how do you choose the vector length? A 16x16 monochromatic patch contains only 256 numbers, so why do we need many more than 256 dimensions to represent the patch?

Hi, a question about the models' output

Thanks for your work. I have some problems using it for another task.
I modified the dataloader; my inputs are img (1, 3, 512, 512) and label (1, 1, 512, 512). In the training phase the output = model(image_patch) has shape (1, 2, 512, 512),
and when computing the loss I get the error RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: [1, 1, 512, 512].

args:
vit_name = R50-ViT-B_16, n_skip=3, vit_patches_size=16

Can you give me some advice? Thank you!
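
A hedged pointer (standard PyTorch behaviour, not an official answer): CrossEntropyLoss over spatial maps expects integer targets of shape (B, H, W), so the singleton channel axis of the label has to be removed before computing the loss:

# label arrives as (1, 1, 512, 512); the CE loss wants (1, 512, 512) integer class indices
label_batch = label_batch.squeeze(1).long()
loss_ce = ce_loss(outputs, label_batch)                      # outputs: (1, num_classes, 512, 512)
loss_dice = dice_loss(outputs, label_batch, softmax=True)    # dice_loss call as quoted in the tracebacks above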

AttributeError: "'skip_channels'"

Hi,
When I run the test.py , I get the following error:

` File "/home/wf/anaconda3/envs/trunet/lib/python3.7/site-packages/ml_collections/config_dict/config_dict.py", line 883, in getitem
field = self._fields[key]
KeyError: 'skip_channels'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/wf/anaconda3/envs/trunet/lib/python3.7/site-packages/ml_collections/config_dict/config_dict.py", line 807, in getattr
return self[attribute]
File "/home/wf/anaconda3/envs/trunet/lib/python3.7/site-packages/ml_collections/config_dict/config_dict.py", line 889, in getitem
raise KeyError(self._generate_did_you_mean_message(key, str(e)))
KeyError: "'skip_channels'"

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/wf/anaconda3/envs/project_TransUNet_2/project_TransUNet/TransUNet/test.py", line 118, in
net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
File "/home/wf/anaconda3/envs/project_TransUNet_2/project_TransUNet/TransUNet/networks/vit_seg_modeling.py", line 378, in init
self.decoder = DecoderCup(config)
File "/home/wf/anaconda3/envs/project_TransUNet_2/project_TransUNet/TransUNet/networks/vit_seg_modeling.py", line 343, in init
skip_channels = self.config.skip_channels
File "/home/wf/anaconda3/envs/trunet/lib/python3.7/site-packages/ml_collections/config_dict/config_dict.py", line 809, in getattr
raise AttributeError(e)
AttributeError: "'skip_channels'"`

Then I changed skip_channels = self.config.skip_channels (line 343 of vit_seg_modeling.py) to skip_channels = [512, 256, 64, 16], and I get the following new error:

RuntimeError: Error(s) in loading state_dict for VisionTransformer: Unexpected key(s) in state_dict: "transformer.embeddings.hybrid_model.root.conv.weight", size mismatch for transformer.embeddings.patch_embeddings.weight: copying a param with shape torch.Size([768, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([768, 3, 16, 16])

For me, the cause of this is unclear. Does anyone have an idea?
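
A hedged reading of both errors (not a maintainer answer): only the hybrid R50 configuration defines skip_channels, and the expected patch-embedding shape in the second error, (768, 3, 16, 16), is that of a plain ViT, so the --vit_name passed to test.py probably does not match the R50-ViT-B_16 model that was trained. A quick check (the config module name is an assumption based on the paths in the traceback):

from networks import vit_seg_configs as configs   # assumed module path

r50_cfg = configs.get_r50_b16_config()
print(getattr(r50_cfg, "skip_channels", None))     # the hybrid config is expected to define this
# passing --vit_name R50-ViT-B_16 at test time (the same name used for training) selects this config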

features is a list?

Hi,
Thanks a lot for sharing your code. When I build the 'R50-ViT-B_16' model, there is an error:

vit_name = 'R50-ViT-B_16'
config_vit = CONFIGS[vit_name]
config_vit.n_classes = 2
config_vit.n_skip = 3
if vit_name.find('R50') != -1:
    # config_vit.patches.grid = (int(args.img_size / args.vit_patches_size), int(args.img_size / args.vit_patches_size))
    config_vit.patches.grid = (int(224 / 16), int(224 / 16))

net = ViT_seg(config_vit, img_size=224, num_classes=2)
summary(net, (1, 224, 224), batch_size=1)

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device)
70 # make a forward pass
71 # print(x.shape)
---> 72 model(*x)
73
74 # remove these hooks

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
138 def forward(self, *inputs, **kwargs):
139 if not self.device_ids:
--> 140 return self.module(*inputs, **kwargs)
141
142 for t in chain(self.module.parameters(), self.module.buffers()):

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)

~/Desktop/TransUNet-main/networks/vit_seg_modeling.py in forward(self, x)
393 x = x.repeat(1,3,1,1)
394 print("x:", x.shape)
--> 395 x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
396 print("x:{}, attn_weights:{}, features:{}".format(x.shape, attn_weights.shape, features.shape))
397 x = self.decoder(x, features)

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)

~/Desktop/TransUNet-main/networks/vit_seg_modeling.py in forward(self, input_ids)
255 def forward(self, input_ids):
256 print("input_ids:", input_ids.shape)
--> 257 embedding_output, features = self.embeddings(input_ids)
258 #print("embedding_output:", embedding_output.shape)
259 encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)

~/Desktop/TransUNet-main/networks/vit_seg_modeling.py in forward(self, x)
157 def forward(self, x):
158 if self.hybrid:
--> 159 x, features = self.hybrid_model(x)
160 else:
161 features = None

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
--> 543 hook_result = hook(self, input, result)
544 if hook_result is not None:
545 result = hook_result

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchsummary/torchsummary.py in hook(module, input, output)
21 if isinstance(output, (list, tuple)):
22 summary[m_key]["output_shape"] = [
---> 23 [-1] + list(o.size())[1:] for o in output
24 ]
25 else:

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchsummary/torchsummary.py in (.0)
21 if isinstance(output, (list, tuple)):
22 summary[m_key]["output_shape"] = [
---> 23 [-1] + list(o.size())[1:] for o in output
24 ]
25 else:

AttributeError: 'list' object has no attribute 'size'

x, attn_weights, features = self.transformer(x): is features a list here?
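
Yes: judging from the decoder call in the traceback (self.decoder(x, features)), features is a list of skip-connection feature maps from the hybrid ResNet, which is what trips up torchsummary's output hook. A simple workaround sketch is to inspect shapes with a plain forward pass instead of summary():

import torch

net.eval()
with torch.no_grad():
    x = torch.randn(1, 1, 224, 224)    # single-channel input; the model repeats it to 3 channels
    out = net(x)
print(out.shape)                        # expected (1, n_classes, 224, 224)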

train error: after a time, NaN or Inf found in input tensor.

Hello,
executing
python train.py --dataset owndataset --vit_name R50-ViT-B_16 --batch_size 12 --max_iterations 1000 --max_epochs 350

Any idea what the reason for this could be?

iteration 755 : loss : 0.232612, loss_ce: 0.031451
iteration 756 : loss : 0.235377, loss_ce: 0.039011
iteration 757 : loss : 0.235090, loss_ce: 0.031103
iteration 758 : loss : 0.234754, loss_ce: 0.035912
iteration 759 : loss : 0.242864, loss_ce: 0.030254
iteration 760 : loss : 0.243939, loss_ce: 0.029230
11%|███▍ | 38/350 [04:56<41:17, 7.94s/it]iteration 761 : loss : 0.232332, loss_ce: 0.029950
iteration 762 : loss : 0.236290, loss_ce: 0.032639
iteration 763 : loss : 0.239043, loss_ce: 0.029412
iteration 764 : loss : 0.223232, loss_ce: 0.036379
iteration 765 : loss : 0.227415, loss_ce: 0.031555
iteration 766 : loss : 0.228688, loss_ce: 0.030908
iteration 767 : loss : 0.246761, loss_ce: 0.032261
iteration 768 : loss : 0.230575, loss_ce: 0.029101
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
iteration 769 : loss : nan, loss_ce: nan
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
iteration 770 : loss : nan, loss_ce: nan
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
iteration 771 : loss : nan, loss_ce: nan
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
iteration 772 : loss : nan, loss_ce: nan
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
iteration 773 : loss : nan, loss_ce: nan
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
iteration 774 : loss : nan, loss_ce: nan
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
iteration 775 : loss : nan, loss_ce: nan
NaN or Inf found in input tensor.
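
Two hedged sanity checks (not maintainer advice): NaNs appearing mid-training on a custom dataset often trace back either to non-finite values in the preprocessed data or to a learning rate that is too high for the chosen batch size:

import glob
import numpy as np

# 1) verify every preprocessed slice is finite (directory and keys follow the Synapse .npz layout)
for path in glob.glob("path/to/train_npz/*.npz"):        # hypothetical data directory
    data = np.load(path)
    assert np.isfinite(data["image"]).all(), path
    assert np.isfinite(data["label"]).all(), path

# 2) when lowering batch_size from the default 24 to 12, also lower base_lr proportionally,
#    as recommended in the Train/Test section above.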

Running the test.py module

Hi! I am very interested in your paper and solution for the medical image segmentation.

Regarding my first question: I ran test.py and it reports a problem at line 122 of the code, saying there is no such file or directory '../model/TU_Synapse224/TU_pretrain_R50-ViT-B_16_skip3_bs24_224/epoch_29.pth'.
I hope you can help me with this problem.
I did run it on a single image, though, in case that could be an issue.

My second question is: why is there a need to clip an image to [-125, 275]? The image I downloaded is grayscale with the uint8 data type.

TransUnet for RGB-images?

Hello, it seems that the code currently only works on grayscale images. I am interested in processing images with 3 channels (RGB). Has anyone already modified the code accordingly? What do I have to pay attention to?

Question about "patch_size"

Thanks for your work. I have some questions about the patch size of patch embedding when using CNN and Transformer as the encoder.

Section 3.2 of the paper mentions that, when the CNN-Transformer hybrid is used as the encoder, patch embedding is applied to 1x1 patches extracted from the CNN feature maps instead of from the raw image.

From my understanding, regardless of the height and width of the feature map extracted from the CNN, the patch embedding will then be an nn.Conv2d with kernel_size=1 and stride=1.

Here is the code.

if config.patches.get("grid") is not None:   # ResNet
    grid_size = config.patches["grid"]
    patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
    patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
    n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
    self.hybrid = True
else:
    patch_size = _pair(config.patches["size"])
    n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
    self.hybrid = False

if self.hybrid:
    self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
    in_channels = self.hybrid_model.width * 16
self.patch_embeddings = Conv2d(in_channels=in_channels,
                               out_channels=config.hidden_size,
                               kernel_size=patch_size,
                               stride=patch_size)

When img_size=512 and the configuration from get_r50_b16_config is applied, the output of patch_embeddings is a tensor of shape (B, 1024, 16, 16). Its height and width are 1/32, not 1/16, of the original image size.
So you would need 5 upsampling operations in total instead of 4, which differs from your implementation.

Shouldn't the kernel_size and stride be 1 when using the CNN-Transformer hybrid as the encoder?

I would be very grateful if you could let me know whether this is a misunderstanding on my part.
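
A worked computation based on the quoted code and on the grid assignment used in train.py/test.py (config_vit.patches.grid = (img_size // vit_patches_size, ...), as quoted in an earlier issue on this page): once the grid is derived from the actual image size, the hybrid branch does reduce to a 1x1 patch embedding at 1/16 resolution.

img_size, vit_patches_size = 512, 16

grid_size = (img_size // vit_patches_size, img_size // vit_patches_size)          # (32, 32)
patch_size = (img_size // 16 // grid_size[0], img_size // 16 // grid_size[1])     # (1, 1)
patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)                        # (16, 16)
n_patches = (img_size // patch_size_real[0]) * (img_size // patch_size_real[1])   # 32 * 32 = 1024

print(patch_size, n_patches)   # a 1x1 conv over the CNN feature map, 32x32 tokens = 1/16 of 512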
