diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py
new file mode 100644
index 00000000000..457f15a2b5a
--- /dev/null
+++ b/benchmarks/ecosystem/gym_env_throughput.py
@@ -0,0 +1,337 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""This script runs a handful of envs from the Gym library with the explicit goal of measuring their throughput through the various TorchRL components.
+
+We test:
+- gym async envs embedded in TorchRL's GymEnv wrapper,
+- ParallelEnv with regular GymEnv instances,
+- single-process data collectors,
+- multiprocessed data collectors with parallel envs.
+
+The tests are executed with various numbers of workers, and on different devices.
+"""
+import time
+
+import myosuite  # noqa: F401  # registers the myosuite envs
+import tqdm
+from torchrl._utils import timeit
+from torchrl.collectors import (
+    MultiaSyncDataCollector,
+    MultiSyncDataCollector,
+    RandomPolicy,
+    SyncDataCollector,
+)
+from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
+from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend
+
+if __name__ == "__main__":
+    for envname in [
+        "HalfCheetah-v4",
+        "CartPole-v1",
+        "myoHandReachRandom-v0",
+        "ALE/Breakout-v5",
+    ]:
+        # The number of collectors does not affect the resources used; it only
+        # determines how the envs are split across sub-processes.
+        for num_workers, num_collectors in zip((8, 16, 32, 64), (2, 4, 8, 8)):
+            with open(
+                f"atari_{envname}_{num_workers}.txt".replace("/", "-"), "w+"
+            ) as log:
+                # myosuite envs are only registered with the old "gym" backend
+                if "myo" in envname:
+                    gym_backend = "gym"
+                else:
+                    gym_backend = "gymnasium"
+
+                total_frames = num_workers * 10_000
+
+                # pure gym
+                def make(envname=envname, gym_backend=gym_backend):
+                    with set_gym_backend(gym_backend):
+                        return gym_bc().make(envname)
+
+                with set_gym_backend(gym_backend):
+                    env = gym_bc().vector.AsyncVectorEnv(
+                        [make for _ in range(num_workers)]
+                    )
+                env.reset()
+                start = time.time()
+                print("Timer started.")
+                for _ in tqdm.tqdm(range(total_frames // num_workers)):
+                    env.step(env.action_space.sample())
+                env.close()
+                log.write(
+                    f"pure gym: {total_frames / (time.time() - start): 4.4f} fps\n"
+                )
+                log.flush()
+
+                # regular parallel env
+                for device in (
+                    "cuda:0",
+                    "cpu",
+                ):
+
+                    def make(envname=envname, gym_backend=gym_backend, device=device):
+                        with set_gym_backend(gym_backend):
+                            return GymEnv(envname, device=device)
+
+                    env_make = EnvCreator(make)
+                    penv = ParallelEnv(num_workers, env_make)
+                    # warmup
+                    penv.rollout(2)
+                    pbar = tqdm.tqdm(total=total_frames)
+                    t0 = time.time()
+                    for _ in range(100):
+                        penv.rollout(100, break_when_any_done=False)
+                        pbar.update(100 * num_workers)
+                    log.write(
+                        f"penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    penv.close()
+                    timeit.print()
+                    del penv
+
+                # single collector + torchrl parallel env
+                for device in ("cuda:0", "cpu"):
+
+                    def make(envname=envname, gym_backend=gym_backend, device=device):
+                        with set_gym_backend(gym_backend):
+                            return GymEnv(envname, device=device)
+
+                    env_make = EnvCreator(make)
+                    penv = ParallelEnv(num_workers, env_make)
+                    collector = SyncDataCollector(
+                        penv,
+                        RandomPolicy(penv.action_spec),
+                        frames_per_batch=1024,
+                        total_frames=total_frames,
+                    )
+                    pbar = tqdm.tqdm(total=total_frames)
+                    frames = 0
+                    for i, data in enumerate(collector):
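+                        # the first num_collectors batches are treated as warmup
+                        # and are excluded from the timed window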
+                        if i == num_collectors:
+                            t0 = time.time()
+                        if i >= num_collectors:
+                            frames += data.numel()
+                            pbar.update(data.numel())
+                            pbar.set_description(
+                                f"single collector + torchrl penv: {frames / (time.time() - t0): 4.4f} fps"
+                            )
+                    log.write(
+                        f"single collector + torchrl penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+
+                # gym parallel env
+                for device in (
+                    "cuda:0",
+                    "cpu",
+                ):
+
+                    def make_env(
+                        envname=envname,
+                        num_workers=num_workers,
+                        gym_backend=gym_backend,
+                        device=device,
+                    ):
+                        with set_gym_backend(gym_backend):
+                            penv = GymEnv(envname, num_envs=num_workers, device=device)
+                        return penv
+
+                    penv = make_env()
+                    # warmup
+                    penv.rollout(2)
+                    pbar = tqdm.tqdm(total=total_frames)
+                    t0 = time.time()
+                    for _ in range(100):
+                        penv.rollout(100, break_when_any_done=False)
+                        pbar.update(100 * num_workers)
+                    log.write(
+                        f"gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    penv.close()
+                    del penv
+
+                # async collector + torchrl parallel env
+                for device in (
+                    "cuda:0",
+                    "cpu",
+                ):
+
+                    def make_env(
+                        envname=envname, gym_backend=gym_backend, device=device
+                    ):
+                        with set_gym_backend(gym_backend):
+                            return GymEnv(envname, device=device)
+
+                    penv = ParallelEnv(
+                        num_workers // num_collectors, EnvCreator(make_env)
+                    )
+                    collector = MultiaSyncDataCollector(
+                        [penv] * num_collectors,
+                        policy=RandomPolicy(penv.action_spec),
+                        frames_per_batch=1024,
+                        total_frames=total_frames,
+                        device=device,
+                    )
+                    pbar = tqdm.tqdm(total=total_frames)
+                    frames = 0
+                    for i, data in enumerate(collector):
+                        if i == num_collectors:
+                            t0 = time.time()
+                        if i >= num_collectors:
+                            frames += data.numel()
+                            pbar.update(data.numel())
+                            pbar.set_description(
+                                f"async collector + torchrl penv: {frames / (time.time() - t0): 4.4f} fps"
+                            )
+                    log.write(
+                        f"async collector + torchrl penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+
+                # async collector + gym async env
+                for device in (
+                    "cuda:0",
+                    "cpu",
+                ):
+
+                    def make_env(
+                        envname=envname,
+                        num_workers=num_workers,
+                        gym_backend=gym_backend,
+                        device=device,
+                    ):
+                        with set_gym_backend(gym_backend):
+                            penv = GymEnv(envname, num_envs=num_workers, device=device)
+                        return penv
+
+                    penv = EnvCreator(
+                        lambda num_workers=num_workers // num_collectors: make_env(
+                            num_workers=num_workers
+                        )
+                    )
+                    collector = MultiaSyncDataCollector(
+                        [penv] * num_collectors,
+                        policy=RandomPolicy(penv().action_spec),
+                        frames_per_batch=1024,
+                        total_frames=total_frames,
+                        num_sub_threads=num_workers // num_collectors,
+                        device=device,
+                    )
+                    pbar = tqdm.tqdm(total=total_frames)
+                    frames = 0
+                    for i, data in enumerate(collector):
+                        if i == num_collectors:
+                            t0 = time.time()
+                        if i >= num_collectors:
+                            frames += data.numel()
+                            pbar.update(data.numel())
+                            pbar.set_description(
+                                f"async collector + gym penv: {frames / (time.time() - t0): 4.4f} fps"
+                            )
+                    log.write(
+                        f"async collector + gym penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+
+                # sync collector + torchrl parallel env
+                for device in (
+                    "cuda:0",
+                    "cpu",
+                ):
+
+                    def make_env(
+                        envname=envname, gym_backend=gym_backend, device=device
+                    ):
+                        with set_gym_backend(gym_backend):
+                            return GymEnv(envname, device=device)
+
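+                    # each of the num_collectors collectors runs its own
+                    # ParallelEnv over num_workers // num_collectors sub-envs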
+                    penv = ParallelEnv(
+                        num_workers // num_collectors, EnvCreator(make_env)
+                    )
+                    collector = MultiSyncDataCollector(
+                        [penv] * num_collectors,
+                        policy=RandomPolicy(penv.action_spec),
+                        frames_per_batch=1024,
+                        total_frames=total_frames,
+                        device=device,
+                    )
+                    pbar = tqdm.tqdm(total=total_frames)
+                    frames = 0
+                    for i, data in enumerate(collector):
+                        if i == num_collectors:
+                            t0 = time.time()
+                        if i >= num_collectors:
+                            frames += data.numel()
+                            pbar.update(data.numel())
+                            pbar.set_description(
+                                f"sync collector + torchrl penv: {frames / (time.time() - t0): 4.4f} fps"
+                            )
+                    log.write(
+                        f"sync collector + torchrl penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+
+                # sync collector + gym async env
+                for device in (
+                    "cuda:0",
+                    "cpu",
+                ):
+
+                    def make_env(
+                        envname=envname,
+                        num_workers=num_workers,
+                        gym_backend=gym_backend,
+                        device=device,
+                    ):
+                        with set_gym_backend(gym_backend):
+                            penv = GymEnv(envname, num_envs=num_workers, device=device)
+                        return penv
+
+                    penv = EnvCreator(
+                        lambda num_workers=num_workers // num_collectors: make_env(
+                            num_workers=num_workers
+                        )
+                    )
+                    collector = MultiSyncDataCollector(
+                        [penv] * num_collectors,
+                        policy=RandomPolicy(penv().action_spec),
+                        frames_per_batch=1024,
+                        total_frames=total_frames,
+                        num_sub_threads=num_workers // num_collectors,
+                        device=device,
+                    )
+                    pbar = tqdm.tqdm(total=total_frames)
+                    frames = 0
+                    for i, data in enumerate(collector):
+                        if i == num_collectors:
+                            t0 = time.time()
+                        if i >= num_collectors:
+                            frames += data.numel()
+                            pbar.update(data.numel())
+                            pbar.set_description(
+                                f"sync collector + gym penv: {frames / (time.time() - t0): 4.4f} fps"
+                            )
+                    log.write(
+                        f"sync collector + gym penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+    # exit explicitly so that lingering worker processes cannot keep the script alive
+    exit()