diff --git a/tensorflow_addons/metrics/kendalls_tau.py b/tensorflow_addons/metrics/kendalls_tau.py index c34489e118..3f1a69f392 100644 --- a/tensorflow_addons/metrics/kendalls_tau.py +++ b/tensorflow_addons/metrics/kendalls_tau.py @@ -76,7 +76,22 @@ def __init__( self.preds_max = preds_max self.actual_cutpoints = actual_cutpoints self.preds_cutpoints = preds_cutpoints - self.reset_state() + self.actual_cuts = tf.linspace( + tf.cast(self.actual_min, tf.float32), + tf.cast(self.actual_max, tf.float32), + self.actual_cutpoints - 1, + ) + self.preds_cuts = tf.linspace( + tf.cast(self.preds_min, tf.float32), + tf.cast(self.preds_max, tf.float32), + self.preds_cutpoints - 1, + ) + self.m = self.add_weight( + "m", (self.actual_cutpoints, self.preds_cutpoints), dtype=tf.int64 + ) + self.nrow = self.add_weight("nrow", (self.actual_cutpoints), dtype=tf.int64) + self.ncol = self.add_weight("ncol", (self.preds_cutpoints), dtype=tf.int64) + self.n = self.add_weight("n", (), dtype=tf.int64) def update_state(self, y_true, y_pred, sample_weight=None): """Accumulates ranks. @@ -89,75 +104,69 @@ def update_state(self, y_true, y_pred, sample_weight=None): Returns: Update op. """ - if y_true.shape and y_true.shape[0]: - i = tf.searchsorted( - self.actual_cuts, - tf.cast(tf.reshape(y_true, -1), self.actual_cuts.dtype), + i = tf.searchsorted( + self.actual_cuts, + tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype), + ) + j = tf.searchsorted( + self.preds_cuts, tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype) + ) + + m = tf.sparse.from_dense(self.m) + nrow = tf.sparse.from_dense(self.nrow) + ncol = tf.sparse.from_dense(self.ncol) + + k = 0 + while k < tf.shape(i)[0]: + m = tf.sparse.add( + m, + tf.SparseTensor( + [[i[k], j[k]]], + tf.cast([1], dtype=m.dtype), + self.m.shape, + ), ) - j = tf.searchsorted( - self.preds_cuts, tf.cast(tf.reshape(y_pred, -1), self.preds_cuts.dtype) + nrow = tf.sparse.add( + nrow, + tf.SparseTensor( + [[i[k]]], + tf.cast([1], dtype=nrow.dtype), + self.nrow.shape, + ), ) - - def body(k, n, m, nrow, ncol): - return ( - k + 1, - n + 1, - tf.sparse.add( - m, - tf.SparseTensor( - [[i[k], j[k]]], - tf.cast([1], dtype=self.m.dtype), - self.m.shape, - ), - ), - tf.sparse.add( - nrow, - tf.SparseTensor( - [[i[k]]], - tf.cast([1], dtype=self.nrow.dtype), - self.nrow.shape, - ), - ), - tf.sparse.add( - ncol, - tf.SparseTensor( - [[j[k]]], - tf.cast([1], dtype=self.ncol.dtype), - self.ncol.shape, - ), - ), - ) - - _, self.n, self.m, self.nrow, self.ncol = tf.while_loop( - lambda k, n, m, nrow, ncol: k < i.shape[0], - body=body, - loop_vars=(0, self.n, self.m, self.nrow, self.ncol), + ncol = tf.sparse.add( + ncol, + tf.SparseTensor( + [[j[k]]], + tf.cast([1], dtype=ncol.dtype), + self.ncol.shape, + ), ) + k += 1 + + self.n.assign_add(tf.cast(k, tf.int64)) + self.m.assign(tf.sparse.to_dense(m)) + self.nrow.assign(tf.sparse.to_dense(nrow)) + self.ncol.assign(tf.sparse.to_dense(ncol)) def result(self): - m_dense = tf.sparse.to_dense(tf.cast(self.m, tf.float32)) - n_cap = tf.cumsum( - tf.cumsum( - tf.slice(tf.pad(m_dense, [[1, 0], [1, 0]]), [0, 0], self.m.shape), - axis=0, - ), - axis=1, - ) + m = tf.cast(self.m, tf.float32) + n_cap = tf.cumsum(tf.cumsum(m, axis=0), axis=1) # Number of concordant pairs. - p = tf.math.reduce_sum(tf.multiply(n_cap, m_dense)) - sum_m_squard = tf.math.reduce_sum(tf.math.square(m_dense)) + p = tf.math.reduce_sum(tf.multiply(n_cap[:-1, :-1], m[1:, 1:])) + sum_m_squard = tf.math.reduce_sum(tf.math.square(m)) # Ties in x. t = ( - tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.nrow))) + tf.cast(tf.math.reduce_sum(tf.math.square(self.nrow)), tf.float32) - sum_m_squard ) / 2.0 # Ties in y. u = ( - tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.ncol))) + tf.cast(tf.math.reduce_sum(tf.math.square(self.ncol)), tf.float32) - sum_m_squard ) / 2.0 # Ties in both. - b = tf.math.reduce_sum(tf.multiply(m_dense, (m_dense - 1.0))) / 2.0 + b = tf.math.reduce_sum(tf.multiply(m, (m - 1.0))) / 2.0 # Number of discordant pairs. n = tf.cast(self.n, tf.float32) q = (n - 1.0) * n / 2.0 - p - t - u - b @@ -179,28 +188,11 @@ def get_config(self): def reset_state(self): """Resets all of the metric state variables.""" - self.actual_cuts = tf.linspace( - tf.cast(self.actual_min, tf.float32), - tf.cast(self.actual_max, tf.float32), - self.actual_cutpoints - 1, - ) - self.preds_cuts = tf.linspace( - tf.cast(self.preds_min, tf.float32), - tf.cast(self.preds_max, tf.float32), - self.preds_cutpoints - 1, - ) - self.m = tf.SparseTensor( - tf.zeros((0, 2), tf.int64), - [], - [self.actual_cutpoints, self.preds_cutpoints], - ) - self.nrow = tf.SparseTensor( - tf.zeros((0, 1), dtype=tf.int64), [], [self.actual_cutpoints] - ) - self.ncol = tf.SparseTensor( - tf.zeros((0, 1), dtype=tf.int64), [], [self.preds_cutpoints] - ) - self.n = 0 + + self.m.assign(tf.zeros((self.actual_cutpoints, self.preds_cutpoints), tf.int64)) + self.nrow.assign(tf.zeros((self.actual_cutpoints), tf.int64)) + self.ncol.assign(tf.zeros((self.preds_cutpoints), tf.int64)) + self.n.assign(0) def reset_states(self): # Backwards compatibility alias of `reset_state`. New classes should diff --git a/tensorflow_addons/metrics/tests/kendalls_tau_test.py b/tensorflow_addons/metrics/tests/kendalls_tau_test.py index 6d9502ea2b..4121c64b5e 100644 --- a/tensorflow_addons/metrics/tests/kendalls_tau_test.py +++ b/tensorflow_addons/metrics/tests/kendalls_tau_test.py @@ -90,7 +90,8 @@ def test_keras_binary_classification_model(): x = np.random.rand(1000, 10).astype(np.float32) y = np.random.rand(1000, 1).astype(np.float32) - model.fit(x, y, epochs=1, verbose=0, batch_size=32) + history = model.fit(x, y, epochs=1, verbose=0, batch_size=32) + assert not any(np.isnan(history.history["kendalls_tau"])) def test_kendalls_tau_serialization():