rajpurkarlab / chexzero

chexzero's Introduction

Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning

Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning, Nat. Biomed. Eng (2022). [Paper]
Ekin Tiu, Ellie Talius, Pujan Patel, Curtis P. Langlotz, Andrew Y. Ng, Pranav Rajpurkar
Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9


This repository contains code to train a self-supervised learning model on chest X-ray images that lack explicit annotations and evaluate this model's performance on pathology-classification tasks.

Main Findings
  1. Automatically detecting pathologies in chest x-rays without explicit annotations: Our method learns directly from the combination of images and unstructured radiology reports, thereby avoiding time-consuming labeling efforts. Our deep learning method is capable of predicting multiple pathologies and differential diagnoses that it had not explicitly seen during training.
  2. Matching radiologist performance on different tasks on an external test set: Our method performed on par with radiologists when evaluated on an external validation set (CheXpert) of chest x-ray images labeled for the presence of 14 different conditions by multiple radiologists.
  3. Outperforming approaches that train on explicitly labeled data on an external test set: Using no labels, we outperformed a fully supervised approach (100% of labels) on 3 out of the 8 selected pathologies on a dataset (PadChest) collected in a different country. We further demonstrated high performance (AUC > 0.9) on 14 findings and at least 0.700 on 53 findings out of 107 radiographic findings that the method had not seen during training.

Dependencies

To clone all files:

git clone https://github.com/rajpurkarlab/CheXzero.git

To install Python dependencies:

pip install -r requirements.txt

Data

Training Dataset

  1. Download images from MIMIC-CXR JPG (https://physionet.org/content/mimic-cxr-jpg/2.0.0/) and reports from the MIMIC-CXR Database. Note: in order to gain access to the data, you must be a credentialed user as defined on PhysioNet.
  2. Copy the dataset into the data/ directory.
  3. Run python run_preprocess.py
  4. This should preprocess the chest x-ray images into a Hierarchical Data Format (HDF5) file used for training, stored at data/cxr.h5, and extract the impressions section as text from the corresponding chest x-ray radiology reports, stored at data/mimic_impressions.csv.
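As a quick sanity check after preprocessing, you can inspect the generated files. Below is a minimal sketch; the dataset key and column names are printed rather than assumed, since they may differ between repo versions:

import h5py
import pandas as pd

# List the datasets stored in the preprocessed image file.
with h5py.File("data/cxr.h5", "r") as f:
    for key in f.keys():
        print(key, f[key].shape, f[key].dtype)

# Inspect the extracted impressions text.
impressions = pd.read_csv("data/mimic_impressions.csv")
print(impressions.columns.tolist())
print(impressions.head())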

Evaluation Dataset

CheXpert Dataset

The CheXpert dataset consists of chest radiographic examinations from Stanford Hospital, performed between October 2002 and July 2017 in both inpatient and outpatient centers. Population-level characteristics are unavailable for the CheXpert test dataset, as it is used for official evaluation on the CheXpert leaderboard.

The main data (CheXpert data) supporting the results of this study are available at https://aimi.stanford.edu/chexpert-chest-x-rays.

The CheXpert test dataset has recently been made public, and can be found by following the steps in the cheXpert-test-set-labels repository.

PadChest Dataset

The PadChest dataset contains chest X-rays that were interpreted by 18 radiologists at the Hospital Universitario de San Juan, Alicante, Spain, from January 2009 to December 2017. The dataset contains 109,931 image studies and 168,861 images. PadChest also contains 206,222 study reports.

PadChest is publicly available at https://bimcv.cipf.es/bimcv-projects/padchest. Those who would like to use PadChest for experimentation should request access at that link.

Model Checkpoints

Model checkpoints of CheXzero pre-trained on MIMIC-CXR are publicly available at the following link. Download files and save them in the ./checkpoints/chexzero_weights directory.

Running Training

Run the following command to perform CheXzero pretraining.

python run_train.py --cxr_filepath "./data/cxr.h5" --txt_filepath "data/mimic_impressions.csv"

Arguments

  • --cxr_filepath Path to the .h5 file of chest x-ray image data.
  • --txt_filepath Path to the .csv file of radiology report impressions text.

Use the -h flag to see all optional arguments.

Zero-Shot Inference

See the following notebook for an example of how to use CheXzero to perform zero-shot inference on a chest x-ray dataset. The example shows how to output predictions from the model ensemble and evaluate the model's performance when ground-truth labels are available.

import zero_shot

# computes predictions for a set of images stored as a np array of probabilities for each pathology
predictions, y_pred_avg = zero_shot.ensemble_models(
    model_paths=model_paths, 
    cxr_filepath=cxr_filepath, 
    cxr_labels=cxr_labels, 
    cxr_pair_template=cxr_pair_template, 
    cache_dir=cache_dir,
)

Arguments

  • model_paths: List[str]: List of paths to all checkpoints to be used in the ensemble. To run on a single model, input a list containing a single path.
  • cxr_filepath: str: Path to images .h5 file
  • cxr_labels: List[str]: List of pathologies to query in each image
  • cxr_pair_template: Tuple[str, str]: Contrasting templates used to query the model (see Figure 1 in the article for a visual explanation).
  • cache_dir: str: Directory in which to cache the predictions of each checkpoint; used to avoid recomputing predictions.
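For reference, a typical setup of these arguments might look as follows. This is a sketch: the label list and template wording are illustrative, and the checkpoint glob assumes the default weights directory.

from pathlib import Path

# Collect all ensemble checkpoints from the default weights directory.
model_paths = [str(p) for p in sorted(Path("checkpoints/chexzero_weights").glob("*.pt"))]

cxr_filepath = "data/chexpert_test.h5"  # images as a single .h5 file
cxr_labels = ["Atelectasis", "Cardiomegaly", "Consolidation",
              "Edema", "Pleural Effusion"]  # pathologies to query
# Contrasting (positive, negative) prompts; "{}" is replaced by each label.
cxr_pair_template = ("{}", "no {}")
cache_dir = "cache/"  # per-checkpoint prediction cache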

In order to use CheXzero for zero-shot inference, ensure the following requirements are met:

  • All input images must be stored in a single .h5 (Hierarchical Data Format) file. See the img_to_h5 function in preprocess_padchest.py for an example of how to convert a list of .png file paths into a valid .h5 file; a minimal sketch also follows this list.
  • The ground-truth labels must be in a .csv where each row represents an image sample and each column holds the binary label for a particular pathology.
  • Ensure all model checkpoints are stored in checkpoints/chexzero_weights/, or the model_dir that is specified in the notebook.
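If you need to build such an .h5 file yourself, the following is a minimal sketch in the spirit of img_to_h5. The dataset key "cxr" and the 320x320 resolution are assumptions; check run_preprocess.py for the values the repo actually uses.

import h5py
import numpy as np
from PIL import Image

def pngs_to_h5(png_paths, out_path, size=320):
    # Write all images into one dataset, grayscale and uniformly resized.
    with h5py.File(out_path, "w") as f:
        dset = f.create_dataset("cxr", shape=(len(png_paths), size, size), dtype="uint8")
        for i, path in enumerate(png_paths):
            img = Image.open(path).convert("L").resize((size, size))
            dset[i] = np.asarray(img)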

Evaluation

Given a numpy array of predictions (obtained from zero-shot inference) and a numpy array of ground-truth labels, one can evaluate the model's performance using the following code:

import zero_shot
import eval

import pandas as pd
from typing import Tuple

# loads in ground truth labels into memory
test_pred = y_pred_avg
test_true = zero_shot.make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)

# evaluate model, no bootstrap
cxr_results: pd.DataFrame = eval.evaluate(test_pred, test_true, cxr_labels) # eval on full test dataset

# bootstrap evaluations for 95% confidence intervals
bootstrap_results: Tuple[pd.DataFrame, pd.DataFrame] = eval.bootstrap(test_pred, test_true, cxr_labels) # (df of results for each bootstrap, df of CI)

# print results with confidence intervals
print(bootstrap_results[1])

The results are represented as a pd.DataFrame, which can be saved as a .csv.
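For example, to persist the confidence-interval table (the output filename here is illustrative):

# Save the 95% CI table produced by eval.bootstrap
bootstrap_results[1].to_csv("zero_shot_ci.csv")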

CheXpert Test Dataset

In order to replicate the results in the paper, zero-shot inference and evaluation can be performed on the now publicly available CheXpert test dataset.

  1. Download the labels from cheXpert-test-set-labels and the image files from Stanford AIMI, and save them in the ./data directory in CheXzero/. The test dataset images should have the following directory structure:
data/
├─ CheXpert/
│  ├─ test/
│  │  ├─ patient64741/
│  │  │  ├─ study1/
│  │  │  │  ├─ view1_frontal.jpg
│  │  ├─ .../
  2. Run the run_preprocess.py script with the following arguments:
python run_preprocess.py --dataset_type "chexpert-test" --cxr_out_path "./data/chexpert_test.h5" --chest_x_ray_path "./data/CheXpert/test/"

This should save a .h5 version of the test dataset images which can be used for evaluation.

  3. Open the sample zero-shot notebook and run all cells. If the directory structure is set up correctly, all cells should run without errors.

Issues

Please open a new issue thread specifying the issue with the codebase, or report it directly to [email protected].

Citation

Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9

License

The source code for the site is licensed under the MIT license, which you can find in the LICENSE file. Also see NOTICE.md for attributions to third-party sources.

chexzero's People

Contributors

ekkin2, rajpurkar


chexzero's Issues

Comprehension Question - Negating Input-Images

Thank you for your awesome code and work.
I have a question:
In your train_batch function in the run_train.py you are negating the images with -images in the forward pass:
logits_per_image, logits_per_text = model(-images, texts)

Why do you do so?

Also, in the train function you initialize "highest_val_auc" but never use it. How would you define the AUC in such a scenario?

def train(model, loader, device, criterion, optimizer, config):
    model_save_dir = os.path.join(config.save_dir, config.model_name)
    if not os.path.exists(model_save_dir):
        # Create a new folder if not exists
        os.makedirs(model_save_dir)

    # Run training
    total_batches = len(loader) * config.epochs
    example_ct = 0  # number of examples seen
    batch_ct = 0
    report_freq = config.log_interval
    highest_val_auc = 0  # save highest mean auc

    for epoch in range(config.epochs):
        running_loss = 0.0  # running loss over batch
        for data in tqdm(loader):
            # get the images
            images = data['img']

            texts = data['txt']
            texts = preprocess_text(texts, model)

            # perform step for a single batch
            loss = train_batch(images, texts, model, device, criterion, optimizer)
            example_ct += len(images)
            batch_ct += 1
            running_loss += loss.item()

            # Report metrics every `report_freq` batch
            if (batch_ct % report_freq) == 0:
                train_log(running_loss / report_freq, example_ct, epoch)
                running_loss = 0.0

            if (batch_ct % config.save_interval) == 0:
                model_path = os.path.join(model_save_dir, "checkpoint_{batch_ct}.pt".format(
                    batch_ct=str(batch_ct),
                ))
                print("Saved checkpoint to: ", model_path)
                save(model, model_path)

I hope you can help me with my questions!

Thanks again for making your code publicly available, and greetings from Germany!

Train Loss seems stuck after a point

Hi, Thanks for the code you've provided.

I was trying to run training with the default specs of the repository. However, I observed that the training loss gets stuck at a value after a few minibatch cycles for both pretrained and non-pretrained cases. I tried this on the complete MIMIC dataset as well as a smaller batch I created for testing, but the loss was getting stuck at a value in all cases rather than moving towards overfitting.

Is this expected behaviour? Has anyone else also come across this?

TypeError: must be real number, not str

Getting the error in title when training.

Full command: python3 run_train.py --cxr_filepath {path to cxr.h5} --txt_filepath {path to mimic_impressions.csv}

Full traceback:

Traceback (most recent call last):
  File "/home/ec2-user/MedZero/CheXzero/run_train.py", line 143, in <module>
    model = model_pipeline(args)
  File "/home/ec2-user/MedZero/CheXzero/run_train.py", line 44, in model_pipeline
    train(model, data_loader, device, criterion, optimizer, config)
  File "/home/ec2-user/MedZero/CheXzero/run_train.py", line 85, in train
    for data in tqdm(loader):
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/tqdm/std.py", line 1195, in __iter__
    for obj in iterable:
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/opt/conda/envs/medzero/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 68, in default_collate
    return torch.tensor(batch, dtype=torch.float64)
TypeError: must be real number, not str

about h5 file key

Hello, thank you for sharing the code.
There seems to be a problem loading the h5 file into the CXRDataset class.
Wouldn't the key value be cxr rather than cxr_unprocessed in L39 (also, L42)?

CheXzero/train.py

Lines 24 to 43 in c303e5c

class CXRDataset(data.Dataset):
    """Represents an abstract HDF5 dataset.
    Input params:
        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
        recursive: If True, searches for h5 files in subdirectories.
        load_data: If True, loads all the data immediately into RAM. Use this if
            the dataset fits into memory. Otherwise, leave this at false and
            the data will load lazily.
        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
        transform: PyTorch transform to apply to every data instance (default=None).
    """
    def __init__(self, img_path, txt_path, column='report', size=None, transform=None):
        super().__init__()
        if size != None:
            self.img_dset = h5py.File(img_path, 'r')['cxr_unprocessed'][:size]
            self.txt_dset = pd.read_csv(txt_path)[column][:size]
        else:
            self.img_dset = h5py.File(img_path, 'r')['cxr_unprocessed']
            self.txt_dset = pd.read_csv(txt_path)[column]

ground_truth format

ground_truth_3.csv
I was able to train a model on a small dataset and wanted to test it on the original JPGs, providing a ground_truth.csv file encoded as 0s for all labels except "No Finding". I am getting back the prediction averages, but bootstrap_results[1] returns a table full of NaNs. What could be the issue?
Attached is a screenshot of the issue.
Hope you can help us.

How to interpret results? (weird predictions)

I tested on a pneumonia image and got these predictions:

[array([[0.49470618, 0.49632818, 0.501808  , 0.49816415, 0.49443188,
              0.4973687 , 0.5035202 , 0.50338894, 0.49944744, 0.4980956 ,
              0.5006629 , 0.5109319 , 0.5024933 , 0.48189908],
             [0.48987147, 0.4886013 , 0.505182  , 0.50687265, 0.49118397,
              0.50182277, 0.5023532 , 0.50404346, 0.49911276, 0.49427748,
              0.49890172, 0.51351357, 0.5014591 , 0.48301208]], dtype=float32)]

What does this mean? I think the predictions should sum to 1.0 in total, but here the model seems to suspect every pathology with roughly 50% probability. How so?
The same happens for both best_64_5e-05_original_22000_0.864.pt and best_128_5e-05_original_22000_0.855.pt weights.

Training not converging with default settings

Hi,
we're from the University of Wuerzburg and tried to replicate your project for German report data.
For now, we simply tried to get your code to run and train on MIMIC with the default settings provided as well as the settings provided in your paper. Of course, we made sure to have the same package versions as in the project.

However, we quickly get NaN loss after some iterations. So first, we tried to create a subsample of the dataset. For a very small dataset (~300 images), the training does converge. However, even for 1000 images the loss does not get smaller. We also tried several different learning rates and hyperparameters, but nothing helped so far.

I was hoping that you might be familiar with our problems and give us advice here.

Thanks in advance!

zeroshot result

When I ran the zero_shot.ipynb file, I'm not sure why the average AUC of "No Finding" in the results is only 0.0700.

