Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alpha zero #30

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions examples/alpha_zero/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np

from examples.alpha_zero.mcts import MCTS
from tetris_gymnasium.envs import Tetris


class MCTSAgent:
"""AI agent based on MCTS"""

def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0):
self.mcts = MCTS(policy_value_function, c_puct, n_playout)
self._is_selfplay = is_selfplay

def reset_agent(self):
self.mcts.update_with_move(-1)

def get_action(self, env: Tetris, temp=1e-3, return_prob=0):
# the pi vector returned by MCTS as in the alphaGo Zero paper
move_probs = np.zeros(env.action_space.n)
acts, probs = self.mcts.get_move_probs(env, temp)
move_probs[list(acts)] = probs
if self._is_selfplay:
# add Dirichlet Noise for exploration (needed for
# self-play training)
move = np.random.choice(
acts,
p=0.75 * probs + 0.25 * np.random.dirichlet(0.3 * np.ones(len(probs))),
)
# update the root node and reuse the search tree
self.mcts.update_with_move(move)
else:
# with the default temp=1e-3, it is almost equivalent
# to choosing the move with the highest prob
move = np.random.choice(acts, p=probs)
# reset the root node
self.mcts.update_with_move(-1)
# location = board.move_to_location(move)
# print("AI move: %d,%d\n" % (location[0], location[1]))

if return_prob:
return move, move_probs
else:
return move

def __str__(self):
return "MCTS Agent"
173 changes: 173 additions & 0 deletions examples/alpha_zero/mcts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""
Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value
network to guide the tree search and evaluate the leaf nodes

@author: Junxiao Song
"""

import copy

import numpy as np

from tetris_gymnasium.envs import Tetris


def softmax(x):
probs = np.exp(x - np.max(x))
probs /= np.sum(probs)
return probs


class TreeNode:
"""A node in the MCTS tree.

