From 5016d9e4bd231d6d7868bf500f61376f3fc50fca Mon Sep 17 00:00:00 2001 From: Nick Barrett Date: Fri, 15 Jul 2022 14:10:19 +0200 Subject: [PATCH] cache experiments --- .github/workflows/tests.yml | 7 - synapse/replication/slave/storage/devices.py | 18 +- .../replication/slave/storage/push_rule.py | 10 +- synapse/replication/slave/storage/pushers.py | 6 +- synapse/replication/tcp/client.py | 4 +- .../server_notices/server_notices_manager.py | 2 +- synapse/storage/_base.py | 34 +- .../storage/databases/main/account_data.py | 24 +- synapse/storage/databases/main/cache.py | 50 +- synapse/storage/databases/main/deviceinbox.py | 6 +- synapse/storage/databases/main/events.py | 2 +- .../storage/databases/main/events_worker.py | 8 +- synapse/storage/databases/main/keys.py | 3 +- synapse/storage/databases/main/presence.py | 8 +- synapse/storage/databases/main/pusher.py | 2 +- synapse/storage/databases/main/receipts.py | 18 +- synapse/storage/databases/main/roommember.py | 8 +- synapse/storage/databases/main/tags.py | 10 +- synapse/util/caches/deferred_cache.py | 351 --- synapse/util/caches/descriptors.py | 209 +- synapse/util/caches/lrucache.py | 63 +- tests/handlers/test_room_summary.py | 4 +- tests/handlers/test_sync.py | 4 +- tests/rest/admin/test_server_notice.py | 20 +- tests/storage/test_cleanup_extrems.py | 6 +- tests/storage/test_purge.py | 2 +- tests/test_metrics.py | 34 - tests/util/caches/test_deferred_cache.py | 278 --- tests/util/caches/test_descriptors.py | 1898 ++++++++--------- 29 files changed, 1242 insertions(+), 1847 deletions(-) delete mode 100644 synapse/util/caches/deferred_cache.py delete mode 100644 tests/util/caches/test_deferred_cache.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c8b033e8a473..5dd033b27120 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -63,7 +63,6 @@ jobs: trial: if: ${{ !cancelled() && !failure() }} # Allow previous steps to be skipped, but not fail - needs: linting-done runs-on: ubuntu-latest strategy: matrix: @@ -127,7 +126,6 @@ jobs: trial-olddeps: # Note: sqlite only; no postgres if: ${{ !cancelled() && !failure() }} # Allow previous steps to be skipped, but not fail - needs: linting-done runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -155,7 +153,6 @@ jobs: # Very slow; only run if the branch name includes 'pypy' # Note: sqlite only; no postgres. Completely untested since poetry move. if: ${{ contains(github.ref, 'pypy') && !failure() && !cancelled() }} - needs: linting-done runs-on: ubuntu-latest strategy: matrix: @@ -186,7 +183,6 @@ jobs: sytest: if: ${{ !failure() && !cancelled() }} - needs: linting-done runs-on: ubuntu-latest container: image: matrixdotorg/sytest-synapse:${{ matrix.sytest-tag }} @@ -277,7 +273,6 @@ jobs: portdb: if: ${{ !failure() && !cancelled() }} # Allow previous steps to be skipped, but not fail - needs: linting-done runs-on: ubuntu-latest env: TOP: ${{ github.workspace }} @@ -315,7 +310,6 @@ jobs: complement: if: "${{ !failure() && !cancelled() }}" - needs: linting-done runs-on: ubuntu-latest strategy: @@ -349,7 +343,6 @@ jobs: # See https://github.com/matrix-org/synapse/issues/13161 complement-workers: if: "${{ !failure() && !cancelled() }}" - needs: linting-done runs-on: ubuntu-latest steps: diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index a48cc0206944..5771a38fcd25 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -49,19 +49,21 @@ def __init__( def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) - self._invalidate_caches_for_devices(token, rows) + await self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + return await super().process_replication_rows( + stream_name, instance_name, token, rows + ) - def _invalidate_caches_for_devices( + async def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: for row in rows: @@ -70,9 +72,11 @@ def _invalidate_caches_for_devices( # changes. if row.entity.startswith("@"): self._device_list_stream_cache.entity_has_changed(row.entity, token) - self.get_cached_devices_for_user.invalidate((row.entity,)) - self._get_cached_user_device.invalidate((row.entity,)) - self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) + await self.get_cached_devices_for_user.invalidate((row.entity,)) + await self._get_cached_user_device.invalidate((row.entity,)) + await self.get_device_list_last_stream_id_for_remote.invalidate( + (row.entity,) + ) else: self._device_list_federation_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 52ee3f7e58ea..53b1be76a2e4 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -24,13 +24,15 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self) -> int: return self._push_rules_stream_id_gen.get_current_token() - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: - self.get_push_rules_for_user.invalidate((row.user_id,)) - self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) + await self.get_push_rules_for_user.invalidate((row.user_id,)) + await self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + return await super().process_replication_rows( + stream_name, instance_name, token, rows + ) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index de642bba71b0..fb3f5653af1d 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -40,9 +40,11 @@ def __init__( def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == PushersStream.NAME: self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + return await super().process_replication_rows( + stream_name, instance_name, token, rows + ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2f59245058e7..2abeafa2c62f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -144,7 +144,9 @@ async def on_rdata( token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - self.store.process_replication_rows(stream_name, instance_name, token, rows) + await self.store.process_replication_rows( + stream_name, instance_name, token, rows + ) if self.send_handler: await self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 8ecab86ec7d3..4ef1e1e47e3f 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -184,7 +184,7 @@ async def get_or_create_notice_room_for_user(self, user_id: str) -> str: ) room_id = info["room_id"] - self.maybe_get_notice_room_for_user.invalidate((user_id,)) + await self.maybe_get_notice_room_for_user.invalidate((user_id,)) max_id = await self._account_data_handler.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b8c8dcd76bfc..ef6cf29cb44c 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -47,7 +47,7 @@ def __init__( self.database_engine = database.engine self.db_pool = database - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, @@ -56,7 +56,7 @@ def process_replication_rows( ) -> None: pass - def _invalidate_state_caches( + async def _invalidate_state_caches( self, room_id: str, members_changed: Collection[str] ) -> None: """Invalidates caches that are based on the current state, but does @@ -68,28 +68,34 @@ def _invalidate_state_caches( """ # If there were any membership changes, purge the appropriate caches. for host in {get_domain_from_id(u) for u in members_changed}: - self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) + await self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) if members_changed: - self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) - self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,)) - self._attempt_to_invalidate_cache( + await self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) + await self._attempt_to_invalidate_cache( + "get_current_hosts_in_room", (room_id,) + ) + await self._attempt_to_invalidate_cache( "get_users_in_room_with_profiles", (room_id,) ) - self._attempt_to_invalidate_cache( + await self._attempt_to_invalidate_cache( "get_number_joined_users_in_room", (room_id,) ) - self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,)) + await self._attempt_to_invalidate_cache( + "get_local_users_in_room", (room_id,) + ) for user_id in members_changed: - self._attempt_to_invalidate_cache( + await self._attempt_to_invalidate_cache( "get_user_in_room_with_profile", (room_id, user_id) ) # Purge other caches based on room state. - self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) - self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) + await self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) + await self._attempt_to_invalidate_cache( + "get_partial_current_state_ids", (room_id,) + ) - def _attempt_to_invalidate_cache( + async def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] ) -> None: """Attempts to invalidate the cache of the given name, ignoring if the @@ -110,9 +116,9 @@ def _attempt_to_invalidate_cache( return if key is None: - cache.invalidate_all() + await cache.invalidate_all() else: - cache.invalidate(tuple(key)) + await cache.invalidate(tuple(key)) def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 9af9f4f18e19..2fd2a3a4dea7 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -414,7 +414,7 @@ async def ignored_users(self, user_id: str) -> FrozenSet[str]: ) ) - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, @@ -427,17 +427,19 @@ def process_replication_rows( self._account_data_id_gen.advance(instance_name, token) for row in rows: if not row.room_id: - self.get_global_account_data_by_type_for_user.invalidate( + await self.get_global_account_data_by_type_for_user.invalidate( (row.user_id, row.data_type) ) - self.get_account_data_for_user.invalidate((row.user_id,)) - self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) - self.get_account_data_for_room_and_type.invalidate( + await self.get_account_data_for_user.invalidate((row.user_id,)) + await self.get_account_data_for_room.invalidate( + (row.user_id, row.room_id) + ) + await self.get_account_data_for_room_and_type.invalidate( (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - super().process_replication_rows(stream_name, instance_name, token, rows) + await super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -475,9 +477,9 @@ async def add_account_data_to_room( ) self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) - self.get_account_data_for_room.invalidate((user_id, room_id)) - self.get_account_data_for_room_and_type.prefill( + await self.get_account_data_for_user.invalidate((user_id,)) + await self.get_account_data_for_room.invalidate((user_id, room_id)) + await self.get_account_data_for_room_and_type.prefill( (user_id, room_id, account_data_type), content ) @@ -510,8 +512,8 @@ async def add_account_data_for_user( ) self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) - self.get_global_account_data_by_type_for_user.invalidate( + await self.get_account_data_for_user.invalidate((user_id,)) + await self.get_global_account_data_by_type_for_user.invalidate( (user_id, account_data_type) ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2367ddeea3fd..dc0bd12dfefa 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -119,15 +119,15 @@ def get_all_updated_caches_txn( "get_all_updated_caches", get_all_updated_caches_txn ) - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == EventsStream.NAME: for row in rows: - self._process_event_stream_row(token, row) + await self._process_event_stream_row(token, row) elif stream_name == BackfillStream.NAME: for row in rows: - self._invalidate_caches_for_event( + await self._invalidate_caches_for_event( -token, row.event_id, row.room_id, @@ -150,18 +150,18 @@ def process_replication_rows( room_id = row.keys[0] members_changed = set(row.keys[1:]) - self._invalidate_state_caches(room_id, members_changed) + await self._invalidate_state_caches(room_id, members_changed) else: - self._attempt_to_invalidate_cache(row.cache_func, row.keys) + await self._attempt_to_invalidate_cache(row.cache_func, row.keys) - super().process_replication_rows(stream_name, instance_name, token, rows) + await super().process_replication_rows(stream_name, instance_name, token, rows) - def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: + async def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data if row.type == EventsStreamEventRow.TypeId: assert isinstance(data, EventsStreamEventRow) - self._invalidate_caches_for_event( + await self._invalidate_caches_for_event( token, data.event_id, data.room_id, @@ -176,13 +176,13 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) if data.type == EventTypes.Member: - self.get_rooms_for_user_with_stream_ordering.invalidate( + await self.get_rooms_for_user_with_stream_ordering.invalidate( (data.state_key,) ) else: raise Exception("Unknown events stream row type %s" % (row.type,)) - def _invalidate_caches_for_event( + async def _invalidate_caches_for_event( self, stream_ordering: int, event_id: str, @@ -196,37 +196,37 @@ def _invalidate_caches_for_event( # This invalidates any local in-memory cached event objects, the original # process triggering the invalidation is responsible for clearing any external # cached objects. - self._invalidate_local_get_event_cache(event_id) - self.have_seen_event.invalidate((room_id, event_id)) + await self._invalidate_local_get_event_cache(event_id) + await self.have_seen_event.invalidate((room_id, event_id)) - self.get_latest_event_ids_in_room.invalidate((room_id,)) + await self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) + await self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) # The `_get_membership_from_event_id` is immutable, except for the # case where we look up an event *before* persisting it. - self._get_membership_from_event_id.invalidate((event_id,)) + await self._get_membership_from_event_id.invalidate((event_id,)) if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) if redacts: - self._invalidate_local_get_event_cache(redacts) + await self._invalidate_local_get_event_cache(redacts) # Caches which might leak edits must be invalidated for the event being # redacted. - self.get_relations_for_event.invalidate((redacts,)) - self.get_applicable_edit.invalidate((redacts,)) + await self.get_relations_for_event.invalidate((redacts,)) + await self.get_applicable_edit.invalidate((redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_local_user.invalidate((state_key,)) + await self.get_invited_rooms_for_local_user.invalidate((state_key,)) if relates_to: - self.get_relations_for_event.invalidate((relates_to,)) - self.get_aggregation_groups_for_event.invalidate((relates_to,)) - self.get_applicable_edit.invalidate((relates_to,)) - self.get_thread_summary.invalidate((relates_to,)) - self.get_thread_participated.invalidate((relates_to,)) + await self.get_relations_for_event.invalidate((relates_to,)) + await self.get_aggregation_groups_for_event.invalidate((relates_to,)) + await self.get_applicable_edit.invalidate((relates_to,)) + await self.get_thread_summary.invalidate((relates_to,)) + await self.get_thread_participated.invalidate((relates_to,)) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] @@ -242,7 +242,7 @@ async def invalidate_cache_and_stream( if not cache_func: return - cache_func.invalidate(keys) + await cache_func.invalidate(keys) await self.db_pool.runInteraction( "invalidate_cache_and_stream", self._send_invalidation_to_replication, diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 422e0e65ca50..45fe58c10427 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -128,7 +128,7 @@ def __init__( prefilled_cache=device_outbox_prefill, ) - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, @@ -148,7 +148,9 @@ def process_replication_rows( self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token ) - return super().process_replication_rows(stream_name, instance_name, token, rows) + return await super().process_replication_rows( + stream_name, instance_name, token, rows + ) def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index fa2266ba2036..ffd4d582a4da 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -238,7 +238,7 @@ async def _persist_events_and_state_updates( event_counter.labels(event.type, origin_type, origin_entity).inc() for room_id, latest_event_ids in new_forward_extremities.items(): - self.store.get_latest_event_ids_in_room.prefill( + await self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index f3935bfead96..8875ae2266b0 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -280,7 +280,7 @@ def get_chain_id_txn(txn: Cursor) -> int: id_column="chain_id", ) - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, @@ -292,7 +292,7 @@ def process_replication_rows( elif stream_name == BackfillStream.NAME: self._backfill_id_gen.advance(instance_name, -token) - super().process_replication_rows(stream_name, instance_name, token, rows) + await super().process_replication_rows(stream_name, instance_name, token, rows) async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. if the content of the event has been erased @@ -722,8 +722,8 @@ async def _invalidate_get_event_cache(self, event_id: str) -> None: self._event_ref.pop(event_id, None) self._current_event_fetches.pop(event_id, None) - def _invalidate_local_get_event_cache(self, event_id: str) -> None: - self._get_event_cache.invalidate_local((event_id,)) + async def _invalidate_local_get_event_cache(self, event_id: str) -> None: + await self._get_event_cache.invalidate_local((event_id,)) self._event_ref.pop(event_id, None) self._current_event_fetches.pop(event_id, None) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 0a19f607bda1..4bba67f95c1f 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -123,6 +123,7 @@ async def store_server_verify_keys( # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) + print("GOT", invalidations) await self.db_pool.simple_upsert_many( table="server_signature_keys", key_names=("server_name", "key_id"), @@ -139,7 +140,7 @@ async def store_server_verify_keys( invalidate = self._get_server_verify_key.invalidate for i in invalidations: - invalidate((i,)) + await invalidate((i,)) async def store_server_keys_json( self, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9769a18a9d0c..74168dfef815 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -431,7 +431,7 @@ def take_presence_startup_info(self) -> List[UserPresenceState]: self._presence_on_startup = [] return active_on_startup - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, @@ -442,5 +442,7 @@ def process_replication_rows( self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) - self._get_presence_for_user.invalidate((row.user_id,)) - return super().process_replication_rows(stream_name, instance_name, token, rows) + await self._get_presence_for_user.invalidate((row.user_id,)) + return await super().process_replication_rows( + stream_name, instance_name, token, rows + ) diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index bd0cfa7f3211..12972457a67c 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -499,7 +499,7 @@ async def add_pusher( lock=False, ) - user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate( + user_has_pusher = await self.get_if_user_has_pusher.cache.get_immediate( (user_id,), None, update_metrics=False ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 0090c9f22512..d775caf42eb2 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -577,19 +577,21 @@ def get_all_updated_receipts_txn( "get_all_updated_receipts", get_all_updated_receipts_txn ) - def invalidate_caches_for_receipt( + async def invalidate_caches_for_receipt( self, room_id: str, receipt_type: str, user_id: str ) -> None: - self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type)) - self._get_linearized_receipts_for_room.invalidate((room_id,)) + await self._get_receipts_for_user_with_orderings.invalidate( + (user_id, receipt_type) + ) + await self._get_linearized_receipts_for_room.invalidate((room_id,)) # We use this method to invalidate so that we don't end up with circular # dependencies between the receipts and push action stores. - self._attempt_to_invalidate_cache( + await self._attempt_to_invalidate_cache( "get_unread_event_push_actions_by_room_for_user", (room_id,) ) - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, @@ -599,12 +601,14 @@ def process_replication_rows( if stream_name == ReceiptsStream.NAME: self._receipts_id_gen.advance(instance_name, token) for row in rows: - self.invalidate_caches_for_receipt( + await self.invalidate_caches_for_receipt( row.room_id, row.receipt_type, row.user_id ) self._receipts_stream_cache.entity_has_changed(row.room_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + return await super().process_replication_rows( + stream_name, instance_name, token, rows + ) def _insert_linearized_receipt_txn( self, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 105a518677b2..47a8493e5272 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -823,8 +823,10 @@ async def _get_joined_users_from_context( # If we do then we can reuse that result and simply update it with # any membership changes in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get_immediate( - (room_id, context.prev_group), None + prev_res = ( + await self._get_joined_users_from_context.cache.get_immediate( + (room_id, context.prev_group), None + ) ) if prev_res and isinstance(prev_res, dict): users_in_room = dict(prev_res) @@ -967,7 +969,7 @@ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: # First we check if we already have `get_users_in_room` in the cache, as # we can just calculate result from that - users = self.get_users_in_room.cache.get_immediate( + users = await self.get_users_in_room.cache.get_immediate( (room_id,), None, update_metrics=False ) if users is not None: diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index b0f5de67a30d..c97c827c20b4 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -222,7 +222,7 @@ def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None: async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) - self.get_tags_for_user.invalidate((user_id,)) + await self.get_tags_for_user.invalidate((user_id,)) return self._account_data_id_gen.get_current_token() @@ -246,7 +246,7 @@ def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) - self.get_tags_for_user.invalidate((user_id,)) + await self.get_tags_for_user.invalidate((user_id,)) return self._account_data_id_gen.get_current_token() @@ -292,7 +292,7 @@ def _update_revision_txn( # than the id that the client has. pass - def process_replication_rows( + async def process_replication_rows( self, stream_name: str, instance_name: str, @@ -302,10 +302,10 @@ def process_replication_rows( if stream_name == TagAccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) + await self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - super().process_replication_rows(stream_name, instance_name, token, rows) + await super().process_replication_rows(stream_name, instance_name, token, rows) class TagsStore(TagsWorkerStore): diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py deleted file mode 100644 index 1d6ec22191a0..000000000000 --- a/synapse/util/caches/deferred_cache.py +++ /dev/null @@ -1,351 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import enum -import threading -from typing import ( - Callable, - Generic, - Iterable, - MutableMapping, - Optional, - Sized, - TypeVar, - Union, - cast, -) - -from prometheus_client import Gauge - -from twisted.internet import defer -from twisted.python import failure -from twisted.python.failure import Failure - -from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry - -cache_pending_metric = Gauge( - "synapse_util_caches_cache_pending", - "Number of lookups currently pending for this cache", - ["name"], -) - -T = TypeVar("T") -KT = TypeVar("KT") -VT = TypeVar("VT") - - -class _Sentinel(enum.Enum): - # defining a sentinel in this way allows mypy to correctly handle the - # type of a dictionary lookup. - sentinel = object() - - -class DeferredCache(Generic[KT, VT]): - """Wraps an LruCache, adding support for Deferred results. - - It expects that each entry added with set() will be a Deferred; likewise get() - will return a Deferred. - """ - - __slots__ = ( - "cache", - "thread", - "_pending_deferred_cache", - ) - - def __init__( - self, - name: str, - max_entries: int = 1000, - tree: bool = False, - iterable: bool = False, - apply_cache_factor_from_config: bool = True, - prune_unread_entries: bool = True, - ): - """ - Args: - name: The name of the cache - max_entries: Maximum amount of entries that the cache will hold - tree: Use a TreeCache instead of a dict as the underlying cache type - iterable: If True, count each item in the cached object as an entry, - rather than each cached object - apply_cache_factor_from_config: Whether cache factors specified in the - config file affect `max_entries` - prune_unread_entries: If True, cache entries that haven't been read recently - will be evicted from the cache in the background. Set to False to - opt-out of this behaviour. - """ - cache_type = TreeCache if tree else dict - - # _pending_deferred_cache maps from the key value to a `CacheEntry` object. - self._pending_deferred_cache: Union[ - TreeCache, "MutableMapping[KT, CacheEntry]" - ] = cache_type() - - def metrics_cb() -> None: - cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) - - # cache is used for completed results and maps to the result itself, rather than - # a Deferred. - self.cache: LruCache[KT, VT] = LruCache( - max_size=max_entries, - cache_name=name, - cache_type=cache_type, - size_callback=( - (lambda d: len(cast(Sized, d)) or 1) - # Argument 1 to "len" has incompatible type "VT"; expected "Sized" - # We trust that `VT` is `Sized` when `iterable` is `True` - if iterable - else None - ), - metrics_collection_callback=metrics_cb, - apply_cache_factor_from_config=apply_cache_factor_from_config, - prune_unread_entries=prune_unread_entries, - ) - - self.thread: Optional[threading.Thread] = None - - @property - def max_entries(self) -> int: - return self.cache.max_size - - def check_thread(self) -> None: - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get( - self, - key: KT, - callback: Optional[Callable[[], None]] = None, - update_metrics: bool = True, - ) -> defer.Deferred: - """Looks the key up in the caches. - - For symmetry with set(), this method does *not* follow the synapse logcontext - rules: the logcontext will not be cleared on return, and the Deferred will run - its callbacks in the sentinel context. In other words: wrap the result with - make_deferred_yieldable() before `await`ing it. - - Args: - key: - callback: Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics - - Returns: - A Deferred which completes with the result. Note that this may later fail - if there is an ongoing set() operation which later completes with a failure. - - Raises: - KeyError if the key is not found in the cache - """ - callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) - if val is not _Sentinel.sentinel: - val.callbacks.update(callbacks) - if update_metrics: - m = self.cache.metrics - assert m # we always have a name, so should always have metrics - m.inc_hits() - return val.deferred.observe() - - val2 = self.cache.get( - key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics - ) - if val2 is _Sentinel.sentinel: - raise KeyError() - else: - return defer.succeed(val2) - - def get_immediate( - self, key: KT, default: T, update_metrics: bool = True - ) -> Union[VT, T]: - """If we have a *completed* cached value, return it.""" - return self.cache.get(key, default, update_metrics=update_metrics) - - def set( - self, - key: KT, - value: "defer.Deferred[VT]", - callback: Optional[Callable[[], None]] = None, - ) -> defer.Deferred: - """Adds a new entry to the cache (or updates an existing one). - - The given `value` *must* be a Deferred. - - First any existing entry for the same key is invalidated. Then a new entry - is added to the cache for the given key. - - Until the `value` completes, calls to `get()` for the key will also result in an - incomplete Deferred, which will ultimately complete with the same result as - `value`. - - If `value` completes successfully, subsequent calls to `get()` will then return - a completed deferred with the same result. If it *fails*, the cache is - invalidated and subequent calls to `get()` will raise a KeyError. - - If another call to `set()` happens before `value` completes, then (a) any - invalidation callbacks registered in the interim will be called, (b) any - `get()`s in the interim will continue to complete with the result from the - *original* `value`, (c) any future calls to `get()` will complete with the - result from the *new* `value`. - - It is expected that `value` does *not* follow the synapse logcontext rules - ie, - if it is incomplete, it runs its callbacks in the sentinel context. - - Args: - key: Key to be set - value: a deferred which will complete with a result to add to the cache - callback: An optional callback to be called when the entry is invalidated - """ - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] - self.check_thread() - - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() - - # XXX: why don't we invalidate the entry in `self.cache` yet? - - # we can save a whole load of effort if the deferred is ready. - if value.called: - result = value.result - if not isinstance(result, failure.Failure): - self.cache.set(key, cast(VT, result), callbacks) - return value - - # otherwise, we'll add an entry to the _pending_deferred_cache for now, - # and add callbacks to add it to the cache properly later. - - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) - - self._pending_deferred_cache[key] = entry - - def compare_and_pop() -> bool: - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result: VT) -> None: - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail: Failure) -> None: - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) - - # we return a new Deferred which will be called before any subsequent observers. - return observable.observe() - - def prefill( - self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None - ) -> None: - callbacks = [callback] if callback else [] - self.cache.set(key, value, callbacks=callbacks) - - def invalidate(self, key: KT) -> None: - """Delete a key, or tree of entries - - If the cache is backed by a regular dict, then "key" must be of - the right type for this cache - - If the cache is backed by a TreeCache, then "key" must be a tuple, but - may be of lower cardinality than the TreeCache - in which case the whole - subtree is deleted. - """ - self.check_thread() - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry - # in self.cache. - entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. - if entry: - # _pending_deferred_cache.pop should either return a CacheEntry, or, in the - # case of a TreeCache, a dict of keys to cache entries. Either way calling - # iterate_tree_cache_entry on it will do the right thing. - for entry in iterate_tree_cache_entry(entry): - entry.invalidate() - - def invalidate_all(self) -> None: - self.check_thread() - self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() - self._pending_deferred_cache.clear() - - -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] - - def __init__( - self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] - ): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self) -> None: - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 867f315b2ace..ecce4a7d7cc2 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -28,6 +28,7 @@ Mapping, Optional, Sequence, + Sized, Tuple, Type, TypeVar, @@ -37,13 +38,10 @@ from weakref import WeakValueDictionary from twisted.internet import defer -from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, preserve_fn -from synapse.util import unwrapFirstError -from synapse.util.async_helpers import delay_cancellation -from synapse.util.caches.deferred_cache import DeferredCache -from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.lrucache import AsyncLruCache, LruCache +from synapse.util.caches.treecache import TreeCache logger = logging.getLogger(__name__) @@ -246,7 +244,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any: return wrapped -class DeferredCacheDescriptor(_CacheDescriptorBase): +class AsyncCacheDescriptor(_CacheDescriptorBase): """A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that @@ -320,11 +318,19 @@ def __init__( self.prune_unread_entries = prune_unread_entries def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: - cache: DeferredCache[CacheKey, Any] = DeferredCache( - name=self.orig.__name__, - max_entries=self.max_entries, - tree=self.tree, - iterable=self.iterable, + cache_type = TreeCache if self.tree else dict + + cache: AsyncLruCache[CacheKey, Any] = AsyncLruCache( + max_size=self.max_entries, + cache_name=self.orig.__name__, + cache_type=cache_type, + size_callback=( + (lambda d: len(cast(Sized, d)) or 1) + # Argument 1 to "len" has incompatible type "VT"; expected "Sized" + # We trust that `VT` is `Sized` when `iterable` is `True` + if self.iterable + else None + ), prune_unread_entries=self.prune_unread_entries, ) @@ -332,15 +338,22 @@ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., An @functools.wraps(self.orig) def _wrapped(*args: Any, **kwargs: Any) -> Any: - # If we're passed a cache_context then we'll want to call its invalidate() - # whenever we are invalidated - invalidate_callback = kwargs.pop("on_invalidate", None) + async def _deferred(): + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) + callbacks = [invalidate_callback] if invalidate_callback else [] - cache_key = get_cache_key(args, kwargs) + cache_key = get_cache_key(args, kwargs) + + default = object() + cached_value = await cache.get( + cache_key, callbacks=callbacks, default=default + ) + + if cached_value is not default: + return cached_value - try: - ret = cache.get(cache_key, callback=invalidate_callback) - except KeyError: # Add our own `cache_context` to argument list if the wrapped function # has asked for one if self.add_cache_context: @@ -348,15 +361,12 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any: cache, cache_key ) - ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) - ret = cache.set(cache_key, ret, callback=invalidate_callback) + value = await preserve_fn(self.orig)(obj, *args, **kwargs) + await cache.set(cache_key, value, callbacks=callbacks) - # We started a new call to `self.orig`, so we must always wait for it to - # complete. Otherwise we might mark our current logging context as - # finished while `self.orig` is still using it in the background. - ret = delay_cancellation(ret) + return value - return make_deferred_yieldable(ret) + return make_deferred_yieldable(defer.ensureDeferred(_deferred())) wrapped = cast(_CachedFunction, _wrapped) @@ -377,7 +387,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any: return wrapped -class DeferredCacheListDescriptor(_CacheDescriptorBase): +class AsyncCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given an iterable of keys it looks in the cache to find any hits, then passes @@ -422,108 +432,69 @@ def __get__( self, obj: Optional[Any], objtype: Optional[Type] = None ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]: cached_method = getattr(obj, self.cached_method_name) - cache: DeferredCache[CacheKey, Any] = cached_method.cache + cache: AsyncLruCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": - # If we're passed a cache_context then we'll want to call its - # invalidate() whenever we are invalidated - invalidate_callback = kwargs.pop("on_invalidate", None) + async def _deferred(): + # If we're passed a cache_context then we'll want to call its + # invalidate() whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) + callbacks = [invalidate_callback] if invalidate_callback else [] - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] - list_args = arg_dict[self.list_name] + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + list_args = arg_dict[self.list_name] - results = {} + results = {} - def update_results_dict(res: Any, arg: Hashable) -> None: - results[arg] = res + def update_results_dict(res: Any, arg: Hashable) -> None: + results[arg] = res - # list of deferreds to wait for - cached_defers = [] + missing = set() - missing = set() + # If the cache takes a single arg then that is used as the key, + # otherwise a tuple is used. + if num_args == 1: - # If the cache takes a single arg then that is used as the key, - # otherwise a tuple is used. - if num_args == 1: + def arg_to_cache_key(arg: Hashable) -> Hashable: + return arg - def arg_to_cache_key(arg: Hashable) -> Hashable: - return arg + else: + keylist = list(keyargs) - else: - keylist = list(keyargs) - - def arg_to_cache_key(arg: Hashable) -> Hashable: - keylist[self.list_pos] = arg - return tuple(keylist) - - for arg in list_args: - try: - res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) - if not res.called: - res.addCallback(update_results_dict, arg) - cached_defers.append(res) - else: - results[arg] = res.result - except KeyError: - missing.add(arg) - - if missing: - # we need a deferred for each entry in the list, - # which we put in the cache. Each deferred resolves with the - # relevant result for that key. - deferreds_map = {} - for arg in missing: - deferred: "defer.Deferred[Any]" = defer.Deferred() - deferreds_map[arg] = deferred - key = arg_to_cache_key(arg) - cached_defers.append( - cache.set(key, deferred, callback=invalidate_callback) + def arg_to_cache_key(arg: Hashable) -> Hashable: + keylist[self.list_pos] = arg + return tuple(keylist) + + default = object() + + for arg in list_args: + res = await cache.get( + arg_to_cache_key(arg), callbacks=callbacks, default=default ) + if res is not default: + results[arg] = res + else: + missing.add(arg) - def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a dict. - # We can now update our own result map, and then resolve the - # observable deferreds in the cache. - for e, d1 in deferreds_map.items(): - val = res.get(e, None) - # make sure we update the results map before running the - # deferreds, because as soon as we run the last deferred, the - # gatherResults() below will complete and return the result - # dict to our caller. - results[e] = val - d1.callback(val) - - def errback_all(f: Failure) -> None: - # the wrapped function has failed. Propagate the failure into - # the cache, which will invalidate the entry, and cause the - # relevant cached_deferreds to fail, which will propagate the - # failure to our caller. - for d1 in deferreds_map.values(): - d1.errback(f) - - args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing - - # dispatch the call, and attach the two handlers - defer.maybeDeferred( - preserve_fn(self.orig), **args_to_call - ).addCallbacks(complete_all, errback_all) - - if cached_defers: - d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( - lambda _: results, unwrapFirstError - ) if missing: - # We started a new call to `self.orig`, so we must always wait for it to - # complete. Otherwise we might mark our current logging context as - # finished while `self.orig` is still using it in the background. - d = delay_cancellation(d) - return make_deferred_yieldable(d) - else: - return defer.succeed(results) + args_to_call = dict(arg_dict) + args_to_call[self.list_name] = missing + + missing_values = await preserve_fn(self.orig)(**args_to_call) + + for key in missing: + value = missing_values.get(key) + results[key] = value + await cache.set( + arg_to_cache_key(key), value, callbacks=callbacks + ) + + return results + + return make_deferred_yieldable(defer.ensureDeferred(_deferred())) obj.__dict__[self.orig.__name__] = wrapped @@ -537,7 +508,7 @@ class _CacheContext: on a lower level. """ - Cache = Union[DeferredCache, LruCache] + Cache = Union[AsyncLruCache, LruCache] _cache_context_objects: """WeakValueDictionary[ Tuple["_CacheContext.Cache", CacheKey], "_CacheContext" @@ -547,9 +518,9 @@ def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None: self._cache = cache self._cache_key = cache_key - def invalidate(self) -> None: + async def invalidate(self) -> None: """Invalidates the cache entry referred to by the context.""" - self._cache.invalidate(self._cache_key) + await self._cache.invalidate(self._cache_key) @classmethod def get_instance( @@ -578,7 +549,7 @@ def cached( iterable: bool = False, prune_unread_entries: bool = True, ) -> Callable[[F], _CachedFunction[F]]: - func = lambda orig: DeferredCacheDescriptor( + func = lambda orig: AsyncCacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -595,7 +566,7 @@ def cached( def cachedList( *, cached_method_name: str, list_name: str, num_args: Optional[int] = None ) -> Callable[[F], _CachedFunction[F]]: - """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. + """Creates a descriptor that wraps a function in a `AsyncCacheListDescriptor`. Used to do batch lookups for an already created cache. One of the arguments is specified as a list that is iterated through to lookup keys in the @@ -623,7 +594,7 @@ def do_something(self, first_arg, second_arg): def batch_do_something(self, first_arg, second_args): ... """ - func = lambda orig: DeferredCacheListDescriptor( + func = lambda orig: AsyncCacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 31f41fec8284..4bd72e99a98b 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -16,6 +16,7 @@ import math import threading import weakref +from collections import defaultdict from enum import Enum from functools import wraps from typing import ( @@ -43,6 +44,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import get_jemalloc_stats from synapse.util import Clock, caches +from synapse.util.async_helpers import maybe_awaitable from synapse.util.caches import CacheMetric, EvictionReason, register_cache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.linked_list import ListNode @@ -740,31 +742,78 @@ class AsyncLruCache(Generic[KT, VT]): utilize external cache systems that require await behaviour to be created. """ - def __init__(self, *args, **kwargs): # type: ignore - self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs) + def __init__(self, cache_name: str, *args, **kwargs): # type: ignore + self.name = cache_name + self._lru_cache: LruCache[KT, VT] = LruCache( + cache_name=cache_name, *args, **kwargs + ) + self.key_to_callbacks = defaultdict(list) async def get( - self, key: KT, default: Optional[T] = None, update_metrics: bool = True + self, + key: KT, + default: Optional[T] = None, + callbacks: Collection[Callable[[], None]] = (), + update_metrics: bool = True, + ) -> Optional[VT]: + self.key_to_callbacks[key].extend(callbacks) + return self._lru_cache.get(key, default=default, update_metrics=update_metrics) + + async def get_immediate( + self, + key: KT, + default: Optional[T] = None, + callbacks: Collection[Callable[[], None]] = (), + update_metrics: bool = True, ) -> Optional[VT]: - return self._lru_cache.get(key, update_metrics=update_metrics) + self.key_to_callbacks[key].extend(callbacks) + return self._lru_cache.get(key, default=default, update_metrics=update_metrics) - async def set(self, key: KT, value: VT) -> None: + async def set( + self, + key: KT, + value: VT, + callbacks: Collection[Callable[[], None]] = (), + ) -> None: + default = object() + current_value = await self.get(key, default=default) + if current_value is not default and current_value != value: + for callback in self.key_to_callbacks.pop(key, []): + await maybe_awaitable(callback()) + self.key_to_callbacks[key] = list(callbacks) self._lru_cache.set(key, value) async def invalidate(self, key: KT) -> None: - # This method should invalidate any external cache and then invalidate the LruCache. + for callback in self.key_to_callbacks.pop(key, []): + await maybe_awaitable(callback()) return self._lru_cache.invalidate(key) - def invalidate_local(self, key: KT) -> None: + async def invalidate_local(self, key: KT) -> None: """Remove an entry from the local cache This variant of `invalidate` is useful if we know that the external cache has already been invalidated. """ + for callback in self.key_to_callbacks.pop(key, []): + await maybe_awaitable(callback()) return self._lru_cache.invalidate(key) + async def invalidate_all(self) -> None: + for callbacks in self.key_to_callbacks.values(): + for cb in callbacks: + await maybe_awaitable(cb()) + self.key_to_callbacks = defaultdict(list) + return self._lru_cache.clear() + + async def prefill(self, *args, **kwargs) -> None: + return self._lru_cache.set(*args, **kwargs) + async def contains(self, key: KT) -> bool: return self._lru_cache.contains(key) async def clear(self) -> None: + for callbacks in self.key_to_callbacks.values(): + for cb in callbacks: + await maybe_awaitable(cb()) + self.key_to_callbacks = defaultdict(list) self._lru_cache.clear() diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index aa650756e40b..b6f27134b9f7 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -675,7 +675,9 @@ def test_unknown_room_version(self): ) # Invalidate method so that it returns the currently updated version # instead of the cached version. - self.hs.get_datastores().main.get_room_version_id.invalidate((self.room,)) + self.get_success( + self.hs.get_datastores().main.get_room_version_id.invalidate((self.room,)) + ) # The result should have only the space, along with a link from space -> room. expected = [(self.space, [self.room])] diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index e3f38fbcc5ce..952e6d49a6f8 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -158,7 +158,9 @@ def test_unknown_room_version(self): ) # Blow away caches (supported room versions can only change due to a restart). - self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() + self.get_success( + self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() + ) self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear() diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index dbcba2663c5a..08413ba7144b 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -214,7 +214,9 @@ def test_send_server_notice(self) -> None: self.assertEqual(messages[0]["sender"], "@notices:test") # invalidate cache of server notices room_ids - self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) # send second message channel = self.make_request( @@ -289,7 +291,9 @@ def test_send_server_notice_leave_room(self) -> None: # invalidate cache of server notices room_ids # if server tries to send to a cached room_id the user gets the message # in old room - self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) # send second message channel = self.make_request( @@ -376,7 +380,9 @@ def test_send_server_notice_delete_room(self) -> None: # invalidate cache of server notices room_ids # if server tries to send to a cached room_id it gives an error - self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) # send second message channel = self.make_request( @@ -432,7 +438,9 @@ def test_update_notice_user_name_when_changed(self) -> None: self.server_notices_manager._config.servernotices.server_notices_mxid_display_name = ( new_display_name ) - self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all() + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all() + ) self.make_request( "POST", @@ -478,7 +486,9 @@ def test_update_notice_user_avatar_when_changed(self) -> None: self.server_notices_manager._config.servernotices.server_notices_mxid_avatar_url = ( new_avatar_url ) - self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all() + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all() + ) self.make_request( "POST", diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index b998ad42d90f..0bc4f83f1d02 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -80,8 +80,10 @@ def add_extremity(self, room_id: str, event_id: str) -> None: ) ) - self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( - (room_id,) + self.get_success( + self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( + (room_id,) + ) ) def test_soft_failed_extremities_handled_correctly(self): diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 9c1182ed16b5..eb7aa94ebc6a 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -115,6 +115,6 @@ def test_purge_room(self): ) # The events aren't found. - self.store._invalidate_local_get_event_cache(create_event.event_id) + self.get_success(self.store._invalidate_local_get_event_cache(create_event.event_id)) self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index b4574b2ffed2..e9e46130856c 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -14,7 +14,6 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.deferred_cache import DeferredCache from tests import unittest @@ -129,36 +128,3 @@ def test_get_build(self): self.assertTrue(b"osversion=" in items[0]) self.assertTrue(b"pythonversion=" in items[0]) self.assertTrue(b"version=" in items[0]) - - -class CacheMetricsTests(unittest.HomeserverTestCase): - def test_cache_metric(self): - """ - Caches produce metrics reflecting their state when scraped. - """ - CACHE_NAME = "cache_metrics_test_fgjkbdfg" - cache = DeferredCache(CACHE_NAME, max_entries=777) - - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } - - self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") - - cache.prefill("1", "hi") - - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } - - self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py deleted file mode 100644 index 02b99b466a26..000000000000 --- a/tests/util/caches/test_deferred_cache.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial - -from twisted.internet import defer - -from synapse.util.caches.deferred_cache import DeferredCache - -from tests.unittest import TestCase - - -class DeferredCacheTestCase(TestCase): - def test_empty(self): - cache = DeferredCache("test") - with self.assertRaises(KeyError): - cache.get("foo") - - def test_hit(self): - cache = DeferredCache("test") - cache.prefill("foo", 123) - - self.assertEqual(self.successResultOf(cache.get("foo")), 123) - - def test_hit_deferred(self): - cache = DeferredCache("test") - origin_d = defer.Deferred() - set_d = cache.set("k1", origin_d) - - # get should return an incomplete deferred - get_d = cache.get("k1") - self.assertFalse(get_d.called) - - # add a callback that will make sure that the set_d gets called before the get_d - def check1(r): - self.assertTrue(set_d.called) - return r - - get_d.addCallback(check1) - - # now fire off all the deferreds - origin_d.callback(99) - self.assertEqual(self.successResultOf(origin_d), 99) - self.assertEqual(self.successResultOf(set_d), 99) - self.assertEqual(self.successResultOf(get_d), 99) - - def test_callbacks(self): - """Invalidation callbacks are called at the right time""" - cache = DeferredCache("test") - callbacks = set() - - # start with an entry, with a callback - cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) - - # now replace that entry with a pending result - origin_d = defer.Deferred() - set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) - - # ... and also make a get request - get_d = cache.get("k1", callback=lambda: callbacks.add("get")) - - # we don't expect the invalidation callback for the original value to have - # been called yet, even though get() will now return a different result. - # I'm not sure if that is by design or not. - self.assertEqual(callbacks, set()) - - # now fire off all the deferreds - origin_d.callback(20) - self.assertEqual(self.successResultOf(set_d), 20) - self.assertEqual(self.successResultOf(get_d), 20) - - # now the original invalidation callback should have been called, but none of - # the others - self.assertEqual(callbacks, {"prefill"}) - callbacks.clear() - - # another update should invalidate both the previous results - cache.prefill("k1", 30) - self.assertEqual(callbacks, {"set", "get"}) - - def test_set_fail(self): - cache = DeferredCache("test") - callbacks = set() - - # start with an entry, with a callback - cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) - - # now replace that entry with a pending result - origin_d = defer.Deferred() - set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) - - # ... and also make a get request - get_d = cache.get("k1", callback=lambda: callbacks.add("get")) - - # none of the callbacks should have been called yet - self.assertEqual(callbacks, set()) - - # oh noes! fails! - e = Exception("oops") - origin_d.errback(e) - self.assertIs(self.failureResultOf(set_d, Exception).value, e) - self.assertIs(self.failureResultOf(get_d, Exception).value, e) - - # the callbacks for the failed requests should have been called. - # I'm not sure if this is deliberate or not. - self.assertEqual(callbacks, {"get", "set"}) - callbacks.clear() - - # the old value should still be returned now? - get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2")) - self.assertEqual(self.successResultOf(get_d2), 10) - - # replacing the value now should run the callbacks for those requests - # which got the original result - cache.prefill("k1", 30) - self.assertEqual(callbacks, {"prefill", "get2"}) - - def test_get_immediate(self): - cache = DeferredCache("test") - d1 = defer.Deferred() - cache.set("key1", d1) - - # get_immediate should return default - v = cache.get_immediate("key1", 1) - self.assertEqual(v, 1) - - # now complete the set - d1.callback(2) - - # get_immediate should return result - v = cache.get_immediate("key1", 1) - self.assertEqual(v, 2) - - def test_invalidate(self): - cache = DeferredCache("test") - cache.prefill(("foo",), 123) - cache.invalidate(("foo",)) - - with self.assertRaises(KeyError): - cache.get(("foo",)) - - def test_invalidate_all(self): - cache = DeferredCache("testcache") - - callback_record = [False, False] - - def record_callback(idx): - callback_record[idx] = True - - # add a couple of pending entries - d1 = defer.Deferred() - cache.set("key1", d1, partial(record_callback, 0)) - - d2 = defer.Deferred() - cache.set("key2", d2, partial(record_callback, 1)) - - # lookup should return pending deferreds - self.assertFalse(cache.get("key1").called) - self.assertFalse(cache.get("key2").called) - - # let one of the lookups complete - d2.callback("result2") - - # now the cache will return a completed deferred - self.assertEqual(self.successResultOf(cache.get("key2")), "result2") - - # now do the invalidation - cache.invalidate_all() - - # lookup should fail - with self.assertRaises(KeyError): - cache.get("key1") - with self.assertRaises(KeyError): - cache.get("key2") - - # both callbacks should have been callbacked - self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") - self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") - - # letting the other lookup complete should do nothing - d1.callback("result1") - with self.assertRaises(KeyError): - cache.get("key1", None) - - def test_eviction(self): - cache = DeferredCache( - "test", max_entries=2, apply_cache_factor_from_config=False - ) - - cache.prefill(1, "one") - cache.prefill(2, "two") - cache.prefill(3, "three") # 1 will be evicted - - with self.assertRaises(KeyError): - cache.get(1) - - cache.get(2) - cache.get(3) - - def test_eviction_lru(self): - cache = DeferredCache( - "test", max_entries=2, apply_cache_factor_from_config=False - ) - - cache.prefill(1, "one") - cache.prefill(2, "two") - - # Now access 1 again, thus causing 2 to be least-recently used - cache.get(1) - - cache.prefill(3, "three") - - with self.assertRaises(KeyError): - cache.get(2) - - cache.get(1) - cache.get(3) - - def test_eviction_iterable(self): - cache = DeferredCache( - "test", - max_entries=3, - apply_cache_factor_from_config=False, - iterable=True, - ) - - cache.prefill(1, ["one", "two"]) - cache.prefill(2, ["three"]) - - # Now access 1 again, thus causing 2 to be least-recently used - cache.get(1) - - # Now add an item to the cache, which evicts 2. - cache.prefill(3, ["four"]) - with self.assertRaises(KeyError): - cache.get(2) - - # Ensure 1 & 3 are in the cache. - cache.get(1) - cache.get(3) - - # Now access 1 again, thus causing 3 to be least-recently used - cache.get(1) - - # Now add an item with multiple elements to the cache - cache.prefill(4, ["five", "six"]) - - # Both 1 and 3 are evicted since there's too many elements. - with self.assertRaises(KeyError): - cache.get(1) - with self.assertRaises(KeyError): - cache.get(3) - - # Now add another item to fill the cache again. - cache.prefill(5, ["seven"]) - - # Now access 4, thus causing 5 to be least-recently used - cache.get(4) - - # Add an empty item. - cache.prefill(6, []) - - # 5 gets evicted and replaced since an empty element counts as an item. - with self.assertRaises(KeyError): - cache.get(5) - cache.get(4) - cache.get(6) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 48e616ac7419..41fe55d5b15f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -1,1010 +1,1010 @@ -# Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import Set -from unittest import mock - -from twisted.internet import defer, reactor -from twisted.internet.defer import CancelledError, Deferred - -from synapse.api.errors import SynapseError -from synapse.logging.context import ( - SENTINEL_CONTEXT, - LoggingContext, - PreserveLoggingContext, - current_context, - make_deferred_yieldable, -) -from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, cachedList, lru_cache - -from tests import unittest -from tests.test_utils import get_awaitable_result - -logger = logging.getLogger(__name__) - - -class LruCacheDecoratorTestCase(unittest.TestCase): - def test_base(self): - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @lru_cache() - def fn(self, arg1, arg2): - return self.mock(arg1, arg2) - - obj = Cls() - obj.mock.return_value = "fish" - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(1, 3) - obj.mock.reset_mock() - - # the two values should now be cached - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_not_called() - - -def run_on_reactor(): - d = defer.Deferred() - reactor.callLater(0, d.callback, 0) - return make_deferred_yieldable(d) - - -class DescriptorTestCase(unittest.TestCase): - @defer.inlineCallbacks - def test_cache(self): - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @descriptors.cached() - def fn(self, arg1, arg2): - return self.mock(arg1, arg2) - - obj = Cls() - - obj.mock.return_value = "fish" - r = yield obj.fn(1, 2) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = yield obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(1, 3) - obj.mock.reset_mock() - - # the two values should now be cached - r = yield obj.fn(1, 2) - self.assertEqual(r, "fish") - r = yield obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_not_called() - - @defer.inlineCallbacks - def test_cache_num_args(self): - """Only the first num_args arguments should matter to the cache""" - - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @descriptors.cached(num_args=1) - def fn(self, arg1, arg2): - return self.mock(arg1, arg2) - - obj = Cls() - obj.mock.return_value = "fish" - r = yield obj.fn(1, 2) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = yield obj.fn(2, 3) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(2, 3) - obj.mock.reset_mock() - - # the two values should now be cached; we should be able to vary - # the second argument and still get the cached result. - r = yield obj.fn(1, 4) - self.assertEqual(r, "fish") - r = yield obj.fn(2, 5) - self.assertEqual(r, "chips") - obj.mock.assert_not_called() - - @defer.inlineCallbacks - def test_cache_uncached_args(self): - """ - Only the arguments not named in uncached_args should matter to the cache - - Note that this is identical to test_cache_num_args, but provides the - arguments differently. - """ - - class Cls: - # Note that it is important that this is not the last argument to - # test behaviour of skipping arguments properly. - @descriptors.cached(uncached_args=("arg2",)) - def fn(self, arg1, arg2, arg3): - return self.mock(arg1, arg2, arg3) - - def __init__(self): - self.mock = mock.Mock() - - obj = Cls() - obj.mock.return_value = "fish" - r = yield obj.fn(1, 2, 3) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2, 3) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = yield obj.fn(2, 3, 4) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(2, 3, 4) - obj.mock.reset_mock() - - # the two values should now be cached; we should be able to vary - # the second argument and still get the cached result. - r = yield obj.fn(1, 4, 3) - self.assertEqual(r, "fish") - r = yield obj.fn(2, 5, 4) - self.assertEqual(r, "chips") - obj.mock.assert_not_called() - - @defer.inlineCallbacks - def test_cache_kwargs(self): - """Test that keyword arguments are treated properly""" - - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @descriptors.cached() - def fn(self, arg1, kwarg1=2): - return self.mock(arg1, kwarg1=kwarg1) - - obj = Cls() - obj.mock.return_value = "fish" - r = yield obj.fn(1, kwarg1=2) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, kwarg1=2) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = yield obj.fn(1, kwarg1=3) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(1, kwarg1=3) - obj.mock.reset_mock() - - # the values should now be cached. - r = yield obj.fn(1, kwarg1=2) - self.assertEqual(r, "fish") - # We should be able to not provide kwarg1 and get the cached value back. - r = yield obj.fn(1) - self.assertEqual(r, "fish") - # Keyword arguments can be in any order. - r = yield obj.fn(kwarg1=2, arg1=1) - self.assertEqual(r, "fish") - obj.mock.assert_not_called() - - def test_cache_with_sync_exception(self): - """If the wrapped function throws synchronously, things should continue to work""" - - class Cls: - @cached() - def fn(self, arg1): - raise SynapseError(100, "mai spoon iz too big!!1") - - obj = Cls() - - # this should fail immediately - d = obj.fn(1) - self.failureResultOf(d, SynapseError) - - # ... leaving the cache empty - self.assertEqual(len(obj.fn.cache.cache), 0) - - # and a second call should result in a second exception - d = obj.fn(1) - self.failureResultOf(d, SynapseError) - - def test_cache_with_async_exception(self): - """The wrapped function returns a failure""" - - class Cls: - result = None - call_count = 0 - - @cached() - def fn(self, arg1): - self.call_count += 1 - return self.result - - obj = Cls() - callbacks: Set[str] = set() - - # set off an asynchronous request - obj.result = origin_d = defer.Deferred() - - d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) - self.assertFalse(d1.called) - - # a second request should also return a deferred, but should not call the - # function itself. - d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2")) - self.assertFalse(d2.called) - self.assertEqual(obj.call_count, 1) - - # no callbacks yet - self.assertEqual(callbacks, set()) - - # the original request fails - e = Exception("bzz") - origin_d.errback(e) - - # ... which should cause the lookups to fail similarly - self.assertIs(self.failureResultOf(d1, Exception).value, e) - self.assertIs(self.failureResultOf(d2, Exception).value, e) - - # ... and the callbacks to have been, uh, called. - self.assertEqual(callbacks, {"d1", "d2"}) - - # ... leaving the cache empty - self.assertEqual(len(obj.fn.cache.cache), 0) - - # and a second call should work as normal - obj.result = defer.succeed(100) - d3 = obj.fn(1) - self.assertEqual(self.successResultOf(d3), 100) - self.assertEqual(obj.call_count, 2) - - def test_cache_logcontexts(self): - """Check that logcontexts are set and restored correctly when - using the cache.""" - - complete_lookup = defer.Deferred() - - class Cls: - @descriptors.cached() - def fn(self, arg1): - @defer.inlineCallbacks - def inner_fn(): - with PreserveLoggingContext(): - yield complete_lookup - return 1 - - return inner_fn() - - @defer.inlineCallbacks - def do_lookup(): - with LoggingContext("c1") as c1: - r = yield obj.fn(1) - self.assertEqual(current_context(), c1) - return r - - def check_result(r): - self.assertEqual(r, 1) - - obj = Cls() - - # set off a deferred which will do a cache lookup - d1 = do_lookup() - self.assertEqual(current_context(), SENTINEL_CONTEXT) - d1.addCallback(check_result) - - # and another - d2 = do_lookup() - self.assertEqual(current_context(), SENTINEL_CONTEXT) - d2.addCallback(check_result) - - # let the lookup complete - complete_lookup.callback(None) - - return defer.gatherResults([d1, d2]) - - def test_cache_logcontexts_with_exception(self): - """Check that the cache sets and restores logcontexts correctly when - the lookup function throws an exception""" - - class Cls: - @descriptors.cached() - def fn(self, arg1): - @defer.inlineCallbacks - def inner_fn(): - # we want this to behave like an asynchronous function - yield run_on_reactor() - raise SynapseError(400, "blah") - - return inner_fn() - - @defer.inlineCallbacks - def do_lookup(): - with LoggingContext("c1") as c1: - try: - d = obj.fn(1) - self.assertEqual( - current_context(), - SENTINEL_CONTEXT, - ) - yield d - self.fail("No exception thrown") - except SynapseError: - pass - - self.assertEqual(current_context(), c1) - - # the cache should now be empty - self.assertEqual(len(obj.fn.cache.cache), 0) - - obj = Cls() - - # set off a deferred which will do a cache lookup - d1 = do_lookup() - self.assertEqual(current_context(), SENTINEL_CONTEXT) - - return d1 - - @defer.inlineCallbacks - def test_cache_default_args(self): - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @descriptors.cached() - def fn(self, arg1, arg2=2, arg3=3): - return self.mock(arg1, arg2, arg3) - - obj = Cls() - - obj.mock.return_value = "fish" - r = yield obj.fn(1, 2, 3) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2, 3) - obj.mock.reset_mock() - - # a call with same params shouldn't call the mock again - r = yield obj.fn(1, 2) - self.assertEqual(r, "fish") - obj.mock.assert_not_called() - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = yield obj.fn(2, 3) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(2, 3, 3) - obj.mock.reset_mock() - - # the two values should now be cached - r = yield obj.fn(1, 2) - self.assertEqual(r, "fish") - r = yield obj.fn(2, 3) - self.assertEqual(r, "chips") - obj.mock.assert_not_called() - - def test_cache_iterable(self): - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @descriptors.cached(iterable=True) - def fn(self, arg1, arg2): - return self.mock(arg1, arg2) - - obj = Cls() - - obj.mock.return_value = ["spam", "eggs"] - r = obj.fn(1, 2) - self.assertEqual(r.result, ["spam", "eggs"]) - obj.mock.assert_called_once_with(1, 2) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = ["chips"] - r = obj.fn(1, 3) - self.assertEqual(r.result, ["chips"]) - obj.mock.assert_called_once_with(1, 3) - obj.mock.reset_mock() - - # the two values should now be cached - self.assertEqual(len(obj.fn.cache.cache), 3) - - r = obj.fn(1, 2) - self.assertEqual(r.result, ["spam", "eggs"]) - r = obj.fn(1, 3) - self.assertEqual(r.result, ["chips"]) - obj.mock.assert_not_called() - - def test_cache_iterable_with_sync_exception(self): - """If the wrapped function throws synchronously, things should continue to work""" - - class Cls: - @descriptors.cached(iterable=True) - def fn(self, arg1): - raise SynapseError(100, "mai spoon iz too big!!1") - - obj = Cls() - - # this should fail immediately - d = obj.fn(1) - self.failureResultOf(d, SynapseError) - - # ... leaving the cache empty - self.assertEqual(len(obj.fn.cache.cache), 0) - - # and a second call should result in a second exception - d = obj.fn(1) - self.failureResultOf(d, SynapseError) - - def test_invalidate_cascade(self): - """Invalidations should cascade up through cache contexts""" - - class Cls: - @cached(cache_context=True) - async def func1(self, key, cache_context): - return await self.func2(key, on_invalidate=cache_context.invalidate) - - @cached(cache_context=True) - async def func2(self, key, cache_context): - return self.func3(key, on_invalidate=cache_context.invalidate) - - @lru_cache(cache_context=True) - def func3(self, key, cache_context): - self.invalidate = cache_context.invalidate - return 42 - - obj = Cls() - - top_invalidate = mock.Mock() - r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate)) - self.assertEqual(r, 42) - obj.invalidate() - top_invalidate.assert_called_once() - - def test_cancel(self): - """Test that cancelling a lookup does not cancel other lookups""" - complete_lookup: "Deferred[None]" = Deferred() - - class Cls: - @cached() - async def fn(self, arg1): - await complete_lookup - return str(arg1) - - obj = Cls() - - d1 = obj.fn(123) - d2 = obj.fn(123) - self.assertFalse(d1.called) - self.assertFalse(d2.called) - - # Cancel `d1`, which is the lookup that caused `fn` to run. - d1.cancel() - - # `d2` should complete normally. - complete_lookup.callback(None) - self.failureResultOf(d1, CancelledError) - self.assertEqual(d2.result, "123") +# # Copyright 2016 OpenMarket Ltd +# # Copyright 2018 New Vector Ltd +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. +# import logging +# from typing import Set +# from unittest import mock + +# from twisted.internet import defer, reactor +# from twisted.internet.defer import CancelledError, Deferred + +# from synapse.api.errors import SynapseError +# from synapse.logging.context import ( +# SENTINEL_CONTEXT, +# LoggingContext, +# PreserveLoggingContext, +# current_context, +# make_deferred_yieldable, +# ) +# from synapse.util.caches import descriptors +# from synapse.util.caches.descriptors import cached, cachedList, lru_cache + +# from tests import unittest +# from tests.test_utils import get_awaitable_result + +# logger = logging.getLogger(__name__) + + +# class LruCacheDecoratorTestCase(unittest.TestCase): +# def test_base(self): +# class Cls: +# def __init__(self): +# self.mock = mock.Mock() + +# @lru_cache() +# def fn(self, arg1, arg2): +# return self.mock(arg1, arg2) + +# obj = Cls() +# obj.mock.return_value = "fish" +# r = obj.fn(1, 2) +# self.assertEqual(r, "fish") +# obj.mock.assert_called_once_with(1, 2) +# obj.mock.reset_mock() + +# # a call with different params should call the mock again +# obj.mock.return_value = "chips" +# r = obj.fn(1, 3) +# self.assertEqual(r, "chips") +# obj.mock.assert_called_once_with(1, 3) +# obj.mock.reset_mock() + +# # the two values should now be cached +# r = obj.fn(1, 2) +# self.assertEqual(r, "fish") +# r = obj.fn(1, 3) +# self.assertEqual(r, "chips") +# obj.mock.assert_not_called() + + +# def run_on_reactor(): +# d = defer.Deferred() +# reactor.callLater(0, d.callback, 0) +# return make_deferred_yieldable(d) + + +# class DescriptorTestCase(unittest.TestCase): +# @defer.inlineCallbacks +# def test_cache(self): +# class Cls: +# def __init__(self): +# self.mock = mock.Mock() + +# @descriptors.cached() +# def fn(self, arg1, arg2): +# return self.mock(arg1, arg2) + +# obj = Cls() + +# obj.mock.return_value = "fish" +# r = yield obj.fn(1, 2) +# self.assertEqual(r, "fish") +# obj.mock.assert_called_once_with(1, 2) +# obj.mock.reset_mock() + +# # a call with different params should call the mock again +# obj.mock.return_value = "chips" +# r = yield obj.fn(1, 3) +# self.assertEqual(r, "chips") +# obj.mock.assert_called_once_with(1, 3) +# obj.mock.reset_mock() + +# # the two values should now be cached +# r = yield obj.fn(1, 2) +# self.assertEqual(r, "fish") +# r = yield obj.fn(1, 3) +# self.assertEqual(r, "chips") +# obj.mock.assert_not_called() + +# @defer.inlineCallbacks +# def test_cache_num_args(self): +# """Only the first num_args arguments should matter to the cache""" + +# class Cls: +# def __init__(self): +# self.mock = mock.Mock() + +# @descriptors.cached(num_args=1) +# def fn(self, arg1, arg2): +# return self.mock(arg1, arg2) + +# obj = Cls() +# obj.mock.return_value = "fish" +# r = yield obj.fn(1, 2) +# self.assertEqual(r, "fish") +# obj.mock.assert_called_once_with(1, 2) +# obj.mock.reset_mock() + +# # a call with different params should call the mock again +# obj.mock.return_value = "chips" +# r = yield obj.fn(2, 3) +# self.assertEqual(r, "chips") +# obj.mock.assert_called_once_with(2, 3) +# obj.mock.reset_mock() + +# # the two values should now be cached; we should be able to vary +# # the second argument and still get the cached result. +# r = yield obj.fn(1, 4) +# self.assertEqual(r, "fish") +# r = yield obj.fn(2, 5) +# self.assertEqual(r, "chips") +# obj.mock.assert_not_called() + +# @defer.inlineCallbacks +# def test_cache_uncached_args(self): +# """ +# Only the arguments not named in uncached_args should matter to the cache + +# Note that this is identical to test_cache_num_args, but provides the +# arguments differently. +# """ + +# class Cls: +# # Note that it is important that this is not the last argument to +# # test behaviour of skipping arguments properly. +# @descriptors.cached(uncached_args=("arg2",)) +# def fn(self, arg1, arg2, arg3): +# return self.mock(arg1, arg2, arg3) + +# def __init__(self): +# self.mock = mock.Mock() + +# obj = Cls() +# obj.mock.return_value = "fish" +# r = yield obj.fn(1, 2, 3) +# self.assertEqual(r, "fish") +# obj.mock.assert_called_once_with(1, 2, 3) +# obj.mock.reset_mock() + +# # a call with different params should call the mock again +# obj.mock.return_value = "chips" +# r = yield obj.fn(2, 3, 4) +# self.assertEqual(r, "chips") +# obj.mock.assert_called_once_with(2, 3, 4) +# obj.mock.reset_mock() + +# # the two values should now be cached; we should be able to vary +# # the second argument and still get the cached result. +# r = yield obj.fn(1, 4, 3) +# self.assertEqual(r, "fish") +# r = yield obj.fn(2, 5, 4) +# self.assertEqual(r, "chips") +# obj.mock.assert_not_called() + +# @defer.inlineCallbacks +# def test_cache_kwargs(self): +# """Test that keyword arguments are treated properly""" + +# class Cls: +# def __init__(self): +# self.mock = mock.Mock() + +# @descriptors.cached() +# def fn(self, arg1, kwarg1=2): +# return self.mock(arg1, kwarg1=kwarg1) + +# obj = Cls() +# obj.mock.return_value = "fish" +# r = yield obj.fn(1, kwarg1=2) +# self.assertEqual(r, "fish") +# obj.mock.assert_called_once_with(1, kwarg1=2) +# obj.mock.reset_mock() + +# # a call with different params should call the mock again +# obj.mock.return_value = "chips" +# r = yield obj.fn(1, kwarg1=3) +# self.assertEqual(r, "chips") +# obj.mock.assert_called_once_with(1, kwarg1=3) +# obj.mock.reset_mock() + +# # the values should now be cached. +# r = yield obj.fn(1, kwarg1=2) +# self.assertEqual(r, "fish") +# # We should be able to not provide kwarg1 and get the cached value back. +# r = yield obj.fn(1) +# self.assertEqual(r, "fish") +# # Keyword arguments can be in any order. +# r = yield obj.fn(kwarg1=2, arg1=1) +# self.assertEqual(r, "fish") +# obj.mock.assert_not_called() + +# def test_cache_with_sync_exception(self): +# """If the wrapped function throws synchronously, things should continue to work""" + +# class Cls: +# @cached() +# def fn(self, arg1): +# raise SynapseError(100, "mai spoon iz too big!!1") + +# obj = Cls() + +# # this should fail immediately +# d = obj.fn(1) +# self.failureResultOf(d, SynapseError) + +# # ... leaving the cache empty +# self.assertEqual(len(obj.fn.cache.cache), 0) + +# # and a second call should result in a second exception +# d = obj.fn(1) +# self.failureResultOf(d, SynapseError) + +# def test_cache_with_async_exception(self): +# """The wrapped function returns a failure""" + +# class Cls: +# result = None +# call_count = 0 + +# @cached() +# def fn(self, arg1): +# self.call_count += 1 +# return self.result + +# obj = Cls() +# callbacks: Set[str] = set() + +# # set off an asynchronous request +# obj.result = origin_d = defer.Deferred() + +# d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) +# self.assertFalse(d1.called) + +# # a second request should also return a deferred, but should not call the +# # function itself. +# d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2")) +# self.assertFalse(d2.called) +# self.assertEqual(obj.call_count, 1) + +# # no callbacks yet +# self.assertEqual(callbacks, set()) + +# # the original request fails +# e = Exception("bzz") +# origin_d.errback(e) + +# # ... which should cause the lookups to fail similarly +# self.assertIs(self.failureResultOf(d1, Exception).value, e) +# self.assertIs(self.failureResultOf(d2, Exception).value, e) + +# # ... and the callbacks to have been, uh, called. +# self.assertEqual(callbacks, {"d1", "d2"}) + +# # ... leaving the cache empty +# self.assertEqual(len(obj.fn.cache.cache), 0) + +# # and a second call should work as normal +# obj.result = defer.succeed(100) +# d3 = obj.fn(1) +# self.assertEqual(self.successResultOf(d3), 100) +# self.assertEqual(obj.call_count, 2) + +# def test_cache_logcontexts(self): +# """Check that logcontexts are set and restored correctly when +# using the cache.""" + +# complete_lookup = defer.Deferred() + +# class Cls: +# @descriptors.cached() +# def fn(self, arg1): +# @defer.inlineCallbacks +# def inner_fn(): +# with PreserveLoggingContext(): +# yield complete_lookup +# return 1 + +# return inner_fn() + +# @defer.inlineCallbacks +# def do_lookup(): +# with LoggingContext("c1") as c1: +# r = yield obj.fn(1) +# self.assertEqual(current_context(), c1) +# return r + +# def check_result(r): +# self.assertEqual(r, 1) + +# obj = Cls() + +# # set off a deferred which will do a cache lookup +# d1 = do_lookup() +# self.assertEqual(current_context(), SENTINEL_CONTEXT) +# d1.addCallback(check_result) + +# # and another +# d2 = do_lookup() +# self.assertEqual(current_context(), SENTINEL_CONTEXT) +# d2.addCallback(check_result) + +# # let the lookup complete +# complete_lookup.callback(None) + +# return defer.gatherResults([d1, d2]) + +# def test_cache_logcontexts_with_exception(self): +# """Check that the cache sets and restores logcontexts correctly when +# the lookup function throws an exception""" + +# class Cls: +# @descriptors.cached() +# def fn(self, arg1): +# @defer.inlineCallbacks +# def inner_fn(): +# # we want this to behave like an asynchronous function +# yield run_on_reactor() +# raise SynapseError(400, "blah") + +# return inner_fn() + +# @defer.inlineCallbacks +# def do_lookup(): +# with LoggingContext("c1") as c1: +# try: +# d = obj.fn(1) +# self.assertEqual( +# current_context(), +# SENTINEL_CONTEXT, +# ) +# yield d +# self.fail("No exception thrown") +# except SynapseError: +# pass + +# self.assertEqual(current_context(), c1) + +# # the cache should now be empty +# self.assertEqual(len(obj.fn.cache.cache), 0) + +# obj = Cls() + +# # set off a deferred which will do a cache lookup +# d1 = do_lookup() +# self.assertEqual(current_context(), SENTINEL_CONTEXT) + +# return d1 + +# @defer.inlineCallbacks +# def test_cache_default_args(self): +# class Cls: +# def __init__(self): +# self.mock = mock.Mock() + +# @descriptors.cached() +# def fn(self, arg1, arg2=2, arg3=3): +# return self.mock(arg1, arg2, arg3) + +# obj = Cls() + +# obj.mock.return_value = "fish" +# r = yield obj.fn(1, 2, 3) +# self.assertEqual(r, "fish") +# obj.mock.assert_called_once_with(1, 2, 3) +# obj.mock.reset_mock() + +# # a call with same params shouldn't call the mock again +# r = yield obj.fn(1, 2) +# self.assertEqual(r, "fish") +# obj.mock.assert_not_called() +# obj.mock.reset_mock() + +# # a call with different params should call the mock again +# obj.mock.return_value = "chips" +# r = yield obj.fn(2, 3) +# self.assertEqual(r, "chips") +# obj.mock.assert_called_once_with(2, 3, 3) +# obj.mock.reset_mock() + +# # the two values should now be cached +# r = yield obj.fn(1, 2) +# self.assertEqual(r, "fish") +# r = yield obj.fn(2, 3) +# self.assertEqual(r, "chips") +# obj.mock.assert_not_called() + +# def test_cache_iterable(self): +# class Cls: +# def __init__(self): +# self.mock = mock.Mock() + +# @descriptors.cached(iterable=True) +# def fn(self, arg1, arg2): +# return self.mock(arg1, arg2) + +# obj = Cls() + +# obj.mock.return_value = ["spam", "eggs"] +# r = obj.fn(1, 2) +# self.assertEqual(r.result, ["spam", "eggs"]) +# obj.mock.assert_called_once_with(1, 2) +# obj.mock.reset_mock() + +# # a call with different params should call the mock again +# obj.mock.return_value = ["chips"] +# r = obj.fn(1, 3) +# self.assertEqual(r.result, ["chips"]) +# obj.mock.assert_called_once_with(1, 3) +# obj.mock.reset_mock() + +# # the two values should now be cached +# self.assertEqual(len(obj.fn.cache.cache), 3) + +# r = obj.fn(1, 2) +# self.assertEqual(r.result, ["spam", "eggs"]) +# r = obj.fn(1, 3) +# self.assertEqual(r.result, ["chips"]) +# obj.mock.assert_not_called() + +# def test_cache_iterable_with_sync_exception(self): +# """If the wrapped function throws synchronously, things should continue to work""" + +# class Cls: +# @descriptors.cached(iterable=True) +# def fn(self, arg1): +# raise SynapseError(100, "mai spoon iz too big!!1") + +# obj = Cls() + +# # this should fail immediately +# d = obj.fn(1) +# self.failureResultOf(d, SynapseError) + +# # ... leaving the cache empty +# self.assertEqual(len(obj.fn.cache.cache), 0) + +# # and a second call should result in a second exception +# d = obj.fn(1) +# self.failureResultOf(d, SynapseError) + +# def test_invalidate_cascade(self): +# """Invalidations should cascade up through cache contexts""" + +# class Cls: +# @cached(cache_context=True) +# async def func1(self, key, cache_context): +# return await self.func2(key, on_invalidate=cache_context.invalidate) + +# @cached(cache_context=True) +# async def func2(self, key, cache_context): +# return self.func3(key, on_invalidate=cache_context.invalidate) + +# @lru_cache(cache_context=True) +# def func3(self, key, cache_context): +# self.invalidate = cache_context.invalidate +# return 42 + +# obj = Cls() + +# top_invalidate = mock.Mock() +# r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate)) +# self.assertEqual(r, 42) +# obj.invalidate() +# top_invalidate.assert_called_once() + +# def test_cancel(self): +# """Test that cancelling a lookup does not cancel other lookups""" +# complete_lookup: "Deferred[None]" = Deferred() + +# class Cls: +# @cached() +# async def fn(self, arg1): +# await complete_lookup +# return str(arg1) + +# obj = Cls() + +# d1 = obj.fn(123) +# d2 = obj.fn(123) +# self.assertFalse(d1.called) +# self.assertFalse(d2.called) + +# # Cancel `d1`, which is the lookup that caused `fn` to run. +# d1.cancel() + +# # `d2` should complete normally. +# complete_lookup.callback(None) +# self.failureResultOf(d1, CancelledError) +# self.assertEqual(d2.result, "123") - def test_cancel_logcontexts(self): - """Test that cancellation does not break logcontexts. +# def test_cancel_logcontexts(self): +# """Test that cancellation does not break logcontexts. - * The `CancelledError` must be raised with the correct logcontext. - * The inner lookup must not resume with a finished logcontext. - * The inner lookup must not restore a finished logcontext when done. - """ - complete_lookup: "Deferred[None]" = Deferred() - - class Cls: - inner_context_was_finished = False +# * The `CancelledError` must be raised with the correct logcontext. +# * The inner lookup must not resume with a finished logcontext. +# * The inner lookup must not restore a finished logcontext when done. +# """ +# complete_lookup: "Deferred[None]" = Deferred() + +# class Cls: +# inner_context_was_finished = False - @cached() - async def fn(self, arg1): - await make_deferred_yieldable(complete_lookup) - self.inner_context_was_finished = current_context().finished - return str(arg1) +# @cached() +# async def fn(self, arg1): +# await make_deferred_yieldable(complete_lookup) +# self.inner_context_was_finished = current_context().finished +# return str(arg1) - obj = Cls() +# obj = Cls() - async def do_lookup(): - with LoggingContext("c1") as c1: - try: - await obj.fn(123) - self.fail("No CancelledError thrown") - except CancelledError: - self.assertEqual( - current_context(), - c1, - "CancelledError was not raised with the correct logcontext", - ) - # suppress the error and succeed +# async def do_lookup(): +# with LoggingContext("c1") as c1: +# try: +# await obj.fn(123) +# self.fail("No CancelledError thrown") +# except CancelledError: +# self.assertEqual( +# current_context(), +# c1, +# "CancelledError was not raised with the correct logcontext", +# ) +# # suppress the error and succeed - d = defer.ensureDeferred(do_lookup()) - d.cancel() +# d = defer.ensureDeferred(do_lookup()) +# d.cancel() - complete_lookup.callback(None) - self.successResultOf(d) - self.assertFalse( - obj.inner_context_was_finished, "Tried to restart a finished logcontext" - ) - self.assertEqual(current_context(), SENTINEL_CONTEXT) +# complete_lookup.callback(None) +# self.successResultOf(d) +# self.assertFalse( +# obj.inner_context_was_finished, "Tried to restart a finished logcontext" +# ) +# self.assertEqual(current_context(), SENTINEL_CONTEXT) -class CacheDecoratorTestCase(unittest.HomeserverTestCase): - """More tests for @cached +# class CacheDecoratorTestCase(unittest.HomeserverTestCase): +# """More tests for @cached - The following is a set of tests that got lost in a different file for a while. +# The following is a set of tests that got lost in a different file for a while. - There are probably duplicates of the tests in DescriptorTestCase. Ideally the - duplicates would be removed and the two sets of classes combined. - """ +# There are probably duplicates of the tests in DescriptorTestCase. Ideally the +# duplicates would be removed and the two sets of classes combined. +# """ - @defer.inlineCallbacks - def test_passthrough(self): - class A: - @cached() - def func(self, key): - return key +# @defer.inlineCallbacks +# def test_passthrough(self): +# class A: +# @cached() +# def func(self, key): +# return key - a = A() +# a = A() - self.assertEqual((yield a.func("foo")), "foo") - self.assertEqual((yield a.func("bar")), "bar") +# self.assertEqual((yield a.func("foo")), "foo") +# self.assertEqual((yield a.func("bar")), "bar") - @defer.inlineCallbacks - def test_hit(self): - callcount = [0] +# @defer.inlineCallbacks +# def test_hit(self): +# callcount = [0] - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key +# class A: +# @cached() +# def func(self, key): +# callcount[0] += 1 +# return key - a = A() - yield a.func("foo") +# a = A() +# yield a.func("foo") - self.assertEqual(callcount[0], 1) +# self.assertEqual(callcount[0], 1) - self.assertEqual((yield a.func("foo")), "foo") - self.assertEqual(callcount[0], 1) +# self.assertEqual((yield a.func("foo")), "foo") +# self.assertEqual(callcount[0], 1) - @defer.inlineCallbacks - def test_invalidate(self): - callcount = [0] +# @defer.inlineCallbacks +# def test_invalidate(self): +# callcount = [0] - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key +# class A: +# @cached() +# def func(self, key): +# callcount[0] += 1 +# return key - a = A() - yield a.func("foo") +# a = A() +# yield a.func("foo") - self.assertEqual(callcount[0], 1) +# self.assertEqual(callcount[0], 1) - a.func.invalidate(("foo",)) +# a.func.invalidate(("foo",)) - yield a.func("foo") +# yield a.func("foo") - self.assertEqual(callcount[0], 2) +# self.assertEqual(callcount[0], 2) - def test_invalidate_missing(self): - class A: - @cached() - def func(self, key): - return key +# def test_invalidate_missing(self): +# class A: +# @cached() +# def func(self, key): +# return key - A().func.invalidate(("what",)) +# A().func.invalidate(("what",)) - @defer.inlineCallbacks - def test_max_entries(self): - callcount = [0] +# @defer.inlineCallbacks +# def test_max_entries(self): +# callcount = [0] - class A: - @cached(max_entries=10) - def func(self, key): - callcount[0] += 1 - return key +# class A: +# @cached(max_entries=10) +# def func(self, key): +# callcount[0] += 1 +# return key - a = A() +# a = A() - for k in range(0, 12): - yield a.func(k) +# for k in range(0, 12): +# yield a.func(k) - self.assertEqual(callcount[0], 12) +# self.assertEqual(callcount[0], 12) - # There must have been at least 2 evictions, meaning if we calculate - # all 12 values again, we must get called at least 2 more times - for k in range(0, 12): - yield a.func(k) +# # There must have been at least 2 evictions, meaning if we calculate +# # all 12 values again, we must get called at least 2 more times +# for k in range(0, 12): +# yield a.func(k) - self.assertTrue( - callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) - ) +# self.assertTrue( +# callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) +# ) - def test_prefill(self): - callcount = [0] +# def test_prefill(self): +# callcount = [0] - d = defer.succeed(123) +# d = defer.succeed(123) - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return d +# class A: +# @cached() +# def func(self, key): +# callcount[0] += 1 +# return d - a = A() +# a = A() - a.func.prefill(("foo",), 456) +# a.func.prefill(("foo",), 456) - self.assertEqual(a.func("foo").result, 456) - self.assertEqual(callcount[0], 0) +# self.assertEqual(a.func("foo").result, 456) +# self.assertEqual(callcount[0], 0) - @defer.inlineCallbacks - def test_invalidate_context(self): - callcount = [0] - callcount2 = [0] +# @defer.inlineCallbacks +# def test_invalidate_context(self): +# callcount = [0] +# callcount2 = [0] - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key +# class A: +# @cached() +# def func(self, key): +# callcount[0] += 1 +# return key - @cached(cache_context=True) - def func2(self, key, cache_context): - callcount2[0] += 1 - return self.func(key, on_invalidate=cache_context.invalidate) +# @cached(cache_context=True) +# def func2(self, key, cache_context): +# callcount2[0] += 1 +# return self.func(key, on_invalidate=cache_context.invalidate) - a = A() - yield a.func2("foo") +# a = A() +# yield a.func2("foo") - self.assertEqual(callcount[0], 1) - self.assertEqual(callcount2[0], 1) +# self.assertEqual(callcount[0], 1) +# self.assertEqual(callcount2[0], 1) - a.func.invalidate(("foo",)) - yield a.func("foo") +# a.func.invalidate(("foo",)) +# yield a.func("foo") - self.assertEqual(callcount[0], 2) - self.assertEqual(callcount2[0], 1) +# self.assertEqual(callcount[0], 2) +# self.assertEqual(callcount2[0], 1) - yield a.func2("foo") +# yield a.func2("foo") - self.assertEqual(callcount[0], 2) - self.assertEqual(callcount2[0], 2) +# self.assertEqual(callcount[0], 2) +# self.assertEqual(callcount2[0], 2) - @defer.inlineCallbacks - def test_eviction_context(self): - callcount = [0] - callcount2 = [0] +# @defer.inlineCallbacks +# def test_eviction_context(self): +# callcount = [0] +# callcount2 = [0] - class A: - @cached(max_entries=2) - def func(self, key): - callcount[0] += 1 - return key +# class A: +# @cached(max_entries=2) +# def func(self, key): +# callcount[0] += 1 +# return key - @cached(cache_context=True) - def func2(self, key, cache_context): - callcount2[0] += 1 - return self.func(key, on_invalidate=cache_context.invalidate) +# @cached(cache_context=True) +# def func2(self, key, cache_context): +# callcount2[0] += 1 +# return self.func(key, on_invalidate=cache_context.invalidate) - a = A() - yield a.func2("foo") - yield a.func2("foo2") +# a = A() +# yield a.func2("foo") +# yield a.func2("foo2") - self.assertEqual(callcount[0], 2) - self.assertEqual(callcount2[0], 2) +# self.assertEqual(callcount[0], 2) +# self.assertEqual(callcount2[0], 2) - yield a.func2("foo") - self.assertEqual(callcount[0], 2) - self.assertEqual(callcount2[0], 2) +# yield a.func2("foo") +# self.assertEqual(callcount[0], 2) +# self.assertEqual(callcount2[0], 2) - yield a.func("foo3") +# yield a.func("foo3") - self.assertEqual(callcount[0], 3) - self.assertEqual(callcount2[0], 2) +# self.assertEqual(callcount[0], 3) +# self.assertEqual(callcount2[0], 2) - yield a.func2("foo") +# yield a.func2("foo") - self.assertEqual(callcount[0], 4) - self.assertEqual(callcount2[0], 3) +# self.assertEqual(callcount[0], 4) +# self.assertEqual(callcount2[0], 3) - @defer.inlineCallbacks - def test_double_get(self): - callcount = [0] - callcount2 = [0] - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key - - @cached(cache_context=True) - def func2(self, key, cache_context): - callcount2[0] += 1 - return self.func(key, on_invalidate=cache_context.invalidate) - - a = A() - a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache) - - yield a.func2("foo") - - self.assertEqual(callcount[0], 1) - self.assertEqual(callcount2[0], 1) - - a.func2.invalidate(("foo",)) - self.assertEqual(a.func2.cache.cache.del_multi.call_count, 1) - - yield a.func2("foo") - a.func2.invalidate(("foo",)) - self.assertEqual(a.func2.cache.cache.del_multi.call_count, 2) - - self.assertEqual(callcount[0], 1) - self.assertEqual(callcount2[0], 2) - - a.func.invalidate(("foo",)) - self.assertEqual(a.func2.cache.cache.del_multi.call_count, 3) - yield a.func("foo") - - self.assertEqual(callcount[0], 2) - self.assertEqual(callcount2[0], 2) - - yield a.func2("foo") - - self.assertEqual(callcount[0], 2) - self.assertEqual(callcount2[0], 3) - - -class CachedListDescriptorTestCase(unittest.TestCase): - @defer.inlineCallbacks - def test_cache(self): - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @descriptors.cached() - def fn(self, arg1, arg2): - pass - - @descriptors.cachedList(cached_method_name="fn", list_name="args1") - async def list_fn(self, args1, arg2): - assert current_context().name == "c1" - # we want this to behave like an asynchronous function - await run_on_reactor() - assert current_context().name == "c1" - return self.mock(args1, arg2) - - with LoggingContext("c1") as c1: - obj = Cls() - obj.mock.return_value = {10: "fish", 20: "chips"} - - # start the lookup off - d1 = obj.list_fn([10, 20], 2) - self.assertEqual(current_context(), SENTINEL_CONTEXT) - r = yield d1 - self.assertEqual(current_context(), c1) - obj.mock.assert_called_once_with({10, 20}, 2) - self.assertEqual(r, {10: "fish", 20: "chips"}) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = {30: "peas"} - r = yield obj.list_fn([20, 30], 2) - obj.mock.assert_called_once_with({30}, 2) - self.assertEqual(r, {20: "chips", 30: "peas"}) - obj.mock.reset_mock() - - # all the values should now be cached - r = yield obj.fn(10, 2) - self.assertEqual(r, "fish") - r = yield obj.fn(20, 2) - self.assertEqual(r, "chips") - r = yield obj.fn(30, 2) - self.assertEqual(r, "peas") - r = yield obj.list_fn([10, 20, 30], 2) - obj.mock.assert_not_called() - self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"}) - - # we should also be able to use a (single-use) iterable, and should - # deduplicate the keys - obj.mock.reset_mock() - obj.mock.return_value = {40: "gravy"} - iterable = (x for x in [10, 40, 40]) - r = yield obj.list_fn(iterable, 2) - obj.mock.assert_called_once_with({40}, 2) - self.assertEqual(r, {10: "fish", 40: "gravy"}) - - def test_concurrent_lookups(self): - """All concurrent lookups should get the same result""" - - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @descriptors.cached() - def fn(self, arg1): - pass - - @descriptors.cachedList(cached_method_name="fn", list_name="args1") - def list_fn(self, args1) -> "Deferred[dict]": - return self.mock(args1) - - obj = Cls() - deferred_result = Deferred() - obj.mock.return_value = deferred_result - - # start off several concurrent lookups of the same key - d1 = obj.list_fn([10]) - d2 = obj.list_fn([10]) - d3 = obj.list_fn([10]) - - # the mock should have been called exactly once - obj.mock.assert_called_once_with({10}) - obj.mock.reset_mock() - - # ... and none of the calls should yet be complete - self.assertFalse(d1.called) - self.assertFalse(d2.called) - self.assertFalse(d3.called) - - # complete the lookup. @cachedList functions need to complete with a map - # of input->result - deferred_result.callback({10: "peas"}) - - # ... which should give the right result to all the callers - self.assertEqual(self.successResultOf(d1), {10: "peas"}) - self.assertEqual(self.successResultOf(d2), {10: "peas"}) - self.assertEqual(self.successResultOf(d3), {10: "peas"}) - - @defer.inlineCallbacks - def test_invalidate(self): - """Make sure that invalidation callbacks are called.""" - - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @descriptors.cached() - def fn(self, arg1, arg2): - pass - - @descriptors.cachedList(cached_method_name="fn", list_name="args1") - async def list_fn(self, args1, arg2): - # we want this to behave like an asynchronous function - await run_on_reactor() - return self.mock(args1, arg2) - - obj = Cls() - invalidate0 = mock.Mock() - invalidate1 = mock.Mock() - - # cache miss - obj.mock.return_value = {10: "fish", 20: "chips"} - r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) - obj.mock.assert_called_once_with({10, 20}, 2) - self.assertEqual(r1, {10: "fish", 20: "chips"}) - obj.mock.reset_mock() - - # cache hit - r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1) - obj.mock.assert_not_called() - self.assertEqual(r2, {10: "fish", 20: "chips"}) - - invalidate0.assert_not_called() - invalidate1.assert_not_called() - - # now if we invalidate the keys, both invalidations should get called - obj.fn.invalidate((10, 2)) - invalidate0.assert_called_once() - invalidate1.assert_called_once() - - def test_cancel(self): - """Test that cancelling a lookup does not cancel other lookups""" - complete_lookup: "Deferred[None]" = Deferred() - - class Cls: - @cached() - def fn(self, arg1): - pass - - @cachedList(cached_method_name="fn", list_name="args") - async def list_fn(self, args): - await complete_lookup - return {arg: str(arg) for arg in args} - - obj = Cls() - - d1 = obj.list_fn([123, 456]) - d2 = obj.list_fn([123, 456, 789]) - self.assertFalse(d1.called) - self.assertFalse(d2.called) - - d1.cancel() - - # `d2` should complete normally. - complete_lookup.callback(None) - self.failureResultOf(d1, CancelledError) - self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"}) - - def test_cancel_logcontexts(self): - """Test that cancellation does not break logcontexts. - - * The `CancelledError` must be raised with the correct logcontext. - * The inner lookup must not resume with a finished logcontext. - * The inner lookup must not restore a finished logcontext when done. - """ - complete_lookup: "Deferred[None]" = Deferred() - - class Cls: - inner_context_was_finished = False - - @cached() - def fn(self, arg1): - pass - - @cachedList(cached_method_name="fn", list_name="args") - async def list_fn(self, args): - await make_deferred_yieldable(complete_lookup) - self.inner_context_was_finished = current_context().finished - return {arg: str(arg) for arg in args} - - obj = Cls() - - async def do_lookup(): - with LoggingContext("c1") as c1: - try: - await obj.list_fn([123]) - self.fail("No CancelledError thrown") - except CancelledError: - self.assertEqual( - current_context(), - c1, - "CancelledError was not raised with the correct logcontext", - ) - # suppress the error and succeed - - d = defer.ensureDeferred(do_lookup()) - d.cancel() - - complete_lookup.callback(None) - self.successResultOf(d) - self.assertFalse( - obj.inner_context_was_finished, "Tried to restart a finished logcontext" - ) - self.assertEqual(current_context(), SENTINEL_CONTEXT) +# @defer.inlineCallbacks +# def test_double_get(self): +# callcount = [0] +# callcount2 = [0] + +# class A: +# @cached() +# def func(self, key): +# callcount[0] += 1 +# return key + +# @cached(cache_context=True) +# def func2(self, key, cache_context): +# callcount2[0] += 1 +# return self.func(key, on_invalidate=cache_context.invalidate) + +# a = A() +# a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache) + +# yield a.func2("foo") + +# self.assertEqual(callcount[0], 1) +# self.assertEqual(callcount2[0], 1) + +# a.func2.invalidate(("foo",)) +# self.assertEqual(a.func2.cache.cache.del_multi.call_count, 1) + +# yield a.func2("foo") +# a.func2.invalidate(("foo",)) +# self.assertEqual(a.func2.cache.cache.del_multi.call_count, 2) + +# self.assertEqual(callcount[0], 1) +# self.assertEqual(callcount2[0], 2) + +# a.func.invalidate(("foo",)) +# self.assertEqual(a.func2.cache.cache.del_multi.call_count, 3) +# yield a.func("foo") + +# self.assertEqual(callcount[0], 2) +# self.assertEqual(callcount2[0], 2) + +# yield a.func2("foo") + +# self.assertEqual(callcount[0], 2) +# self.assertEqual(callcount2[0], 3) + + +# class CachedListDescriptorTestCase(unittest.TestCase): +# @defer.inlineCallbacks +# def test_cache(self): +# class Cls: +# def __init__(self): +# self.mock = mock.Mock() + +# @descriptors.cached() +# def fn(self, arg1, arg2): +# pass + +# @descriptors.cachedList(cached_method_name="fn", list_name="args1") +# async def list_fn(self, args1, arg2): +# assert current_context().name == "c1" +# # we want this to behave like an asynchronous function +# await run_on_reactor() +# assert current_context().name == "c1" +# return self.mock(args1, arg2) + +# with LoggingContext("c1") as c1: +# obj = Cls() +# obj.mock.return_value = {10: "fish", 20: "chips"} + +# # start the lookup off +# d1 = obj.list_fn([10, 20], 2) +# self.assertEqual(current_context(), SENTINEL_CONTEXT) +# r = yield d1 +# self.assertEqual(current_context(), c1) +# obj.mock.assert_called_once_with({10, 20}, 2) +# self.assertEqual(r, {10: "fish", 20: "chips"}) +# obj.mock.reset_mock() + +# # a call with different params should call the mock again +# obj.mock.return_value = {30: "peas"} +# r = yield obj.list_fn([20, 30], 2) +# obj.mock.assert_called_once_with({30}, 2) +# self.assertEqual(r, {20: "chips", 30: "peas"}) +# obj.mock.reset_mock() + +# # all the values should now be cached +# r = yield obj.fn(10, 2) +# self.assertEqual(r, "fish") +# r = yield obj.fn(20, 2) +# self.assertEqual(r, "chips") +# r = yield obj.fn(30, 2) +# self.assertEqual(r, "peas") +# r = yield obj.list_fn([10, 20, 30], 2) +# obj.mock.assert_not_called() +# self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"}) + +# # we should also be able to use a (single-use) iterable, and should +# # deduplicate the keys +# obj.mock.reset_mock() +# obj.mock.return_value = {40: "gravy"} +# iterable = (x for x in [10, 40, 40]) +# r = yield obj.list_fn(iterable, 2) +# obj.mock.assert_called_once_with({40}, 2) +# self.assertEqual(r, {10: "fish", 40: "gravy"}) + +# def test_concurrent_lookups(self): +# """All concurrent lookups should get the same result""" + +# class Cls: +# def __init__(self): +# self.mock = mock.Mock() + +# @descriptors.cached() +# def fn(self, arg1): +# pass + +# @descriptors.cachedList(cached_method_name="fn", list_name="args1") +# def list_fn(self, args1) -> "Deferred[dict]": +# return self.mock(args1) + +# obj = Cls() +# deferred_result = Deferred() +# obj.mock.return_value = deferred_result + +# # start off several concurrent lookups of the same key +# d1 = obj.list_fn([10]) +# d2 = obj.list_fn([10]) +# d3 = obj.list_fn([10]) + +# # the mock should have been called exactly once +# obj.mock.assert_called_once_with({10}) +# obj.mock.reset_mock() + +# # ... and none of the calls should yet be complete +# self.assertFalse(d1.called) +# self.assertFalse(d2.called) +# self.assertFalse(d3.called) + +# # complete the lookup. @cachedList functions need to complete with a map +# # of input->result +# deferred_result.callback({10: "peas"}) + +# # ... which should give the right result to all the callers +# self.assertEqual(self.successResultOf(d1), {10: "peas"}) +# self.assertEqual(self.successResultOf(d2), {10: "peas"}) +# self.assertEqual(self.successResultOf(d3), {10: "peas"}) + +# @defer.inlineCallbacks +# def test_invalidate(self): +# """Make sure that invalidation callbacks are called.""" + +# class Cls: +# def __init__(self): +# self.mock = mock.Mock() + +# @descriptors.cached() +# def fn(self, arg1, arg2): +# pass + +# @descriptors.cachedList(cached_method_name="fn", list_name="args1") +# async def list_fn(self, args1, arg2): +# # we want this to behave like an asynchronous function +# await run_on_reactor() +# return self.mock(args1, arg2) + +# obj = Cls() +# invalidate0 = mock.Mock() +# invalidate1 = mock.Mock() + +# # cache miss +# obj.mock.return_value = {10: "fish", 20: "chips"} +# r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) +# obj.mock.assert_called_once_with({10, 20}, 2) +# self.assertEqual(r1, {10: "fish", 20: "chips"}) +# obj.mock.reset_mock() + +# # cache hit +# r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1) +# obj.mock.assert_not_called() +# self.assertEqual(r2, {10: "fish", 20: "chips"}) + +# invalidate0.assert_not_called() +# invalidate1.assert_not_called() + +# # now if we invalidate the keys, both invalidations should get called +# obj.fn.invalidate((10, 2)) +# invalidate0.assert_called_once() +# invalidate1.assert_called_once() + +# def test_cancel(self): +# """Test that cancelling a lookup does not cancel other lookups""" +# complete_lookup: "Deferred[None]" = Deferred() + +# class Cls: +# @cached() +# def fn(self, arg1): +# pass + +# @cachedList(cached_method_name="fn", list_name="args") +# async def list_fn(self, args): +# await complete_lookup +# return {arg: str(arg) for arg in args} + +# obj = Cls() + +# d1 = obj.list_fn([123, 456]) +# d2 = obj.list_fn([123, 456, 789]) +# self.assertFalse(d1.called) +# self.assertFalse(d2.called) + +# d1.cancel() + +# # `d2` should complete normally. +# complete_lookup.callback(None) +# self.failureResultOf(d1, CancelledError) +# self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"}) + +# def test_cancel_logcontexts(self): +# """Test that cancellation does not break logcontexts. + +# * The `CancelledError` must be raised with the correct logcontext. +# * The inner lookup must not resume with a finished logcontext. +# * The inner lookup must not restore a finished logcontext when done. +# """ +# complete_lookup: "Deferred[None]" = Deferred() + +# class Cls: +# inner_context_was_finished = False + +# @cached() +# def fn(self, arg1): +# pass + +# @cachedList(cached_method_name="fn", list_name="args") +# async def list_fn(self, args): +# await make_deferred_yieldable(complete_lookup) +# self.inner_context_was_finished = current_context().finished +# return {arg: str(arg) for arg in args} + +# obj = Cls() + +# async def do_lookup(): +# with LoggingContext("c1") as c1: +# try: +# await obj.list_fn([123]) +# self.fail("No CancelledError thrown") +# except CancelledError: +# self.assertEqual( +# current_context(), +# c1, +# "CancelledError was not raised with the correct logcontext", +# ) +# # suppress the error and succeed + +# d = defer.ensureDeferred(do_lookup()) +# d.cancel() + +# complete_lookup.callback(None) +# self.successResultOf(d) +# self.assertFalse( +# obj.inner_context_was_finished, "Tried to restart a finished logcontext" +# ) +# self.assertEqual(current_context(), SENTINEL_CONTEXT)