Skip to content

Commit

Permalink
new Dice and IoU #6 & #7
Browse files Browse the repository at this point in the history
  • Loading branch information
ebgoldstein committed Sep 30, 2022
1 parent 691abe7 commit f7236ab
Showing 1 changed file with 49 additions and 14 deletions.
63 changes: 49 additions & 14 deletions doodleverse_utils/model_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,18 @@ def batchnorm_act(x):
###############################################################

# -----------------------------------
def mean_iou(y_true, y_pred):

#define the basic IOU formula.
def new_iou(y_true, y_pred):
smooth = 10e-6
y_true_f = tf.reshape(tf.dtypes.cast(y_true, tf.float32), [-1])
y_pred_f = tf.reshape(tf.dtypes.cast(y_pred, tf.float32), [-1])
intersection = tf.reduce_sum(y_true_f * y_pred_f)
union = tf.reduce_sum(y_true_f + y_pred_f) - intersection
return (intersection+smooth)/(union+ smooth)

#define the IoU metric for nclasses
def MC_mean_iou(nclasses):
"""
mean_iou(y_true, y_pred)
This function computes the mean IoU between `y_true` and `y_pred`: this version is tensorflow (not numpy) and is used by tensorflow training and evaluation functions
Expand All @@ -971,16 +982,19 @@ def mean_iou(y_true, y_pred):
OUTPUTS:
* IoU score [tensor]
"""
yt0 = y_true[:, :, :, 0]
yp0 = tf.keras.backend.cast(y_pred[:, :, :, 0] > 0.5, "float32")
inter = tf.math.count_nonzero(tf.logical_and(tf.equal(yt0, 1), tf.equal(yp0, 1)))
union = tf.math.count_nonzero(tf.add(yt0, yp0))
iou = tf.where(tf.equal(union, 0), 1.0, tf.cast(inter / union, "float32"))
return iou
def mean_iou(y_true, y_pred):
iousum = 0
y_pred = tf.one_hot(tf.argmax(y_pred, -1), 4)
for index in range(nclasses):
iousum += new_iou(y_true[:,:,:,index], y_pred[:,:,:,index])
return iousum/nclasses

return mean_iou


# -----------------------------------
def dice_coef(y_true, y_pred):
#define basic Dice formula
def basic_dice_coef(y_true, y_pred):
"""
dice_coef(y_true, y_pred)
Expand All @@ -1004,17 +1018,30 @@ def dice_coef(y_true, y_pred):
OUTPUTS:
* Dice score [tensor]
"""
smooth = 1.0
smooth = 10e-6
y_true_f = tf.reshape(tf.dtypes.cast(y_true, tf.float32), [-1])
y_pred_f = tf.reshape(tf.dtypes.cast(y_pred, tf.float32), [-1])
intersection = tf.reduce_sum(y_true_f * y_pred_f)
return (2.0 * intersection + smooth) / (
tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
)
dice = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
return dice

#define Dice formula for multiple classes
def dice_multi(nclasses):

def dice_coef(y_true, y_pred):
dice = 0
#can't have an argmax in a loss
#y_pred = tf.one_hot(tf.argmax(y_pred, -1), 4)
for index in range(nclasses):
dice += basic_dice_coef(y_true[:,:,:,index], y_pred[:,:,:,index])
return dice/nclasses

return dice_coef


# ---------------------------------------------------
def dice_coef_loss(y_true, y_pred):
#define Dice loss for multiple classes
def dice_coef_loss(nclasses):
"""
dice_coef_loss(y_true, y_pred)
Expand All @@ -1038,7 +1065,15 @@ def dice_coef_loss(y_true, y_pred):
OUTPUTS:
* Dice loss [tensor]
"""
return 1.0 - dice_coef(y_true, y_pred)
def MC_dice_coef_loss(y_true, y_pred):
dice = 0
#can't have an argmax in a loss
#y_pred = tf.one_hot(tf.argmax(y_pred, -1), 4)
for index in range(nclasses):
dice += basic_dice_coef(y_true[:,:,:,index], y_pred[:,:,:,index])
return 1 - (dice/nclasses)

return MC_dice_coef_loss


# -----------------------------------
Expand Down

0 comments on commit f7236ab

Please sign in to comment.