diff --git a/tensorflow_privacy/privacy/dp_query/dp_query.py b/tensorflow_privacy/privacy/dp_query/dp_query.py index 008ff9bbc..8b4944f1a 100644 --- a/tensorflow_privacy/privacy/dp_query/dp_query.py +++ b/tensorflow_privacy/privacy/dp_query/dp_query.py @@ -268,10 +268,11 @@ def derive_metrics(self, global_state): def _zeros_like(arg): """A `zeros_like` function that also works for `tf.TensorSpec`s.""" - try: - arg = tf.convert_to_tensor(value=arg) - except TypeError: - pass + if not isinstance(arg, tf.TensorSpec): + try: + arg = tf.convert_to_tensor(value=arg) + except TypeError: + pass return tf.zeros(arg.shape, arg.dtype) diff --git a/tensorflow_privacy/privacy/dp_query/tree_range_query.py b/tensorflow_privacy/privacy/dp_query/tree_range_query.py index 7cdc0b0dd..0548ec2a6 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_range_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_range_query.py @@ -18,7 +18,7 @@ import distutils import math -from typing import Optional +from typing import Any, Optional import attr import dp_accounting @@ -136,6 +136,12 @@ def initial_global_state(self): arity=self._arity, inner_query_state=self._inner_query.initial_global_state()) + def initial_sample_state(self, template: Optional[Any] = None): + """Implements `tensorflow_privacy.DPQuery.initial_sample_state`.""" + unprocessed_sample_state = super().initial_sample_state(template) + sample_params = self.derive_sample_params(self.initial_global_state()) + return self.preprocess_record(sample_params, unprocessed_sample_state) + def derive_sample_params(self, global_state): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" return (global_state.arity,