arnoweng / chexnet

A PyTorch reimplementation of CheXNet

Python 100.00%
thoracic-diseases classification localization pytorch pneumonia x-ray medical-images deep-learning

chexnet's Introduction

CheXNet for Classification and Localization of Thoracic Diseases

This is a Python 3 (PyTorch) reimplementation of CheXNet. The model takes a chest X-ray image as input and outputs the probability of each thoracic disease along with a likelihood map of pathologies.

Dataset

The ChestX-ray14 dataset comprises 112,120 frontal-view chest X-ray images of 30,805 unique patients with 14 disease labels. To evaluate the model, we randomly split the dataset into training (70%), validation (10%) and test (20%) sets, following the CheXNet paper. The partitioned image names and their corresponding labels are placed under the directory labels.
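
As a rough illustration of a 70/10/20 split, here is a minimal sketch. It assumes the split is done per patient (so that no patient appears in two sets) and that the patient ID is the part of the file name before the underscore, as in `00000001_000.png`. The function name `split_by_patient` and the seed are hypothetical, and this is not the script that produced the files under labels.

```python
import random
from collections import defaultdict

def split_by_patient(image_names, fractions=(0.7, 0.1, 0.2), seed=0):
    """Split image names into train/val/test so each patient lands in one set only."""
    # Assumption: the file name prefix before "_" is the patient ID.
    by_patient = defaultdict(list)
    for name in image_names:
        by_patient[name.split("_")[0]].append(name)

    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)

    n_train = round(fractions[0] * len(patients))
    n_val = round(fractions[1] * len(patients))
    groups = (patients[:n_train],
              patients[n_train:n_train + n_val],
              patients[n_train + n_val:])
    # Expand each group of patients back into its image names.
    return [[img for p in group for img in by_patient[p]] for group in groups]

# Toy data: 10 patients with 3 images each.
names = ["%08d_%03d.png" % (p, k) for p in range(1, 11) for k in range(3)]
train, val, test = split_by_patient(names)
```

Splitting per patient rather than per image is a design choice made here to avoid leakage between sets; the README only states that the split was random.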

Prerequisites

  • Python 3.4+
  • PyTorch and its dependencies

Usage

  1. Clone this repository.

  2. Download the ChestX-ray14 images from the dataset's release page and decompress them into the directory images.

  3. Specify one or multiple GPUs and run

    python model.py
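
One common way to choose which GPUs a PyTorch process can see is the `CUDA_VISIBLE_DEVICES` environment variable. The sketch below sets it from Python; the helper name `select_gpus` is hypothetical, and whether model.py already sets this variable itself is not stated here, so check the script first.

```python
import os

def select_gpus(gpu_ids):
    """Restrict this process to the given GPU indices via CUDA_VISIBLE_DEVICES."""
    value = ",".join(str(i) for i in gpu_ids)
    # Must be set before the first CUDA call (safest: before importing torch).
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value

select_gpus([0, 1])  # expose only GPUs 0 and 1 to the process
```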

Comparison

We followed the training strategy described in the official paper and adopted a ten-crop method in both validation and testing. The per-class AUROC of our reproduced model is almost the same as that of the original CheXNet. We also propose a slightly improved model that achieves a mean AUROC of 0.847 (vs. 0.841 for the original CheXNet).

Pathology           Wang et al.  Yao et al.  CheXNet  Our Implemented CheXNet  Our Improved Model
Atelectasis         0.716        0.772       0.8094   0.8294                   0.8311
Cardiomegaly        0.807        0.904       0.9248   0.9165                   0.9220
Effusion            0.784        0.859       0.8638   0.8870                   0.8891
Infiltration        0.609        0.695       0.7345   0.7143                   0.7146
Mass                0.706        0.792       0.8676   0.8597                   0.8627
Nodule              0.671        0.717       0.7802   0.7873                   0.7883
Pneumonia           0.633        0.713       0.7680   0.7745                   0.7820
Pneumothorax        0.806        0.841       0.8887   0.8726                   0.8844
Consolidation       0.708        0.788       0.7901   0.8142                   0.8148
Edema               0.835        0.882       0.8878   0.8932                   0.8992
Emphysema           0.815        0.829       0.9371   0.9254                   0.9343
Fibrosis            0.769        0.767       0.8047   0.8304                   0.8385
Pleural Thickening  0.708        0.765       0.8062   0.7831                   0.7914
Hernia              0.767        0.914       0.9164   0.9104                   0.9206

Contributions

This work was collaboratively conducted by Xinyu Weng, Nan Zhuang, Jingjing Tian and Yingcheng Liu.

Our Team

All of us are students/interns of Machine Intelligence Lab, Institute of Computer Science & Technology, Peking University, directed by Prof. Yadong Mu (http://www.muyadong.com).


chexnet's Issues

Model not working on real-life X-rays

I have a couple of frontal chest X-rays that I know show disease. But when I run the model on them, it gives very low probabilities for all 14 diseases.

It doesn't seem to work on real-life X-rays.

Please help.

Regarding data normalization

Hi,

Thank you for sharing the implementation. As mentioned in the paper, I see that you are normalizing the images with the mean and standard deviation of the ImageNet data. However, the X-ray images are in the range 0-255; shouldn't you scale the image values to 0-1 and then normalize with the mean and standard deviation?
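
For context, `transforms.ToTensor()` already scales 8-bit image values from 0-255 to 0.0-1.0, and the transform pipeline quoted in a later issue on this page applies it before `normalize`. The arithmetic can be sketched without PyTorch as follows; the helper names are hypothetical and only mimic what the two transforms do to a single pixel.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def to_unit_range(pixel):
    """Mimic the scaling ToTensor applies to 8-bit images: 0..255 -> 0.0..1.0."""
    return pixel / 255.0

def normalize(rgb_unit):
    """Channel-wise (x - mean) / std, as transforms.Normalize does."""
    return tuple((x - m) / s
                 for x, m, s in zip(rgb_unit, IMAGENET_MEAN, IMAGENET_STD))

gray = 128  # one grayscale X-ray pixel, replicated to three channels
print(normalize(tuple(to_unit_range(gray) for _ in range(3))))
```

So a separate division by 255 is only needed if the images reach `Normalize` without passing through `ToTensor` first.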

what torchvision version?

in main
transforms.Resize(256),

AttributeError: module 'torchvision.transforms' has no attribute 'Resize'

Problem replicating results

Hello, great work. The requirements are fulfilled, but I am still getting an error when loading the model: Unexpected key(s) in state_dict.
Please update the README to explain how a newcomer can run, retrain or test the model and obtain the AUC and evaluation metrics.
Thanks

model.py stops at a strange position

I have already downloaded the images and put them in the right place.
After this step, I ran model.py.
But the process stops at a strange position.

root@fd1a0b35acfd:/CheXNet# python model.py
=> loading checkpoint
=> loaded checkpoint

The process stops here.
Can anyone help me solve this problem?
Thank you very much

FPGA Development

Can it run in an FPGA environment?
Does PyTorch support deployment on FPGAs?
Thanks

How to Cite?

Hi! I wanted to cite this repository. How can I do so?

Error on the first line of the code: 'tuple' object is not callable

for i, (inp, target) in enumerate(test_loader):
    target = target.cuda()
    gt = torch.cat((gt, target), 0)
    bs, n_crops, c, h, w = inp.size()
    input_var = torch.autograd.Variable(inp.view(-1, c, h, w).cuda(), volatile=True)
    output = model(input_var)
    output_mean = output.view(bs, n_crops, -1).mean(1)
    pred = torch.cat((pred, output_mean.data), 0)

AUROCs = compute_AUCs(gt, pred)
AUROC_avg = np.array(AUROCs).mean()

Can't pickle local object

Hi,
Thank you for sharing.
When I run the code, an error occurs at the line "for i, (inp, target) in enumerate(test_loader)".
Somebody told me that "python can't pickle functions". Can you point me to the right way to fix it?
I run the code on Windows 7 with Python 3.6 and PyTorch 0.1.12.
Thank you~
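
A likely cause, stated as an assumption rather than a confirmed diagnosis: on Windows, DataLoader worker processes are started with spawn, which pickles the dataset together with its transforms, and the lambdas passed to `transforms.Lambda` inside `main()` cannot be pickled. The sketch below only demonstrates the pickling behaviour with the standard library; `is_picklable` and `make_local_transform` are hypothetical names.

```python
import pickle

def is_picklable(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False

def make_local_transform():
    # Defined inside a function, like the lambdas inside main() in model.py.
    return lambda crops: crops

print(is_picklable(make_local_transform()))  # local lambdas cannot be pickled
print(is_picklable(len))                     # importable callables can
```

Two workarounds that usually help in this situation are setting `num_workers=0` in the DataLoader, or replacing the lambdas with functions defined at module level.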

cannot be unzipped

The compressed package model.pth in the folder is shown to be corrupt and cannot be unzipped

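Hello

Could you send me part of the test data? I am unable to download these images.

License declaration

Hello, I think your project is great and the work is very good, but could you add a license to the project? Apache would be ideal. Your other steganography project has the same issue.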

Broken pipe error

I am getting this Broken pipe error.

runfile('E:/constalytics/chexnet-master_2pytorosho/Main.py', wdir='E:/constalytics/chexnet-master_2pytorosho')
Reloaded modules: ChexnetTrainer, DensenetModels, DatasetGenerator
Training NN architecture =  DENSE-NET-121
D:\Anaconda\envs\pytr\lib\site-packages\torchvision\models\densenet.py:212: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
  nn.init.kaiming_normal(m.weight.data)
Traceback (most recent call last):

  File "<ipython-input-2-97c96155fd04>", line 1, in <module>
    runfile('E:/constalytics/chexnet-master_2pytorosho/Main.py', wdir='E:/constalytics/chexnet-master_2pytorosho')

  File "D:\Anaconda\envs\pytr\lib\site-packages\spyder\utils\site\sitecustomize.py", line 705, in runfile
    execfile(filename, namespace)

  File "D:\Anaconda\envs\pytr\lib\site-packages\spyder\utils\site\sitecustomize.py", line 102, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "E:/constalytics/chexnet-master_2pytorosho/Main.py", line 82, in <module>
    runTrain()

  File "E:/constalytics/chexnet-master_2pytorosho/Main.py", line 54, in runTrain
    ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal, nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, imgtransResize, imgtransCrop, timestampLaunch, None)

  File "E:\constalytics\chexnet-master_2pytorosho\ChexnetTrainer.py", line 94, in train
    ChexnetTrainer.epochTrain (model, dataLoaderTrain, optimizer, scheduler, trMaxEpoch, nnClassCount, loss)

  File "E:\constalytics\chexnet-master_2pytorosho\ChexnetTrainer.py", line 116, in epochTrain
    for batchID, (input, target) in enumerate (dataLoader):

  File "D:\Anaconda\envs\pytr\lib\site-packages\torch\utils\data\dataloader.py", line 451, in __iter__
    return _DataLoaderIter(self)

  File "D:\Anaconda\envs\pytr\lib\site-packages\torch\utils\data\dataloader.py", line 239, in __init__
    w.start()

  File "D:\Anaconda\envs\pytr\lib\multiprocessing\process.py", line 105, in start
    self._popen = self._Popen(self)

  File "D:\Anaconda\envs\pytr\lib\multiprocessing\context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)

  File "D:\Anaconda\envs\pytr\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)

  File "D:\Anaconda\envs\pytr\lib\multiprocessing\popen_spawn_win32.py", line 65, in __init__
    reduction.dump(process_obj, to_child)

  File "D:\Anaconda\envs\pytr\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)

BrokenPipeError: [Errno 32] Broken pipe

Any idea how to fix this?

Did you use data augmentation?

Thanks for sharing the code! Did you use data augmentation?
Also, this seems to be testing code only; would you mind sharing the training code as well?

Thanks

Low AUROC

The description reports a mean AUROC of 0.847, but when I run the model I get a mean AUROC of 0.489. Here is the output from running the model.py script:

ambouk3@u109023:/data/ambouk3/CheXNet$ python3 model.py
=> no checkpoint found
model.py:77: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
  input_var = torch.autograd.Variable(inp.view(-1, c, h, w).cuda(), volatile=True)
The average AUROC is 0.489
The AUROC of Atelectasis is 0.5147218977086381
The AUROC of Cardiomegaly is 0.4816967965324666
The AUROC of Effusion is 0.5525329228378363
The AUROC of Infiltration is 0.4785231304193095
The AUROC of Mass is 0.5108256156533197
The AUROC of Nodule is 0.4956833510675879
The AUROC of Pneumonia is 0.4754862089500211
The AUROC of Pneumothorax is 0.5329778292671846
The AUROC of Consolidation is 0.49934268439528523
The AUROC of Edema is 0.429807208063108
The AUROC of Emphysema is 0.3639704709500116
The AUROC of Fibrosis is 0.5687886732196019
The AUROC of Pleural_Thickening is 0.443773228540649
The AUROC of Hernia is 0.5042597897539615

I am using the latest version of PyTorch and running on a Quadro RTX 6000 with a batch size of 8. Any ideas as to why this is happening?
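
Note that the log above prints "=> no checkpoint found", so the model was probably evaluated without the trained weights. Scores unrelated to the labels give an AUROC near 0.5, which matches the numbers reported. A small standard-library sketch of that effect follows; it uses the pairwise-ranking definition of AUROC, not the repository's `compute_AUCs`.

```python
import random

def auroc(labels, scores):
    """AUROC as the probability that a random positive outranks a random negative."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    # Count wins of positives over negatives; ties count as half a win.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

rng = random.Random(0)
labels = [rng.randint(0, 1) for _ in range(2000)]
random_scores = [rng.random() for _ in labels]
print(round(auroc(labels, random_scores), 2))  # close to 0.5 for uninformative scores
```

Checking that the checkpoint path resolves to an existing file before evaluating is therefore a reasonable first step.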

TenCrop on test images & preprocessing imbalanced data

Hi arnoweng,

I have the following two questions about this project.
1):
I'm wondering why the TenCrop technique is applied to the test images. I thought data augmentation should only be applied to the training set to add diversity to the training images, while test images should be kept unchanged so that the results can be compared with others. If you use TenCrop on test images, you are technically using a different test set, right?
2):
Do you think it is necessary to pre-process the imbalanced data? I noticed a HUGE imbalance between the number of hernia samples and those of other diseases (~200 images vs. 10,000 for, e.g., infiltration), so different test sets would give very different AUROC scores, at least for Hernia. E.g., if my test set happens to include only 10 Hernia images, I guess the AUROC for Hernia on this test set would be really high, like 93%. In contrast, if there are around 150 Hernia images in the test set, the AUROC would be low.
Thanks a lot!
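
On the first question, it may help to see what the evaluation loop does with the crops: `output.view(bs, n_crops, -1).mean(1)` averages the predictions over the ten views (four corners and the centre, plus their horizontal flips) of each test image, so every test image still yields exactly one prediction. A minimal sketch of that averaging in plain Python, with a hypothetical function name:

```python
def average_over_crops(crop_outputs):
    """Average per-class probabilities across the crops of one image.

    crop_outputs: list of n_crops lists, each holding one probability per class.
    """
    n_crops = len(crop_outputs)
    # zip(*...) groups the values of each class across all crops.
    return [sum(col) / n_crops for col in zip(*crop_outputs)]

# Two crops and three classes here; the repository uses 10 crops and 14 classes.
print(average_over_crops([[0.2, 0.9, 0.4], [0.4, 0.7, 0.4]]))
```

This is usually described as test-time augmentation: the test set itself stays the same, and only the way a single prediction is formed per image changes.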

What loss function do you use?

Hello
Thanks for sharing your code. I'm trying to make a Caffe reimplementation of it. I use SigmoidCrossEntropyLoss, and I have modified the batch size and learning rate several times, but I can't get a good result. Actually, I only get a mean AUROC of 70%.
Could you give some information about the loss function used in your work?
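
The training code is not part of this page, so the following is only a sketch of the objective the CheXNet paper describes: binary cross-entropy applied independently to each of the 14 sigmoid outputs. In PyTorch this would typically be `torch.nn.BCELoss` on sigmoid outputs; the plain-Python function below is a hypothetical stand-in that shows the formula, not the authors' code.

```python
import math

def multilabel_bce(probs, targets, eps=1e-7):
    """Mean binary cross-entropy over independent per-class sigmoid outputs."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
    return total / len(probs)

# One image, three labels: confident and correct predictions give a small loss.
print(multilabel_bce([0.9, 0.1, 0.8], [1, 0, 1]))
```

A sigmoid cross-entropy loss belongs to the same family, so differences in preprocessing, initialization from pretrained weights, or learning-rate schedule may be worth checking as well.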

About the program

Thank you for sharing.
I want to know whether this program follows Andrew Y. Ng's paper, CheXNet: Radiologist-Level Pneumonia Detection on Chest X-Rays with Deep Learning.
I would also like to know if you can share the training code. Thank you!

Running CheXNet on normal chest X-rays produces false positives (pretty much every time)

False positives are everywhere.

I ran it on a few images, such as this one: https://radiopaedia.org/cases/normal-chest-x-ray

Results:

[ { "Atelectasis":0.513898491859436 }, { "Cardiomegaly":0.4632066786289215 }, { "Effusion":0.7996526956558228 }, { "Infiltration":0.3642980754375458 }, { "Mass":0.4949002265930176 }, { "Nodule":0.2438926249742508 }, { "Pneumonia":0.35062509775161743 }, { "Pneumothorax":0.2875746786594391 }, { "Consolidation":0.5513240098953247 }, { "Edema":0.4362095892429352 }, { "Emphysema":0.5255514979362488 }, { "Fibrosis":0.4974825382232666 }, { "Pleural_Thickening":0.38915345072746277 }, { "Hernia":0.5669712424278259 } ]

In fact, this is a clean X-ray.

I ran it on a few other actual CXRs with no known issues, yet it is producing all kinds of false positives.

Has anyone else noticed that the model is faulty?

Unable to run the model on CPU

I was trying to run the model on a CPU. The changes I made are listed below, but I hit the following error when running the model. Can you help fix this?

Error generated:

Traceback (most recent call last):
  File "model.py", line 129, in <module>
    main()
  File "model.py", line 40, in main
    model.load_state_dict(checkpoint['state_dict'])
  File "E:\anaconda\lib\site-packages\torch\nn\modules\module.py", line 719, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
    Missing key(s) in state_dict: "module.densenet121.features.denseblock1.denselayer1.norm1.weight", "module.densenet121.features.denseblock1.denselayer1.norm1.bias", "module.densenet121.features.denseblock1.denselayer1.norm1.running_mean", "module.densenet121.features.denseblock1.denselayer1.norm1.running_var",

Model changes:

def main():

    cudnn.benchmark = False

    # initialize and load the model
    model = DenseNet121(N_CLASSES).to(torch.device("cpu"))
    model = torch.nn.DataParallel(model)

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=transforms.Compose([
                                        transforms.Resize(256),
                                        transforms.TenCrop(224),
                                        transforms.Lambda
                                        (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                                        transforms.Lambda
                                        (lambda crops: torch.stack([normalize(crop) for crop in crops]))
                                    ]))
    test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=8, pin_memory=True)

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor()
    #gt = gt.cuda()
    pred = torch.FloatTensor()
    #pred = pred.cuda()

    # switch to evaluate mode
    model.eval()
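
If the `DataParallel` wrapper is dropped for CPU inference, the checkpoint keys, which carry a `module.` prefix in the error messages quoted on this page, no longer match the unwrapped model. One common workaround is to strip that prefix before calling `load_state_dict`. The helper below is a hypothetical sketch operating on a plain dictionary; the values stand in for tensors.

```python
from collections import OrderedDict

def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel prefix removed."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )

# Placeholder values; in practice these come from checkpoint['state_dict'].
ckpt = OrderedDict([("module.densenet121.classifier.0.weight", 1),
                    ("module.densenet121.classifier.0.bias", 2)])
print(list(strip_module_prefix(ckpt)))
```

This only handles the prefix. Other key mismatches, such as the submodule naming change discussed in the PyTorch 0.4.0 issue on this page, would need separate handling.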

How to create the heatmap

Your work is good.
I want to know how to create the heatmap. Currently the "output" variable in the code is a 10*14 matrix, which does not seem to be the heatmap.
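
The 10*14 output holds the class probabilities for the ten crops of one image, so it is indeed not a heatmap. The CheXNet paper describes producing heatmaps with class activation mappings (CAM): the feature maps of the last convolutional layer are summed, weighted by the final layer's weights for the class of interest. The sketch below shows that computation on nested lists; it is an illustration under that assumption, not code from this repository, and the function name is hypothetical.

```python
def class_activation_map(features, class_weights):
    """Weighted sum of feature maps: cam[i][j] = sum_k w[k] * features[k][i][j].

    features: K feature maps (each H x W) from the last convolutional layer.
    class_weights: the K weights of the final linear layer for one class.
    """
    height, width = len(features[0]), len(features[0][0])
    cam = [[0.0] * width for _ in range(height)]
    for w, fmap in zip(class_weights, features):
        for i in range(height):
            for j in range(width):
                cam[i][j] += w * fmap[i][j]
    # Rescale to 0..1 so the map can be rendered as an image.
    lo = min(min(row) for row in cam)
    hi = max(max(row) for row in cam)
    span = (hi - lo) or 1.0  # avoid division by zero for a flat map
    return [[(v - lo) / span for v in row] for row in cam]

print(class_activation_map([[[1, 0], [0, 0]], [[0, 0], [0, 1]]], [2.0, 1.0]))
```

In practice the resulting low-resolution map is upsampled to the input size and overlaid on the X-ray.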

Code for training the model?

Hi, thanks a lot for providing CheXNet implementation in PyTorch!

Could you please also provide the script to train the model on nih-14 dataset (exactly the same way it was trained in CheXNet paper)?

Also have you used the train-test split provided in the nih-14 dataset itself or have you used your own strategy for splitting the data? Looking forward to hearing from you.

how to train

Hi, I'm new to PyTorch.
I have no idea how to train with this code, because it only covers evaluation.
I hope you can release the training code. It would be really helpful, thanks :)

model.load_state_dict(checkpoint['state_dict']) error with pytorch 0.4.0

I was running the code without any problem on pytorch 0.3.0.
I upgraded yesterday to pytorch 0.4.0 and can't load the checkpoint file. I am on Ubuntu and python 3.6 in conda env.
I get this error:

RuntimeError Traceback (most recent call last)
in ()
    181 if __name__ == '__main__':
--> 182     main()

in main()
     39         print("=> loading checkpoint")
     40         checkpoint = torch.load(CKPT_PATH)
---> 41         model.load_state_dict(checkpoint['state_dict'])
     42         print("=> loaded checkpoint")
     43     else:

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    719         if len(error_msgs) > 0:
    720             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 721                 self.__class__.__name__, "\n\t".join(error_msgs)))
    722
    723     def parameters(self):

RuntimeError: Error(s) in loading state_dict for DenseNet121:
Missing key(s) in state_dict: "densenet121.features.conv0.weight", "densenet121.features.norm0.weight", "densenet121.features.norm0.bias", "densenet121.features.norm0.running_mean", "densenet121.features.norm0.running_var", "densenet121.features.denseblock1.denselayer1.norm1.weight", "densenet121.features.denseblock1.denselayer1.norm1.bias", "densenet121.features.denseblock1.denselayer1.norm1.running_mean",
(entire network ...)
"module.densenet121.features.denseblock4.denselayer16.conv.2.weight", "module.densenet121.features.norm5.weight", "module.densenet121.features.norm5.bias", "module.densenet121.features.norm5.running_mean", "module.densenet121.features.norm5.running_var", "module.densenet121.classifier.0.weight", "module.densenet121.classifier.0.bias".

It is likely related to this information about pytorch 0.4.0:
https://pytorch.org/2018/04/22/0_4_0-migration-guide.html
New edge-case constraints on names of submodules, parameters, and buffers in nn.Module
name that is an empty string or contains "." is no longer permitted in module.add_module(name, value), module.add_parameter(name, value) or module.add_buffer(name, value) because such names may cause lost data in the state_dict. If you are loading a checkpoint for modules containing such names, please update the module definition and patch the state_dict before loading it.
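
Following the migration note's advice to patch the state_dict before loading it, one possible approach is to rename old-style keys such as `...denselayer16.conv.2.weight` to `...denselayer16.conv2.weight`. The sketch below works on a plain dictionary and uses a regular expression similar in spirit to the one torchvision applies to its own pretrained DenseNet weights; the function name `patch_densenet_keys` is hypothetical and the code has not been verified against this repository's checkpoint.

```python
import re
from collections import OrderedDict

# Matches old-style DenseNet keys such as "...denselayer16.conv.2.weight".
_OLD_KEY = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\."
    r"(?:weight|bias|running_mean|running_var))$"
)

def patch_densenet_keys(state_dict):
    """Rename 'norm.1' / 'conv.2' style keys to 'norm1' / 'conv2'."""
    patched = OrderedDict()
    for key, value in state_dict.items():
        match = _OLD_KEY.match(key)
        # Join the two captured groups without the dot; keep other keys as-is.
        patched[match.group(1) + match.group(2) if match else key] = value
    return patched

old = OrderedDict([
    ("module.densenet121.features.denseblock4.denselayer16.conv.2.weight", 0),
    ("module.densenet121.features.norm5.weight", 1),
])
print(list(patch_densenet_keys(old)))
```

Under these assumptions, usage would look like `model.load_state_dict(patch_densenet_keys(checkpoint['state_dict']))`.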
