Add configurable Monte Carlo dropout uncertainty estimation [updated] #30

Open · wants to merge 9 commits into master
12 changes: 12 additions & 0 deletions medsegpy/config.py
@@ -106,6 +106,14 @@ class Config(object):
EARLY_STOPPING_PATIENCE = 0
EARLY_STOPPING_CRITERION = "val_loss"

# Dropout rate
DROPOUT_RATE = 0.0

# Use Monte Carlo dropout to generate predictive uncertainty values
MC_DROPOUT = False
# Number of Monte Carlo dropout iterations
MC_DROPOUT_T = 100

# Batch sizes
TRAIN_BATCH_SIZE = 12
VALID_BATCH_SIZE = 35
@@ -589,6 +597,10 @@ def summary(self, additional_vars=None):
"EARLY_STOPPING_PATIENCE" if self.USE_EARLY_STOPPING else "",
"EARLY_STOPPING_CRITERION" if self.USE_EARLY_STOPPING else "",
"",
"DROPOUT_RATE",
"MC_DROPOUT",
"MC_DROPOUT_T" if self.MC_DROPOUT else ""
"",
"KERNEL_INITIALIZER",
"SEED" if self.SEED else "",
"" "INIT_WEIGHTS",
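Taken together, the three new fields are the entire user-facing surface of this feature. A minimal sketch of an inference config that enables it (the field names are the ones added above; the 0.1 rate and T=50 are illustrative values, not defaults):

```python
from medsegpy.config import UNetConfig

cfg = UNetConfig()
cfg.DROPOUT_RATE = 0.1   # non-zero rate so the conv blocks actually drop units
cfg.MC_DROPOUT = True    # keep dropout stochastic at inference time
cfg.MC_DROPOUT_T = 50    # number of stochastic forward passes per scan
```

With `MC_DROPOUT` left at `False`, the data loader yields `y_mc_dropout` as `None` and the evaluator skips the extra dataset, so existing behaviour is unchanged.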
13 changes: 12 additions & 1 deletion medsegpy/data/data_loader.py
@@ -337,6 +337,10 @@ def inference(self, model: Model, **kwargs):

workers = kwargs.pop("workers", self._cfg.NUM_WORKERS)
use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1)

kwargs["mc_dropout"] = self._cfg.MC_DROPOUT
kwargs["mc_dropout_T"] = self._cfg.MC_DROPOUT_T

for scan_id in scan_ids:
self._dataset_dicts = scan_to_dict_mapping[scan_id]

@@ -353,6 +357,13 @@
)
time_elapsed = time.perf_counter() - start

# The TF2 inference path returns a dict; unpack the MC dropout samples
# (moving the stacked MC axis to the end) and keep the point predictions as `preds`.
preds_mc_dropout = None
if isinstance(preds, dict):
if preds['preds_mc_dropout'] is not None:
preds_mc_dropout = np.squeeze(preds['preds_mc_dropout']).transpose((1, 2, 3, 0))

preds = preds['preds']

x, y, preds = self._restructure_data((x, y, preds))

input = {"x": x, "scan_id": scan_id}
@@ -363,7 +374,7 @@
}
input.update(scan_params)

output = {"y_pred": preds, "y_true": y, "time_elapsed": time_elapsed}
output = {"y_pred": preds, "y_mc_dropout":preds_mc_dropout, "y_true": y, "time_elapsed": time_elapsed}

yield input, output

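The `y_mc_dropout` entry yielded here is what downstream code reduces into an uncertainty map. A minimal sketch, assuming (as the squeeze/transpose above suggests) that the T Monte Carlo samples sit on the last axis; the helper name is hypothetical:

```python
import numpy as np


def mc_uncertainty(y_mc_dropout, mc_axis=-1):
    """Reduce T stochastic predictions to a mean prediction and a spread map.

    `y_mc_dropout` holds per-class probabilities with the Monte Carlo samples
    stacked along `mc_axis`, e.g. shape (H, W, num_classes, T).
    """
    mean_pred = y_mc_dropout.mean(axis=mc_axis)  # MC-averaged probabilities
    spread = y_mc_dropout.std(axis=mc_axis)      # per-voxel disagreement across the T passes
    return mean_pred, spread


# e.g. for one (input, output) pair yielded by inference():
# mean_pred, spread = mc_uncertainty(output["y_mc_dropout"])
```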
4 changes: 4 additions & 0 deletions medsegpy/evaluation/sem_seg_evaluation.py
@@ -134,13 +134,15 @@ def process(self, inputs, outputs):
if includes_bg:
y_true = output["y_true"][..., 1:]
y_pred = output["y_pred"][..., 1:]
y_mc_dropout = None if output["y_mc_dropout"] is None else output["y_mc_dropout"][..., 1:]
labels = labels[..., 1:]
# if y_true.ndim == 3:
# y_true = y_true[..., np.newaxis]
# y_pred = y_pred[..., np.newaxis]
# labels = labels[..., np.newaxis]
output["y_true"] = y_true
output["y_pred"] = y_pred
output["y_mc_dropout"] = y_mc_dropout

time_elapsed = output["time_elapsed"]
if self.stream_evaluation:
@@ -178,6 +180,8 @@ def eval_single_scan(self, input, output, labels, time_elapsed):
with h5py.File(save_name, "w") as h5f:
h5f.create_dataset("probs", data=output["y_pred"])
h5f.create_dataset("labels", data=labels)
if output["y_mc_dropout"] is not None:
h5f.create_dataset("mc_dropout", data=output["y_mc_dropout"])

def evaluate(self):
"""Evaluates popular medical segmentation metrics specified in config.
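Because the raw samples are written alongside `probs` and `labels`, uncertainty can also be computed offline from the saved H5 files. A minimal sketch that reads one back and derives predictive entropy (the file path is hypothetical; the dataset names are the ones created above, and the samples are assumed to sit on the last axis):

```python
import h5py
import numpy as np

with h5py.File("scan_output.h5", "r") as h5f:  # hypothetical output path
    probs = h5f["probs"][:]
    labels = h5f["labels"][:]
    mc = h5f["mc_dropout"][:] if "mc_dropout" in h5f else None

if mc is not None:
    mean_probs = mc.mean(axis=-1)  # average over the T samples
    # Predictive entropy over classes as a voxel-wise uncertainty score.
    entropy = -np.sum(mean_probs * np.log(mean_probs + 1e-12), axis=-1)
```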
5 changes: 3 additions & 2 deletions medsegpy/modeling/meta_arch/unet.py
@@ -145,6 +145,7 @@ def build_model(self, input_tensor=None) -> Model:
seed = cfg.SEED
depth = cfg.DEPTH
kernel_size = self.kernel_size
dropout_rate = cfg.DROPOUT_RATE
self.use_attention = cfg.USE_ATTENTION
self.use_deep_supervision = cfg.USE_DEEP_SUPERVISION

@@ -178,7 +179,7 @@
num_conv=2,
activation="relu",
kernel_initializer=kernel_initializer,
dropout=0.0,
dropout=dropout_rate,
)

# Maxpool until penultimate depth.
@@ -220,7 +221,7 @@
num_conv=2,
activation="relu",
kernel_initializer=kernel_initializer,
dropout=0.0,
dropout=dropout_rate,
)

if self.use_deep_supervision:
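Threading `cfg.DROPOUT_RATE` into the encoder and decoder blocks matters because Keras `Dropout` is only stochastic when the model is called with `training=True`, and with a zero rate the Monte Carlo passes in `model.py` would all be identical. A self-contained toy illustration of that behaviour (not medsegpy's actual block builder):

```python
import tensorflow as tf
from tensorflow.keras import Input, Model, layers


def conv_block(x, filters, dropout_rate):
    """Two conv layers followed by dropout, mirroring the pattern parameterised above."""
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    if dropout_rate > 0.0:
        x = layers.Dropout(dropout_rate)(x)
    return x


inp = Input((64, 64, 1))
out = layers.Conv2D(2, 1, activation="softmax")(conv_block(inp, 16, dropout_rate=0.5))
toy = Model(inp, out)

x = tf.random.normal((1, 64, 64, 1))
a = toy(x, training=True)   # dropout active: stochastic output
b = toy(x, training=True)   # a second pass generally differs from the first
c = toy(x, training=False)  # dropout disabled: deterministic output
```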
32 changes: 27 additions & 5 deletions medsegpy/modeling/model.py
@@ -6,6 +6,7 @@
from keras.models import Model as _Model
from keras.utils.data_utils import GeneratorEnqueuer, OrderedEnqueuer
from keras.utils.generic_utils import Progbar
import tensorflow as tf

from medsegpy.utils import env

@@ -42,10 +43,12 @@ def inference_generator(
max_queue_size=10,
workers=1,
use_multiprocessing=False,
verbose=0,
mc_dropout=False,
mc_dropout_T=100,
verbose=0
):
return self.inference_generator_static(
self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose
self, generator, steps, max_queue_size, workers, use_multiprocessing, mc_dropout, mc_dropout_T, verbose
)

@classmethod
@@ -57,7 +60,9 @@ def inference_generator_static(
max_queue_size=10,
workers=1,
use_multiprocessing=False,
verbose=0,
mc_dropout=False,
mc_dropout_T=100,
verbose=0
):
"""Generates predictions for the input samples from a data generator
and returns inputs, ground truth, and predictions.
@@ -115,6 +120,8 @@ def inference_generator_static(
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
mc_dropout=mc_dropout,
mc_dropout_T=mc_dropout_T,
verbose=verbose,
)
else:
@@ -252,9 +259,13 @@ def _inference_generator_tf2(
max_queue_size=10,
workers=1,
use_multiprocessing=False,
mc_dropout=False,
mc_dropout_T=100
):
"""Inference generator for TensorFlow 2."""
# Seed TensorFlow's RNG so the MC dropout masks (and therefore the sampled
# predictions) are reproducible across runs; Python's `random.seed` has no
# effect on Keras dropout layers.
tf.random.set_seed(0)
outputs = []
outputs_mc_dropout = []
xs = []
ys = []
with model.distribute_strategy.scope():
@@ -295,14 +306,21 @@
batch_x, batch_y, batch_x_raw = _extract_inference_inputs(next(iterator))
# tmp_batch_outputs = predict_function(iterator)
tmp_batch_outputs = model.predict(batch_x)


# Sample the predictive distribution with T extra stochastic passes (dropout kept active).
tmp_batch_outputs_mc_dropout = None
if mc_dropout:
tmp_batch_outputs_mc_dropout = np.stack([model(batch_x, training=True) for _ in range(mc_dropout_T)])

if data_handler.should_sync:
context.async_wait() # noqa: F821
batch_outputs = tmp_batch_outputs # No error, now safe to assign.
batch_outputs_mc_dropout = tmp_batch_outputs_mc_dropout

if batch_x_raw is not None:
batch_x = batch_x_raw
for batch, running in zip(
[batch_x, batch_y, batch_outputs], [xs, ys, outputs]
[batch_x, batch_y, batch_outputs, batch_outputs_mc_dropout], [xs, ys, outputs, outputs_mc_dropout]
):
nest.map_structure_up_to(
batch, lambda x, batch_x: x.append(batch_x), running, batch
@@ -318,7 +336,11 @@
all_xs = nest.map_structure_up_to(batch_x, np.concatenate, xs)
all_ys = nest.map_structure_up_to(batch_y, np.concatenate, ys)
all_outputs = nest.map_structure_up_to(batch_outputs, np.concatenate, outputs)
return all_xs, all_ys, all_outputs
all_outputs_mc_dropout = nest.map_structure_up_to(batch_outputs_mc_dropout, np.concatenate, outputs_mc_dropout) if mc_dropout else None

outputs = {'preds': all_outputs, 'preds_mc_dropout': all_outputs_mc_dropout}

return all_xs, all_ys, outputs

# all_xs = nest.map_structure_up_to(batch_x, concat, xs)
# all_ys = nest.map_structure_up_to(batch_y, concat, ys)
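Stripped of the generator machinery, the sampling added above is just T forward passes with dropout left active, stacked into a single array; reproducibility then comes down to seeding TensorFlow's RNG. A standalone sketch of that core loop (the helper, `model`, and `batch_x` are hypothetical):

```python
import numpy as np
import tensorflow as tf


def mc_dropout_samples(model, batch_x, T=100, seed=0):
    """Stack T stochastic forward passes, mirroring the loop in `_inference_generator_tf2`."""
    tf.random.set_seed(seed)  # re-seed TF's RNG so repeated calls draw the same dropout masks
    return np.stack([model(batch_x, training=True).numpy() for _ in range(T)])


# samples has shape (T, batch, ...); the mean over axis 0 approximates the
# predictive mean and the spread over axis 0 is the Monte Carlo uncertainty.
# samples = mc_dropout_samples(model, batch_x, T=cfg.MC_DROPOUT_T)
```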
81 changes: 81 additions & 0 deletions tests/modeling/test_model.py
@@ -0,0 +1,81 @@
"""Test model output reproducability.

These tests check that Monte Carlo dropout during inference produces
reproducible results.
"""

import unittest
import numpy as np
import os
import h5py
import shutil
from fvcore.common.file_io import PathManager

from medsegpy.config import UNetConfig
from medsegpy.modeling.meta_arch import build_model
from medsegpy.modeling.model import Model
from medsegpy.data import DefaultDataLoader



class TestMCDropout(unittest.TestCase):
IMG_SIZE = (512, 512, 1)
NUM_CLASSES = 4
FILE_PATH = "mock_data://temp_data/scan.h5"

@classmethod
def setUpClass(cls):
np.random.seed(0)
img = np.random.rand(*cls.IMG_SIZE).astype(np.float32)
seg = (np.random.rand(*cls.IMG_SIZE, cls.NUM_CLASSES) >= 0.5).astype(np.uint8)

file_path = PathManager.get_local_path(cls.FILE_PATH)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with h5py.File(file_path, "w") as f:
f.create_dataset("data", data=img)
f.create_dataset("volume", data=img)
f.create_dataset("seg", data=seg)

@classmethod
def tearDownClass(cls):
file_path = PathManager.get_local_path(cls.FILE_PATH)
shutil.rmtree(os.path.dirname(file_path))

def get_dataset_dicts(self):
file_path = PathManager.get_local_path(self.FILE_PATH)
return [
{
"file_name": file_path,
"sem_seg_file": file_path,
"scan_id": os.path.splitext(os.path.basename(file_path))[0],
"image_size": self.IMG_SIZE,
}
]

def test_inference_with_mc_dropout(self):
cfg = UNetConfig()
cfg.MC_DROPOUT = True
cfg.MC_DROPOUT_T = 10
cfg.IMG_SIZE = self.IMG_SIZE
model = build_model(cfg)

with h5py.File(PathManager.get_local_path(self.FILE_PATH), "r") as f:
volume = f["volume"][:]
mask = f["seg"][:]
dataset_dicts = self.get_dataset_dicts()
data_loader = DefaultDataLoader(cfg, dataset_dicts, is_test=True, shuffle=False)

# Feed same data to inference generator twice
kwargs = dict()
kwargs["mc_dropout"] = data_loader._cfg.MC_DROPOUT
kwargs["mc_dropout_T"] = data_loader._cfg.MC_DROPOUT_T
x1, y1, preds1 = Model.inference_generator_static(model, data_loader, **kwargs)
x2, y2, preds2 = Model.inference_generator_static(model, data_loader, **kwargs)

# Outputs should be the same
assert np.array_equal(x1, x2)
assert np.array_equal(y1, y2)
assert np.array_equal(preds1["preds"], preds2["preds"])
assert np.array_equal(preds1["preds_mc_dropout"], preds2["preds_mc_dropout"])


if "__name__" == "__main__":
unittest.main()