Skip to content

Commit

Permalink
fix(nyz): fix priority buffer delete bug (#844)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Dec 2, 2024
1 parent de9ada0 commit 548406f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
9 changes: 8 additions & 1 deletion ding/data/buffer/deque_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,17 @@ def count(self) -> int:
def get(self, idx: int) -> BufferedData:
"""
Overview:
The method that returns the BufferedData object given a specific index.
The method that returns the BufferedData object by subscript idx (int).
"""
return self.storage[idx]

def get_by_index(self, index: str) -> BufferedData:
"""
Overview:
The method that returns the BufferedData object given a specific index (str).
"""
return self.storage[self.indices.get(index)]

@apply_middleware("clear")
def clear(self) -> None:
"""
Expand Down
12 changes: 6 additions & 6 deletions ding/data/buffer/middleware/priority.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwa
self.max_priority = max(self.max_priority, new_priority)

def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
for item in self.buffer.storage:
meta = item.meta
priority_idx = meta['priority_idx']
self.sum_tree[priority_idx] = self.sum_tree.neutral_element
self.min_tree[priority_idx] = self.min_tree.neutral_element
self.buffer_idx.pop(priority_idx)
item = self.buffer.get_by_index(index)
meta = item.meta
priority_idx = meta['priority_idx']
self.sum_tree[priority_idx] = self.sum_tree.neutral_element
self.min_tree[priority_idx] = self.min_tree.neutral_element
self.buffer_idx.pop(priority_idx)
return chain(index, *args, **kwargs)

def clear(self, chain: Callable) -> None:
Expand Down
1 change: 1 addition & 0 deletions ding/data/buffer/tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def test_priority():
assert data[0].meta['priority'] == 3.0
buffer.delete(data[0].index)
assert buffer.count() == N + N - 1
assert len(buffer._middleware[0].buffer_idx) == N + N - 1
buffer.clear()
assert buffer.count() == 0

Expand Down

0 comments on commit 548406f

Please sign in to comment.