From 20476b49005a4668bfa834a0ca8743c606fc0cae Mon Sep 17 00:00:00 2001 From: TF-Agents Team Date: Tue, 12 Dec 2023 12:35:48 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 590298665 Change-Id: Ie389bec30e47c40cf0da7d86f91a29c56c290725 --- tf_agents/policies/py_tf_policy.py | 3 ++- tf_agents/replay_buffers/reverb_utils.py | 32 ++++++++++++++++-------- tf_agents/utils/session_utils.py | 2 +- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/tf_agents/policies/py_tf_policy.py b/tf_agents/policies/py_tf_policy.py index 983deec1a..9e3dd8493 100644 --- a/tf_agents/policies/py_tf_policy.py +++ b/tf_agents/policies/py_tf_policy.py @@ -14,6 +14,7 @@ # limitations under the License. """Converts TensorFlow Policies into Python Policies.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -76,7 +77,7 @@ def __init__( ) self._tf_policy = policy - self.session = None + self.session: tf.compat.v1.Session = None self._policy_state_spec = tensor_spec.to_nest_array_spec( self._tf_policy.policy_state_spec diff --git a/tf_agents/replay_buffers/reverb_utils.py b/tf_agents/replay_buffers/reverb_utils.py index 2b018f516..c1c0d70af 100644 --- a/tf_agents/replay_buffers/reverb_utils.py +++ b/tf_agents/replay_buffers/reverb_utils.py @@ -117,6 +117,11 @@ def __init__( self._overflow_episode = False self._writer_has_data = False + def _get_writer(self): + if self._writer is None: + raise ValueError("Could not obtain writer from py_client.") + return self._writer + def update_priority(self, priority: Union[float, int]) -> None: """Update the table priority. @@ -181,7 +186,7 @@ def __call__(self, trajectory: trajectory_lib.Trajectory) -> None: return if not self._overflow_episode: - self._writer.append(trajectory) + self._get_writer().append(trajectory) self._writer_has_data = True self._cached_steps += 1 @@ -197,9 +202,11 @@ def _write_cached_steps(self) -> None: # Only writes to Reverb when the writer has cached trajectories. if self._writer_has_data: # No need to truncate since the truncation is done in the class. - trajectory = tf.nest.map_structure(lambda h: h[:], self._writer.history) + trajectory = tf.nest.map_structure( + lambda h: h[:], self._get_writer().history + ) for table_name in self._table_names: - self._writer.create_item( + self._get_writer().create_item( table=table_name, trajectory=trajectory, priority=self._priority ) self._writer_has_data = False @@ -214,7 +221,7 @@ def flush(self): By calling this method before the `learner.run()`, we ensure that there is enough data to be consumed. """ - self._writer.flush() + self._get_writer().flush() def reset(self, write_cached_steps: bool = True) -> None: """Resets the state of the observer. @@ -347,6 +354,11 @@ def __init__( self._cached_steps = 0 self._last_trajectory = None + def _get_writer(self): + if self._writer is None: + raise ValueError("Could not obtain writer from py_client.") + return self._writer + @property def py_client(self) -> types.ReverbClient: return self._py_client @@ -367,7 +379,7 @@ def __call__(self, trajectory: trajectory_lib.Trajectory) -> None: there is *no* batch dimension. """ self._last_trajectory = trajectory - self._writer.append(trajectory) + self._get_writer().append(trajectory) self._cached_steps += 1 # If the fixed sequence length is reached, write the sequence. @@ -395,10 +407,10 @@ def _write_cached_steps(self) -> None: if self._sequence_lengths_reached(): trajectory = tf.nest.map_structure( - lambda d: d[-self._sequence_length :], self._writer.history + lambda d: d[-self._sequence_length :], self._get_writer().history ) for table_name in self._table_names: - self._writer.create_item( + self._get_writer().create_item( table=table_name, trajectory=trajectory, priority=self._priority ) @@ -412,7 +424,7 @@ def flush(self): By calling this method before the `learner.run()`, we ensure that there is enough data to be consumed. """ - self._writer.flush() + self._get_writer().flush() def _get_padding_step( self, example_trajectory: trajectory_lib.Trajectory @@ -452,7 +464,7 @@ def reset(self, write_cached_steps: bool = True) -> None: pad_range = range(self._sequence_length - self._cached_steps) for _ in pad_range: - self._writer.append(zero_step) + self._get_writer().append(zero_step) self._cached_steps += 1 self._write_cached_steps() # Write the cached trajectories without padding, if the cache contains @@ -522,7 +534,7 @@ def __call__(self, trajectory: trajectory_lib.Trajectory) -> None: # We record the last step called to generate the padding step when padding # is enabled. self._last_trajectory = trajectory - self._writer.append(trajectory) + self._get_writer().append(trajectory) self._cached_steps += 1 self._write_cached_steps() diff --git a/tf_agents/utils/session_utils.py b/tf_agents/utils/session_utils.py index e01d25e71..fb69cd5b2 100644 --- a/tf_agents/utils/session_utils.py +++ b/tf_agents/utils/session_utils.py @@ -96,7 +96,7 @@ def session(self, session): """ @property - def session(self): + def session(self) -> tf.compat.v1.Session: """Returns the TensorFlow session-like object used by this object. Returns: