Skip to content

Commit

Permalink
refactor: update evaluate script
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin18 committed Dec 6, 2023
1 parent 402c92c commit 582650c
Showing 1 changed file with 87 additions and 16 deletions.
103 changes: 87 additions & 16 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
NTupleNetworkTDPolicySmall,
)

# Apply the "ggplot" matplotlib style to every figure produced by this script.
plt.style.use("ggplot")


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -46,6 +48,16 @@ def parse_args() -> argparse.Namespace:
default=42,
help="random generator seed",
)
parser.add_argument(
"-t",
"--title",
help="figure title",
)
parser.add_argument(
"-o",
"--output-path",
help="path to output png file",
)
args = parser.parse_args()
return args

Expand All @@ -57,6 +69,13 @@ def make_env(env_id: str) -> gym.Env:


def make_policy(algo: str, trained_agent: str) -> NTupleNetworkBasePolicy:
"""
Makes the policy to evaluate.
:param algo: Name of the algorithm.
:param trained_agent: Path to a trained agent.
:return: Policy.
"""
algo_policy_map = {
"ql": NTupleNetworkQLearningPolicy,
"tdl": NTupleNetworkTDPolicy,
Expand All @@ -66,45 +85,66 @@ def make_policy(algo: str, trained_agent: str) -> NTupleNetworkBasePolicy:
return policy.load(trained_agent)


def evaluate() -> None:
args = parse_args()

np.random.seed(args.seed)
env = make_env(env_id=args.env)
if args.algo is not None and args.trained_agent is not None:
policy = make_policy(algo=args.algo, trained_agent=args.trained_agent)
else:
policy = None
def run_episodes(
    env: gym.Env,
    policy: NTupleNetworkBasePolicy | None,
    n_episodes: int,
) -> tuple[list[int], list[int], list[int], list[int]]:
    """
    Runs episodes and records per-episode statistics.

    :param env: Game environment. Assumed to be wrapped so that ``info``
        carries ``"episode"`` (with ``"l"``/``"r"``), ``"max"``,
        ``"total_score"`` and ``"board"`` entries — TODO confirm against
        ``make_env``.
    :param policy: Policy, or None to sample random actions.
    :param n_episodes: Number of episodes to run.
    :return: Lengths, rewards, max tiles and total scores.
    """
    lengths: list[int] = []
    rewards: list[int] = []
    max_tiles: list[int] = []
    total_score: list[int] = []

    for _ in trange(n_episodes, desc="Episode", unit="episode"):
        _observation, info = env.reset()
        terminated = truncated = False

        while not terminated and not truncated:
            if policy is None:
                # Random baseline: sample uniformly from the action space.
                action = env.action_space.sample()
            else:
                # The policy predicts from the raw board state, not the
                # observation returned by the environment.
                state = info["board"]
                action = policy.predict(state=state)

            _observation, _reward, terminated, truncated, info = env.step(action)

        # Episode statistics are exposed through ``info`` once the episode ends.
        lengths.extend(info["episode"]["l"])
        rewards.extend(info["episode"]["r"])
        max_tiles.append(info["max"])
        total_score.append(info["total_score"])

    env.reset()

    return lengths, rewards, max_tiles, total_score


def plot_statistics(
lengths: list[int],
rewards: list[int],
max_tiles: list[int],
total_score: list[int],
title: str | None = None,
) -> plt.Figure:
"""
Plots episode statistics.
:param lengths: Lengths.
:param rewards: Rewards.
:param max_tiles: Maximum tiles reached.
:param total_score: Total game score.
:param title: Figure title. Default to None.
:return: Figure with statistics.
"""
fig, axs = plt.subplots(2, 2)

axs[0, 0].hist(lengths)
axs[0, 0].set_xlabel("Length")
axs[0, 0].set_ylabel("Count")
Expand All @@ -128,9 +168,40 @@ def evaluate() -> None:
axs[1, 1].set_ylabel("Count")
axs[1, 1].set_title("Score")

fig.suptitle(title)
fig.tight_layout()

return fig


def evaluate() -> None:
    """
    Evaluates an agent on the game environment and plots episode statistics.

    Parses the command-line arguments, runs the requested number of
    episodes (with a trained policy if one was supplied, otherwise with
    random actions), shows the statistics figure and optionally saves it
    to a PNG file.
    """
    args = parse_args()

    np.random.seed(args.seed)
    env = make_env(env_id=args.env)

    # A trained policy is only built when both the algorithm name and the
    # path to the trained agent were provided; otherwise play randomly.
    use_trained_agent = args.algo is not None and args.trained_agent is not None
    policy = (
        make_policy(algo=args.algo, trained_agent=args.trained_agent)
        if use_trained_agent
        else None
    )

    lengths, rewards, max_tiles, total_score = run_episodes(
        env=env,
        policy=policy,
        n_episodes=args.n_episodes,
    )
    env.close()

    fig = plot_statistics(
        lengths=lengths,
        rewards=rewards,
        max_tiles=max_tiles,
        total_score=total_score,
        title=args.title,
    )
    fig.show()

    if args.output_path is not None:
        fig.savefig(args.output_path)


# Script entry point: run the evaluation when executed directly.
if __name__ == "__main__":
    evaluate()

0 comments on commit 582650c

Please sign in to comment.