diff --git a/ding/data/buffer/deque_buffer.py b/ding/data/buffer/deque_buffer.py index 26c7cebc8e..6c4af9cac5 100644 --- a/ding/data/buffer/deque_buffer.py +++ b/ding/data/buffer/deque_buffer.py @@ -239,10 +239,17 @@ def count(self) -> int: def get(self, idx: int) -> BufferedData: """ Overview: - The method that returns the BufferedData object given a specific index. + The method that returns the BufferedData object by subscript idx (int). """ return self.storage[idx] + def get_by_index(self, index: str) -> BufferedData: + """ + Overview: + The method that returns the BufferedData object given a specific index (str). + """ + return self.storage[self.indices.get(index)] + @apply_middleware("clear") def clear(self) -> None: """ diff --git a/ding/data/buffer/middleware/priority.py b/ding/data/buffer/middleware/priority.py index 017b302a5f..b890bac78f 100644 --- a/ding/data/buffer/middleware/priority.py +++ b/ding/data/buffer/middleware/priority.py @@ -108,12 +108,12 @@ def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwa self.max_priority = max(self.max_priority, new_priority) def delete(self, chain: Callable, index: str, *args, **kwargs) -> None: - for item in self.buffer.storage: - meta = item.meta - priority_idx = meta['priority_idx'] - self.sum_tree[priority_idx] = self.sum_tree.neutral_element - self.min_tree[priority_idx] = self.min_tree.neutral_element - self.buffer_idx.pop(priority_idx) + item = self.buffer.get_by_index(index) + meta = item.meta + priority_idx = meta['priority_idx'] + self.sum_tree[priority_idx] = self.sum_tree.neutral_element + self.min_tree[priority_idx] = self.min_tree.neutral_element + self.buffer_idx.pop(priority_idx) return chain(index, *args, **kwargs) def clear(self, chain: Callable) -> None: diff --git a/ding/data/buffer/tests/test_middleware.py b/ding/data/buffer/tests/test_middleware.py index cc19866ee3..94c51ba740 100644 --- a/ding/data/buffer/tests/test_middleware.py +++ b/ding/data/buffer/tests/test_middleware.py @@ -106,6 +106,7 @@ def test_priority(): assert data[0].meta['priority'] == 3.0 buffer.delete(data[0].index) assert buffer.count() == N + N - 1 + assert len(buffer._middleware[0].buffer_idx) == N + N - 1 buffer.clear() assert buffer.count() == 0