
VSmTrans: A Hybrid Paradigm Integrating Self-Attention and Convolution for 3D Medical Image Segmentation

[Paper] [Dataset] [BibTeX]

Variable-Shape design

VSmTrans is a hybrid Transformer that tightly integrates self-attention and convolution into a single paradigm, enjoying the benefits of both: a large receptive field from self-attention and a strong inductive bias from convolution.

Installation

The code requires python>=3.9 and pytorch>1.12. Please follow the instructions here to install the PyTorch and TorchVision dependencies. Installing both with CUDA support is strongly recommended.

Install nnUNet:

cd nnUNet
pip install -e .
pip install einops timm monai

Our model relies on the nnU-Net framework, and it needs to know where you intend to save raw data, preprocessed data, and trained models. For this you need to set a few environment variables:

export nnUNet_raw="raw data path"
export nnUNet_preprocessed="preprocessed data path"
export nnUNet_results="result path"
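As a quick sanity check before running any nnU-Net command, you can verify that all three variables are set. This is an illustrative sketch using only the Python standard library (the helper name is ours, not part of nnU-Net):

```python
import os

# The three path variables nnU-Net expects (matching the exports above).
REQUIRED = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")

def missing_vars(env=None):
    """Return the nnU-Net path variables that are unset or empty."""
    env = os.environ if env is None else env
    return [v for v in REQUIRED if not env.get(v)]

# Example with a hypothetical environment where only nnUNet_raw is set:
print(missing_vars({"nnUNet_raw": "/data/nnUNet_raw"}))
# ['nnUNet_preprocessed', 'nnUNet_results']
```

An empty list means all three paths are configured and nnU-Net commands should find them.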

Dataset Format

Datasets must be located in the nnUNet_raw folder (which you either define when installing nnU-Net or export/set every time you intend to run nnU-Net commands!). Each segmentation dataset is stored as a separate 'Dataset'. Datasets are associated with a dataset ID, a three-digit integer, and a dataset name (which you can freely choose): for example, Dataset005_Prostate has 'Prostate' as the dataset name and 5 as the dataset ID. Datasets are stored in the nnUNet_raw folder like this:

nnUNet_raw/
├── Dataset001_BrainTumour
├── Dataset002_Heart
├── Dataset003_Liver
├── Dataset004_Hippocampus
├── Dataset005_Prostate
├── ...
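The Dataset{ID}_{Name} convention above can be sketched with a small helper. This is illustrative only (the function is ours, not part of nnU-Net's API):

```python
def dataset_folder(dataset_id: int, name: str) -> str:
    """Build a dataset folder name: 'Dataset' + three-digit ID + free-form name."""
    return f"Dataset{dataset_id:03d}_{name}"

print(dataset_folder(5, "Prostate"))     # Dataset005_Prostate
print(dataset_folder(1, "BrainTumour"))  # Dataset001_BrainTumour
```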

Within each dataset folder, the following structure is expected:

Dataset001_BrainTumour/
├── dataset.json
├── imagesTr
├── imagesTs  # optional
└── labelsTr

For details on the dataset.json format, please refer here.

Note that each image file name must encode its input channel (modality) as a four-digit suffix. Take BraTS as an example: this dataset has four input channels, FLAIR (0000), T1w (0001), T1gd (0002), and T2w (0003).

nnUNet_raw/Dataset001_BrainTumour/
├── dataset.json
├── imagesTr
│   ├── BRATS_001_0000.nii.gz
│   ├── BRATS_001_0001.nii.gz
│   ├── BRATS_001_0002.nii.gz
│   ├── BRATS_001_0003.nii.gz
│   ├── BRATS_002_0000.nii.gz
│   ├── BRATS_002_0001.nii.gz
│   ├── BRATS_002_0002.nii.gz
│   ├── BRATS_002_0003.nii.gz
│   ├── ...
├── imagesTs
│   ├── BRATS_485_0000.nii.gz
│   ├── BRATS_485_0001.nii.gz
│   ├── BRATS_485_0002.nii.gz
│   ├── BRATS_485_0003.nii.gz
│   ├── BRATS_486_0000.nii.gz
│   ├── BRATS_486_0001.nii.gz
│   ├── BRATS_486_0002.nii.gz
│   ├── BRATS_486_0003.nii.gz
│   ├── ...
└── labelsTr
    ├── BRATS_001.nii.gz
    ├── BRATS_002.nii.gz
    ├── ...
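The CASE_XXXX.nii.gz naming convention shown above can be checked with a short sketch. The regular expression is our assumption based on the file names listed here, not code from the repository:

```python
import re

# CASE_XXXX.nii.gz, where XXXX is the four-digit channel (modality) index.
NAME_RE = re.compile(r"^(?P<case>.+)_(?P<channel>\d{4})\.nii\.gz$")

def parse_image_name(filename):
    """Split an imagesTr/imagesTs file name into (case id, channel index)."""
    m = NAME_RE.match(filename)
    if m is None:
        raise ValueError(f"not a valid nnU-Net image name: {filename}")
    return m.group("case"), int(m.group("channel"))

print(parse_image_name("BRATS_001_0002.nii.gz"))  # ('BRATS_001', 2)
```

Running this over imagesTr before preprocessing can catch misnamed files early; labels in labelsTr carry no channel suffix and would deliberately fail this check.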

Here is another example, the second dataset of the MSD (Medical Segmentation Decathlon), which has only one input channel:

nnUNet_raw/Dataset002_Heart/
├── dataset.json
├── imagesTr
│   ├── la_003_0000.nii.gz
│   ├── la_004_0000.nii.gz
│   ├── ...
├── imagesTs
│   ├── la_001_0000.nii.gz
│   ├── la_002_0000.nii.gz
│   ├── ...
└── labelsTr
    ├── la_003.nii.gz
    ├── la_004.nii.gz
    ├── ...

Remember: For each training case, all images must have the same geometry to ensure that their pixel arrays are aligned. Also make sure that all your data is co-registered!
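A minimal alignment check might look like the sketch below. It uses NumPy arrays as stand-ins for loaded volumes and compares only array shapes; in practice full geometric agreement also requires matching origin, spacing, and orientation (e.g. comparing nibabel's img.affine), which is not shown here:

```python
import numpy as np

def same_geometry(volumes):
    """True if every channel of a case has the same array shape.

    Shape equality is only a proxy for full geometric agreement:
    origin, spacing, and orientation must match as well.
    """
    return len({v.shape for v in volumes}) == 1

flair = np.zeros((96, 96, 96))
t1w = np.zeros((96, 96, 96))
t2w = np.zeros((96, 96, 48))  # deliberately mismatched

print(same_geometry([flair, t1w]))       # True
print(same_geometry([flair, t1w, t2w]))  # False
```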


Experiment planning and preprocessing

nnUNetv2_plan_and_preprocess -d DATASET_ID --verify_dataset_integrity -c 3d_fullres

Where DATASET_ID is the dataset ID. Using the --verify_dataset_integrity flag is recommended, as it checks for some of the most common error sources. For all available options, run nnUNetv2_plan_and_preprocess -h.

Training

nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres FOLD [additional options, see -h]

DATASET_NAME_OR_ID specifies which dataset to train on, and FOLD specifies which fold of the 5-fold cross-validation to train. For more details, please refer here.
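To train the full 5-fold cross-validation, the command is simply repeated once per fold. The sketch below only builds and prints the command lines (dataset ID 2 is a hypothetical example); actually running them requires an installed nnU-Net:

```python
DATASET = "2"          # hypothetical dataset ID
CONFIG = "3d_fullres"

# One nnUNetv2_train invocation per cross-validation fold (0 through 4).
commands = [f"nnUNetv2_train {DATASET} {CONFIG} {fold}" for fold in range(5)]
for cmd in commands:
    print(cmd)
```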

Inference

nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d DATASET_NAME_OR_ID -c 3d_fullres -f FOLD [--save_probabilities]

Note that by default, inference is done as an ensemble of all 5 folds from the cross-validation. We very strongly recommend using all 5 folds; thus, all 5 folds must have been trained prior to running inference.

If you wish to make predictions with a single model, train the 'all' fold and specify it in nnUNetv2_predict with -f all.

You can also run inference directly; for the default configuration, see the nnU-Net framework.

Model

Our model is located at nnUNet/nnunetv2/training/nnUNetTrainer/VSmTrans.py. You can easily use it in your own framework as follows:

self.network = VSmixTUnet(
    in_channels=1,
    out_channels=16,
    feature_size=48,
    split_size=[1, 2, 3, 4],
    window_size=6,
    num_heads=[3, 6, 12, 24],
    img_size=[96, 96, 96],
    depths=[2, 2, 2, 2],
    patch_size=(2, 2, 2),
    do_ds=enable_deep_supervision
)

Citation
