facebookresearch / covidprognosis

Stars: 157 · Watchers: 15 · Forks: 41 · Size: 36 KB

COVID deterioration prediction based on chest X-ray radiographs via MoCo-trained image representations

Home Page: https://arxiv.org/abs/2101.04909

License: MIT License

Python 100.00%
medical-imaging deep-learning radiography x-ray pytorch covid-19 medical-image-analysis

covidprognosis's People

Contributors

anuroopsriram, dependabot[bot], mmuckley

covidprognosis's Issues

Run inference

Thank you for sharing your work! I would like to know how to run inference on some data to obtain the output mentioned in your paper.
Thank you in advance.

Transformer masking

I have a question about how the transformer layer treats sequences of images of different lengths. Consider the following example evaluation of a MIPModel:

image_model = DenseNet()
net = MIPModel(image_model, 1024, 64, 2, 2, 1024, 0.1, 0.1, 'sum')
B1, N1, C1, H1, W1 = 2, 4, 3, 224, 224
images = torch.rand(B1, N1, C1, H1, W1)
times = torch.tensor([
[-30, -20, -10, 0],
[-30, -20, 0, 0], # Extra 0 for padding
])
lens = torch.tensor([4, 3])
x = net(images, times, lens)

In this batch of size 2, the first sequence has 4 images, and the second sequence has 3 images. We add padding to the times corresponding to the shorter sequence, and we pad the shorter sequence of images with a "random" image so they are both of length 4.

This "padding" image should not be considered by the transformer. However, the function "_apply_transformer" in MIPModel does not mask this image. In fact, "lens" is passed to "_apply_transformer" but never used there. "lens" is used to drop the extra images from "image_feats_combined" in the "_pool" function, but by then the transformer has already used the "padding" images to compute "image_feats_trans".

Is this correct, and if so, how should we pad image sequences of different lengths so that the "padding" images do not affect the transformer output?
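PyTorch's `nn.TransformerEncoder` accepts a `src_key_padding_mask` argument in which `True` marks positions that attention should ignore. A minimal sketch, assuming the per-image features are laid out as (B, N, D) before the transformer: build that mask from `lens` and pass it in. The toy encoder and its sizes are hypothetical stand-ins, not MIPModel's actual transformer, and whether `_apply_transformer` can be changed this way is an assumption.

```python
import torch
import torch.nn as nn


def lens_to_padding_mask(lens: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a (B, max_len) bool mask that is True at padded time steps."""
    positions = torch.arange(max_len, device=lens.device)  # (max_len,)
    return positions.unsqueeze(0) >= lens.unsqueeze(1)     # (B, max_len)


# Toy stand-in for per-image features fed to the transformer (hypothetical sizes).
B, N, D = 2, 4, 64
feats = torch.rand(B, N, D)
lens = torch.tensor([4, 3])  # second sequence has one padding step

layer = nn.TransformerEncoderLayer(
    d_model=D, nhead=2, dim_feedforward=128, dropout=0.0, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()

mask = lens_to_padding_mask(lens, N)
with torch.no_grad():
    out = encoder(feats, src_key_padding_mask=mask)

# Replacing the padding image's features should not change the valid outputs.
feats_alt = feats.clone()
feats_alt[1, 3] = torch.rand(D) * 100
with torch.no_grad():
    out_alt = encoder(feats_alt, src_key_padding_mask=mask)
```

Each sequence needs at least one unmasked step (every entry of `lens` at least 1); a fully masked row can produce NaNs. The pooling step would still need to skip the padded positions separately.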

I cannot figure out the inputs to MIPModel

I hope someone can help me figure out the inputs to MIPModel.

import torch
import mip_model as mip
image_model = mip.Densenet()
net = mip.MIPModel(image_model, 1024, 64, 2, 2, 1024, 0.1, 0.1, 'sum')
B1, N1, C1, H1, W1 = 2, 4, 3, 224, 224
images = torch.rand(B1, N1, C1, H1, W1)
x = net(images, torch.tensor([-30, -20, -10, 0]), torch.tensor([N1]))
print(x.shape)

Errors:

Traceback (most recent call last):
  File "mip_model.py", line 305, in <module>
    net(images, torch.tensor([-30, -20, -10, 0]), torch.tensor([N1]))
  File "/opt/software/install/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "mip_model.py", line 144, in forward
    return self.classifier(image_feats_pooled)
  File "/opt/software/install/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/software/install/miniconda3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/opt/software/install/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x4 and 1088x2)
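The example in the "Transformer masking" issue above passes images of shape (B, N, C, H, W), times of shape (B, N), and lens of shape (B,). The call in this issue instead passes a 1-D times tensor and a single-element lens for a batch of two sequences. A minimal sketch of a hypothetical helper (`check_mip_inputs`, not part of the repository) that checks inputs against that convention; whether this mismatch alone explains the 1x4 error is not established here.

```python
import torch


def check_mip_inputs(images: torch.Tensor, times: torch.Tensor, lens: torch.Tensor) -> None:
    """Hypothetical shape check, assuming images (B, N, C, H, W), times (B, N), lens (B,)."""
    if images.dim() != 5:
        raise ValueError(f"images must be 5-D (B, N, C, H, W), got {tuple(images.shape)}")
    B, N = images.shape[:2]
    if times.shape != (B, N):
        raise ValueError(f"times must have shape {(B, N)}, got {tuple(times.shape)}")
    if lens.shape != (B,):
        raise ValueError(f"lens must have shape {(B,)}, got {tuple(lens.shape)}")
    if (lens < 1).any() or (lens > N).any():
        raise ValueError("each entry of lens must be between 1 and N")


B1, N1, C1, H1, W1 = 2, 4, 3, 224, 224
images = torch.rand(B1, N1, C1, H1, W1)
# One row of times and one length per sequence in the batch.
times = torch.tensor([
    [-30, -20, -10, 0],
    [-30, -20, -10, 0],
])
lens = torch.tensor([N1, N1])
check_mip_inputs(images, times, lens)  # passes
```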

Missing keys error when loading pretrain model to train_sip.py

Hi, I tried to fine-tune with train_sip.py by loading a pretrained model. I tried both your provided models and my own pretrained models, but in every case I get an error at line 210: trainer.fit(model.model, datamodule=data_module)

The error says: "
Exception has occurred: RuntimeError
Error(s) in loading state_dict for SipModule:
Missing key(s) in state_dict: "pos_weights", "model.features.conv0.weight", "model.features.norm0.weight", "model.features.norm0.bias", "model.features.norm0.running_mean", "model.features.norm0.running_var", "model.features.denseblock1.denselayer1.norm1.weight", "model.features.denseblock1.denselayer1.norm1.bias", "model.features.denseblock1.denselayer1.norm1.running_mean", "model.features.denseblock1.denselayer1.norm1.running_var", "model.features.denseblock1.denselayer1.conv1.weight", "model.features.denseblock1.denselayer1.norm2.weight", "model.features.denseblock1.denselayer1.norm2.bias", "model.features.denseblock1.denselayer1.norm2.running_mean", "model.features.denseblock1.denselayer1.norm2.running_var", "model.features.denseblock1.denselayer1.conv2.weight", "model.features.denseblock1.denselayer2.norm1.weight", "model.features.denseblock1.denselayer2.norm1.bias", "model.features.denseblock1.denselayer2.norm1.running_mean", "model.features.denseblock1.denselayer2.norm1.running_var", "model.features.denseblock1.denselayer2.conv1.weight", "model.features.denseblock1.denselayer2.norm2.weight", "model.features.denseblock1.denselayer2.norm2.bias", "model.features.denseblock1.denselayer2.norm2.running_mean", "model.features.denseblock1.denselayer2.norm2.running_var", "model.features.denseblock1.denselayer2.conv2.weight", "model.features.denseblock1.denselayer3.norm1.weight", "model.features.denseblock1.denselayer3.norm1.bias",
........
"

However, I can see that the model contains these keys after it is constructed.

Any hints as to why this is happening? Could it be a PyTorch Lightning issue?

Thanks in advance for your help.
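One common cause of "Missing key(s)" errors, offered as an assumption rather than a diagnosis of this repository, is a key-prefix mismatch between the checkpoint and the module, for example weights saved without the `model.` prefix that a wrapping module expects. PyTorch's `load_state_dict(..., strict=False)` returns the `missing_keys` and `unexpected_keys`, which makes such mismatches visible. The `Wrapper` class and the `diagnose` and `add_prefix` helpers below are hypothetical.

```python
import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for a LightningModule that holds its network in `self.model`."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


def diagnose(module: nn.Module, state_dict: dict):
    """Load non-strictly and return the key names that did not line up."""
    result = module.load_state_dict(state_dict, strict=False)
    return result.missing_keys, result.unexpected_keys


def add_prefix(state_dict: dict, prefix: str) -> dict:
    """Prepend `prefix` to every key, e.g. '0.weight' -> 'model.0.weight'."""
    return {prefix + k: v for k, v in state_dict.items()}


wrapper = Wrapper()
# A checkpoint saved from the bare network lacks the "model." prefix.
bare_ckpt = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2)).state_dict()

missing_before, unexpected_before = diagnose(wrapper, bare_ckpt)  # all keys mismatched
missing_after, unexpected_after = diagnose(wrapper, add_prefix(bare_ckpt, "model."))
print(missing_before, unexpected_before)
print(missing_after, unexpected_after)
```

Comparing a few entries of `missing_keys` with `unexpected_keys` usually shows which prefix has to be added or stripped.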

Inference result from last Linear Classifier of MIP Model

I have a question about the MIPModel class, specifically this line:

return self.classifier(image_feats_pooled)

The paper shows the following output for MIP:

MIP Predictions
ICU24: 0.784
Int24: 0.782
Mor24: 0.965

but the current output of MIPModel(image, times, lens) looks like

[0.2244, -0.4123]

even after training on new images.

  1. Are we expected to add this step ourselves, for example by adding nn.Softmax() after nn.Linear(), to get outputs like [0.53, 0.47] for 2 classes or [0.30, 0.23, 0.47] for 3 classes? If not, should we treat the output of nn.Linear as equivalent to the output shown in the paper, where negative values indicate classes that will not occur and positive values indicate classes that may occur?
  2. For the MIP prediction results in the paper, could you share whether you encountered cases where ICU24 and Int24 were predicted but Mor24 was not, and what those results looked like?

Thanks again for creating this.
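The outputs of a final `nn.Linear` are logits: unbounded real numbers, not probabilities. A minimal sketch of the two usual conversions, applied to the output quoted above. Softmax fits mutually exclusive classes (it pairs with `nn.CrossEntropyLoss`); sigmoid fits independent yes/no outcomes (it pairs with `nn.BCEWithLogitsLoss`). Which loss the MIP training uses is not stated in this thread, so the choice between them is an assumption.

```python
import torch

# One example with two output logits (the values quoted in the question).
logits = torch.tensor([[0.2244, -0.4123]])

# Mutually exclusive classes: softmax over the class dimension, rows sum to 1.
probs_softmax = torch.softmax(logits, dim=1)

# Independent binary outcomes: sigmoid per output, each in (0, 1).
# A negative logit maps to a probability below 0.5, a positive one above 0.5.
probs_sigmoid = torch.sigmoid(logits)

print(probs_softmax, probs_sigmoid)
```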

Transform for trainset and valset during pretraining

Hi, after reading the transforms used for pretraining MoCo, it seems that trainset and valset use the same transforms, which are:

    transform_list = [
        transforms.RandomResizedCrop(args.im_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        RandomGaussianBlur(),
        AddGaussianNoise(snr_range=(4, 8)),
        HistogramNormalize(),
        TensorToRGB(),
    ]

However, the MoCoV2 implementation seems to avoid non-deterministic transforms (such as Random*) for the validation set, so there appear to be some differences between the original MoCoV2 and this paper's implementation. Do you have any thoughts on this?

SIP Finetune on CheXpert - Validation sanity check error - target has to be an integer tensor

Hello,
I am trying to finetune with train_sip.py and the public dataset CheXpert.
At the moment, I simply launched the training by specifying one of the pretrained models available online, plus the dataset name and directory. Is there any other configuration step I should take?

Currently, I get the error "target has to be an integer tensor" during _basic_input_validation (see the log below).
Going back to the function validation_epoch_end ("sip_finetune.py", line 213), I noticed that "targets" is a Tensor with dtype float32, while the "logits" values are in the range [-1, 1]. I was expecting them to be between 0 and 1; is that right?
What can I do to fix this?

Thank you in advance!
Silvia

LOG:
Validation sanity check: 50%|█████ | 1/2 [00:00<00:00, 1.40it/s]path: Atelectasis, len: 64
Traceback (most recent call last):
  File "/Facebook_NYU/CovidPrognosis-master/cp_examples/sip_finetune/train_sip.py", line 205, in
    cli_main(args)
  [...]
  File "/Facebook_NYU/CovidPrognosis-master/cp_examples/sip_finetune/sip_finetune.py", line 231, in validation_epoch_end
    self.val_acc[i](logits, targets)
  File "/home/anaconda3/envs/facebook/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/anaconda3/envs/facebook/lib/python3.9/site-packages/torchmetrics/metric.py", line 152, in forward
    self.update(*args, **kwargs)
  File "/home/anaconda3/envs/facebook/lib/python3.9/site-packages/torchmetrics/metric.py", line 199, in wrapped_func
    return update(*args, **kwargs)
  File "/home/anaconda3/envs/facebook/lib/python3.9/site-packages/pytorch_lightning/metrics/classification/accuracy.py", line 138, in update
    correct, total = _accuracy_update(
  File "/home/anaconda3/envs/facebook/lib/python3.9/site-packages/pytorch_lightning/metrics/functional/accuracy.py", line 25, in _accuracy_update
    preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
  File "/home/anaconda3/envs/facebook/lib/python3.9/site-packages/pytorch_lightning/metrics/classification/helpers.py", line 433, in _input_format_classification
    case = _check_classification_inputs(
  File "/home/anaconda3/envs/facebook/lib/python3.9/site-packages/pytorch_lightning/metrics/classification/helpers.py", line 296, in _check_classification_inputs
    _basic_input_validation(preds, target, threshold, is_multiclass)
  File "/home/anaconda3/envs/facebook/lib/python3.9/site-packages/pytorch_lightning/metrics/classification/helpers.py", line 62, in _basic_input_validation
    raise ValueError("The target has to be an integer tensor.")
ValueError: The target has to be an integer tensor.
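The metric in this traceback rejects float targets. A minimal sketch, assuming binary labels stored as 0.0/1.0 floats: cast the targets with `.long()`, and convert the logits (which are unbounded, so values outside [0, 1] are expected) to probabilities with `torch.sigmoid` before thresholding. The tensors are toy values, and whether this matches how `sip_finetune.py` should call its metric is an assumption.

```python
import torch

# Toy stand-ins for the values collected in validation_epoch_end (hypothetical).
logits = torch.tensor([1.3, -0.7, 0.2, -2.1])
targets = torch.tensor([1.0, 0.0, 0.0, 0.0])  # float32, as reported in the issue

# Metrics that require integer targets: cast explicitly.
targets_int = targets.long()

# Logits are unbounded; sigmoid maps them into (0, 1) before thresholding at 0.5.
probs = torch.sigmoid(logits)
preds = (probs >= 0.5).long()

accuracy = (preds == targets_int).float().mean().item()
print(preds.tolist(), accuracy)
```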

Can you provide load and predict guideline on Colab?

Hi, thanks for your excellent work!

I was impressed by this work, so I want to apply your model to my hospital's datasets.

However, I could not find a guide for loading the pretrained model and running predictions (it was too hard for me to figure out).

Could you provide some example code on Colab?

Thank you!
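A minimal sketch of a generic load-and-predict loop that could run in Colab. The toy `nn.Linear` stands in for the real network, the `"model."` prefix and the `"state_dict"` key mirror Lightning-style checkpoints but are assumptions about the released files, and the final `sigmoid` assumes independent per-output probabilities. `load_for_inference` and `predict` are hypothetical helpers, not part of the repository.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_for_inference(model: nn.Module, ckpt_path: str, prefix: str = "") -> nn.Module:
    """Load weights from `ckpt_path` into `model` and switch it to eval mode."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Lightning checkpoints keep the weights under "state_dict";
    # a plain state_dict file is used as-is.
    state_dict = ckpt.get("state_dict", ckpt)
    if prefix:
        # Keep only keys under `prefix` and strip it, e.g. "model.weight" -> "weight".
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    model.load_state_dict(state_dict)
    return model.eval()


@torch.no_grad()
def predict(model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    """Forward pass without gradients, mapping logits to probabilities."""
    return torch.sigmoid(model(batch))


# Usage with a toy network standing in for the real model (hypothetical).
trained = nn.Linear(8, 3)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.ckpt")
    torch.save({"state_dict": {"model." + k: v for k, v in trained.state_dict().items()}}, path)
    model = load_for_inference(nn.Linear(8, 3), path, prefix="model.")

probs = predict(model, torch.rand(2, 8))
```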

SUM Pooling causing multiplication of matrix with different shape

Hi @anuroopsriram
Thanks for the help on issue #7; I now understand how times and lens should be structured. However, I ran into something else.

Given

mipmodel = MIPModel(
    image_model,
    feature_dim,
    64,
    2,
    2,
    feature_dim,
    0.1,
    0.1,
    "last_timestep"
)

your example

B1, N1, C1, H1, W1 = 2, 4, 3, 224, 224
images = torch.rand(B1, N1, C1, H1, W1)
times = torch.tensor([
    [-30, -20, -10, 0],
    [-30, -20, 0, 0],      # Extra 0 for padding
])
lens = torch.tensor([4, 4])
mipmodel.eval()
x = mipmodel(images, times, lens)
print(images.shape, x.shape)

works perfectly well, but if I switch the pooling from last_timestep back to sum, I still get an error about multiplying two matrices with incompatible shapes:

mipmodel = MIPModel(
    image_model,
    feature_dim,
    64,
    2,
    2,
    feature_dim,
    0.1,
    0.1,
#     "last_timestep"
    "sum"
)

The error I encountered:

RuntimeError                              Traceback (most recent call last)
<ipython-input-140-8254e1a45d57> in <module>
      6 lens = torch.tensor([4, 4])
      7 mipmodel.eval()
----> 8 x = mipmodel(images, times, lens)
      9 print(images.shape, x.shape)

/opt/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-118-4961eeb32228> in forward(self, images, times, lens)
     70         image_feats_combined = torch.cat([image_feats, image_feats_trans], dim=2)
     71         image_feats_pooled = self._pool(image_feats_combined, lens)
---> 72         return self.classifier(image_feats_pooled)

/opt/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
     91 
     92     def forward(self, input: Tensor) -> Tensor:
---> 93         return F.linear(input, self.weight, self.bias)
     94 
     95     def extra_repr(self) -> str:

/opt/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1688     if input.dim() == 2 and bias is not None:
   1689         # fused op is marginally faster
-> 1690         ret = torch.addmm(bias, input, weight.t())
   1691     else:
   1692         output = input.matmul(weight.t())

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x4 and 1088x2)

Is there anything additional I need to do to the data for sum pooling to work? I decided to raise this as a new issue since it wasn't discussed in #7.
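In the traceback, the tensor reaching the classifier is 2x4, which is (B, N) for B=2 and N=4, while the classifier's transposed weight is 1088x2, i.e. it expects 1088 input features per example. One plausible reading, not confirmed against the repository's `_pool`, is that the sum reduced the feature axis instead of the time axis. A minimal sketch of a masked sum over the time axis (`masked_sum_pool` is hypothetical) that yields (B, D) and ignores padded steps:

```python
import torch
import torch.nn as nn


def masked_sum_pool(feats: torch.Tensor, lens: torch.Tensor) -> torch.Tensor:
    """Sum (B, N, D) features over the time axis N, skipping padded steps.

    Returns (B, D), the shape an nn.Linear(D, n_classes) classifier expects.
    """
    _, N, _ = feats.shape
    valid = torch.arange(N, device=feats.device).unsqueeze(0) < lens.unsqueeze(1)  # (B, N)
    return (feats * valid.unsqueeze(-1).to(feats.dtype)).sum(dim=1)


B, N, D = 2, 4, 1088  # D matches the classifier input size in the traceback
feats = torch.rand(B, N, D)
lens = torch.tensor([4, 3])

pooled = masked_sum_pool(feats, lens)  # (2, 1088), not (2, 4)
classifier = nn.Linear(D, 2)
logits = classifier(pooled)
```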

release of pretrained model for MIP

Hi, I read your paper and found it very interesting. However, I could only find the pretrained model for SIP, not for MIP. Do you plan to release the MIP models as well?

Normalization during training

Hi, as I understand it, inputs to deep learning models usually need to be scaled to a specific range such as [0.0, 1.0] for the model to perform well. However, when reading the transforms used in the paper, it seems that transforms.ToTensor() scales the input from [0, 255] to [0.0, 1.0], but HistogramNormalize scales it back to [0., 255.]. Can you explain why we scale it back to the original range? This differs from the typical workflow of training neural networks on inputs in the [0, 1] range.

transform_list = [
    transforms.RandomResizedCrop(args.im_size, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    RandomGaussianBlur(),
    AddGaussianNoise(snr_range=(4, 8)),
    HistogramNormalize(),
    TensorToRGB(),
]
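To make the range question concrete, here is a generic histogram-equalization sketch in PyTorch. It is an assumption that `HistogramNormalize` behaves like this; `histogram_equalize` is a hypothetical function, not the repository's code. It maps each pixel to the empirical CDF of its intensity bin and scales the result to (0, 255], the kind of output range the question describes; dividing by 255 brings it back to a unit range.

```python
import torch


def histogram_equalize(img: torch.Tensor, n_bins: int = 256, out_max: float = 255.0) -> torch.Tensor:
    """Generic histogram equalization scaled to (0, out_max]. Illustrative only."""
    flat = img.flatten().float()
    lo, hi = flat.min(), flat.max()
    if hi == lo:  # constant image: nothing to equalize
        return torch.zeros_like(img, dtype=torch.float32)
    hist = torch.histc(flat, bins=n_bins, min=lo.item(), max=hi.item())
    cdf = hist.cumsum(0)
    cdf = cdf / cdf[-1]  # normalized CDF; the last entry is exactly 1
    # Bin index of each pixel, using the same equal-width bins as histc.
    idx = ((flat - lo) / (hi - lo) * n_bins).long().clamp(max=n_bins - 1)
    return (cdf[idx] * out_max).reshape(img.shape)


x = torch.rand(1, 64, 64)  # e.g. the output of transforms.ToTensor(), in [0, 1]
y = histogram_equalize(x)  # now spans (0, 255]
y01 = y / 255.0            # rescaled to (0, 1] if a unit range is preferred
```

Because the mapping is monotonic, the intensity ordering of pixels is preserved; only the spacing and the overall scale change.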
