Deep learning models for COVID-19 chest x-ray classification: Preventing shortcut learning using feature disentanglement
Disclaimer: This repository is provided for research and development use only. The models described in this repo are not intended for use in clinical decision-making or for any other clinical use and the performance of these models for clinical use has not been established.
This README describes how to reproduce the models and some of the experiments from the preprint "Deep learning models for COVID-19 chest x-ray classification: Preventing shortcut learning using feature disentanglement". Note: the "CC-CCII" dataset is proprietary, so we cannot include it; however, this repository reproduces all results that depend only on the COVIDx dataset.
A rough sketch of how the repository is set up:
- `preprocess.py` is run with a path to a "metadata.csv" file that contains pointers to unprocessed chest x-ray images (in DICOM, JPEG, or PNG format) and their disease labels. This script will create masked and unmasked copies of each input image, resized/cropped to 224x224, in a desired location, with an accompanying "metadata_preprocessed.csv" file.
- `create_embeddings.py` is run with a path to a preprocessed dataset (a "metadata_preprocessed.csv" file). This script will create embeddings using an existing pre-trained model (currently, the three models we list in the paper) that can be used to quickly train a classifier. We call the output of this step an embedded dataset.
- `train.py` uses an embedded dataset and corresponding domain/task labels to train a classifier. This can be done with or without "feature disentanglement" (see the sketch after this list).
- `evaluate.py` uses a classifier and a preprocessed dataset to generate embeddings for new data. Note that these embeddings are not the same as those generated by one of the existing pre-trained models (e.g. torchxrayvision), but are the result of using a pre-trained model and a classifier together (see the paper for more details about this distinction).
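To make the "feature disentanglement" option in `train.py` concrete, here is a minimal sketch of one common way to implement such a constraint: domain-adversarial training with a gradient reversal layer, which penalizes embeddings that predict the domain (e.g. the source dataset) while still predicting the task label. This is an illustration only, not the paper's exact method; the class names, layer sizes, and `lambda_` weight below are all hypothetical.

```python
import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    """Identity on the forward pass; reverses (and scales) gradients on the backward pass."""
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_ * grad_output, None

class DisentangledClassifier(nn.Module):
    """Task head and adversarial domain head sharing one embedding of the input features."""
    def __init__(self, in_dim=1024, hidden=256, n_tasks=2, n_domains=2):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.task_head = nn.Linear(hidden, n_tasks)
        self.domain_head = nn.Linear(hidden, n_domains)

    def forward(self, x, lambda_=1.0):
        z = self.trunk(x)
        # The domain head sees reversed gradients, pushing the trunk toward
        # embeddings that are *not* predictive of the domain.
        return self.task_head(z), self.domain_head(GradReverse.apply(z, lambda_))

# Hypothetical training step: minimizing both losses removes domain
# information from the shared embedding via the reversed gradients.
model = DisentangledClassifier()
ce = nn.CrossEntropyLoss()
x = torch.randn(8, 1024)              # a batch of precomputed embeddings
y_task = torch.randint(0, 2, (8,))    # disease labels
y_domain = torch.randint(0, 2, (8,))  # dataset-of-origin labels
task_logits, domain_logits = model(x)
loss = ce(task_logits, y_task) + ce(domain_logits, y_domain)
loss.backward()
```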
The following commands should create a conda environment with the necessary requirements for reproducing our results. Note: our conda environment assumes a CUDA version of 11; you may need to adjust this to match your system.
```
conda env create -f environment.yml
conda activate xray-feature-disentanglement
```
To get the COVID-Net pre-trained model weights:
- Go here and click on the "COVIDNet-CXR Large" link. This should take you to a Google Drive folder hosted by the authors of COVID-Net.
- Download the following files into `data/pretrained_models/COVIDNet-CXR_Large/` (a sanity-check snippet follows this list):
  - `checkpoint`
  - `model-8485.data-00000-of-00001`
  - `model-8485.index`
  - `model.meta`
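As a quick check that the checkpoint files downloaded correctly, you can try restoring them. This is a sketch assuming a TensorFlow 1.x environment (which the COVID-Net code targets); under TensorFlow 2 you would need the `tf.compat.v1` shim, and the directory path below assumes the layout described above.

```python
import tensorflow as tf  # assumes TensorFlow 1.x

ckpt_dir = "data/pretrained_models/COVIDNet-CXR_Large"
with tf.Session() as sess:
    # model.meta stores the graph; model-8485.index/.data-* store the weights
    saver = tf.train.import_meta_graph(ckpt_dir + "/model.meta")
    saver.restore(sess, ckpt_dir + "/model-8485")
    print("COVIDNet-CXR Large checkpoint restored")
```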
We use two CXR datasets in the accompanying paper: the open COVIDx dataset, and a private dataset from the CC-CCII. We cannot include the CC-CCII dataset, so this repository serves to reproduce the experiments that depend only on the COVIDx dataset.
To download the COVIDx dataset and create a "metadata.csv" file to use throughout the pipeline:
- Follow the directions here to download the 5 sub-datasets into a `$BASE_DIR` directory.
- Additionally, download the two binary files describing RSNA patients from the COVIDx repo into the same `$BASE_DIR` directory.
- Run `notebooks/Preprocessing - COVIDx - create dataset.ipynb`, editing the second cell to point to `$BASE_DIR`.
- Run `notebooks/Preprocessing - COVIDx - combine splits and create metadata file.ipynb`, again editing the second cell to point to `$BASE_DIR`.
In `utils.py`, set `BASE_DIR` to point to the full path of where you have cloned this repository.
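For example (the path below is a placeholder for your own clone location):

```python
# in utils.py -- hypothetical value; point this at your local clone
BASE_DIR = "/home/yourname/xray-feature-disentanglement"
```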
We assume that datasets are defined as a list of filenames with corresponding metadata (e.g. "label" and "patient_id") in what we call "metadata.csv" files. The `preprocess.py` script will consume a "metadata.csv" file, standardize the dimensions of each image, apply lung masking to each image, etc., and save the results to a directory.
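For illustration, a minimal "metadata.csv" might be built like the sketch below. The "image_fn" column name and all values are hypothetical; the notebooks referenced below produce the exact schema the pipeline expects.

```python
import pandas as pd

# Hypothetical example of the expected layout: one row per image, with a
# pointer to the raw file plus its label and patient metadata.
df = pd.DataFrame({
    "image_fn": ["images/patient001_view1.png", "images/patient002_view1.dcm"],
    "label": ["COVID-19", "normal"],
    "patient_id": ["patient001", "patient002"],
})
df.to_csv("data/metadata_example.csv", index=False)
```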
Run the following using the `data/metadata_covidx.csv` file created by `notebooks/Preprocessing - COVIDx - combine splits and create metadata file.ipynb`:
```
mkdir -p datasets/covidx/
MKL_THREADING_LAYER=GNU python preprocess.py --input_fn data/metadata_covidx.csv --output_dir datasets/covidx/ --disable_flip_preprocessing --overwrite
```
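For intuition, the resize/crop that `preprocess.py` applies might look roughly like the following (a minimal sketch using PIL; the script's actual resampling, cropping, and lung-masking logic may differ):

```python
from PIL import Image

def resize_and_center_crop(path, size=224):
    """Resize the short side to `size`, then center-crop to size x size."""
    img = Image.open(path).convert("L")  # treat chest x-rays as grayscale
    w, h = img.size
    scale = size / min(w, h)
    img = img.resize((round(w * scale), round(h * scale)), Image.BILINEAR)
    w, h = img.size
    left, top = (w - size) // 2, (h - size) // 2
    return img.crop((left, top, left + size, top + size))
```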
After the datasets have been preprocessed, we run `create_embeddings.py` to generate embeddings from different feature extractor models, which we will then use to train a classifier.
```
# generate embeddings for masked images
python create_embeddings.py --input covidx --name covidx --model xrv --mask masked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model densenet --mask masked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model covidnet --mask masked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model histogram --mask masked --output_dir datasets/embeddings/

# generate embeddings for unmasked images
python create_embeddings.py --input covidx --name covidx --model xrv --mask unmasked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model densenet --mask unmasked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model covidnet --mask unmasked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model histogram --mask unmasked --output_dir datasets/embeddings/
```
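As a rough picture of what an embedding is here, the sketch below pools the feature maps of a pre-trained DenseNet into one vector per image. It uses torchvision's generic ImageNet DenseNet-121 as a stand-in; the actual `xrv`, `densenet`, `covidnet`, and `histogram` extractors used above differ in weights and preprocessing.

```python
import torch
import torch.nn.functional as F
from torchvision import models

densenet = models.densenet121(weights="IMAGENET1K_V1").eval()

@torch.no_grad()
def embed(batch):                                      # batch: (N, 3, 224, 224)
    fmap = densenet.features(batch)                    # (N, 1024, 7, 7) feature maps
    fmap = F.relu(fmap)
    return F.adaptive_avg_pool2d(fmap, 1).flatten(1)   # (N, 1024) embedding vectors

vectors = embed(torch.randn(2, 3, 224, 224))
```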
We provide the means to reproduce three sets of results from the paper:
- The first two columns of Table 2
- The first two columns of Table 3
- (roughly) The UMAP visualization
The first set of results is generated by `python generate_table2_results.py`.

The second set of results is generated by `python run_main_experiments.py` followed by `python generate_table3_results.py`. NOTE: running the experiments will take a long time and require GPU resources.

The third set of results is generated by `bash run_umap_experiments.sh` and the notebook `notebooks/Results - Generate UMAPs.ipynb`.
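For reference, projecting a matrix of embeddings down to 2-D with UMAP follows the pattern below (a minimal sketch using the umap-learn package with placeholder data; the repository's scripts handle the real embeddings and plotting):

```python
import numpy as np
import umap  # pip install umap-learn

embeddings = np.random.rand(500, 1024)  # placeholder for real embedding vectors
coords = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings)
print(coords.shape)                     # (500, 2)
```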
We use the torchxrayvision project from here at commit b274a7a32c462faff6df8cde711498d34f1acc36 on the `master` branch.
We use the COVID-Net project from here at commit d6f3552f44f1af99981dbc960ee46ea3bceecd61 on the `master` branch. Specifically, we use the pre-trained model "COVIDNet-CXR Large" and the dataset creation notebook (copied to `notebooks/Preprocessing- Create COVID-Net dataset.ipynb`).
We use the lungVAE project from here at commit 52b44df82a351706db2f575758ea3b8452389998 on the `master` branch. We also make the following small changes (see `lungVAE/`):
```diff
diff --git a/predict.py b/predict.py
index 163a775..3876066 100644
--- a/predict.py
+++ b/predict.py
@@ -124,7 +124,7 @@ t = time.strftime("%Y%m%d_%H_%M")
 if args.saveLoc is '':
     save_dir = args.data+'pred_'+t+'/'
 else:
-    save_dir = args.saveLoc+'pred_'+t+'/'
+    save_dir = args.saveLoc + '/'
 if not os.path.exists(save_dir):
     os.mkdir(save_dir)
@@ -134,7 +134,7 @@ print("Model "+args.model.split('/')[-1]+" Number of parameters:%d"%(nParam))
 if args.dicom:
     filetype = 'DCM'
 else:
-    filetype= 'png'
+    filetype= 'jpg'
 files = list(set(glob(args.data+'*.'+filetype)) \
         - set(glob(args.data+'*_mask*.'+filetype)) \
@@ -144,11 +144,14 @@ files = sorted(files)
 for fIdx in range(len(files)):
     f = files[fIdx]
     fName = f.split('/')[-1]
-    img, roi, h, w, hLoc, wLoc, imH, imW = loadDCM(f,
-                no_preprocess=args.no_preprocess,
-                dicom=args.dicom)
+    img, roi, h, w, hLoc, wLoc, imH, imW = loadDCM(
+        f,
+        no_preprocess=args.no_preprocess,
+        dicom=args.dicom
+    )
     img = img.to(device)
-    _,mask = net(img)
+    _, mask = net(img)
+    mask = mask.cpu()
     mask = torch.sigmoid(mask*roi)
     f = save_dir+fName.replace('.'+filetype,'_mask.png')
```
This project is licensed under the MIT License.