-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalue_iteration_agent.py
72 lines (65 loc) · 2.41 KB
/
value_iteration_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import random
from agent import Agent
from blackjack_markov_decision_process import BlackjackMarkovDecisionProcess
class ValueIterationAgent(Agent):
    """Blackjack agent whose policy is derived from offline value iteration.

    On construction it runs a fixed number of synchronous (batch) value
    iteration sweeps over every state returned by
    ``BlackjackMarkovDecisionProcess.getStates()``. At play time it picks an
    action with maximal Q-value, breaking ties uniformly at random.
    """

    def __init__(self, iterations, discount=0.9):
        """Build the MDP and run value iteration.

        Args:
            iterations: number of full sweeps over the state space.
            discount: factor applied to successor-state values in the
                Bellman backup (default 0.9, the previously hard-coded value).
        """
        self.mdp = BlackjackMarkovDecisionProcess()
        # state -> estimated value; states never stored are read as 0.
        self.values = {}
        self.discount = discount
        states = self.mdp.getStates()
        for _ in range(iterations):
            # Batch update: every new value is computed from the previous
            # sweep's self.values, which is replaced only after the sweep.
            nextValues = {}
            for state in states:
                if self.mdp.isTerminal(state):
                    nextValues[state] = self.values.get(state, 0)
                else:
                    nextValues[state] = self._maxQValue(state)
            self.values = nextValues

    def _maxQValue(self, state):
        """Return the best Q-value over ``state``'s legal actions.

        A state with no legal actions keeps its current value, exactly like a
        terminal state (calling max() on an empty sequence would raise
        ValueError).
        """
        qValues = [self.getQValue(state, action)
                   for action in self.mdp.getPossibleActions(state)]
        if not qValues:
            return self.values.get(state, 0)
        return max(qValues)

    def getNextAction(self, gameState, hand):
        """Choose an action for ``hand`` given the current ``gameState``.

        The MDP state tuple is built as:
        (False, hand has exactly two cards, hand.isDoubleDown(),
         hand.getHasAce(), hand.getHardCount(), dealer up-card soft count).
        NOTE(review): the meaning of the constant leading False is not visible
        here -- confirm against BlackjackMarkovDecisionProcess.
        """
        state = (False, len(hand.getCards()) == 2, hand.isDoubleDown(),
                 hand.getHasAce(), hand.getHardCount(),
                 gameState.getDealerUpCard().getSoftCount())
        return self.getActionForState(state)

    def getActionForState(self, state):
        """Return a greedy action for ``state``, or None if it has no actions.

        Ties between equally valued actions are broken uniformly at random.
        """
        actions = self.mdp.getPossibleActions(state)
        if not actions:
            return None
        maxValue, maxActions = None, []
        for action in actions:
            value = self.getQValue(state, action)
            # Explicit None check: ordering a number against None only works
            # in Python 2 and raises TypeError in Python 3.
            if maxValue is None or value > maxValue:
                maxValue, maxActions = value, [action]
            elif value == maxValue:
                maxActions.append(action)
        return random.choice(maxActions)

    def getQValue(self, state, action):
        """
        The q-value of the state-action pair.

        Computed as the sum over (successor, prob) pairs from the MDP of
        prob * (reward + discount * V(successor)); successors without a
        stored value count as 0.
        """
        statesAndProbs = self.mdp.getTransitionStatesAndProbs(state, action)
        qValue = 0
        for transitionState, prob in statesAndProbs:
            reward = self.mdp.getReward(state, action, transitionState)
            qValue += prob * (reward + self.discount * self.values.get(transitionState, 0))
        return qValue

    def printPolicies(self):
        """Print the greedy policy as two tab-separated tables.

        The first table (header "sc") covers hands without an ace, the second
        (header "aces") hands with one. Rows are the player's hard count
        (2-21); columns are the dealer's up-card soft count (2-11). Each cell
        is the first character of the chosen action, or "-" when the state has
        no legal actions. Only states of the form
        (False, True, False, hasAce, hardCount, dealerSoftCount) are shown.
        """
        dealerSoftCounts = range(2, 12)
        for hasAce in [False, True]:
            label = "aces" if hasAce else "sc"
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print("\t".join([label] + [str(count) for count in dealerSoftCounts]))
            for hardCount in range(2, 22):
                cells = [str(hardCount)]
                for dealerSoftCount in dealerSoftCounts:
                    state = (False, True, False, hasAce, hardCount, dealerSoftCount)
                    action = self.getActionForState(state)
                    if action is None:
                        cells.append("-")
                    else:
                        cells.append("{0}".format(action[0]))
                print("\t".join(cells))

    def __str__(self):
        return "Value iteration agent"