diff --git a/src/cachew/__init__.py b/src/cachew/__init__.py index 1774e2d..ec8ae00 100644 --- a/src/cachew/__init__.py +++ b/src/cachew/__init__.py @@ -41,6 +41,14 @@ fromisoformat = datetime.fromisoformat +from .compat import fix_sqlalchemy_StatementError_str +try: + fix_sqlalchemy_StatementError_str() +except Exception as e: + # todo warn or something?? + pass + + # in case of changes in the way cachew stores data, this should be changed to discard old caches CACHEW_VERSION: str = __version__ @@ -873,6 +881,8 @@ def cachew_wrapper( yield from func(*args, **kwargs) return + early_exit = False + # WARNING: annoyingly huge try/catch ahead... # but it lets us save a function call, hence a stack frame # see test_recursive and test_deep_recursive @@ -980,7 +990,12 @@ def flush(): chunk = [] for d in datas: - yield d + try: + yield d + except GeneratorExit: + early_exit = True + return + chunk.append(binder.to_row(d)) if len(chunk) >= chunk_by: flush() @@ -993,6 +1008,10 @@ def flush(): # pylint: disable=no-value-for-parameter conn.execute(db.table_hash.insert().values([{'value': h}])) except Exception as e: + # sigh... see test_early_exit_shutdown... + if early_exit and 'Cannot operate on a closed database' in str(e): + return + # todo hmm, kinda annoying that it tries calling the function twice? # but gonna require some sophisticated cooperation with the cached wrapper otherwise cachew_error(e) diff --git a/src/cachew/compat.py b/src/cachew/compat.py index 39044d8..3939d7a 100644 --- a/src/cachew/compat.py +++ b/src/cachew/compat.py @@ -130,3 +130,48 @@ def nullcontext(): yield finally: pass + + +### + +import sys +def fix_sqlalchemy_StatementError_str(): + # see https://github.com/sqlalchemy/sqlalchemy/issues/5632 + import sqlalchemy # type: ignore + v = sqlalchemy.__version__ + if v != '1.3.19': + # sigh... will still affect smaller versions.. but patching code to remove import dynamically would be far too mad + return + + from sqlalchemy.util import compat # type: ignore + from sqlalchemy.exc import StatementError as SE # type: ignore + + def _sql_message(self, as_unicode): + details = [self._message(as_unicode=as_unicode)] + if self.statement: + if not as_unicode and not compat.py3k: + stmt_detail = "[SQL: %s]" % compat.safe_bytestring( + self.statement + ) + else: + stmt_detail = "[SQL: %s]" % self.statement + details.append(stmt_detail) + if self.params: + if self.hide_parameters: + details.append( + "[SQL parameters hidden due to hide_parameters=True]" + ) + else: + # NOTE: this will still cause issues + from sqlalchemy.sql import util # type: ignore + + params_repr = util._repr_params( + self.params, 10, ismulti=self.ismulti + ) + details.append("[parameters: %r]" % params_repr) + code_str = self._code_str() + if code_str: + details.append(code_str) + return "\n".join(["(%s)" % det for det in self.detail] + details) + + SE._sql_message = _sql_message diff --git a/src/cachew/tests/test_cachew.py b/src/cachew/tests/test_cachew.py index 52b0c25..4306b55 100644 --- a/src/cachew/tests/test_cachew.py +++ b/src/cachew/tests/test_cachew.py @@ -1,9 +1,10 @@ from datetime import datetime, date, timezone import inspect +from itertools import islice import logging from pathlib import Path from random import Random -from subprocess import check_call +from subprocess import check_call, run, PIPE import string import sys import time @@ -1034,3 +1035,64 @@ def fun() -> Iterator[int]: assert calls == 2 assert list(fun()) == [1, 2] assert calls == 3 + + +def test_early_exit(tmp_path: Path): + cf = 0 + @cachew(tmp_path) # / 'fun', force_file=True) + def f() -> Iterator[int]: + yield from range(20) + nonlocal cf + cf += 1 + + cg = 0 + @cachew(tmp_path) # / 'fun', force_file=True) + def g() -> Iterator[int]: + yield from f() + nonlocal cg + cg += 1 + + assert len(list(islice(g(), 0, 10))) == 10 + assert cf == 0 # hasn't finished + assert cg == 0 # hasn't finished + + # todo not sure if need to check that db is empty? + + assert len(list(g())) == 20 + assert cf == 1 + assert cg == 1 + + # should be cached now + assert len(list(g())) == 20 + assert cf == 1 + assert cg == 1 + + +# see https://github.com/sqlalchemy/sqlalchemy/issues/5522#issuecomment-705156746 +def test_early_exit_shutdown(tmp_path: Path): + # don't ask... otherwise the exception doesn't appear :shrug: + import_hack = ''' +from sqlalchemy import Column + +import re +re.hack = lambda: None + ''' + Path(tmp_path / 'import_hack.py').write_text(import_hack) + + prog = f''' +import import_hack + +import cachew +cachew.settings.THROW_ON_ERROR = True # todo check with both? +@cachew.cachew('{tmp_path}', cls=int) +def fun(): + yield 0 + +g = fun() +e = next(g) + +print("FINISHED") + ''' + r = run(['python3', '-c', prog], cwd=tmp_path, stderr=PIPE, stdout=PIPE, check=True) + assert r.stdout.strip() == b'FINISHED' + assert b'Traceback' not in r.stderr