maitrechen / medgan-reslite

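An implementation of a lightweight ResNet18-SAM model for pneumonia image classification. SH-DCGAN is used to generate minority-class samples to address data imbalance, and a pruning strategy is applied to make the model lightweight! It's still being updated! Stay tuned!❤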

License: Apache License 2.0

Python 100.00%
augmentation medical-image-processing medmnist resnet lightweight-neural-network classification cnn dcgan spatial-attention deep-learning


📣Introduction

This is a pneumonia classification project that addresses class imbalance by using generative adversarial networks (GANs) to generate images of minority-class samples. In addition, a spatial attention mechanism is introduced into ResNet18 to improve the generalization performance of the classifier!
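The README does not show the SAM module itself. As a rough illustration only, here is a minimal PyTorch sketch of a CBAM-style spatial attention block (channel-wise average and max pooling, a 7×7 convolution, then a sigmoid mask), which is one common way to add spatial attention to ResNet18. The class name `SpatialAttention`, the kernel size, and where such a block would be inserted are assumptions, not the project's code.

```python
import torch
import torch.nn as nn


class SpatialAttention(nn.Module):
    """CBAM-style spatial attention (sketch): reweights every (h, w) location of a feature map."""

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        # 2 input channels: the channel-wise average map and the channel-wise max map.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg = x.mean(dim=1, keepdim=True)                              # (N, 1, H, W)
        mx, _ = x.max(dim=1, keepdim=True)                             # (N, 1, H, W)
        mask = torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))   # values in (0, 1)
        return x * mask                                                # broadcast over channels


feats = torch.randn(2, 512, 7, 7)  # e.g. ResNet18's last-stage output for a 224x224 input
out = SpatialAttention()(feats)
print(out.shape)  # torch.Size([2, 512, 7, 7])
```

Because the output has the same shape as the input, a block like this can be placed after any ResNet stage without changing the rest of the network.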

🔥 Workflow

🚩Updates & Roadmap

🌟New Updates

  • ✅ Mar 21, 2023. Create the "MedGAN-ResLite" project repository and find MedMNIST.
  • ✅ Mar 22, 2023. Generate pneumonia samples with DCGAN.
  • ✅ May 30, 2023. Replace the original loss function with Hinge Adversarial Loss.
  • ✅ April 1, 2023. DCGAN + Spectral Normalization.
  • ✅ April 4, 2023. Add DCGAN metrics: Inception Score + FID + KID; fuse and split the dataset.
  • ✅ April 5, 2023. Override the dataset class through inheritance.
  • ✅ April 6, 2023. Write train, eval, and infer scripts for the classifier; build a new model by modifying the input and output shapes of a pre-trained model. Add metrics: acc + auc + f1 + confusion matrix.
  • ✅ April 7, 2023. Add scripts: export_onnx.py and inference_onnx.py.
  • ✅ April 8, 2023. Tune the hyperparameters of DCGAN.
  • ✅ April 10, 2023. Explore where, and how many, CBAM attention modules to add.
  • ✅ April 14, 2023. Ablation study: GAN, DCGAN, DCGAN + Hinge, DCGAN + SN, DCGAN + SH.
  • ✅ April 21, 2023. Visualize the attention mechanism using CAM.
  • ✅ April 25, 2023. Make a presentation.
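The ablation results suggest that SH-DCGAN combines spectral normalization (SN) with the hinge adversarial loss. A minimal sketch of both pieces in PyTorch, assuming 1×28×28 inputs (the MedMNIST default size) and a hypothetical three-layer discriminator `SNDiscriminator`; it is not the repository's architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


def d_hinge_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    """Discriminator hinge loss: mean(relu(1 - D(x))) + mean(relu(1 + D(G(z))))."""
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()


def g_hinge_loss(fake_logits: torch.Tensor) -> torch.Tensor:
    """Generator hinge loss: -mean(D(G(z)))."""
    return -fake_logits.mean()


class SNDiscriminator(nn.Module):
    """Hypothetical DCGAN-style discriminator with spectral norm on every conv layer."""

    def __init__(self, ch: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            spectral_norm(nn.Conv2d(1, ch, 4, 2, 1)),       # 28 -> 14
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(ch, ch * 2, 4, 2, 1)),  # 14 -> 7
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(ch * 2, 1, 7, 1, 0)),   # 7 -> 1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One raw logit per image; no sigmoid, because the hinge loss works on raw scores.
        return self.net(x).view(-1)


D = SNDiscriminator()
real, fake = torch.randn(8, 1, 28, 28), torch.randn(8, 1, 28, 28)
loss_d = d_hinge_loss(D(real), D(fake.detach()))  # detach: D's step must not update G
loss_g = g_hinge_loss(D(fake))
```

Spectral normalization constrains each layer's largest singular value, which limits how sharply the discriminator can change and tends to stabilize GAN training.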

💤Progress & Upcoming work

✅ Finished, and successful!
❗ Finished, but failed!
❓ Unfinished!

Part 1: Dataset and Preprocessing

  • ❓ Experiment with more challenging datasets, such as ChestXRay2017, Kaggle, etc.
  • ❓ Consider making image scaling learnable, e.g. using transposed convolution instead of interpolation when increasing the image size.
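To illustrate the idea in the second item: interpolation upscales with fixed weights, while a transposed convolution learns its upsampling kernel during training. A minimal PyTorch comparison, assuming a single 28×28 grayscale image; the kernel size, stride, and padding are illustrative choices that double the spatial size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 1, 28, 28)  # one 28x28 grayscale image (assumed size)

# Fixed upscaling: bilinear interpolation has no trainable parameters.
up_interp = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)

# Learnable upscaling: output size = (28 - 1) * 2 - 2 * 1 + 4 = 56.
deconv = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2, padding=1)
up_learned = deconv(x)

print(up_interp.shape, up_learned.shape)  # both torch.Size([1, 1, 56, 56])
```

Here `deconv` has 17 trainable parameters (a 4×4 kernel plus a bias) that would be optimized together with the rest of the model.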

Part 2: Generation part

  • ✅❗ May 23, 2023. Try multi-scale fusion.
  • ✅❗ May 25, 2023. Introduce class information into DCGAN to generate samples.【cDCGAN】
  • ❓ Replace the original loss function with the Wasserstein distance.
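One common way to inject class information into a DCGAN generator (the cDCGAN item above) is to embed the label and concatenate it with the noise vector. The sketch below assumes 28×28 single-channel outputs, a 100-dimensional noise vector, and two classes (normal, pneumonia); `CondGenerator` is a hypothetical name, and the project's cDCGAN may condition differently.

```python
import torch
import torch.nn as nn


class CondGenerator(nn.Module):
    """Hypothetical conditional DCGAN generator: label embedding concatenated with noise."""

    def __init__(self, z_dim: int = 100, n_classes: int = 2, emb_dim: int = 10, ch: int = 64):
        super().__init__()
        self.embed = nn.Embedding(n_classes, emb_dim)
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim + emb_dim, ch * 2, 7, 1, 0),  # 1 -> 7
            nn.BatchNorm2d(ch * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ch * 2, ch, 4, 2, 1),              # 7 -> 14
            nn.BatchNorm2d(ch),
            nn.ReLU(True),
            nn.ConvTranspose2d(ch, 1, 4, 2, 1),                   # 14 -> 28
            nn.Tanh(),                                            # outputs in [-1, 1]
        )

    def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        cond = torch.cat([z, self.embed(labels)], dim=1)  # (N, z_dim + emb_dim)
        return self.net(cond[:, :, None, None])           # treat it as a 1x1 feature map


G = CondGenerator()
z = torch.randn(4, 100)
labels = torch.tensor([1, 1, 1, 1])  # e.g. request four minority-class samples (label 1 assumed)
imgs = G(z, labels)                  # (4, 1, 28, 28)
```

The discriminator usually receives the same label (for example as an extra input channel or a projection term), so it judges whether an image is realistic for that class.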

Part 3: Classification part

  • ❓ Apply ensemble learning methods, such as majority voting.
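A minimal sketch of the voting idea for an ensemble of classifiers, assuming each model returns logits of shape (N, num_classes); `hard_vote` and `soft_vote` are hypothetical helper names.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def hard_vote(logits_per_model: Sequence[torch.Tensor], num_classes: int) -> torch.Tensor:
    """Majority voting: each model casts one vote (its argmax class) per sample."""
    votes = torch.stack([logits.argmax(dim=1) for logits in logits_per_model])  # (M, N)
    counts = F.one_hot(votes, num_classes).sum(dim=0)                           # (N, C)
    return counts.argmax(dim=1)


def soft_vote(logits_per_model: Sequence[torch.Tensor]) -> torch.Tensor:
    """Soft voting: average the softmax probabilities of all models, then take the argmax."""
    probs = torch.stack([logits.softmax(dim=1) for logits in logits_per_model]).mean(dim=0)
    return probs.argmax(dim=1)


# Three models, two samples, two classes (normal = 0, pneumonia = 1 assumed).
a = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
b = torch.tensor([[0.0, 1.0], [0.0, 1.0]])
c = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
print(hard_vote([a, b, c], num_classes=2))  # tensor([1, 0])
```

An odd number of models avoids ties in hard voting for a two-class problem.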

Part 4: Lightweight-NN part

  • ❓ Pruning: one-shot + iterative, including L1✅, L2✅, FPGM✅, BNScale.
  • ❓ Build the pruned model automatically.
  • ❓ Knowledge distillation: design a lightweight network A, and use the pruned model to guide A.
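L1-norm filter pruning ranks the output filters of a convolution by the sum of their absolute weights and removes the smallest ones. A one-shot sketch for a single `nn.Conv2d` (assuming `groups=1`); `l1_filter_ranking` and `prune_conv` are hypothetical names, and the repository's pruning scripts may differ.

```python
import torch
import torch.nn as nn


def l1_filter_ranking(conv: nn.Conv2d, prune_ratio: float) -> torch.Tensor:
    """Indices (ascending) of the output filters to KEEP: those with the largest L1 norms."""
    l1 = conv.weight.detach().abs().sum(dim=(1, 2, 3))  # one L1 norm per output filter
    n_keep = max(1, round(conv.out_channels * (1.0 - prune_ratio)))
    keep = torch.argsort(l1, descending=True)[:n_keep]
    return torch.sort(keep).values


def prune_conv(conv: nn.Conv2d, keep: torch.Tensor) -> nn.Conv2d:
    """Copy the kept filters (and biases) into a new, narrower Conv2d. Assumes groups == 1."""
    new = nn.Conv2d(conv.in_channels, len(keep), conv.kernel_size, stride=conv.stride,
                    padding=conv.padding, dilation=conv.dilation, bias=conv.bias is not None)
    with torch.no_grad():
        new.weight.copy_(conv.weight[keep])
        if conv.bias is not None:
            new.bias.copy_(conv.bias[keep])
    return new


layer = nn.Conv2d(64, 128, 3, padding=1)  # an illustrative ResNet18-sized conv
keep = l1_filter_ranking(layer, prune_ratio=0.3)
pruned = prune_conv(layer, keep)          # 128 -> 90 output filters
```

In a full network, the next layer's input channels (and any BatchNorm after the pruned convolution) must be sliced with the same `keep` indices, and the model is usually fine-tuned afterwards; iterative pruning repeats prune-then-fine-tune in several small steps.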

Part 5: Deployment part

  • ❓ Deploy the model on CPU and NCS2 using OpenVINO.【Python ✅ and C++ versions】
  • ❓ Deploy on the web side using Django or Flask.

Other

  • ❓ Explore the influence of the attention mechanism on deep and shallow networks.

✨Usage

Install

Clone the repo and install the dependencies in requirements.txt.

git clone git@github.com:MaitreChen/MedGAN-ResLite.git
cd MedGAN-ResLite
pip install -r requirements.txt

Preparations

Dataset

You can download the dataset from this link. It includes the original real PneumoniaMNIST dataset and the fake dataset synthesized using GAN (see data/README.md for details).

The dataset directory structure is as follows:

MedGAN-ResLite/
|__ data/
    |__ real/
        |__ train/
            |__ normal/
                |__ img_1.png
                |__ ...
            |__ pneumonia/
                |__ img_1.png
                |__ ...
        |__ val/
            |__ normal/
            |__ pneumonia/
        |__ test/
            |__ ...
    |__ fake/
        |__ ...

Pretrained Checkpoints

You can download the pretrained checkpoints from this link and put them in your pretrained/ folder. They include the resnet18-sam and sh-dcgan models.

Inference

Classification part

🚀Quick start (the results will be saved in the figures/classifier_torch folder):

python infer_classifier.py --ckpt-path pretrained/resnet18-sam.pth --image-path imgs/pneumonia_img1.png

🌜Here are the options in more detail:

Option Description
--ckpt-path Path to the checkpoint file containing the pre-trained weights.
--image-path Path to the input image.
--device Inference device: cpu or cuda (default: cpu).

📛Note

If you want to visualize the attention mechanism, run the following command; the results will be saved in the figures/heatmap folder.

python utils/cam.py --image-path imgs/pneumonia_img1.png

More information about CAM can be found here!💖

Generation part

🚀Quick start (the results will be saved in the figures/generator_torch folder):

python infer_generator.py --ckpt-path pretrained/sn-dcgan.pth --batch-size 1 --mode -1

📛Note

If you want to generate fake images for training or a sprite image, run the following commands:

  • Generate a sprite map.【results saved in outs/sprite】

    python infer_generator.py --ckpt-path pretrained/sn-dcgan.pth --batch-size 64 --mode 0
  • Generate a batch of images.【results saved in outs/singles】

    python infer_generator.py --ckpt-path pretrained/sn-dcgan.pth --batch-size 50 --mode 1

    💨When generating a batch of images, you can set --batch-size to any value you like❤

Evaluate

Classification part

python eval.py --ckpt-path pretrained/resnet18-sam.pth
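The classifier metrics named in the update log (acc, auc, f1, confusion matrix) can be computed with scikit-learn. This sketch assumes a binary task where the positive-class probability is available, with 0 = normal and 1 = pneumonia; the weighted averaging for precision, recall, and F1 is an assumption, and `binary_metrics` is a hypothetical helper rather than the code in eval.py.

```python
import numpy as np
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score, roc_auc_score)


def binary_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> dict:
    """y_true: ground-truth labels; y_prob: predicted probability of class 1."""
    y_pred = (y_prob >= threshold).astype(int)
    return {
        "acc": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average="weighted"),
        "recall": recall_score(y_true, y_pred, average="weighted"),
        "f1": f1_score(y_true, y_pred, average="weighted"),
        "auc": roc_auc_score(y_true, y_prob),     # threshold-free: uses raw probabilities
        "cm": confusion_matrix(y_true, y_pred),   # rows = true class, cols = predicted class
    }


metrics = binary_metrics(np.array([0, 0, 1, 1]), np.array([0.1, 0.6, 0.8, 0.9]))
print(metrics["acc"], metrics["auc"])  # 0.75 1.0
```

Note that weighted-average recall always equals accuracy, since each class's recall is weighted by its share of the samples.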

Generation part

To evaluate the generator, make sure torch-fidelity (listed in requirements.txt) is installed❗

Then, prepare two datasets:

  • the training dataset in the data/merge folder.【real images】
  • the generated dataset in the outs folder.【fake images】

Once everything is ready, run the following command:

fidelity --gpu 0 --fid --input1 data/merge --input2 data/outs/singles

More information about fidelity can be found here!💖
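For reference, FID compares the mean and covariance of Inception features of the real and generated sets: FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)), so lower is better. The sketch below implements only that formula on precomputed feature arrays with NumPy and SciPy; torch-fidelity additionally extracts the Inception-v3 features, which is omitted here. `frechet_distance` is a hypothetical name.

```python
import numpy as np
from scipy import linalg


def frechet_distance(feats_real: np.ndarray, feats_fake: np.ndarray) -> float:
    """FID between two feature matrices of shape (N, D):
    ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2))."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_fake, rowvar=False)
    covmean = np.real(linalg.sqrtm(s1 @ s2))  # drop tiny imaginary parts from numerical error
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(s1) + np.trace(s2) - 2.0 * np.trace(covmean))


rng = np.random.default_rng(0)
real = rng.normal(size=(500, 4))            # stand-in for Inception features
print(frechet_distance(real, real))         # ~0: identical sets
print(frechet_distance(real, real + 1.0))   # ~4: mean shifted by 1 in each of 4 dims
```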

Train

Classification part

python train_classifier.py

💝 More details about training on your own dataset

Please refer to data/config.yaml and README.md.

In addition, you need to set the normalization parameters mean and std! Please refer to utils/image_utils.py.
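When training on your own dataset, mean and std are usually the per-channel statistics of the training images after `ToTensor()` scaling to [0, 1]. A minimal sketch that accumulates sums over a DataLoader; `channel_mean_std` is a hypothetical helper, and how utils/image_utils.py consumes the values is not shown here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader: DataLoader):
    """Per-channel mean and (population) std over all pixels of all images in the loader."""
    total, sum_, sum_sq = 0, None, None
    for images, _ in loader:  # images: (B, C, H, W) in [0, 1]
        x = images.transpose(0, 1).reshape(images.shape[1], -1).double()  # (C, B*H*W)
        sum_ = x.sum(dim=1) if sum_ is None else sum_ + x.sum(dim=1)
        sum_sq = (x * x).sum(dim=1) if sum_sq is None else sum_sq + (x * x).sum(dim=1)
        total += x.shape[1]
    mean = sum_ / total
    std = (sum_sq / total - mean * mean).clamp_min(0).sqrt()
    return mean.float(), std.float()


# Toy check: half the images are all 0 and half all 1, so mean = 0.5 and std = 0.5.
data = torch.cat([torch.zeros(4, 1, 8, 8), torch.ones(4, 1, 8, 8)])
loader = DataLoader(TensorDataset(data, torch.zeros(8)), batch_size=3)
mean, std = channel_mean_std(loader)
print(mean, std)
```

Accumulating sums instead of stacking every image keeps memory constant, so the same function works for large datasets.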

Generation part

python train_dcgan.py

Export

If you want to export the ONNX model for ONNXRuntime or OpenVINO, please refer to README.md!

Deploy

To use ONNXRuntime, refer to README.md and onnx/inference_onnx.py!

To use OpenVINO, refer to README.md!

🌞Results

Performance comparison of different GANs

Method Inception Score FID KID
GAN 2.20 260.15 0.42
DCGAN 2.20 259.72 0.39
SH-DCGAN 2.20 206.14 0.31

Figure panels: Original | SH-DCGAN

Ablation study

Method Inception Score FID KID
DCGAN 2.20 259.72 0.39
DCGAN + Hinge 2.20 252.42 0.38
DCGAN + SN 2.20 232.59 0.35
SH-DCGAN 2.20 206.14 0.31

Performance comparison before and after improvement

Comparison of different CNN models

Model Accuracy/% Precision/% Recall/% F1 score/%
AlexNet 90.16 90.16 90.16 90.16
VGG16 91.22 92.23 91.22 91.17
VGG19 91.76 92.70 91.76 91.71
ResNet34 92.55 93.26 92.55 92.52
ResNet50 91.15 92.44 92.15 92.14
MobileNetV2 92.29 92.60 92.29 92.27
ResNet18 92.02 92.02 92.02 92.02
ResNet18-SAM 93.48 93.82 93.48 93.47

Interpretability

📞Contact

If you have any questions or suggestions about this project, feel free to open an issue!

Also, please feel free to contact [email protected].

Thank you, and I hope you have a pleasant experience~~💓🧡💛💚💙💜
