-
Notifications
You must be signed in to change notification settings - Fork 17
/
main.py
64 lines (58 loc) · 1.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/python
# -*- coding: utf-8 -*-
from Agent import AgentDiscretePPO
from core import ReplayBuffer
from draw import Painter
from env4Snake import Snake
import random
import pygame
import numpy as np
import torch
import matplotlib.pyplot as plt
def testAgent(test_env,agent,episode):
ep_reward = 0
o = test_env.reset()
for _ in range(650):
if episode % 100 == 0:
test_env.render()
for event in pygame.event.get(): # 不加这句render要卡,不清楚原因
pass
a_int, a_prob = agent.select_action(o)
o2, reward, done, _ = test_env.step(a_int)
ep_reward += reward
if done: break
o = o2
return ep_reward
if __name__ == "__main__":
    # Separate environments for training rollouts and periodic evaluation.
    env = Snake()
    test_env = Snake()
    act_dim = 4    # four movement directions
    obs_dim = 6    # observation vector size expected by the agent

    agent = AgentDiscretePPO()
    agent.init(512, obs_dim, act_dim, if_use_gae=True)
    agent.state = env.reset()
    buffer = ReplayBuffer(2 ** 12, obs_dim, act_dim, True)

    MAX_EPISODE = 800
    batch_size = 256
    rewardList = []
    maxReward = -np.inf

    for episode in range(MAX_EPISODE):
        # Collect one batch of on-policy experience; no gradients needed here.
        with torch.no_grad():
            trajectory_list = agent.explore_env(env, 2 ** 12, 1, 0.99)
        buffer.extend_buffer_from_list(trajectory_list)
        agent.update_net(buffer, batch_size, 1, 2 ** -8)

        ep_reward = testAgent(test_env, agent, episode)
        print('Episode:', episode, 'Reward:%f' % ep_reward)
        rewardList.append(ep_reward)

        # Checkpoint only after the first third of training, and only on a
        # new best evaluation reward.
        if episode > MAX_EPISODE / 3 and ep_reward > maxReward:
            maxReward = ep_reward
            print('保存模型!')
            torch.save(agent.act.state_dict(), 'act_weight.pkl')

    pygame.quit()

    # Append this run's reward curve to the CSV log and plot it.
    painter = Painter(load_csv=True, load_dir='reward.csv')
    painter.addData(rewardList, 'PPO')
    painter.saveData('reward.csv')
    painter.setTitle('snake game reward')
    painter.setXlabel('episode')
    painter.setYlabel('reward')
    painter.drawFigure()