Each node keeps track of its own value Q, prior probability P, and
its visit-count-adjusted prior score u.
"""

def __init__(self, parent, prior_p):
self._parent = parent
self._children = {} # a map from action to TreeNode
self._n_visits = 0
self._Q = 0
self._u = 0
self._P = prior_p

def expand(self, action_priors):
"""Expand tree by creating new children.
action_priors: a list of tuples of actions and their prior probability
according to the policy function.
"""
for action, prob in action_priors:
if action not in self._children:
self._children[action] = TreeNode(self, prob)

def select(self, c_puct):
"""Select action among children that gives maximum action value Q
plus bonus u(P).
Return: A tuple of (action, next_node)
"""
return max(
self._children.items(), key=lambda act_node: act_node[1].get_value(c_puct)
)

def update(self, leaf_value):
"""Update node values from leaf evaluation.
leaf_value: the value of subtree evaluation from the current player's
perspective.
"""
# Count visit.
self._n_visits += 1
# Update Q, a running average of values for all visits.
self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits

def update_recursive(self, leaf_value):
"""Like a call to update(), but applied recursively for all ancestors."""
# If it is not root, this node's parent should be updated first.
if self._parent:
self._parent.update_recursive(-leaf_value)
self.update(leaf_value)

def get_value(self, c_puct):
"""Calculate and return the value for this node.
It is a combination of leaf evaluations Q, and this node's prior
adjusted for its visit count, u.
c_puct: a number in (0, inf) controlling the relative impact of
value Q, and prior probability P, on this node's score.
"""
self._u = (
c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits)
)
return self._Q + self._u

def is_leaf(self):
"""Check if leaf node (i.e. no nodes below this have been expanded)."""
return self._children == {}

def is_root(self):
return self._parent is None


class MCTS:
"""An implementation of Monte Carlo Tree Search."""

def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
"""
policy_value_fn: a function that takes in a board state and outputs
a list of (action, probability) tuples and also a score in [-1, 1]
(i.e. the expected value of the end game score from the current
player's perspective) for the current player.
c_puct: a number in (0, inf) that controls how quickly exploration
converges to the maximum-value policy. A higher value means
relying on the prior more.
"""
self._root = TreeNode(None, 1.0)
self._policy = policy_value_fn
self._c_puct = c_puct
self._n_playout = n_playout

def _playout(self, env: Tetris):
"""Run a single playout from the root to the leaf, getting a value at
the leaf and propagating it back through its parents.
State is modified in-place, so a copy must be provided.
"""
node = self._root
reward = 0
terminated = False
while 1:
if node.is_leaf():
break
# Greedily select next move.
action, node = node.select(self._c_puct)
obs, reward, terminated, truncated, info = env.step(action)

# Evaluate the leaf using a network which outputs a list of
# (action, probability) tuples p and also a score v in [-1, 1]
# for the current player.
action_probs, leaf_value = self._policy(env)

reward = reward / ((4**2) * 10) # normalize reward
leaf_value = (leaf_value + reward) / 2 # average reward and value to normalize

# Check for end of game.
if not terminated:
node.expand(action_probs)

# Update value and visit count of nodes in this traversal.
node.update_recursive(-leaf_value)

def get_move_probs(self, env, temp=1e-3):
"""Run all playouts sequentially and return the available actions and
their corresponding probabilities.
state: the current game state
temp: temperature parameter in (0, 1] controls the level of exploration
"""

state_copy = env.unwrapped.clone_state()
for n in range(self._n_playout):
env.unwrapped.restore_state(state_copy)
self._playout(env)

env.unwrapped.restore_state(state_copy)

# calc the move probabilities based on visit counts at the root node
act_visits = [
(act, node._n_visits) for act, node in self._root._children.items()
]
acts, visits = zip(*act_visits)
act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))

return acts, act_probs

def update_with_move(self, last_move):
"""Step forward in the tree, keeping everything we already know
about the subtree.
"""
if last_move in self._root._children:
self._root = self._root._children[last_move]
self._root._parent = None
else:
self._root = TreeNode(None, 1.0)

def __str__(self):
return "MCTS"
170 changes: 170 additions & 0 deletions examples/alpha_zero/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""
An implementation of the policyValueNet in PyTorch
Tested in PyTorch 0.2.0 and 0.3.0

@author: Junxiao Song
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from tetris_gymnasium.envs import Tetris


def set_learning_rate(optimizer, lr):
"""Sets the learning rate to the given value"""
for param_group in optimizer.param_groups:
param_group["lr"] = lr


class Net(nn.Module):
"""policy-value network module"""

def __init__(self, board_width, board_height, action_size):
super().__init__()

self.board_width = board_width
self.board_height = board_height
# common layers
n_channels = 1
self.conv1 = nn.Conv2d(n_channels, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
# action policy layers
self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
self.act_fc1 = nn.Linear(4 * board_width * board_height, action_size)
# state value layers
self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
self.val_fc1 = nn.Linear(2 * board_width * board_height, 64)
self.val_fc2 = nn.Linear(64, 1)

def forward(self, state_input):
# common layers
x = F.relu(self.conv1(state_input))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
# action policy layers
x_act = F.relu(self.act_conv1(x))
x_act = x_act.view(-1, 4 * self.board_width * self.board_height)
x_act = F.log_softmax(self.act_fc1(x_act))
# state value layers
x_val = F.relu(self.val_conv1(x))
x_val = x_val.view(-1, 2 * self.board_width * self.board_height)
x_val = F.relu(self.val_fc1(x_val))
x_val = F.tanh(self.val_fc2(x_val))
return x_act, x_val


class PolicyValueNet:
"""policy-value network"""

def __init__(
self, board_width, board_height, action_size, model_file=None, use_gpu=False
):
self.use_gpu = use_gpu
self.board_width = board_width
self.board_height = board_height
self.l2_const = 1e-4 # coef of l2 penalty
# the policy value net module
if self.use_gpu:
self.policy_value_net = Net(board_width, board_height, action_size).cuda()
else:
self.policy_value_net = Net(board_width, board_height, action_size)
self.optimizer = optim.Adam(
self.policy_value_net.parameters(), weight_decay=self.l2_const
)

if model_file:
net_params = torch.load(model_file)
self.policy_value_net.load_state_dict(net_params)

def policy_value(self, state_batch):
"""
input: a batch of states
output: a batch of action probabilities and state values
"""
if self.use_gpu:
state_batch = Variable(torch.FloatTensor(state_batch).cuda())
state_batch = state_batch.unsqueeze(1)
log_act_probs, value = self.policy_value_net(state_batch)
act_probs = np.exp(log_act_probs.data.cpu().numpy())
return act_probs, value.data.cpu().numpy()
else:
state_batch = Variable(torch.FloatTensor(state_batch))
state_batch = state_batch.unsqueeze(1)
log_act_probs, value = self.policy_value_net(state_batch)
act_probs = np.exp(log_act_probs.data.numpy())
return act_probs, value.data.numpy()

def policy_value_fn(self, env: Tetris):
"""
input: board
output: a list of (action, probability) tuples for each available
action and the score of the board state
"""
legal_positions = list(range(env.action_space.n))
current_state = np.ascontiguousarray(np.expand_dims(env.get_obs(), axis=0))
# current_state = np.ascontiguousarray(env._get_obs().reshape(
# -1, 4, self.board_width, self.board_height))
if self.use_gpu:
log_act_probs, value = self.policy_value_net(
Variable(torch.from_numpy(current_state)).cuda().float()
)
act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
value = value.data.cpu().numpy()[0][0]
else:
log_act_probs, value = self.policy_value_net(
Variable(torch.from_numpy(current_state)).float()
)
act_probs = np.exp(log_act_probs.data.numpy().flatten())
value = value.data.numpy()[0][0]
act_probs = zip(legal_positions, act_probs[legal_positions])
return act_probs, value

def train_step(self, state_batch, mcts_probs, z_batch, lr):
"""perform a training step"""
# wrap in Variable
if self.use_gpu:
state_batch = Variable(torch.FloatTensor(state_batch).cuda())
state_batch = state_batch.unsqueeze(1)
mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda())
z_batch = Variable(torch.FloatTensor(z_batch).cuda())
else:
state_batch = Variable(torch.FloatTensor(state_batch))
state_batch = state_batch.unsqueeze(1)
mcts_probs = Variable(torch.FloatTensor(mcts_probs))
z_batch = Variable(torch.FloatTensor(z_batch))

# zero the parameter gradients
self.optimizer.zero_grad()
# set learning rate
set_learning_rate(self.optimizer, lr)

# forward
log_act_probs, value = self.policy_value_net(state_batch)
# define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
# Note: the L2 penalty is incorporated in optimizer
value_loss = F.mse_loss(value.view(-1), z_batch)
policy_loss = -torch.mean(torch.sum(mcts_probs * log_act_probs, 1))
loss = value_loss + policy_loss
# backward and optimize
loss.backward()
self.optimizer.step()
# calc policy entropy, for monitoring only
entropy = -torch.mean(torch.sum(torch.exp(log_act_probs) * log_act_probs, 1))
# return loss.data[0], entropy.data[0]
# for pytorch version >= 0.5 please use the following line instead.
return loss.item(), entropy.item()

def get_policy_param(self):
net_params = self.policy_value_net.state_dict()
return net_params

def save_model(self, model_file):
"""save model params to file"""
net_params = self.get_policy_param() # get model params
torch.save(net_params, model_file)
Loading