diff --git a/CHANGELOG.md b/CHANGELOG.md index ceaa298b18..05c28562fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,11 @@ These changes are available on the `master` branch, but have not yet been releas - Added `Guild.fetch_role` method. ([#2528](https://github.com/Pycord-Development/pycord/pull/2528)) +### Fixed + +- Fixed `EntitlementIterator` behavior with `limit > 100`. + ([#2555](https://github.com/Pycord-Development/pycord/pull/2555)) + ## [2.6.0] - 2024-07-09 ### Added diff --git a/discord/iterators.py b/discord/iterators.py index 7404e790a4..13f67266ea 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -64,6 +64,7 @@ from .types.audit_log import AuditLog as AuditLogPayload from .types.guild import Guild as GuildPayload from .types.message import Message as MessagePayload + from .types.monetization import Entitlement as EntitlementPayload from .types.threads import Thread as ThreadPayload from .types.user import PartialUser as PartialUserPayload from .user import User @@ -988,11 +989,21 @@ def __init__( self.guild_id = guild_id self.exclude_ended = exclude_ended + self._filter = None + + if self.before and self.after: + self._retrieve_entitlements = self._retrieve_entitlements_before_strategy + self._filter = lambda e: int(e["id"]) > self.after.id + elif self.after: + self._retrieve_entitlements = self._retrieve_entitlements_after_strategy + else: + self._retrieve_entitlements = self._retrieve_entitlements_before_strategy + self.state = state self.get_entitlements = state.http.list_entitlements self.entitlements = asyncio.Queue() - async def next(self) -> BanEntry: + async def next(self) -> Entitlement: if self.entitlements.empty(): await self.fill_entitlements() @@ -1014,30 +1025,57 @@ async def fill_entitlements(self): if not self._get_retrieve(): return + data = await self._retrieve_entitlements(self.retrieve) + + if self._filter: + data = list(filter(self._filter, data)) + + if len(data) < 100: + self.limit = 0 # terminate loop + + for element in data: + await self.entitlements.put(Entitlement(data=element, state=self.state)) + + async def _retrieve_entitlements(self, retrieve) -> list[Entitlement]: + """Retrieve entitlements and update next parameters.""" + raise NotImplementedError + + async def _retrieve_entitlements_before_strategy( + self, retrieve: int + ) -> list[EntitlementPayload]: + """Retrieve entitlements using before parameter.""" before = self.before.id if self.before else None - after = self.after.id if self.after else None data = await self.get_entitlements( self.state.application_id, before=before, - after=after, - limit=self.retrieve, + limit=retrieve, user_id=self.user_id, guild_id=self.guild_id, sku_ids=self.sku_ids, exclude_ended=self.exclude_ended, ) + if data: + if self.limit is not None: + self.limit -= retrieve + self.before = Object(id=int(data[-1]["id"])) + return data - if not data: - # no data, terminate - return - - if self.limit: - self.limit -= self.retrieve - - if len(data) < 100: - self.limit = 0 # terminate loop - - self.after = Object(id=int(data[-1]["id"])) - - for element in reversed(data): - await self.entitlements.put(Entitlement(data=element, state=self.state)) + async def _retrieve_entitlements_after_strategy( + self, retrieve: int + ) -> list[EntitlementPayload]: + """Retrieve entitlements using after parameter.""" + after = self.after.id if self.after else None + data = await self.get_entitlements( + self.state.application_id, + after=after, + limit=retrieve, + user_id=self.user_id, + guild_id=self.guild_id, + sku_ids=self.sku_ids, + exclude_ended=self.exclude_ended, + ) + if data: + if self.limit is not None: + self.limit -= retrieve + self.after = Object(id=int(data[-1]["id"])) + return data