-
Notifications
You must be signed in to change notification settings - Fork 0
/
Agent.h
69 lines (61 loc) · 1.57 KB
/
Agent.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// Implementation of Policy Evaluation and Policy Improvement
#pragma once
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

#include "Environment.h"
using namespace std;
// Agent that runs the two halves of policy iteration (evaluation and greedy
// improvement) over a square grid of states (i, j), where both indices range
// over 0 .. env._max_customer inclusive.
//
// The state-value table lives inside the owned Environment (accessed through
// get_v / set_v); this class owns only the integer policy table.
class Agent {
    // main() is granted direct access to the private environment and policy.
    friend int main();

public:
    // Creates an all-zero policy table with one entry per (i, j) state.
    Agent() {
        const int side = env._max_customer + 1;
        policy.assign(side, std::vector<int>(side, 0));
    }

    // Evaluates the current policy by sweeping every state repeatedly.
    //
    // Each sweep overwrites the stored value of every state with
    // env.expected_reward(state, policy[i][j]). Values are updated in place,
    // so (assuming expected_reward reads the table behind get_v -- TODO
    // confirm in Environment) later states in a sweep see earlier updates.
    // The largest absolute change of the sweep is printed to stdout, and the
    // loop ends once that change drops below `theta`.
    //
    // theta: convergence threshold on the per-sweep maximum value change.
    void policy_eval(double theta = 0.5) {
        const int side = env._max_customer + 1;
        while (true) {
            double delta = 0.0;
            for (int i = 0; i < side; ++i) {
                for (int j = 0; j < side; ++j) {
                    const double old_val = env.get_v(i, j);
                    int state[2] = { i, j };
                    env.set_v(i, j, env.expected_reward(state, policy[i][j]));
                    // Explicitly the floating-point overload: the C `abs(int)`
                    // would silently truncate the difference.
                    delta = std::max(delta, std::abs(old_val - env.get_v(i, j)));
                }
            }
            std::cout << delta << std::endl;
            if (theta > delta)
                break;
        }
    }

    // Greedy policy improvement: for every state, picks the action with the
    // highest env.expected_reward and stores it in the policy table.
    //
    // Candidate actions for state (i, j) are the integers in
    // [-min(j, 5), min(i, 5)]. NOTE(review): the original local names
    // (a2b / b2a) suggest positive actions move units from the first
    // location to the second and negative ones the other way -- confirm
    // against Environment.
    //
    // Returns true when no state's action changed (the policy is stable),
    // false if at least one action was replaced.
    bool policy_iter() {
        const int side = env._max_customer + 1;
        const int max_move = 5;  // largest magnitude of a single action
        bool stable = true;
        for (int i = 0; i < side; ++i) {
            for (int j = 0; j < side; ++j) {
                const int old_action = policy[i][j];
                const int highest = std::min(i, max_move);
                const int lowest = -std::min(j, max_move);

                // A fresh state array is built for every call, as the
                // original code did, in case expected_reward modifies it.
                auto value_of = [&](int action) -> double {
                    int state[2] = { i, j };
                    return env.expected_reward(state, action);
                };

                // Seed the arg-max with the first candidate instead of a magic
                // sentinel, so a valid action is always chosen even when every
                // value is very negative. The range always contains action 0,
                // so it is never empty.
                int best_action = lowest;
                double best_value = value_of(lowest);
                for (int action = lowest + 1; action <= highest; ++action) {
                    const double value = value_of(action);
                    if (best_value < value) {
                        best_action = action;
                        best_value = value;
                    }
                }

                policy[i][j] = best_action;
                // The policy is unstable only if the greedy action differs
                // from the previous one.
                if (best_action != old_action)
                    stable = false;
            }
        }
        return stable;
    }

private:
    Environment env;                       // owns the value table and reward model
    std::vector<std::vector<int>> policy;  // policy[i][j] = action chosen in state (i, j)
};