diff --git a/examples/alpha_zero/agent.py b/examples/alpha_zero/agent.py new file mode 100644 index 0000000..c6324df --- /dev/null +++ b/examples/alpha_zero/agent.py @@ -0,0 +1,46 @@ +import numpy as np + +from examples.alpha_zero.mcts import MCTS +from tetris_gymnasium.envs import Tetris + + +class MCTSAgent: + """AI agent based on MCTS""" + + def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0): + self.mcts = MCTS(policy_value_function, c_puct, n_playout) + self._is_selfplay = is_selfplay + + def reset_agent(self): + self.mcts.update_with_move(-1) + + def get_action(self, env: Tetris, temp=1e-3, return_prob=0): + # the pi vector returned by MCTS as in the alphaGo Zero paper + move_probs = np.zeros(env.action_space.n) + acts, probs = self.mcts.get_move_probs(env, temp) + move_probs[list(acts)] = probs + if self._is_selfplay: + # add Dirichlet Noise for exploration (needed for + # self-play training) + move = np.random.choice( + acts, + p=0.75 * probs + 0.25 * np.random.dirichlet(0.3 * np.ones(len(probs))), + ) + # update the root node and reuse the search tree + self.mcts.update_with_move(move) + else: + # with the default temp=1e-3, it is almost equivalent + # to choosing the move with the highest prob + move = np.random.choice(acts, p=probs) + # reset the root node + self.mcts.update_with_move(-1) + # location = board.move_to_location(move) + # print("AI move: %d,%d\n" % (location[0], location[1])) + + if return_prob: + return move, move_probs + else: + return move + + def __str__(self): + return "MCTS Agent" diff --git a/examples/alpha_zero/mcts.py b/examples/alpha_zero/mcts.py new file mode 100644 index 0000000..4633dc5 --- /dev/null +++ b/examples/alpha_zero/mcts.py @@ -0,0 +1,173 @@ +""" +Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value +network to guide the tree search and evaluate the leaf nodes + +@author: Junxiao Song +""" + +import copy + +import numpy as np + +from tetris_gymnasium.envs import Tetris + + +def softmax(x): + probs = np.exp(x - np.max(x)) + probs /= np.sum(probs) + return probs + + +class TreeNode: + """A node in the MCTS tree. + + Each node keeps track of its own value Q, prior probability P, and + its visit-count-adjusted prior score u. + """ + + def __init__(self, parent, prior_p): + self._parent = parent + self._children = {} # a map from action to TreeNode + self._n_visits = 0 + self._Q = 0 + self._u = 0 + self._P = prior_p + + def expand(self, action_priors): + """Expand tree by creating new children. + action_priors: a list of tuples of actions and their prior probability + according to the policy function. + """ + for action, prob in action_priors: + if action not in self._children: + self._children[action] = TreeNode(self, prob) + + def select(self, c_puct): + """Select action among children that gives maximum action value Q + plus bonus u(P). + Return: A tuple of (action, next_node) + """ + return max( + self._children.items(), key=lambda act_node: act_node[1].get_value(c_puct) + ) + + def update(self, leaf_value): + """Update node values from leaf evaluation. + leaf_value: the value of subtree evaluation from the current player's + perspective. + """ + # Count visit. + self._n_visits += 1 + # Update Q, a running average of values for all visits. + self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits + + def update_recursive(self, leaf_value): + """Like a call to update(), but applied recursively for all ancestors.""" + # If it is not root, this node's parent should be updated first. 
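A minimal sketch (not part of the patch) of the Dirichlet root-noise mixing used in MCTSAgent.get_action above during self-play: the 0.75/0.25 weights and alpha = 0.3 mirror the code, while the action set and probabilities are made-up example values.

import numpy as np

rng = np.random.default_rng(0)
acts = np.array([0, 1, 2, 3])             # hypothetical legal actions
probs = np.array([0.5, 0.3, 0.15, 0.05])  # hypothetical MCTS visit probabilities

# Blend the search distribution with Dirichlet noise so self-play keeps
# exploring moves the current network assigns little probability to.
noise = rng.dirichlet(0.3 * np.ones(len(probs)))
mixed = 0.75 * probs + 0.25 * noise  # still a valid probability distribution
move = rng.choice(acts, p=mixed)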
+ if self._parent: + self._parent.update_recursive(-leaf_value) + self.update(leaf_value) + + def get_value(self, c_puct): + """Calculate and return the value for this node. + It is a combination of leaf evaluations Q, and this node's prior + adjusted for its visit count, u. + c_puct: a number in (0, inf) controlling the relative impact of + value Q, and prior probability P, on this node's score. + """ + self._u = ( + c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits) + ) + return self._Q + self._u + + def is_leaf(self): + """Check if leaf node (i.e. no nodes below this have been expanded).""" + return self._children == {} + + def is_root(self): + return self._parent is None + + +class MCTS: + """An implementation of Monte Carlo Tree Search.""" + + def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): + """ + policy_value_fn: a function that takes in a board state and outputs + a list of (action, probability) tuples and also a score in [-1, 1] + (i.e. the expected value of the end game score from the current + player's perspective) for the current player. + c_puct: a number in (0, inf) that controls how quickly exploration + converges to the maximum-value policy. A higher value means + relying on the prior more. + """ + self._root = TreeNode(None, 1.0) + self._policy = policy_value_fn + self._c_puct = c_puct + self._n_playout = n_playout + + def _playout(self, env: Tetris): + """Run a single playout from the root to the leaf, getting a value at + the leaf and propagating it back through its parents. + State is modified in-place, so a copy must be provided. + """ + node = self._root + reward = 0 + terminated = False + while 1: + if node.is_leaf(): + break + # Greedily select next move. + action, node = node.select(self._c_puct) + obs, reward, terminated, truncated, info = env.step(action) + + # Evaluate the leaf using a network which outputs a list of + # (action, probability) tuples p and also a score v in [-1, 1] + # for the current player. + action_probs, leaf_value = self._policy(env) + + reward = reward / ((4**2) * 10) # normalize reward + leaf_value = (leaf_value + reward) / 2 # average reward and value to normalize + + # Check for end of game. + if not terminated: + node.expand(action_probs) + + # Update value and visit count of nodes in this traversal. + node.update_recursive(-leaf_value) + + def get_move_probs(self, env, temp=1e-3): + """Run all playouts sequentially and return the available actions and + their corresponding probabilities. + state: the current game state + temp: temperature parameter in (0, 1] controls the level of exploration + """ + + state_copy = env.unwrapped.clone_state() + for n in range(self._n_playout): + env.unwrapped.restore_state(state_copy) + self._playout(env) + + env.unwrapped.restore_state(state_copy) + + # calc the move probabilities based on visit counts at the root node + act_visits = [ + (act, node._n_visits) for act, node in self._root._children.items() + ] + acts, visits = zip(*act_visits) + act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10)) + + return acts, act_probs + + def update_with_move(self, last_move): + """Step forward in the tree, keeping everything we already know + about the subtree. 
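A quick illustration (not from the patch) of how the temp parameter of get_move_probs above shapes the returned distribution: probabilities are softmax(log(visits) / temp), so temp = 1 keeps them proportional to visit counts, while the agent's default of 1e-3 is effectively an argmax over visits. The visit counts below are made up.

import numpy as np


def softmax(x):
    probs = np.exp(x - np.max(x))
    return probs / np.sum(probs)


visits = np.array([120, 60, 15, 5])  # hypothetical root visit counts
for temp in (1.0, 1e-3):
    probs = softmax(1.0 / temp * np.log(visits + 1e-10))
    print(temp, np.round(probs, 3))
# temp=1.0 -> proportional to visits; temp=1e-3 -> nearly one-hot on the best move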
+ """ + if last_move in self._root._children: + self._root = self._root._children[last_move] + self._root._parent = None + else: + self._root = TreeNode(None, 1.0) + + def __str__(self): + return "MCTS" diff --git a/examples/alpha_zero/model.py b/examples/alpha_zero/model.py new file mode 100644 index 0000000..86a3ebd --- /dev/null +++ b/examples/alpha_zero/model.py @@ -0,0 +1,170 @@ +""" +An implementation of the policyValueNet in PyTorch +Tested in PyTorch 0.2.0 and 0.3.0 + +@author: Junxiao Song +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.autograd import Variable + +from tetris_gymnasium.envs import Tetris + + +def set_learning_rate(optimizer, lr): + """Sets the learning rate to the given value""" + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +class Net(nn.Module): + """policy-value network module""" + + def __init__(self, board_width, board_height, action_size): + super().__init__() + + self.board_width = board_width + self.board_height = board_height + # common layers + n_channels = 1 + self.conv1 = nn.Conv2d(n_channels, 32, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) + # action policy layers + self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1) + self.act_fc1 = nn.Linear(4 * board_width * board_height, action_size) + # state value layers + self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1) + self.val_fc1 = nn.Linear(2 * board_width * board_height, 64) + self.val_fc2 = nn.Linear(64, 1) + + def forward(self, state_input): + # common layers + x = F.relu(self.conv1(state_input)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + # action policy layers + x_act = F.relu(self.act_conv1(x)) + x_act = x_act.view(-1, 4 * self.board_width * self.board_height) + x_act = F.log_softmax(self.act_fc1(x_act)) + # state value layers + x_val = F.relu(self.val_conv1(x)) + x_val = x_val.view(-1, 2 * self.board_width * self.board_height) + x_val = F.relu(self.val_fc1(x_val)) + x_val = F.tanh(self.val_fc2(x_val)) + return x_act, x_val + + +class PolicyValueNet: + """policy-value network""" + + def __init__( + self, board_width, board_height, action_size, model_file=None, use_gpu=False + ): + self.use_gpu = use_gpu + self.board_width = board_width + self.board_height = board_height + self.l2_const = 1e-4 # coef of l2 penalty + # the policy value net module + if self.use_gpu: + self.policy_value_net = Net(board_width, board_height, action_size).cuda() + else: + self.policy_value_net = Net(board_width, board_height, action_size) + self.optimizer = optim.Adam( + self.policy_value_net.parameters(), weight_decay=self.l2_const + ) + + if model_file: + net_params = torch.load(model_file) + self.policy_value_net.load_state_dict(net_params) + + def policy_value(self, state_batch): + """ + input: a batch of states + output: a batch of action probabilities and state values + """ + if self.use_gpu: + state_batch = Variable(torch.FloatTensor(state_batch).cuda()) + state_batch = state_batch.unsqueeze(1) + log_act_probs, value = self.policy_value_net(state_batch) + act_probs = np.exp(log_act_probs.data.cpu().numpy()) + return act_probs, value.data.cpu().numpy() + else: + state_batch = Variable(torch.FloatTensor(state_batch)) + state_batch = state_batch.unsqueeze(1) + log_act_probs, value = self.policy_value_net(state_batch) + act_probs = np.exp(log_act_probs.data.numpy()) + return act_probs, 
value.data.numpy() + + def policy_value_fn(self, env: Tetris): + """ + input: board + output: a list of (action, probability) tuples for each available + action and the score of the board state + """ + legal_positions = list(range(env.action_space.n)) + current_state = np.ascontiguousarray(np.expand_dims(env.get_obs(), axis=0)) + # current_state = np.ascontiguousarray(env._get_obs().reshape( + # -1, 4, self.board_width, self.board_height)) + if self.use_gpu: + log_act_probs, value = self.policy_value_net( + Variable(torch.from_numpy(current_state)).cuda().float() + ) + act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten()) + value = value.data.cpu().numpy()[0][0] + else: + log_act_probs, value = self.policy_value_net( + Variable(torch.from_numpy(current_state)).float() + ) + act_probs = np.exp(log_act_probs.data.numpy().flatten()) + value = value.data.numpy()[0][0] + act_probs = zip(legal_positions, act_probs[legal_positions]) + return act_probs, value + + def train_step(self, state_batch, mcts_probs, z_batch, lr): + """perform a training step""" + # wrap in Variable + if self.use_gpu: + state_batch = Variable(torch.FloatTensor(state_batch).cuda()) + state_batch = state_batch.unsqueeze(1) + mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda()) + z_batch = Variable(torch.FloatTensor(z_batch).cuda()) + else: + state_batch = Variable(torch.FloatTensor(state_batch)) + state_batch = state_batch.unsqueeze(1) + mcts_probs = Variable(torch.FloatTensor(mcts_probs)) + z_batch = Variable(torch.FloatTensor(z_batch)) + + # zero the parameter gradients + self.optimizer.zero_grad() + # set learning rate + set_learning_rate(self.optimizer, lr) + + # forward + log_act_probs, value = self.policy_value_net(state_batch) + # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2 + # Note: the L2 penalty is incorporated in optimizer + value_loss = F.mse_loss(value.view(-1), z_batch) + policy_loss = -torch.mean(torch.sum(mcts_probs * log_act_probs, 1)) + loss = value_loss + policy_loss + # backward and optimize + loss.backward() + self.optimizer.step() + # calc policy entropy, for monitoring only + entropy = -torch.mean(torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)) + # return loss.data[0], entropy.data[0] + # for pytorch version >= 0.5 please use the following line instead. 
+ return loss.item(), entropy.item() + + def get_policy_param(self): + net_params = self.policy_value_net.state_dict() + return net_params + + def save_model(self, model_file): + """save model params to file""" + net_params = self.get_policy_param() # get model params + torch.save(net_params, model_file) diff --git a/examples/alpha_zero/train.py b/examples/alpha_zero/train.py new file mode 100644 index 0000000..e62ea7a --- /dev/null +++ b/examples/alpha_zero/train.py @@ -0,0 +1,199 @@ +""" +An implementation of the training pipeline of AlphaZero for Gomoku + +@author: Junxiao Song +""" + +import random +import time +from collections import defaultdict, deque + +import gymnasium +import numpy as np + +from examples.alpha_zero.agent import MCTSAgent +from examples.alpha_zero.model import PolicyValueNet +from tetris_gymnasium.envs import Tetris +from tetris_gymnasium.wrappers.observation import SimpleObservationWrapper + + +class TrainPipeline: + def __init__(self, init_model=None): + # params of the board and the game + self.board_width = 10 + self.board_height = 20 + self.n_in_row = 4 + self.env = SimpleObservationWrapper( + gymnasium.make( + "tetris_gymnasium/Tetris", + width=self.board_width, + height=self.board_height, + render_mode="rgb_array", + gravity=True, + ) + ) + # training params + self.learn_rate = 2e-3 + self.lr_multiplier = 1.0 # adaptively adjust the learning rate based on KL + self.temp = 1.0 # the temperature param + self.n_playout = 400 # num of simulations for each move + self.c_puct = 5 + self.buffer_size = 10000 + self.batch_size = 512 # mini-batch size for training + self.data_buffer = deque(maxlen=self.buffer_size) + self.play_batch_size = 1 + self.epochs = 5 # num of train_steps for each update + self.kl_targ = 0.02 + self.check_freq = 50 + self.game_batch_num = 1500 + self.best_win_ratio = 0.0 + # num of simulations used for the pure mcts, which is used as + # the opponent to evaluate the trained policy + self.pure_mcts_playout_num = 1000 + action_size = self.env.action_space.n + if init_model: + # start training from an initial policy-value net + self.policy_value_net = PolicyValueNet( + self.board_width, self.board_height, action_size, model_file=init_model + ) + else: + # start training from a new policy-value net + self.policy_value_net = PolicyValueNet( + self.board_width, self.board_height, action_size + ) + self.agent = MCTSAgent( + self.policy_value_net.policy_value_fn, + c_puct=self.c_puct, + n_playout=self.n_playout, + is_selfplay=1, + ) + + def start_self_play(self, is_shown=False): + """start a self-play game using a MCTS player, reuse the search tree, + and store the self-play data: (state, mcts_probs, z) for training + """ + obs, _ = self.env.reset() + states, mcts_probs, rewards = [], [], [] + while True: + move, move_probs = self.agent.get_action( + self.env, temp=self.temp, return_prob=1 + ) + # store the data + states.append(obs) + mcts_probs.append(move_probs) + # perform a move + obs, reward, terminated, truncated, info = self.env.step(move) + rewards.append(reward) + if is_shown: + self.env.render() + if terminated: + self.agent.reset_agent() + if is_shown: + print("Game over") + return zip(states, mcts_probs, rewards) + + def collect_selfplay_data(self, n_games=1): + """collect self-play data for training""" + for i in range(n_games): + play_data = self.start_self_play() + play_data = list(play_data)[:] + self.episode_len = len(play_data) + # augment the data + self.data_buffer.extend(play_data) + + def policy_update(self): + """update the 
policy-value net""" + mini_batch = random.sample(self.data_buffer, self.batch_size) + state_batch = [data[0] for data in mini_batch] + mcts_probs_batch = [data[1] for data in mini_batch] + z_batch = [data[2] for data in mini_batch] + old_probs, old_v = self.policy_value_net.policy_value(state_batch) + for i in range(self.epochs): + loss, entropy = self.policy_value_net.train_step( + state_batch, + mcts_probs_batch, + z_batch, + self.learn_rate * self.lr_multiplier, + ) + new_probs, new_v = self.policy_value_net.policy_value(state_batch) + kl = np.mean( + np.sum( + old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), + axis=1, + ) + ) + if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly + break + # adaptively adjust the learning rate + if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1: + self.lr_multiplier /= 1.5 + elif kl < self.kl_targ / 2 and self.lr_multiplier < 10: + self.lr_multiplier *= 1.5 + + explained_var_old = 1 - np.var(np.array(z_batch) - old_v.flatten()) / np.var( + np.array(z_batch) + ) + explained_var_new = 1 - np.var(np.array(z_batch) - new_v.flatten()) / np.var( + np.array(z_batch) + ) + print( + ( + "kl:{:.5f}," + "lr_multiplier:{:.3f}," + "loss:{}," + "entropy:{}," + "explained_var_old:{:.3f}," + "explained_var_new:{:.3f}" + ).format( + kl, + self.lr_multiplier, + loss, + entropy, + explained_var_old, + explained_var_new, + ) + ) + return loss, entropy + + def run(self): + """run the training pipeline""" + try: + for i in range(self.game_batch_num): + start_time = time.time() + self.collect_selfplay_data(self.play_batch_size) + end_time = time.time() + + print(f"batch i:{i + 1}, episode_len:{self.episode_len}") + + execution_time = end_time - start_time + print(f"collect_selfplay_data took {execution_time:.4f} seconds") + print("SPS", self.episode_len / execution_time) + + print(f"batch i:{i + 1}, episode_len:{self.episode_len}") + if len(self.data_buffer) > self.batch_size: + loss, entropy = self.policy_update() + # check the performance of the current model, + # and save the model params + if (i + 1) % self.check_freq == 0: + print(f"current self-play batch: {i+1}") + win_ratio = 0.5 + # win_ratio = self.policy_evaluate() + self.policy_value_net.save_model("./current_policy.model") + if win_ratio > self.best_win_ratio: + print("New best policy!!!!!!!!") + self.best_win_ratio = win_ratio + # update the best_policy + self.policy_value_net.save_model("./best_policy.model") + if ( + self.best_win_ratio == 1.0 + and self.pure_mcts_playout_num < 5000 + ): + self.pure_mcts_playout_num += 1000 + self.best_win_ratio = 0.0 + except KeyboardInterrupt: + print("\n\rquit") + + +if __name__ == "__main__": + training_pipeline = TrainPipeline() + training_pipeline.run() diff --git a/examples/perf_state_test.py b/examples/perf_state_test.py new file mode 100644 index 0000000..f7820a8 --- /dev/null +++ b/examples/perf_state_test.py @@ -0,0 +1,69 @@ +import copy +import time + +import gymnasium as gym +import numpy as np + +from tetris_gymnasium.envs import Tetris +from tetris_gymnasium.wrappers.observation import RgbObservation + + +def create_env(): + env = gym.make("tetris_gymnasium/Tetris", render_mode="rgb_array") + return RgbObservation(env) + + +def run_test(n_state_inits, n_copy_operations): + # Create a single environment to use for all operations + env = create_env() + _ = env.reset() + + # Create n_state_inits copies of the state + start_time = time.time() + state_arr = [env.unwrapped.clone_state() for _ in range(n_state_inits)] + 
creation_time = time.time() - start_time + + # Perform n_copy_operations + start_time = time.time() + for i in np.random.choice(n_state_inits, n_copy_operations, replace=False): + state_arr[i] = env.unwrapped.clone_state() + copy_operations_time = time.time() - start_time + + # Simulate a single copy operation + start_time = time.time() + idx = np.random.randint(n_state_inits) + state = copy.deepcopy(env.unwrapped.clone_state()) + env.unwrapped.restore_state(state) + state_arr[idx] = env.unwrapped.clone_state() + single_copy_time = time.time() - start_time + + # Clean up + env.close() + + return creation_time, copy_operations_time, single_copy_time + + +def main(): + n_state_inits_values = [500000] + n_copy_operations_values = [500000] + + for n_state_inits in n_state_inits_values: + for n_copy_operations in n_copy_operations_values: + if n_copy_operations <= n_state_inits: + creation_time, copy_operations_time, single_copy_time = run_test( + n_state_inits, n_copy_operations + ) + print( + f"Test with {n_state_inits} state initializations and {n_copy_operations} copy operations" + ) + print(f" State creation time: {creation_time:.6f} seconds") + print(f" Copy operations time: {copy_operations_time:.6f} seconds") + print(f" Single copy time: {single_copy_time:.6f} seconds") + print( + f" Total time: {creation_time + copy_operations_time + single_copy_time:.6f} seconds" + ) + print() + + +if __name__ == "__main__": + main() diff --git a/examples/perf_val_clones.py b/examples/perf_val_clones.py new file mode 100644 index 0000000..7096fb9 --- /dev/null +++ b/examples/perf_val_clones.py @@ -0,0 +1,119 @@ +import gymnasium as gym +import numpy as np +import pytest + +from tetris_gymnasium.components.tetromino_queue import TetrominoQueue +from tetris_gymnasium.components.tetromino_randomizer import Randomizer +from tetris_gymnasium.envs import Tetris +from tetris_gymnasium.wrappers.observation import RgbObservation + + +def create_env(): + env = gym.make("tetris_gymnasium/Tetris", render_mode="rgb_array") + return RgbObservation(env) + + +def compare_states(state1, state2): + """Compare two Tetris game states.""" + assert state1.keys() == state2.keys(), "States have different keys" + + for key in state1.keys(): + if key == "board": + assert np.array_equal(state1[key], state2[key]), f"Board mismatch" + elif key == "active_tetromino": + compare_tetrominos(state1[key], state2[key]) + elif key in ["x", "y", "has_swapped", "game_over", "score"]: + assert state1[key] == state2[key], f"Value mismatch for key: {key}" + elif key == "queue": + compare_tetromino_queues(state1[key], state2[key]) + elif key == "holder": + compare_holders(state1[key], state2[key]) + elif key == "randomizer": + compare_randomizers(state1[key], state2[key]) + else: + raise ValueError(f"Unknown key in state: {key}") + + +def compare_tetrominos(tetromino1, tetromino2): + assert tetromino1.id == tetromino2.id, "Tetromino ID mismatch" + assert np.array_equal( + tetromino1.color_rgb, tetromino2.color_rgb + ), "Tetromino color mismatch" + assert np.array_equal( + tetromino1.matrix, tetromino2.matrix + ), "Tetromino matrix mismatch" + + +def compare_tetromino_queues(queue1, queue2): + assert isinstance(queue1, TetrominoQueue) and isinstance( + queue2, TetrominoQueue + ), "Queue type mismatch" + assert queue1.size == queue2.size, "Queue size mismatch" + assert len(queue1.queue) == len(queue2.queue), "Queue length mismatch" + for t1, t2 in zip(queue1.queue, queue2.queue): + assert t1 == t2, "Queue content mismatch" + 
    compare_randomizers(queue1.randomizer, queue2.randomizer)
+
+
+def compare_holders(holder1, holder2):
+    assert holder1.size == holder2.size, "Holder size mismatch"
+    assert len(holder1.queue) == len(holder2.queue), "Holder queue length mismatch"
+    for t1, t2 in zip(holder1.queue, holder2.queue):
+        if t1 is None and t2 is None:
+            continue
+        compare_tetrominos(t1, t2)
+
+
+def compare_randomizers(randomizer1, randomizer2):
+    assert isinstance(randomizer1, Randomizer) and isinstance(
+        randomizer2, Randomizer
+    ), "Randomizer type mismatch"
+    assert randomizer1.__class__ == randomizer2.__class__, "Randomizer class mismatch"
+    assert randomizer1.size == randomizer2.size, "Randomizer size mismatch"
+    if hasattr(randomizer1, "bag"):
+        assert np.array_equal(
+            randomizer1.bag, randomizer2.bag
+        ), "Randomizer bag mismatch"
+        assert randomizer1.index == randomizer2.index, "Randomizer index mismatch"
+
+
+@pytest.fixture(scope="module")
+def env():
+    environment = create_env()
+    yield environment
+    environment.close()
+
+
+@pytest.mark.parametrize("test_number", range(1000))
+def test_clone_restore_consistency(env, test_number):
+    # Reset the environment if it's the first test or if the previous test ended
+    if test_number == 0 or env.unwrapped.game_over:
+        env.reset()
+
+    # Clone the current state
+    original_state = env.unwrapped.clone_state()
+
+    # Take a random action in the environment
+    action = env.action_space.sample()
+    original_obs, original_reward, original_done, _, original_info = env.step(action)
+
+    # Store the current state after the action
+    post_action_state_a = env.unwrapped.clone_state()
+
+    # Restore the original state
+    env.unwrapped.restore_state(original_state)
+
+    # Take the same action again
+    cloned_obs, cloned_reward, cloned_done, _, cloned_info = env.step(action)
+
+    # Clone the current state
+    post_action_state_b = env.unwrapped.clone_state()
+
+    # Compare results
+    assert np.array_equal(original_obs, cloned_obs), "Observations don't match"
+    assert original_reward == cloned_reward, "Rewards don't match"
+    assert original_done == cloned_done, "Done flags don't match"
+    assert original_info == cloned_info, "Info dictionaries don't match"
+
+    # Compare the full state after action with the stored post-action state
+    compare_states(post_action_state_a, post_action_state_b)
diff --git a/tetris_gymnasium/components/tetromino.py b/tetris_gymnasium/components/tetromino.py
index bd62e7c..e4003e7 100644
--- a/tetris_gymnasium/components/tetromino.py
+++ b/tetris_gymnasium/components/tetromino.py
@@ -1,4 +1,5 @@
 """Data structures for Tetris."""
+from copy import deepcopy
 from dataclasses import dataclass

 import numpy as np
@@ -17,6 +18,11 @@ class Pixel:
     id: int
     color_rgb: list

+    def __deepcopy__(self, memo):
+        # id is an immutable int, so only the color_rgb list needs to be
+        # copied when duplicating a Pixel.
+        return Pixel(id=self.id, color_rgb=deepcopy(self.color_rgb, memo))
+

 @dataclass
 class Tetromino(Pixel):
@@ -42,3 +48,11 @@ class Tetromino(Pixel):
     """

     matrix: np.ndarray
+
+    def __deepcopy__(self, memo):
+        # Create a new instance with copied attributes
+        return Tetromino(
+            id=self.id,
+            color_rgb=deepcopy(self.color_rgb, memo),
+            matrix=self.matrix.copy(),
+        )
diff --git a/tetris_gymnasium/components/tetromino_holder.py b/tetris_gymnasium/components/tetromino_holder.py
index 905f2f4..87bfc42 100644
--- a/tetris_gymnasium/components/tetromino_holder.py
+++ b/tetris_gymnasium/components/tetromino_holder.py
@@ -1,5 +1,6 @@
 """Module for the Holder class, which stores one or more tetrominoes
for later use in a game of Tetris.""" from collections import deque +from copy import deepcopy from typing import Optional from tetris_gymnasium.components.tetromino import Tetromino @@ -55,3 +56,12 @@ def reset(self): def get_tetrominoes(self): """Get all the tetrominoes currently in the holder.""" return list(self.queue) + + def __deepcopy__(self, memo): + # Create a new instance + new_holder = TetrominoHolder(self.size) + + # Deep copy the queue + new_holder.queue = deepcopy(self.queue, memo) + + return new_holder diff --git a/tetris_gymnasium/components/tetromino_queue.py b/tetris_gymnasium/components/tetromino_queue.py index 9b8bbbe..f96cd53 100644 --- a/tetris_gymnasium/components/tetromino_queue.py +++ b/tetris_gymnasium/components/tetromino_queue.py @@ -1,5 +1,6 @@ """Module for a queue of tetrominoes for use in a game of Tetris.""" from collections import deque +from copy import deepcopy from tetris_gymnasium.components.tetromino_randomizer import Randomizer @@ -44,3 +45,12 @@ def get_next_tetromino(self): def get_queue(self): """Get all tetrominoes currently in the queue.""" return list(self.queue) + + def __deepcopy__(self, memo): + # Create a new instance + new_queue = TetrominoQueue(deepcopy(self.randomizer, memo), self.size) + + # Deep copy the queue + new_queue.queue = deepcopy(self.queue, memo) + + return new_queue diff --git a/tetris_gymnasium/envs/tetris.py b/tetris_gymnasium/envs/tetris.py index c611772..fbe5250 100644 --- a/tetris_gymnasium/envs/tetris.py +++ b/tetris_gymnasium/envs/tetris.py @@ -1,5 +1,6 @@ """Tetris environment for Gymnasium.""" import copy +from collections import deque from dataclasses import fields from typing import Any, List @@ -122,12 +123,19 @@ def __init__( # Reason for this kind of initialization: https://stackoverflow.com/q/41686829 if randomizer is None: self.randomizer = BagRandomizer(len(self.tetrominoes)) + else: + self.randomizer = randomizer(len(self.tetrominoes)) if queue is None: self.queue = TetrominoQueue(self.randomizer) + else: + self.queue = queue(self.randomizer) if holder is None: self.holder = TetrominoHolder() + else: + self.holder = holder self.has_swapped = False self.gravity_enabled = gravity + self.score = 0 # Position self.x: int = 0 @@ -247,12 +255,15 @@ def step(self, action: ActType) -> "tuple[dict, float, bool, bool, dict]": # If there's no more room to move, lock in the tetromino reward, self.game_over, lines_cleared = self.commit_active_tetromino() + # update score + self.score += reward + return ( self._get_obs(), reward, self.game_over, truncated, - {"lines_cleared": lines_cleared}, + {"lines_cleared": lines_cleared, "score": self.score}, ) def reset( @@ -288,6 +299,9 @@ def reset( # Render self.window_name = None + # Score + self.score = 0 + return self._get_obs(), self._get_info() def get_rgb(self, observation): @@ -449,7 +463,7 @@ def commit_active_tetromino(self): self.drop_active_tetromino() self.place_active_tetromino() self.board, lines_cleared = self.clear_filled_rows(self.board) - reward = self.score(lines_cleared) + reward = self.calc_score(lines_cleared) # 2. 
Spawn the next tetromino and check if the game continues self.game_over = not self.spawn_tetromino() @@ -559,6 +573,18 @@ def _get_obs(self) -> "dict[str, Any]": active_tetromino_mask = np.zeros_like(board_obs) active_tetromino_mask[active_tetromino_slices] = 1 + # todo: make this cleaner + return { + "board": board_obs.astype(np.uint8), + "active_tetromino_mask": active_tetromino_mask.astype(np.uint8), + "holder": np.zeros( + (self.padding, self.padding * self.holder.size), dtype=np.uint8 + ), + "queue": np.zeros( + (self.padding, self.padding * self.queue.size), dtype=np.uint8 + ), + } + # Holder max_size = self.padding holder_tetrominoes = self.holder.get_tetrominoes() @@ -602,7 +628,7 @@ def _get_info(self) -> dict: """Return the current game state as info.""" return {"lines_cleared": 0} - def score(self, rows_cleared) -> int: + def calc_score(self, rows_cleared) -> int: """Calculate the score based on the number of lines cleared. Args: @@ -661,3 +687,68 @@ def offset_tetromino_id( tetrominoes[i].matrix = tetrominoes[i].matrix * (i + offset) return tetrominoes + + def restore_state(self, state): + """Restore the state of the environment.""" + self.board = state["board"] + self.active_tetromino = state["active_tetromino"] + self.x = state["x"] + self.y = state["y"] + self.queue = state["queue"] + self.holder = state["holder"] + self.randomizer = state["randomizer"] + self.has_swapped = state["has_swapped"] + self.game_over = state["game_over"] + self.score = state["score"] + + def clone_state(self): + """Clone the current state of the environment.""" + randomizer = self.clone_randomizer(self.randomizer) + return { + "board": self.clone_board(self.board), + "active_tetromino": self.clone_tetromino(self.active_tetromino), + "x": self.x, + "y": self.y, + "queue": self.clone_queue(self.queue, randomizer), + "holder": self.clone_holder(self.holder), + "randomizer": randomizer, + "has_swapped": self.has_swapped, + "game_over": self.game_over, + "score": self.score, + } + + @staticmethod + def clone_board(board): + return board.copy() + + @staticmethod + def clone_tetromino(tetromino): + return Tetromino( + id=tetromino.id, + color_rgb=tetromino.color_rgb.copy(), + matrix=tetromino.matrix.copy(), + ) + + @staticmethod + def clone_queue(queue, randomizer): + new_queue = TetrominoQueue(randomizer, queue.size) + new_queue.queue = deque(queue.queue) + return new_queue + + @staticmethod + def clone_holder(holder): + new_holder = TetrominoHolder(holder.size) + new_holder.queue = deque(holder.queue) + return new_holder + + @staticmethod + def clone_randomizer(randomizer): + new_randomizer = randomizer.__class__(randomizer.size) + new_randomizer.rng = np.random.Generator(np.random.PCG64()) + new_randomizer.rng.bit_generator.state = randomizer.rng.bit_generator.state + # new_randomizer.rng = copy.deepcopy(randomizer.rng) + + if hasattr(randomizer, "bag"): + new_randomizer.bag = randomizer.bag.copy() + new_randomizer.index = randomizer.index + return new_randomizer diff --git a/tetris_gymnasium/wrappers/observation.py b/tetris_gymnasium/wrappers/observation.py index 675a406..e34243a 100644 --- a/tetris_gymnasium/wrappers/observation.py +++ b/tetris_gymnasium/wrappers/observation.py @@ -8,6 +8,55 @@ from tetris_gymnasium.envs import Tetris +class SimpleObservationWrapper(gym.ObservationWrapper): + """Observation wrapper that displays all observations (board, holder, queue) as one single RGB Image. 
+
+    In contrast to RgbObservation, the observation is reduced to the visible board:
+    a (height, width) matrix where 0 = empty, 1 = placed pieces, 2 = active tetromino.
+    """
+
+    def __init__(self, env: Tetris):
+        """Initialize the SimpleObservationWrapper.
+
+        Args:
+            env (Tetris): The environment
+        """
+        super().__init__(env)
+        self.observation_space = Box(
+            low=0,
+            high=2,
+            shape=(env.unwrapped.height, env.unwrapped.width),
+            dtype=np.uint8,
+        )
+
+    def observation(self, observation):
+        """Convert the raw observation into a single binary board matrix.
+
+        Placed pieces are 1, the active tetromino is 2, and the board padding is cropped away.
+        """
+        # Board
+        board_obs = observation["board"]
+        active_tetromino_mask = observation["active_tetromino_mask"]
+
+        # make board binary (0-1)
+        board_obs = np.where(board_obs > 0, 1, 0).astype(np.uint8)
+        # add active tetromino with value 2
+        board_obs[active_tetromino_mask == 1] = 2
+
+        # crop the padding so only the visible playfield remains
+        board_obs = board_obs[
+            0 : -self.env.unwrapped.padding,
+            self.env.unwrapped.padding : -self.env.unwrapped.padding,
+        ]
+
+        return board_obs
+
+    def get_obs(self):
+        """Return the observation for the current state without stepping the environment."""
+        obs = self.env.unwrapped._get_obs()
+        return self.observation(obs)
+
+
 class RgbObservation(gym.ObservationWrapper):
     """Observation wrapper that displays all observations (board, holder, queue) as one single RGB Image.
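To close, a minimal usage sketch (not part of the patch) of the clone_state / restore_state pair this diff adds to the Tetris environment, the pattern MCTS.get_move_probs relies on for its playouts. The seed and the action values are arbitrary example inputs.

import gymnasium as gym

from tetris_gymnasium.envs import Tetris  # noqa: F401  (same import pattern as the example scripts above)

env = gym.make("tetris_gymnasium/Tetris", render_mode="rgb_array")
env.reset(seed=42)

# Save the full game state (board, active piece, queue, holder, randomizer, score).
snapshot = env.unwrapped.clone_state()

# Simulate ahead with arbitrary example actions; this mutates the environment in place.
for action in (0, 1, 2):
    env.step(action)

# Rewind to the saved state, e.g. before starting the next MCTS playout.
env.unwrapped.restore_state(snapshot)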