From 1f4d9081b9fcea793446683f2e8cb4845c77b694 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 3 Jun 2024 10:24:03 -0400 Subject: [PATCH] fix: prevent qthrottled and qdebounced from holding strong references with bound methods (#247) * finish * linting * done * use weakmethod, add signature * add test for warning --- .github/workflows/test_and_deploy.yml | 15 +++-- pyproject.toml | 3 +- src/superqt/utils/_throttler.py | 86 +++++++++++++++++++-------- tests/test_throttler.py | 36 ++++++++++- 4 files changed, 106 insertions(+), 34 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index af86339e..389fc2c6 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -37,19 +37,18 @@ jobs: - python-version: "3.11" backend: pyside2 include: - # https://bugreports.qt.io/browse/PYSIDE-2627 - python-version: "3.10" platform: macos-latest - backend: "'pyside6!=6.6.2'" + backend: pyside6 - python-version: "3.11" platform: macos-latest - backend: "'pyside6!=6.6.2'" + backend: pyside6 - python-version: "3.10" platform: windows-latest - backend: "'pyside6!=6.6.2'" + backend: pyside6 - python-version: "3.11" platform: windows-latest - backend: "'pyside6!=6.6.2'" + backend: pyside6 - python-version: "3.12" platform: macos-latest backend: pyqt6 @@ -69,7 +68,7 @@ jobs: with: python-version: "3.8" qt: pyqt5 - pip-post-installs: 'qtpy==1.1.0 typing-extensions==3.7.4.3' + pip-post-installs: "qtpy==1.1.0 typing-extensions==3.7.4.3" pip-install-flags: -e coverage-upload: artifact @@ -84,11 +83,11 @@ jobs: with: dependency-repo: napari/napari dependency-ref: ${{ matrix.napari-version }} - dependency-extras: 'testing' + dependency-extras: "testing" qt: ${{ matrix.qt }} pytest-args: 'napari/_qt -k "not async and not qt_dims_2 and not qt_viewer_console_focus and not keybinding_editor"' python-version: "3.10" - post-install-cmd: 'pip install lxml_html_clean' + post-install-cmd: "pip install lxml_html_clean" strategy: fail-fast: false matrix: diff --git a/pyproject.toml b/pyproject.toml index 7a788fd9..7e779b9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,6 +155,7 @@ minversion = "6.0" testpaths = ["tests"] filterwarnings = [ "error", + "ignore:Failed to disconnect::pytestqt", "ignore:QPixmapCache.find:DeprecationWarning:", "ignore:SelectableGroups dict interface:DeprecationWarning", "ignore:The distutils package is deprecated:DeprecationWarning", @@ -191,7 +192,7 @@ exclude_lines = [ "@overload", "except ImportError", "\\.\\.\\.", - "pass" + "pass", ] # https://github.com/mgedmin/check-manifest#configuration diff --git a/src/superqt/utils/_throttler.py b/src/superqt/utils/_throttler.py index b63f0db2..c5628c7a 100644 --- a/src/superqt/utils/_throttler.py +++ b/src/superqt/utils/_throttler.py @@ -29,11 +29,15 @@ from __future__ import annotations +import warnings from concurrent.futures import Future +from contextlib import suppress from enum import IntFlag, auto from functools import wraps -from typing import TYPE_CHECKING, Callable, Generic, TypeVar, overload -from weakref import WeakKeyDictionary +from inspect import signature +from types import MethodType +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, overload +from weakref import WeakKeyDictionary, WeakMethod from qtpy.QtCore import QObject, Qt, QTimer, Signal @@ -53,6 +57,12 @@ P = TypeVar("P") R = TypeVar("R") +REF_ERROR = ( + "To use qthrottled or qdebounced as a method decorator, " + "objects must have `__dict__` or be weak referenceable. " + "Please either add `__weakref__` to `__slots__` or use" + "qthrottled/qdebounced as a function (not a decorator)." +) class Kind(IntFlag): @@ -157,7 +167,7 @@ def _emitTriggered(self) -> None: self.triggered.emit() self._timer.start() - def _maybeEmitTriggered(self, restart_timer=True) -> None: + def _maybeEmitTriggered(self, restart_timer: bool = True) -> None: if self._hasPendingEmission: self._emitTriggered() if not restart_timer: @@ -203,6 +213,26 @@ def __init__( # below here part is unique to superqt (not from KD) +def _weak_func(func: Callable[P, R]) -> Callable[P, R]: + if isinstance(func, MethodType): + # this is a bound method, we need to avoid strong references + try: + weak_method = WeakMethod(func) + except TypeError as e: + raise TypeError(REF_ERROR) from e + + def weak_func(*args, **kwargs): + if method := weak_method(): + return method(*args, **kwargs) + warnings.warn( + "Method has been garbage collected", RuntimeWarning, stacklevel=2 + ) + + return weak_func + + return func + + class ThrottledCallable(GenericSignalThrottler, Generic[P, R]): def __init__( self, @@ -214,26 +244,32 @@ def __init__( super().__init__(kind, emissionPolicy, parent) self._future: Future[R] = Future() + + self._is_static_method: bool = False if isinstance(func, staticmethod): - self._func = func.__func__ - else: - self._func = func + self._is_static_method = True + func = func.__func__ + + max_args = get_max_args(func) + with suppress(TypeError, ValueError): + self.__signature__ = signature(func) - self.__wrapped__ = func + self._func = _weak_func(func) + self.__wrapped__ = self._func self._args: tuple = () self._kwargs: dict = {} self.triggered.connect(self._set_future_result) self._name = None - self._obj_dkt = WeakKeyDictionary() + self._obj_dkt: WeakKeyDictionary[Any, ThrottledCallable] = WeakKeyDictionary() # even if we were to compile __call__ with a signature matching that of func, # PySide wouldn't correctly inspect the signature of the ThrottledCallable # instance: https://bugreports.qt.io/browse/PYSIDE-2423 # so we do it ourselfs and limit the number of positional arguments # that we pass to func - self._max_args: int | None = get_max_args(self._func) + self._max_args: int | None = max_args def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Future[R]": # noqa if not self._future.done(): @@ -251,12 +287,18 @@ def _set_future_result(self): self._future.set_result(result) def __set_name__(self, owner, name): - if not isinstance(self.__wrapped__, staticmethod): + if not self._is_static_method: self._name = name - def _get_throttler(self, instance, owner, parent, obj): + def _get_throttler(self, instance, owner, parent, obj, name): + try: + bound_method = self._func.__get__(instance, owner) + except Exception as e: # pragma: no cover + raise RuntimeError( + f"Failed to bind function {self._func!r} to object {instance!r}" + ) from e throttler = ThrottledCallable( - self.__wrapped__.__get__(instance, owner), + bound_method, self._kind, self._emissionPolicy, parent=parent, @@ -264,21 +306,12 @@ def _get_throttler(self, instance, owner, parent, obj): throttler.setTimerType(self.timerType()) throttler.setTimeout(self.timeout()) try: - setattr( - obj, - self._name, - throttler, - ) + setattr(obj, name, throttler) except AttributeError: try: self._obj_dkt[obj] = throttler except TypeError as e: - raise TypeError( - "To use qthrottled or qdebounced as a method decorator, " - "objects must have `__dict__` or be weak referenceable. " - "Please either add `__weakref__` to `__slots__` or use" - "qthrottled/qdebounced as a function (not a decorator)." - ) from e + raise TypeError(REF_ERROR) from e return throttler def __get__(self, instance, owner): @@ -292,7 +325,7 @@ def __get__(self, instance, owner): if parent is None and isinstance(instance, QObject): parent = instance - return self._get_throttler(instance, owner, parent, instance) + return self._get_throttler(instance, owner, parent, instance, self._name) @overload @@ -438,6 +471,11 @@ def deco(func: Callable[P, R]) -> ThrottledCallable[P, R]: obj = ThrottledCallable(func, kind, policy, parent=parent) obj.setTimerType(timer_type) obj.setTimeout(timeout) + + if instance is not None: + # this is a bound method, we need to avoid strong references, + # and functools.wraps will prevent garbage collection on bound methods + return obj return wraps(func)(obj) return deco(func) if func is not None else deco diff --git a/tests/test_throttler.py b/tests/test_throttler.py index f5aa156b..b62cbd53 100644 --- a/tests/test_throttler.py +++ b/tests/test_throttler.py @@ -1,3 +1,5 @@ +import gc +import weakref from unittest.mock import Mock import pytest @@ -116,7 +118,6 @@ def call2(): A.call2(32) qtbot.wait(5) - assert a.count == 1 mock1.assert_called_once() mock2.assert_called_once() @@ -201,3 +202,36 @@ def func(a: int, b: int): mock.assert_called_once_with(1, 2) assert func.__doc__ == "docstring" assert func.__name__ == "func" + + +def test_qthrottled_does_not_prevent_gc(qtbot): + mock = Mock() + + class Thing: + @qdebounced(timeout=1) + def dmethod(self) -> None: + mock() + + @qthrottled(timeout=1) + def tmethod(self, x: int = 1) -> None: + mock() + + thing = Thing() + thing_ref = weakref.ref(thing) + assert thing_ref() is not None + thing.dmethod() + qtbot.waitUntil(thing.dmethod._future.done, timeout=2000) + assert mock.call_count == 1 + thing.tmethod() + qtbot.waitUntil(thing.tmethod._future.done, timeout=2000) + assert mock.call_count == 2 + + wm = thing.tmethod + assert isinstance(wm, ThrottledCallable) + del thing + gc.collect() + assert thing_ref() is None + + with pytest.warns(RuntimeWarning, match="Method has been garbage collected"): + wm() + wm._set_future_result()