Skip to content

Commit

Permalink
COVIDNet CXR-3 (#204)
Browse files Browse the repository at this point in the history
* Added eval and inference scripts for COVIDNet CXR-3 with MEDUSA backbone

Co-authored-by: mayaliliya <[email protected]>
  • Loading branch information
mayaliliya and mayaliliya authored Oct 19, 2021
1 parent 4347213 commit 1090a2f
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 55 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

**Recording to webinar on [How we built COVID-Net in 7 days with Gensynth](https://darwinai.news/fny)**

**Update 10/19/2021:** We released a new COVID-Net CXR-3 [model](docs/models.md) for COVID-19 positive/negative detection which was trained and tested on the COVIDx8B dataset leveraging the new MEDUSA (Multi-scale Encoder-Decoder Self-Attention) architecture.\
**Update 04/21/2021:** We released a new COVIDNet CXR-S [model](docs/models.md) and [COVIDxSev](create_COVIDxSev.ipynb) dataset for airspace severity grading in COVID-19 positive patient CXR images. For more information on training, testing and inference please refer to severity [docs](docs/covidnet_severity.md).\
**Update 03/20/2021:** We released a new COVID-Net CXR-2 [model](docs/models.md) for COVID-19 positive/negative detection which was trained on the new COVIDx8B dataset with 16,352 CXR images from a multinational cohort of 15,346 patients from at least 51 countries. The test results are based on the new COVIDx8B test set of 200 COVID-19 positive and 200 negative CXR images.\
**Update 03/19/2021:** We released updated datasets and dataset curation scripts. The COVIDx V8A dataset and create_COVIDx.ipynb are for detection of no pneumonia/non-COVID-19 pneumonia/COVID-19 pneumonia, and COVIDx V8B dataset and create_COVIDx_binary.ipynb are for COVID-19 positive/negative detection. Both datasets contain over 16000 CXR images with over 2300 positive COVID-19 images.\
Expand Down Expand Up @@ -131,6 +132,35 @@ Additional requirements to generate dataset:
## Results
These are the final results for the COVIDNet models.

### COVIDNet-CXR-3 on COVIDx8B (200 COVID-19 test)
<div class="tg-wrap"><table class="tg">
<tr>
<th class="tg-7btt" colspan="3">Sensitivity (%)</th>
</tr>
<tr>
<td class="tg-7btt">Negative</td>
<td class="tg-7btt">Positive</td>
</tr>
<tr>
<td class="tg-c3ow">99.0</td>
<td class="tg-c3ow">97.5</td>
</tr>
</table></div>

<div class="tg-wrap"><table class="tg">
<tr>
<th class="tg-7btt" colspan="3">Positive Predictive Value (%)</th>
</tr>
<tr>
<td class="tg-7btt">Negative</td>
<td class="tg-7btt">Positive</td>
</tr>
<tr>
<td class="tg-c3ow">97.5</td>
<td class="tg-c3ow">99.0</td>
</tr>
</table></div>

### COVIDNet-CXR-2 on COVIDx8B (200 COVID-19 test)
<div class="tg-wrap"><table class="tg">
<tr>
Expand Down
63 changes: 48 additions & 15 deletions data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tensorflow as tf
from tensorflow import keras

from functools import partial
import numpy as np
import os
import cv2
Expand All @@ -17,13 +18,22 @@ def central_crop(img):
offset_w = int((img.shape[1] - size) / 2)
return img[offset_h:offset_h + size, offset_w:offset_w + size]

def process_image_file(filepath, top_percent, size):
def process_image_file(filepath, size, top_percent=0.08, crop=True):
img = cv2.imread(filepath)
img = crop_top(img, percent=top_percent)
img = central_crop(img)
if crop:
img = central_crop(img)
img = cv2.resize(img, (size, size))
return img

def process_image_file_medusa(filepath, size):
img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (size, size))
img = img.astype('float64')
img -= img.mean()
img /= img.std()
return np.expand_dims(img, -1)

def random_ratio_resize(img, prob=0.3, delta=0.1):
if np.random.rand() >= prob:
return img
Expand Down Expand Up @@ -84,7 +94,8 @@ def __init__(
csv_file,
is_training=True,
batch_size=8,
input_shape=(224, 224),
medusa_input_shape=(256, 256),
input_shape=(480, 480),
n_classes=2,
num_channels=3,
mapping={
Expand All @@ -96,14 +107,16 @@ def __init__(
covid_percent=0.5,
class_weights=[1., 1.],
top_percent=0.08,
is_severity_model=False
is_severity_model=False,
is_medusa_backbone=False,
):
'Initialization'
self.datadir = data_dir
self.dataset = _process_csv_file(csv_file)
self.is_training = is_training
self.batch_size = batch_size
self.N = len(self.dataset)
self.medusa_input_shape = medusa_input_shape
self.input_shape = input_shape
self.n_classes = n_classes
self.num_channels = num_channels
Expand All @@ -115,6 +128,13 @@ def __init__(
self.augmentation = augmentation
self.top_percent = top_percent
self.is_severity_model = is_severity_model
self.is_medusa_backbone = is_medusa_backbone

# If using MEDUSA backbone load images without crop
if self.is_medusa_backbone:
self.load_image = partial(process_image_file, top_percent=0, crop=False)
else:
self.load_image = process_image_file

datasets = {}
for key in self.mapping.keys():
Expand Down Expand Up @@ -147,7 +167,7 @@ def __init__(

def __next__(self):
# Get one batch of data
batch_x, batch_y, weights = self.__getitem__(self.n)
model_inputs = self.__getitem__(self.n)
# Batch index
self.n += 1

Expand All @@ -156,7 +176,7 @@ def __next__(self):
self.on_epoch_end()
self.n = 0

return batch_x, batch_y, weights
return model_inputs

def __len__(self):
return int(np.ceil(len(self.datasets[0]) / float(self.batch_size)))
Expand All @@ -168,12 +188,13 @@ def on_epoch_end(self):
np.random.shuffle(v)

def __getitem__(self, idx):
batch_x, batch_y = np.zeros(
(self.batch_size, *self.input_shape,
self.num_channels)), np.zeros(self.batch_size)
batch_x = np.zeros((self.batch_size, *self.input_shape, self.num_channels))
batch_y = np.zeros(self.batch_size)

batch_files = self.datasets[0][idx * self.batch_size:(idx + 1) *
self.batch_size]
if self.is_medusa_backbone:
batch_sem_x = np.zeros((self.batch_size, *self.medusa_input_shape, 1))

batch_files = self.datasets[0][idx * self.batch_size:(idx + 1) * self.batch_size]

# upsample covid cases
covid_size = max(int(len(batch_files) * self.covid_percent), 1)
Expand All @@ -198,21 +219,33 @@ def __getitem__(self, idx):
else:
folder = 'test'

x = process_image_file(os.path.join(self.datadir, folder, sample[1]),
self.top_percent,
self.input_shape[0])
image_file = os.path.join(self.datadir, folder, sample[1])
x = self.load_image(
image_file,
self.input_shape[0],
top_percent=self.top_percent,
)

if self.is_training and hasattr(self, 'augmentation'):
x = self.augmentation(x)

x = x.astype('float32') / 255.0

if self.is_medusa_backbone:
sem_x = process_image_file_medusa(image_file, self.medusa_input_shape[0])
batch_sem_x[i] = sem_x

y = self.mapping[sample[2]]

batch_x[i] = x
batch_y[i] = y

class_weights = self.class_weights
weights = np.take(class_weights, batch_y.astype('int64'))
batch_y = keras.utils.to_categorical(batch_y, num_classes=self.n_classes)

return batch_x, keras.utils.to_categorical(batch_y, num_classes=self.n_classes), weights
if self.is_medusa_backbone:
return batch_sem_x, batch_x, batch_y, weights, self.is_training
else:
return batch_x, batch_y, weights, self.is_training

3 changes: 2 additions & 1 deletion docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
## COVIDNet Chest X-Ray Classification
| Type | Input Resolution | COVID-19 Sensitivity | Accuracy | # Params (M) | MACs (G) | Model |
|:-----:|:----------------:|:--------------------:|:--------:|:------------:|:--------:|:-------------------:|
| ckpt | 480x480 | 95.5 | 96.3 | 8.8 | 5.55 |[COVIDNet-CXR-2](https://bit.ly/COVIDNet-CXR-2)|
| ckpt | 480x480 | 97.5 | 98.3 | 29.2 | 29.1 |[COVIDNet-CXR-3](https://bit.ly/COVIDNet-CXR-3)|
| ckpt | 480x480 | 95.5 | 96.3 | 8.8 | 5.55 |[COVIDNet-CXR-2](https://bit.ly/COVIDNet-CXR-2)|
| ckpt | 480x480 | 95.0 | 94.3 | 40.2 | 23.63 |[COVIDNet-CXR4-A](https://bit.ly/COVIDNet-CXR4-A)|
| ckpt | 480x480 | 93.0 | 93.7 | 11.7 | 7.50 |[COVIDNet-CXR4-B](https://bit.ly/COVIDNet-CXR4-B)|
| ckpt | 480x480 | 96.0 | 93.3 | 9.2 | 5.55 |[COVIDNet-CXR4-C](https://bit.ly/COVIDNet-CXR4-C)|
Expand Down
47 changes: 44 additions & 3 deletions docs/train_eval_inference.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
# Training, Evaluation and Inference
## COVID-19 positive/negative detection
COVIDNet-CXR-2 model takes as input an image of shape (N, 480, 480, 3) and outputs the softmax probabilities of COVID-19 positive and negative detection as (N, 2), where N is the number of batches.
COVIDNet-CXR-3 model takes as input two images, one of shape (N, 256, 256, 1) for the MEDUSA architecture and one of shape (N, 480, 480, 3) for the COVIDNet architecture and outputs the softmax probabilities of COVID-19 positive and negative detection as (N, 2), where N is the number of batches.
Older COVIDNet models take a single input of shape (N, 480, 480, 3) and do not leverage a MEDUSA architecture.

If using the TF checkpoints, here are some useful tensors:
For COVIDNet-CXR-3:
* input tensor: `input_2:0`
* input medusa tensor: `input_1:0`
* logit tensor: `final_output/MatMul:0`
* output tensor: `softmax/Softmax:0`
* label tensor: `Placeholder:0`
* class weights tensor: `Placeholder_1:0`
* loss tensor: `Mean:0`

For COVIDNet-CXR-2:
* input tensor: `input_1:0`
* logit tensor: `norm_dense_2/MatMul:0`
* output tensor: `norm_dense_2/Softmax:0`
Expand Down Expand Up @@ -32,14 +43,28 @@ python train_tf.py \

1. We provide you with the tensorflow evaluation script, [eval.py](../eval.py)
2. Locate the tensorflow checkpoint files
3. To evaluate a tf checkpoint:
3. To evaluate a tf checkpoint
For COVIDNet-CXR-3:
```
python eval.py \
--weightspath models/COVIDNet-CXR-3 \
--metaname model.meta \
--ckptname model \
--n_classes 2 \
--testfile labels/test_COVIDx8B.txt \
--out_tensorname softmax/Softmax:0 \
--is_medusa_backbone
```

For COVIDNet-CXR-2:
```
python eval.py \
--weightspath models/COVIDNet-CXR-2 \
--metaname model.meta \
--ckptname model \
--n_classes 2 \
--testfile labels/test_COVIDx8B.txt \
--in_tensorname input_1:0 \
--out_tensorname norm_dense_2/Softmax:0
```
4. For more options and information, `python eval.py --help`
Expand All @@ -49,14 +74,28 @@ python eval.py \

1. Download a model from the [pretrained models section](models.md)
2. Locate models and xray image to be inferenced
3. To inference,
3. To inference
For COVIDNet-CXR-3:
```
python inference.py \
--weightspath models/COVIDNet-CXR-3 \
--metaname model.meta \
--ckptname model \
--n_classes 2 \
--imagepath assets/ex-covid.jpeg \
--out_tensorname softmax/Softmax:0 \
--is_medusa_backbone
```

For COVIDNet-CXR-2:
```
python inference.py \
--weightspath models/COVIDNet-CXR-2 \
--metaname model.meta \
--ckptname model \
--n_classes 2 \
--imagepath assets/ex-covid.jpeg \
--in_tensorname input_1:0 \
--out_tensorname norm_dense_2/Softmax:0
```
4. For more options and information, `python inference.py --help`
Expand Down Expand Up @@ -102,6 +141,7 @@ python eval.py \
--ckptname model-18540 \
--n_classes 3 \
--testfile labels/test_COVIDx8A.txt \
--in_tensorname input_1:0 \
--out_tensorname norm_dense_1/Softmax:0
```
4. For more options and information, `python eval.py --help`
Expand All @@ -119,6 +159,7 @@ python inference.py \
--ckptname model-18540 \
--n_classes 3 \
--imagepath assets/ex-covid.jpeg \
--in_tensorname input_1:0 \
--out_tensorname norm_dense_1/Softmax:0
```
4. For more options and information, `python inference.py --help`
Expand Down
Loading

0 comments on commit 1090a2f

Please sign in to comment.