LiteMedSAM

A lightweight version of MedSAM for fast training and inference. The model was trained in the following two stages:

  • Stage 1. Distill a lightweight image encoder, TinyViT, from the MedSAM ViT image encoder by constraining the two encoders' image embedding outputs to match (a minimal sketch of this objective follows the list)
  • Stage 2. Replace the MedSAM ViT image encoder with TinyViT and fine-tune the whole pipeline
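
As a rough illustration of the Stage 1 objective, the sketch below trains a student encoder to reproduce a frozen teacher's embeddings with an MSE loss. The function and variable names are illustrative, not the repository's actual training code:

import torch
import torch.nn.functional as F

def distillation_step(teacher_encoder, student_encoder, images, optimizer):
    # The teacher (MedSAM ViT encoder) is frozen; only the student
    # (TinyViT) is updated so that its embeddings match the teacher's.
    with torch.no_grad():
        target = teacher_encoder(images)
    student = student_encoder(images)
    loss = F.mse_loss(student, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()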

Installation

The codebase is tested with: Ubuntu 20.04 | Python 3.10 | CUDA 11.8 | PyTorch 2.1.2

  1. Create a virtual environment with conda create -n medsam python=3.10 -y and activate it with conda activate medsam
  2. Install PyTorch (the codebase is tested with 2.1.2)
  3. Clone the LiteMedSAM branch: git clone -b LiteMedSAM https://github.com/bowang-lab/MedSAM/
  4. Enter the MedSAM folder with cd MedSAM and run pip install -e .
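
After installation, a quick environment check can catch version or CUDA mismatches early. This is a generic PyTorch check, not part of the repository:

import torch

print(torch.__version__)           # expect a 2.x release
print(torch.cuda.is_available())   # True if the CUDA runtime is visible
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))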

Quick tutorial on making submissions to CVPR 2024 MedSAM on Laptop Challenge

Sanity test

  • Download the LiteMedSAM checkpoint here and put it in work_dir/LiteMedSAM.
  • Download the demo data demo_scribble.zip here.
  • Run the following command for a sanity test:
python CVPR24_LiteMedSAM_infer.py -i test_demo/imgs/ -o test_demo/segs
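
To confirm that the sanity test produced output, you can inspect the segmentation directory. The snippet below assumes the script writes one npz file per input image; verify the actual naming against your outputs:

from pathlib import Path
import numpy as np

seg_files = sorted(Path("test_demo/segs").glob("*.npz"))  # assumed output format
print(f"{len(seg_files)} segmentation files written")
if seg_files:
    case = np.load(seg_files[0])
    print(seg_files[0].name, {k: case[k].shape for k in case.files})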

Build Docker

docker build -f Dockerfile -t litemedsam .

Note: don't forget the trailing . at the end of the command.

Run the docker on the testing demo images

docker container run -m 8G --name litemedsam --rm -v $PWD/test_demo/imgs/:/workspace/inputs/ -v $PWD/test_demo/litemedsam-seg/:/workspace/outputs/ litemedsam:latest /bin/bash -c "sh predict.sh"

Note: please run chmod -R 777 ./* if you run into a Permission denied error.

Save docker

docker save litemedsam | gzip -c > litemedsam.tar.gz
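
Note: the compressed archive can be restored on another machine with docker load -i litemedsam.tar.gz (docker load accepts the gzipped file directly).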

Compute Metrics

python evaluation/compute_metrics.py -s test_demo/litemedsam-seg -g test_demo/gts -csv_dir ./metrics.csv
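
For reference, the standard overlap metric for segmentation is the Dice similarity coefficient (DSC); the self-contained sketch below shows how it is computed for a pair of binary masks. This illustrates the metric itself, not the repository's compute_metrics.py (which may also report additional metrics):

import numpy as np

def dice_coefficient(pred, gt):
    # Dice similarity coefficient between two binary masks:
    # DSC = 2 * |pred AND gt| / (|pred| + |gt|)
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0  # both masks empty: count as perfect agreement
    return 2.0 * np.logical_and(pred, gt).sum() / denom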

Model Training

Data preprocessing

  1. Download the Lite-MedSAM checkpoint and put it under the current directory.
  2. Download the demo dataset. This tutorial assumes it is unzipped to data/FLARE22Train/.
  3. Run the pre-processing script to convert the dataset to npz format:
python pre_CT_MR.py \
    -img_path data/FLARE22Train/images \ ## path to training images
    -img_name_suffix _0000.nii.gz \ ## extension of training images
    -gt_path data/FLARE22Train/labels \ ## path to training labels
    -gt_name_suffix .nii.gz \ ## extension of training labels
    -output_path data \ ## path to save the preprocessed data
    -num_workers 4 \ ## number of workers for preprocessing
    -modality CT \ ## modality of the preprocessed data
    -anatomy Abd \ ## anatomy of the preprocessed data
    -window_level 40 \ ## window level for CT
    -window_width 400 \ ## window width for CT
    --save_nii ## Also save the preprocessed data in nii.gz format for visual inspection in other software
  • The script splits the dataset: the first 40 cases are used for training (saved in MedSAM_train) and the last 10 for testing (saved in MedSAM_test).
  • For detailed usage of the script, see python pre_CT_MR.py -h.
  4. Convert the training npz files to npy format for training:
python npz_to_npy.py \
    -npz_dir data/MedSAM_train \ ## path to the preprocessed npz training data
    -npy_dir data/npy \ ## path to save the converted npy data for training
    -num_workers 4 ## number of workers for conversion in parallel

Fine-tune pretrained Lite-MedSAM

The training pipeline requires about 10 GB of GPU memory with a batch size of 4.

Single GPU

To train Lite-MedSAM on a single GPU, run:

python train_one_gpu.py \
    -data_root data/MedSAM_train \
    -pretrained_checkpoint lite_medsam.pth \
    -work_dir work_dir \
    -num_workers 4 \
    -batch_size 4 \
    -num_epochs 10

To resume interrupted training from a checkpoint, run:

python train_one_gpu.py \
    -data_root data/MedSAM_train \
    -resume work_dir/medsam_lite_latest.pth \
    -work_dir work_dir \
    -num_workers 4 \
    -batch_size 4 \
    -num_epochs 10

For additional command line arguments, see python train_one_gpu.py -h.

Multi-GPU

To fine-tune Lite-MedSAM on multiple GPUs, run:

python train_multi_gpus.py \
    -i data/npy \ ## path to the training dataset
    -task_name MedSAM-Lite-Box \
    -pretrained_checkpoint lite_medsam.pth \
    -work_dir ./work_dir_ddp \
    -batch_size 16 \
    -num_workers 8 \
    -lr 0.0005 \
    --data_aug \ ## use data augmentation
    -world_size <WORLD_SIZE> \ ## total number of GPUs to use
    -node_rank 0 \ ## if training on a single machine, set to 0
    -init_method tcp://<MASTER_ADDR>:<MASTER_PORT>

Alternatively, you can use the provided train_multi_gpus.sh script to train on multiple GPUs. To resume interrupted multi-GPU training, pass -resume <your_work_dir> (the work directory, not a checkpoint path); the script will automatically find the latest checkpoint in the work directory. For additional command line arguments, see python train_multi_gpus.py -h.
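
For example, on a single machine with 4 GPUs you would set -world_size 4, -node_rank 0, and something like -init_method tcp://localhost:12345 (the address and port are placeholders; any free port on the master node works).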

Inference (sanity test)

The inference script assumes the testing data have been converted to npz format. To run inference on the 3D CT FLARE22 dataset, run:

python inference_3D.py \
    -data_root data/npz/MedSAM_test/CT_Abd \ ## preprocessed npz data
    -pred_save_dir ./preds/CT_Abd \
    -medsam_lite_checkpoint_path work_dir/medsam_lite_latest.pth \
    -num_workers 4 \
    --save_overlay \ ## save segmentation overlay on the input image
    -png_save_dir ./preds/CT_Abd_overlay \ ## only used when --save_overlay is set
    --overwrite ## overwrite existing predictions; by default, inference resumes from existing predictions

For additional command line arguments, see python inference_3D.py -h.

We also provide a script, inference_2D.py, for running inference on 2D images; its usage is the same as that of the 3D script.

Frequently Asked Questions (FAQ)

What is the difference between the preprocessed npz and npy data?

  • The npz format is used to store both 2D and 3D images (focusing on the ROI), along with their corresponding ground truth masks. In the case of 3D images, spacings are also included. All these data elements are compactly packed into a single npz file. This format is primarily used for distributing our training and validation datasets. Notably, for the validation dataset, bounding boxes are provided in place of ground truth label masks.

  • In contrast, the npy format stores each 2D image or individual slice of a 3D image along with its label mask in separate files. This format is utilized when loading data for training purposes.
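
A quick way to see the difference is to open one file of each kind with NumPy. The file names below are hypothetical, and the key names follow the description above; verify both against your own data:

import numpy as np

# npz: one compressed archive per case, bundling images, ground-truth
# masks and (for 3D cases) the voxel spacing.
case = np.load("data/MedSAM_train/CT_Abd_demo.npz")  # hypothetical filename
print(case.files)                                    # e.g. ['imgs', 'gts', 'spacing']

# npy: one array per file, i.e. a single 2D image or slice, with its
# label mask stored in a separate file.
img = np.load("data/npy/imgs/CT_Abd_demo-000.npy")   # hypothetical filename
print(img.shape, img.dtype)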

I'm having trouble loading my trained model's checkpoint for inference. What should I do?

If you encounter difficulties loading a trained model's checkpoint for inference, first try the extract_weights.py script located under MedSAM/utils/. This script extracts the model weights from your existing checkpoint and saves them into a new checkpoint file.

To use this script, execute the following command in your terminal:

python extract_weights.py \
    -from_pth <YOUR_CHECKPOINT_PATH> \
    -to_pth <NEW_CHECKPOINT_PATH>

Replace <YOUR_CHECKPOINT_PATH> with the path to your saved trained model checkpoint, and <NEW_CHECKPOINT_PATH> with the desired path for the new checkpoint file. Once you have executed this command and created the new checkpoint, it should be ready for use in inference tasks.
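
If you want to see what such an extraction step amounts to, the hedged sketch below shows the typical pattern: a training checkpoint often bundles the model weights with optimizer state and bookkeeping, and extraction keeps only the weights. The 'model' key is an assumption, so inspect your checkpoint's keys first; this is not the repository's actual extract_weights.py:

import torch

ckpt = torch.load("work_dir/medsam_lite_latest.pth", map_location="cpu")
print(list(ckpt.keys()))              # inspect the stored keys first
weights = ckpt.get("model", ckpt)     # 'model' key is an assumption
torch.save(weights, "lite_medsam_weights.pth")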

Acknowledgements

We thank the authors of MobileSAM and TinyViT for making their source code publicly available.
