Skip to content

Commit

Permalink
Add tf test.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed Apr 3, 2024
1 parent 1e203bf commit be4c6e2
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions tests/test_tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Test: create and run tensorflow policy network in environment loop."""

import sonnet as snt
from acme import wrappers
from acme.tf import networks as network_utils
from acme.tf import utils as tf2_utils

from flybody.fly_envs import walk_on_ball
from flybody.agents.network_factory import make_network_factory_dmpo


def test_can_create_and_run_tf_policy():

env = walk_on_ball()
env = wrappers.SinglePrecisionWrapper(env)
env = wrappers.CanonicalSpecWrapper(env, clip=True)

network_factory = make_network_factory_dmpo()
networks = network_factory(env.action_spec())
assert set(networks.keys()) == set(('observation', 'policy', 'critic'))

policy_network = snt.Sequential([
networks['observation'],
networks['policy'],
network_utils.StochasticSamplingHead()
])

timestep = env.reset()
for _ in range(100):
batched_observation = tf2_utils.add_batch_dim(timestep.observation)
action = policy_network(batched_observation)
action = tf2_utils.to_numpy_squeeze(action)
timestep = env.step(action)

0 comments on commit be4c6e2

Please sign in to comment.