[Benchmark] Benchmark Gym vs TorchRL (pytorch#1602)
vmoens committed Oct 10, 2023
1 parent acefedf commit 82fa8f6
Showing 1 changed file: benchmarks/ecosystem/gym_env_throughput.py (337 additions, 0 deletions)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""This script executes some envs across the Gym library with the explicit scope of testing the throughput using the various TorchRL components.
We test:
- gym async envs embedded in a TorchRL's GymEnv wrapper,
- ParallelEnv with regular GymEnv instances,
- Data collector
- Multiprocessed data collectors with parallel envs.
The tests are executed with various number of cpus, and on different devices.
"""
import time

import myosuite # noqa: F401
import tqdm
from torchrl._utils import timeit
from torchrl.collectors import (
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    RandomPolicy,
    SyncDataCollector,
)
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend

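# The loops below append lines of the form "<label> <device>: <fps> fps" to one
# log file per (env, num_workers) pair. Below is a minimal sketch for
# aggregating those logs; `summarize_logs` is a hypothetical helper, not part
# of the original benchmark, and assumes the "benchmark_*.txt" naming scheme
# used in this script.
import glob
import re


def summarize_logs(pattern="benchmark_*.txt"):
    """Print every recorded throughput, fastest first."""
    results = []
    for path in glob.glob(pattern):
        with open(path) as f:
            for line in f:
                # lines look like "penv cuda:0:  1234.5678 fps"
                match = re.match(r"(?P<label>.+):\s*(?P<fps>[\d.]+) fps", line.strip())
                if match:
                    results.append((float(match["fps"]), match["label"], path))
    for fps, label, path in sorted(results, reverse=True):
        print(f"{fps:12.4f} fps  {label} ({path})")
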
if __name__ == "__main__":
    for envname in [
        "HalfCheetah-v4",
        "CartPole-v1",
        "myoHandReachRandom-v0",
        "ALE/Breakout-v5",
    ]:
        # The number of collectors does not change the resources used; it only
        # affects how the envs are split across sub-processes (e.g. 32 workers
        # with 8 collectors means 8 collectors of 4 envs each).
        for num_workers, num_collectors in zip((8, 16, 32, 64), (2, 4, 8, 8)):
            with open(
                f"benchmark_{envname}_{num_workers}.txt".replace("/", "-"), "w+"
            ) as log:
                # myosuite envs are only registered with the legacy gym backend
                if "myo" in envname:
                    gym_backend = "gym"
                else:
                    gym_backend = "gymnasium"

                total_frames = num_workers * 10_000

                # pure gym: an async vector env stepped with random actions
                def make(envname=envname, gym_backend=gym_backend):
                    with set_gym_backend(gym_backend):
                        return gym_bc().make(envname)

                with set_gym_backend(gym_backend):
                    env = gym_bc().vector.AsyncVectorEnv(
                        [make for _ in range(num_workers)]
                    )
                env.reset()
                start = time.time()
                print("Timer started.")
                for _ in tqdm.tqdm(range(total_frames // num_workers)):
                    env.step(env.action_space.sample())
                env.close()
                log.write(
                    f"pure gym: {total_frames / (time.time() - start): 4.4f} fps\n"
                )
                log.flush()

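                # The next block benchmarks TorchRL's ParallelEnv: one GymEnv
                # per subprocess, with transitions stacked into a single
                # tensordict. rollout(100, break_when_any_done=False) steps
                # every worker 100 times, auto-resetting finished sub-envs, so
                # each pass collects 100 * num_workers frames.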
                # regular parallel env
                for device in (
                    "cuda:0",
                    "cpu",
                ):

                    def make(envname=envname, gym_backend=gym_backend, device=device):
                        with set_gym_backend(gym_backend):
                            return GymEnv(envname, device=device)

                    env_make = EnvCreator(make)
                    penv = ParallelEnv(num_workers, env_make)
                    # warmup
                    penv.rollout(2)
                    pbar = tqdm.tqdm(total=total_frames)
                    t0 = time.time()
                    for _ in range(100):
                        data = penv.rollout(100, break_when_any_done=False)
                        pbar.update(100 * num_workers)
                    log.write(
                        f"penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    penv.close()
                    timeit.print()
                    del penv

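                # A single SyncDataCollector wrapping the same ParallelEnv:
                # this measures the overhead the collector machinery (policy
                # calls, batching) adds on top of the parallel env benchmarked
                # above. The first num_collectors batches are treated as
                # warm-up and excluded from the timing.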
                for device in ("cuda:0", "cpu"):

                    def make(envname=envname, gym_backend=gym_backend, device=device):
                        with set_gym_backend(gym_backend):
                            return GymEnv(envname, device=device)

                    env_make = EnvCreator(make)
                    # penv = SerialEnv(num_workers, env_make)
                    penv = ParallelEnv(num_workers, env_make)
                    collector = SyncDataCollector(
                        penv,
                        RandomPolicy(penv.action_spec),
                        frames_per_batch=1024,
                        total_frames=total_frames,
                    )
                    pbar = tqdm.tqdm(total=total_frames)
                    frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            # first num_collectors batches are warm-up; start timing now
                            t0 = time.time()
                        if i >= num_collectors:
                            frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"single collector + torchrl penv: {frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"single collector + torchrl penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector

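                # GymEnv(..., num_envs=N) vectorizes on the gym side (the
                # backend's async vector env) inside a single TorchRL env,
                # rather than spawning TorchRL worker processes.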
                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # gym parallel env
                    def make_env(
                        envname=envname,
                        num_workers=num_workers,
                        gym_backend=gym_backend,
                        device=device,
                    ):
                        with set_gym_backend(gym_backend):
                            penv = GymEnv(envname, num_envs=num_workers, device=device)
                        return penv

                    penv = make_env()
                    # warmup
                    penv.rollout(2)
                    pbar = tqdm.tqdm(total=total_frames)
                    t0 = time.time()
                    for _ in range(100):
                        data = penv.rollout(100, break_when_any_done=False)
                        pbar.update(100 * num_workers)
                    log.write(
                        f"gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    penv.close()
                    del penv

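                # MultiaSyncDataCollector runs num_collectors collector
                # processes, each stepping its own ParallelEnv of
                # num_workers // num_collectors envs; batches are yielded as
                # soon as any collector finishes one (asynchronous delivery).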
                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # async collector
                    # + torchrl parallel env
                    def make_env(
                        envname=envname, gym_backend=gym_backend, device=device
                    ):
                        with set_gym_backend(gym_backend):
                            return GymEnv(envname, device=device)

                    penv = ParallelEnv(
                        num_workers // num_collectors, EnvCreator(make_env)
                    )
                    collector = MultiaSyncDataCollector(
                        [penv] * num_collectors,
                        policy=RandomPolicy(penv.action_spec),
                        frames_per_batch=1024,
                        total_frames=total_frames,
                        device=device,
                    )
                    pbar = tqdm.tqdm(total=total_frames)
                    frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            # first num_collectors batches are warm-up; start timing now
                            t0 = time.time()
                        if i >= num_collectors:
                            frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"collector + torchrl penv: {frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"async collector + torchrl penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector

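                # Same async collector, but each collector process steps one
                # gym-vectorized GymEnv instead of a ParallelEnv;
                # num_sub_threads reserves enough threads per process for the
                # vector env's workers.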
                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # async collector
                    # + gym async env
                    def make_env(
                        envname=envname,
                        num_workers=num_workers,
                        gym_backend=gym_backend,
                        device=device,
                    ):
                        with set_gym_backend(gym_backend):
                            penv = GymEnv(envname, num_envs=num_workers, device=device)
                        return penv

                    penv = EnvCreator(
                        lambda num_workers=num_workers // num_collectors: make_env(
                            num_workers=num_workers
                        )
                    )
                    collector = MultiaSyncDataCollector(
                        [penv] * num_collectors,
                        policy=RandomPolicy(penv().action_spec),
                        frames_per_batch=1024,
                        total_frames=total_frames,
                        num_sub_threads=num_workers // num_collectors,
                        device=device,
                    )
                    pbar = tqdm.tqdm(total=total_frames)
                    frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            # first num_collectors batches are warm-up; start timing now
                            t0 = time.time()
                        if i >= num_collectors:
                            frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"{i} collector + gym penv: {frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"async collector + gym penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector

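                # MultiSyncDataCollector is the synchronous variant: each
                # batch it yields aggregates the output of all collectors, so
                # delivery waits for the slowest one.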
                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # sync collector
                    # + torchrl parallel env
                    def make_env(
                        envname=envname, gym_backend=gym_backend, device=device
                    ):
                        with set_gym_backend(gym_backend):
                            return GymEnv(envname, device=device)

                    penv = ParallelEnv(
                        num_workers // num_collectors, EnvCreator(make_env)
                    )
                    collector = MultiSyncDataCollector(
                        [penv] * num_collectors,
                        policy=RandomPolicy(penv.action_spec),
                        frames_per_batch=1024,
                        total_frames=total_frames,
                        device=device,
                    )
                    pbar = tqdm.tqdm(total=total_frames)
                    frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            # first num_collectors batches are warm-up; start timing now
                            t0 = time.time()
                        if i >= num_collectors:
                            frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"collector + torchrl penv: {frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"sync collector + torchrl penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector

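                # Synchronous collector over gym-vectorized envs: the same
                # worker split as above, with batches gathered from all
                # collectors at once.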
                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # sync collector
                    # + gym async env
                    def make_env(
                        envname=envname,
                        num_workers=num_workers,
                        gym_backend=gym_backend,
                        device=device,
                    ):
                        with set_gym_backend(gym_backend):
                            penv = GymEnv(envname, num_envs=num_workers, device=device)
                        return penv

                    penv = EnvCreator(
                        lambda num_workers=num_workers // num_collectors: make_env(
                            num_workers=num_workers
                        )
                    )
                    collector = MultiSyncDataCollector(
                        [penv] * num_collectors,
                        policy=RandomPolicy(penv().action_spec),
                        frames_per_batch=1024,
                        total_frames=total_frames,
                        num_sub_threads=num_workers // num_collectors,
                        device=device,
                    )
                    pbar = tqdm.tqdm(total=total_frames)
                    frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            # first num_collectors batches are warm-up; start timing now
                            t0 = time.time()
                        if i >= num_collectors:
                            frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"{i} collector + gym penv: {frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"sync collector + gym penv {device}: {frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector
    # terminate explicitly once all configurations have run
    exit()
