
yilmazkorkmaz1 / ssdiffrecon


Official TensorFlow implementation of Self-Supervised MRI Reconstruction with Unrolled Diffusion Models

License: MIT License

Python 92.62% Cuda 7.38%

ssdiffrecon's Introduction

Self-Supervised MRI Reconstruction with Unrolled Diffusion Models



Official TensorFlow implementation of SSDiffRecon (MICCAI 2023)

About

Magnetic Resonance Imaging (MRI) produces excellent soft-tissue contrast, although it is an inherently slow imaging modality. Promising deep learning methods have recently been proposed to reconstruct accelerated MRI scans. However, existing methods still suffer from various limitations regarding image fidelity, contextual sensitivity, and reliance on fully-sampled acquisitions for model training. To comprehensively address these limitations, we propose a novel self-supervised deep reconstruction model named Self-Supervised Diffusion Reconstruction (SSDiffRecon). SSDiffRecon expresses a conditional diffusion process as an unrolled architecture that interleaves cross-attention transformers for reverse diffusion steps with data-consistency blocks for physics-driven processing. Unlike recent diffusion methods for MRI reconstruction, SSDiffRecon is trained with a self-supervision strategy that uses only undersampled k-space data. Comprehensive experiments on public brain MR datasets demonstrate the superiority of SSDiffRecon over state-of-the-art supervised and self-supervised baselines in terms of reconstruction speed and quality.
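The data-consistency idea mentioned above can be illustrated with a minimal single-coil NumPy sketch. It assumes hard data consistency (acquired k-space samples overwrite the network's estimate) and a centered orthonormal FFT; the repository's actual blocks may differ. The `split_mask` helper shows one common way to train with undersampled data only: splitting the acquired locations into two disjoint sets, as in SSDU-style self-supervision. All function names and the `rho` split ratio are hypothetical and are not taken from the SSDiffRecon code.

```python
import numpy as np


def fft2c(x):
    """Centered, orthonormal 2-D FFT over the last two axes (image -> k-space)."""
    axes = (-2, -1)
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(x, axes=axes), norm="ortho"), axes=axes)


def ifft2c(k):
    """Centered, orthonormal 2-D inverse FFT over the last two axes (k-space -> image)."""
    axes = (-2, -1)
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(k, axes=axes), norm="ortho"), axes=axes)


def data_consistency(x, y, mask):
    """Hard data consistency (hypothetical helper).

    x:    current complex image estimate (H, W), e.g. a denoiser output
    y:    acquired k-space, zero where not sampled
    mask: binary sampling mask (1 = acquired)
    """
    k = fft2c(x)
    # Keep measured samples; fill unmeasured locations from the estimate.
    return ifft2c(mask * y + (1 - mask) * k)


def split_mask(mask, rho=0.4, seed=0):
    """Split acquired locations into two disjoint masks (SSDU-style, hypothetical).

    Returns (input_mask, loss_mask): the network would see k-space on input_mask,
    and a loss would be computed on the held-out loss_mask samples.
    """
    rng = np.random.default_rng(seed)
    acquired = mask.astype(bool)
    held_out = acquired & (rng.random(mask.shape) < rho)
    return (acquired & ~held_out).astype(mask.dtype), held_out.astype(mask.dtype)


# Usage: after data consistency, the measured k-space samples are reproduced.
rng = np.random.default_rng(1)
img = rng.standard_normal((8, 8)) + 1j * rng.standard_normal((8, 8))
mask = (rng.random((8, 8)) < 0.5).astype(float)
y = mask * fft2c(img)
x_dc = data_consistency(np.zeros_like(img), y, mask)
input_mask, loss_mask = split_mask(mask)
```

Hard replacement is the simplest choice for illustration; weighted (soft) consistency is a common alternative when the measurements are noisy.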

Prerequisites

Required packages can easily be installed via conda:

conda env create -f environment.yml

Then:

conda activate ssdiffrecon_env

TensorFlow 1.14+ should also work, since we do not use TF2-specific functionality.

Datasets

  1. IXI dataset: https://brain-development.org/ixi-dataset/
  2. fastMRI Brain dataset: https://fastmri.med.nyu.edu/

For the IXI dataset, image dimensions are 256x256. For the fastMRI dataset, image dimensions vary with contrast (T1: 256x320, T2: 288x384, FLAIR: 256x320).
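Because the expected image size depends on the dataset and contrast, a quick shape check before writing or loading data can catch mismatches early. The sketch below only encodes the sizes listed above; it assumes each size is given as (rows, columns), and the dictionary and function names are hypothetical, not part of the repository.

```python
import numpy as np

# Sizes from the list above, assumed to be (rows, columns).
IXI_SHAPE = (256, 256)
FASTMRI_SHAPES = {"T1": (256, 320), "T2": (288, 384), "FLAIR": (256, 320)}


def expected_shape(dataset, contrast=None):
    """Return the expected (rows, cols) for a dataset/contrast pair (hypothetical helper)."""
    if dataset == "ixi":
        return IXI_SHAPE
    if dataset == "fastmri":
        if contrast not in FASTMRI_SHAPES:
            raise ValueError(f"unknown fastMRI contrast: {contrast!r}")
        return FASTMRI_SHAPES[contrast]
    raise ValueError(f"unknown dataset: {dataset!r}")


def check_shape(image, dataset, contrast=None):
    """Raise ValueError if the last two axes of `image` do not match the expected size."""
    exp = expected_shape(dataset, contrast)
    if tuple(image.shape[-2:]) != exp:
        raise ValueError(f"{dataset}/{contrast}: got {tuple(image.shape[-2:])}, expected {exp}")
    return image


# Usage: a 2-channel T1 slice passes; a 256x256 slice labeled fastMRI T2 would raise.
check_shape(np.zeros((2, 256, 320)), "fastmri", "T1")
```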

The TensorFlow data pipeline expects datasets in the TFRecords format. To create TFRecords files for new datasets, you can use dataset_tool.py.

TFRecords files created for the fastMRI and IXI datasets can be downloaded from the following link:

https://drive.google.com/drive/folders/1h1kt8b4JgPOG-tNtRxAeEfMLUMIEzw9r?usp=drive_link

Coil sensitivity maps are estimated using ESPIRiT (http://people.eecs.berkeley.edu/~mlustig/Software.html).
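For the multi-coil (fastMRI) setting, sensitivity maps such as those estimated with ESPIRiT typically enter the standard SENSE-type forward model A x = M F S x, where S multiplies the image by each coil's sensitivity, F is the 2-D Fourier transform, and M is the undersampling mask. The NumPy sketch below shows this operator and its adjoint. It assumes a centered orthonormal FFT and maps normalized so that the sum over coils of |S_c|^2 equals 1; the function names are hypothetical, and how the repository applies the maps inside its data-consistency blocks is not shown here.

```python
import numpy as np


def fft2c(x):
    """Centered, orthonormal 2-D FFT over the last two axes."""
    axes = (-2, -1)
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(x, axes=axes), norm="ortho"), axes=axes)


def ifft2c(k):
    """Centered, orthonormal 2-D inverse FFT over the last two axes."""
    axes = (-2, -1)
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(k, axes=axes), norm="ortho"), axes=axes)


def normalize_maps(smaps, eps=1e-12):
    """Scale maps (C, H, W) so that sum_c |S_c|^2 is (approximately) 1 at every pixel."""
    rss = np.sqrt(np.sum(np.abs(smaps) ** 2, axis=0, keepdims=True))
    return smaps / (rss + eps)


def forward(x, smaps, mask):
    """A x = M F S x: image (H, W) -> undersampled multi-coil k-space (C, H, W)."""
    return mask * fft2c(smaps * x[None, ...])


def adjoint(k, smaps, mask):
    """A^H k = sum_c conj(S_c) F^H (M k_c): k-space (C, H, W) -> coil-combined image (H, W)."""
    return np.sum(np.conj(smaps) * ifft2c(mask * k), axis=0)
```

With normalized maps and a fully sampled mask, `adjoint(forward(x, ...), ...)` returns `x`, which is why this normalization is convenient for coil combination.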

Run Commands

After setting up the environment and downloading dataset files, you can simply run the following commands.

To train the single-coil model (IXI) with default parameters:

python run_ixi.py --train --exp_name ixi_trial1 --gpu 1

To train the multi-coil model (fastMRI) with learning rate 1e-5, run the following:

python run_fastmri.py --train --exp_name fastmri_trial1 --gpu 1 --lr 1e-5

To evaluate the single-coil model (IXI) with default parameters using the checkpoint saved at training step 1000, run the following:

python run_ixi.py --eval --results_dir ./results/ixi_trial1 --eval_checkpoint 1000

To evaluate the multi-coil model (fastMRI) with beta_start set to 0.005, run the following:

python run_fastmri.py --eval --results_dir ./results/fastmri_trial1 --beta_start 0.005
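In diffusion models, beta_start is conventionally the first value of the noise-variance schedule. As an illustration only, the sketch below assumes a linear schedule from `beta_start` to a `beta_end` over `num_steps` steps, as in the DDPM formulation; the repository's actual schedule, its `beta_end`, and its number of steps may differ, and these function names are hypothetical. It shows that a larger `beta_start` lowers the cumulative signal weight alpha_bar at the first step, i.e. more noise is injected early.

```python
import numpy as np


def linear_beta_schedule(beta_start, beta_end, num_steps):
    """Noise variances beta_1..beta_T spaced linearly (assumed schedule)."""
    return np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)


def alpha_bars(betas):
    """Cumulative signal weights alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    return np.cumprod(1.0 - betas)


def q_sample(x0, t, betas, rng):
    """Forward diffusion: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a_bar = alpha_bars(betas)[t]
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise


# Usage: beta_end=0.02 and num_steps=4 are hypothetical values for illustration.
betas = linear_beta_schedule(0.005, 0.02, 4)
x_t = q_sample(np.ones((4, 4)), t=3, betas=betas, rng=np.random.default_rng(0))
```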

Trained Models

Trained model weights for both datasets can be downloaded from this link: https://drive.google.com/drive/folders/1ApxzBWqyD7Km0vAm-pILCyN6nvlsfjSg?usp=drive_link

Citation

You are encouraged to modify/distribute this code. However, please acknowledge this code and cite the paper appropriately.

@article{korkmaz2023self,
  title={Self-Supervised MRI Reconstruction with Unrolled Diffusion Models},
  author={Korkmaz, Yilmaz and Cukur, Tolga and Patel, Vishal M.},
  journal={arXiv preprint arXiv:2306.16654},
  year={2023}
}

Acknowledgements

This repository uses code from https://github.com/hojonathanho/diffusion/tree/master and https://github.com/icon-lab/SLATER/tree/main.


ssdiffrecon's Issues

Re-Implementing your code

Hi, we are having trouble re-implementing your code. Could you please share a Dockerfile or the specific versions of your packages? We are facing multiple errors while re-running the code.
We are working on Ubuntu 22 with an NVIDIA A4000, and we have tried Docker as well.

Error when trying with undersampled data

Hi, I successfully reproduced your experiments with the fastMRI training data and ran inference on the fastMRI test data (-r09.tfrecords). As far as I know, -r09 means 512x512, so I tried inference with -r08, which is 256x256, and got this error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.

(0) Invalid argument: required broadcastable shapes at loc(unknown)
[[{{node while/cond/add}}]]
[[Abs/_1463]]
(1) Invalid argument: required broadcastable shapes at loc(unknown)
[[{{node while/cond/add}}]]
0 successful operations.
0 derived errors ignored.

  1. Maybe I need to pad the images before testing?
  2. Additionally, I found that the images in -r09.tfrecords all have two channels, e.g. 2x512x512. I don't understand why two channels are needed. Is there any relation between these two channels?
  3. I noticed that you mentioned the mapper gets two inputs: the time index and extracted information about the data (such as image contrast and R). How did you obtain the latter? Is it extracted automatically by attention, or provided manually as a condition label?
    Looking forward to your reply! It is an honor to reproduce your method.
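Regarding question 2: a common convention in TensorFlow MRI code is to store complex-valued images or k-space as two real channels holding the real and imaginary parts. Whether the two channels in -r09.tfrecords follow this convention is an assumption, not something stated on this page. Under that assumption, the conversion looks like this (hypothetical helper names):

```python
import numpy as np


def complex_to_channels(x):
    """(H, W) complex -> (2, H, W) real; channel 0 = real part, channel 1 = imaginary part."""
    return np.stack([x.real, x.imag], axis=0)


def channels_to_complex(c):
    """(2, H, W) real -> (H, W) complex; inverse of complex_to_channels."""
    return c[0] + 1j * c[1]


def magnitude(c):
    """Magnitude image from the two-channel form, as typically used for display."""
    return np.sqrt(c[0] ** 2 + c[1] ** 2)
```

Under this convention the two channels are related in that together they form one complex image; neither is meaningful on its own for magnitude display.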

A question about training time

Hi, I'm reproducing your code using the dataset you provided, but there is a problem during training. We have been training for two weeks and the program hasn't stopped. How long did training take for you? I was wondering whether there is a parameter I didn't set that would make training end, or whether it is stuck in a loop.

We are working on CentOS 7, and a single NVIDIA RTX A40 GPU is used for training.

Thank you very much!

Reproducing model results

Hello, I have successfully reproduced your code using the dataset you provided. However, there is a problem: I made my own training and test sets, and the model's output on my own test set is a mess, while the results on the test set you provided are normal. I am using the masks from the training and test datasets you provided. Could you explain your dataset creation process in detail?
Thank you very much; it's an honour to reproduce your project!

Question about data details

It's a great honor to reproduce your method, but I ran into some confusion when trying the code with your fastMRI data.
The file fastmri_mixed_us/-r09.tfrecords, downloaded from your Google Drive link, contains two things: data and labels. What do the labels mean? They look like a one-hot code. Your paper mentions that the mapper network takes the time index and an extracted label of the undersampled image (i.e., rate or contrast), but I can't find any clue about the latter input (rate or contrast) in the code. Looking forward to your help, thanks!

A question about test data

Hi, great paper, and thanks for providing the code and related dataset.
I'm having difficulty reproducing your test results. How can I obtain the ground-truth brain MRI test images? I didn't see reconstructed images in the fastMRI multi-coil brain test dataset.
