desi-ivanov / med-seg-diff-pytorch

This project forked from lucidrains/med-seg-diff-pytorch

Implementation of MedSegDiff in PyTorch: SOTA medical segmentation using DDPM and filtering of features in Fourier space

License: MIT License

MedSegDiff - Pytorch

Implementation of MedSegDiff in PyTorch: SOTA medical segmentation out of Baidu using DDPM, with enhanced conditioning at the feature level and filtering of features in Fourier space.
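The Fourier-space filtering can be sketched as a module that takes the 2D FFT of a feature map, multiplies each frequency by a learnable complex weight, and transforms back to the spatial domain. This is a minimal illustration in the spirit of the paper's frequency-domain filtering, not this library's actual module: the class name FourierFeatureFilter, the identity initialisation and the fixed feature-map size are all assumptions.

```python
import torch
from torch import nn


class FourierFeatureFilter(nn.Module):
    """Hypothetical sketch: re-weight each spatial frequency of a feature map
    with a learnable complex weight. Not the library's actual module."""

    def __init__(self, channels: int, height: int, width: int):
        super().__init__()
        # Real and imaginary parts stored in the last dim; start at 1 + 0j
        # so the filter is an identity map at initialisation (assumption).
        weight = torch.zeros(channels, height, width, 2)
        weight[..., 0] = 1.0
        self.weight = nn.Parameter(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, height, width), real-valued features
        freq = torch.fft.fft2(x, norm="ortho")            # complex spectrum over the two spatial dims
        freq = freq * torch.view_as_complex(self.weight)  # per-channel, per-frequency re-weighting
        return torch.fft.ifft2(freq, norm="ortho").real   # back to the spatial domain


features = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
filt = FourierFeatureFilter(64, 32, 32)
out = filt(features)                     # same shape as features
```

Because the weight starts at 1 + 0j, the filter initially passes features through unchanged, and training can then learn to attenuate or amplify particular frequency bands.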

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

  • Isamu and Daniel for adding a training script for a skin lesion dataset!

Install

$ pip install med-seg-diff-pytorch

Usage

import torch
from med_seg_diff_pytorch import Unet, MedSegDiff

model = Unet(
    dim = 64,
    image_size = 128,
    mask_channels = 1,          # segmentation has 1 channel
    input_img_channels = 3,     # input images have 3 channels
    dim_mults = (1, 2, 4, 8)
)

diffusion = MedSegDiff(
    model,
    timesteps = 1000
).cuda()

segmented_imgs = torch.rand(8, 1, 128, 128)  # inputs are normalized from 0 to 1
input_imgs = torch.rand(8, 3, 128, 128)

loss = diffusion(segmented_imgs, input_imgs)
loss.backward()

# after a lot of training

pred = diffusion.sample(input_imgs)     # pass in your unsegmented images
pred.shape                              # predicted segmentation masks - (8, 1, 128, 128), i.e. (batch, mask_channels, image_size, image_size)
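Conceptually, each call to diffusion(segmented_imgs, input_imgs) noises the masks at randomly drawn timesteps and trains the Unet, conditioned on the input images, to denoise them; sampling then runs the reverse chain starting from pure noise. The forward noising step of a standard DDPM can be sketched as below. The linear beta schedule (1e-4 to 0.02), the [0, 1] to [-1, 1] rescaling and the helper names linear_alpha_bars and q_sample are assumptions for illustration; the library may use a different schedule and different internals.

```python
import torch


def linear_alpha_bars(timesteps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """abar_t = prod_{s <= t} (1 - beta_s) for a linear beta schedule (assumed)."""
    betas = torch.linspace(beta_start, beta_end, timesteps)
    return torch.cumprod(1.0 - betas, dim=0)


def q_sample(masks: torch.Tensor, t: torch.Tensor, alpha_bars: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Draw x_t ~ q(x_t | x_0): x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise."""
    x0 = masks * 2 - 1                        # rescale masks from [0, 1] to [-1, 1] (assumption)
    abar = alpha_bars[t].view(-1, 1, 1, 1)    # one abar per sample, broadcast over (C, H, W)
    return abar.sqrt() * x0 + (1 - abar).sqrt() * noise


alpha_bars = linear_alpha_bars(1000)          # matches timesteps = 1000 in the usage example
masks = torch.rand(8, 1, 128, 128)            # same shape as segmented_imgs in the usage example
t = torch.randint(0, 1000, (8,))              # one random timestep per sample
x_t = q_sample(masks, t, alpha_bars, torch.randn_like(masks))
```

At small t the noisy mask is close to the clean one; near t = 1000 the signal factor sqrt(abar_t) is tiny and x_t is almost pure Gaussian noise, which is where sampling starts.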

Training

Command to run

accelerate launch driver.py --mask_channels=1 --input_img_channels=3 --image_size=64 --data_path='./data' --dim=64 --epochs=100 --batch_size=1 --scale_lr --gradient_accumulation_steps=4
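In this command, --batch_size=1 combined with --gradient_accumulation_steps=4 means gradients from four single-sample micro-batches are accumulated before each optimizer update, so every update sees an effective batch of 4 samples per process. driver.py presumably delegates this to Accelerate; the sketch below shows the equivalent manual logic in plain PyTorch. The helper name accumulate_and_step is hypothetical, and it assumes the model's forward takes (masks, images) and returns a scalar loss, as MedSegDiff does in the usage example.

```python
from itertools import islice

import torch
from torch import nn


def accumulate_and_step(model: nn.Module, optimizer: torch.optim.Optimizer,
                        micro_batches, accumulation_steps: int = 4) -> float:
    """Hypothetical helper: one optimizer update from `accumulation_steps` micro-batches.
    Returns the average loss over those micro-batches."""
    optimizer.zero_grad()
    avg_loss = 0.0
    for masks, images in islice(micro_batches, accumulation_steps):
        # Divide so the summed gradients equal the gradient of the average loss.
        loss = model(masks, images) / accumulation_steps
        loss.backward()               # gradients add up in each parameter's .grad
        avg_loss += loss.item()
    optimizer.step()                  # single update using the accumulated gradients
    return avg_loss
```

When the micro-batches are the same size and the loss is a per-batch mean, this update matches a single step on the concatenated batch, at the memory cost of one sample at a time.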

To enable self-conditioning, where the model is also conditioned on the mask it has predicted so far, pass --self_condition.
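Self-conditioning, as commonly implemented, runs the denoiser twice on some training steps: first without gradients to get an estimate of the clean mask, then again with that estimate fed back as an extra input. During sampling, the previous step's estimate is fed back instead. The sketch below assumes a hypothetical model signature model(x_t, t, cond_img, self_cond) and a 50% self-conditioning probability; it illustrates the idea and is not this repository's code.

```python
import random
from typing import Optional

import torch
from torch import nn


def self_conditioned_forward(model: nn.Module, x_t: torch.Tensor, t: torch.Tensor,
                             cond_img: torch.Tensor, prev_estimate: Optional[torch.Tensor] = None,
                             p: float = 0.5) -> torch.Tensor:
    """Hypothetical helper; assumes model(x_t, t, cond_img, self_cond) -> mask estimate."""
    if prev_estimate is not None:
        # Sampling: reuse the estimate produced at the previous denoising step.
        return model(x_t, t, cond_img, prev_estimate)
    self_cond = torch.zeros_like(x_t)            # "no estimate yet" is represented by zeros
    if model.training and random.random() < p:
        with torch.no_grad():                    # first pass only provides an input, no gradients
            self_cond = model(x_t, t, cond_img, self_cond)
    return model(x_t, t, cond_img, self_cond)    # second pass is the one that gets trained
```

Skipping the first pass on some steps keeps the model usable when no estimate is available yet, which is exactly the situation at the first sampling step.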

Todo

  • basic training code, with a Trainer that accepts custom datasets tailored to medical image formats (thanks to @isamu-isozaki)
  • a full-blown transformer of any depth in the middle of the Unet, as done in simple diffusion

Citations

@article{Wu2022MedSegDiffMI,
    title   = {MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model},
    author  = {Junde Wu and Huihui Fang and Yu Zhang and Yehui Yang and Yanwu Xu},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2211.00611}
}
@inproceedings{Hoogeboom2023simpleDE,
    title   = {simple diffusion: End-to-end diffusion for high resolution images},
    author  = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
    year    = {2023}
}

Contributors

dsbuddy, isamu-isozaki, lucidrains