Skip to content

Commit

Permalink
feat: update random policy script
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin18 committed Dec 2, 2023
1 parent f871ea6 commit 319ab61
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions scripts/random_policy.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import argparse

import gymnasium as gym
from tqdm import trange


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Random policy")
parser = argparse.ArgumentParser(
description="Random policy",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--env",
default="gymnasium_2048:gymnasium_2048/TwentyFortyEight-v0",
help="environment id",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="seed",
help="random generator seed",
)
parser.add_argument(
"--n-timesteps",
Expand All @@ -21,16 +30,14 @@ def parse_args() -> argparse.Namespace:
return args


def main() -> None:
def random_policy() -> None:
args = parse_args()
env = gym.make(
"gymnasium_2048:gymnasium_2048/TwentyFortyEight-v0",
render_mode="human",
)

env = gym.make(args.env, render_mode="human")

env.reset(seed=args.seed)

for _ in range(args.n_timesteps):
for _ in trange(args.n_timesteps, desc="Random policy"):
action = env.action_space.sample()
_, _, terminated, truncated, _ = env.step(action)

Expand All @@ -41,4 +48,4 @@ def main() -> None:


if __name__ == "__main__":
main()
random_policy()

0 comments on commit 319ab61

Please sign in to comment.