utils_memory.py
from typing import (
    Tuple,
)

from SumTree import SumTree
import torch
import numpy as np
import random

from utils_types import (
    BatchAction,
    BatchDone,
    BatchNext,
    BatchReward,
    BatchState,
    BatchIndice,
    Batchweight,
    BatchPriority,
    TensorStack5,
    TorchDevice,
)
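
# Prioritized replay buffer backed by a SumTree: transitions are drawn with
# probability roughly proportional to their stored priority (stratified
# sampling over the tree's total mass), and each sampled transition carries an
# importance-sampling weight
#     w_i = (P(i) / min_j P(j)) ** (-beta),
# i.e. the max-normalised form of (N * P(i)) ** (-beta).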
class Prioritized_ReplayMemory(object):
    def __init__(
            self,
            channels: int,
            capacity: int,
            alpha: float,
            device: TorchDevice,
            full_sink: bool = True,
    ) -> None:
        self.__capacity = capacity
        self.__alpha = alpha
        self.__device = device
        self.__size = 0
        self.__pos = 0

        # Optionally keep the whole buffer on the target device ("sink" it);
        # otherwise tensors stay where they were allocated and are moved per
        # batch in sample().
        sink = lambda x: x.to(device) if full_sink else x
        self.__m_states = sink(torch.zeros(
            (capacity, channels, 84, 84), dtype=torch.uint8))
        self.__m_actions = sink(torch.zeros((capacity, 1), dtype=torch.long))
        self.__m_rewards = sink(torch.zeros((capacity, 1), dtype=torch.int8))
        self.__m_dones = sink(torch.zeros((capacity, 1), dtype=torch.bool))

        # Sum tree holding one priority per stored transition.
        self.tree = SumTree(capacity)

    def push(
            self,
            folded_state: TensorStack5,
            action: int,
            reward: int,
            done: bool,
    ) -> None:
        # Write the transition into the ring buffer at the current position.
        self.__m_states[self.__pos] = folded_state
        self.__m_actions[self.__pos, 0] = action
        self.__m_rewards[self.__pos, 0] = reward
        self.__m_dones[self.__pos, 0] = done

        # New transitions get the current maximum priority so they are
        # sampled at least once before their priority is refreshed.
        max_priority = self.tree.max() if self.__size else 1.0
        self.tree.add(max_priority, self.__pos)

        self.__pos = (self.__pos + 1) % self.__capacity
        self.__size = min(self.__size + 1, self.__capacity)

    def sample(self, batch_size: int, beta: float) -> Tuple[
            BatchState,
            BatchAction,
            BatchReward,
            BatchNext,
            BatchDone,
            BatchIndice,
            Batchweight,
    ]:
        # Stratified sampling: split the total priority mass into equal
        # segments and draw one transition from each segment.
        segment = self.tree.total() / batch_size
        priorities = []
        indices = []
        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = random.uniform(a, b)
            priority, data_idx = self.tree.get(s)
            priorities.append(priority)
            indices.append(data_idx)

        # Importance-sampling weights, normalised by the maximum weight
        # (the weight of the minimum-probability transition).
        probabilities = np.array(priorities) / self.tree.total()
        min_prob = self.tree.min(self.__size) / self.tree.total()
        weights = np.power(probabilities / min_prob, -beta)

        indices = torch.tensor(indices)
        weights = torch.from_numpy(weights).to(self.__device).float()

        # Frames [:4] of a folded stack form the state, frames [1:] the next state.
        b_state = self.__m_states[indices, :4].to(self.__device).float()
        b_next = self.__m_states[indices, 1:].to(self.__device).float()
        b_action = self.__m_actions[indices].to(self.__device)
        b_reward = self.__m_rewards[indices].to(self.__device).float()
        b_done = self.__m_dones[indices].to(self.__device).float()
        return b_state, b_action, b_reward, b_next, b_done, indices, weights

    def update_priorities(self, batch_indice: BatchIndice, batch_priority: BatchPriority) -> None:
        # Refresh the priorities of the sampled transitions (typically |TD error| + eps).
        for data_idx, priority in zip(batch_indice, batch_priority):
            self.tree.update(data_idx, priority)

    def __len__(self) -> int:
        return self.__size
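

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes SumTree exposes total(), max(), min(size), add(priority, data),
# get(s) and update(idx, priority) with the semantics used above, and that
# utils_types defines the type aliases imported at the top of this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cpu")
    memory = Prioritized_ReplayMemory(
        channels=5, capacity=1_000, alpha=0.6, device=device)

    # Push a few dummy transitions: each folded state is 5 stacked 84x84 frames;
    # frames [:4] are read back as the state and frames [1:] as the next state.
    for _ in range(64):
        folded_state = torch.randint(0, 256, (5, 84, 84), dtype=torch.uint8)
        memory.push(folded_state, action=0, reward=1, done=False)

    # Draw a prioritized batch; `weights` are the importance-sampling weights
    # to multiply into the per-sample loss.
    state, action, reward, next_state, done, indices, weights = \
        memory.sample(batch_size=8, beta=0.4)

    # After computing TD errors, write the new priorities back into the tree.
    new_priorities = np.abs(np.random.randn(8)) + 1e-6  # placeholder for |TD error| + eps
    memory.update_priorities(indices, new_priorities)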