From e9529445538be506d7c4c87e775b3791eed236aa Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 31 Oct 2024 16:43:55 +0100 Subject: [PATCH] QA and CI: Format code using ruff. Validate using ruff and mypy. --- .github/workflows/tests.yml | 2 +- DEVELOP.rst | 19 +- bootstrap.sh | 2 +- docs/conf.py | 24 +- pyproject.toml | 103 +++++++- setup.cfg | 2 - setup.py | 108 +++++---- src/crate/__init__.py | 2 + src/crate/client/__init__.py | 4 +- src/crate/client/blob.py | 16 +- src/crate/client/connection.py | 100 ++++---- src/crate/client/converter.py | 19 +- src/crate/client/cursor.py | 99 ++++---- src/crate/client/exceptions.py | 8 +- src/crate/client/http.py | 350 ++++++++++++++------------- src/crate/testing/layer.py | 242 +++++++++++-------- src/crate/testing/util.py | 22 +- tests/client/layer.py | 133 ++++++----- tests/client/settings.py | 7 +- tests/client/test_connection.py | 47 ++-- tests/client/test_cursor.py | 324 ++++++++++++++++--------- tests/client/test_exceptions.py | 1 - tests/client/test_http.py | 403 ++++++++++++++++++-------------- tests/client/tests.py | 48 ++-- tests/testing/test_layer.py | 225 ++++++++++-------- tests/testing/tests.py | 2 +- 26 files changed, 1371 insertions(+), 941 deletions(-) delete mode 100644 setup.cfg diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index afb5a5b2..d2aa1af4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,7 +60,7 @@ jobs: echo "Invoking tests with CrateDB ${CRATEDB_VERSION}" # Run linter. - flake8 src bin + poe lint # Run tests. coverage run bin/test -vvv diff --git a/DEVELOP.rst b/DEVELOP.rst index b523a4bf..9dd66656 100644 --- a/DEVELOP.rst +++ b/DEVELOP.rst @@ -26,7 +26,7 @@ see, for example, `useful command-line options for zope-testrunner`_. Run all tests:: - bin/test + poe test Run specific tests:: @@ -77,6 +77,23 @@ are listening on the default CrateDB transport port to avoid side effects with the test layer. 
+Formatting and linting code +=========================== + +To use Ruff for code formatting, according to the standards configured in +``pyproject.toml``, use:: + + poe format + +To lint the code base using Ruff and mypy, use:: + + poe lint + +Linting and software testing, all together now:: + + poe check + + Renew certificates ================== diff --git a/bootstrap.sh b/bootstrap.sh index 06c52f12..50ab6d35 100644 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -97,7 +97,7 @@ function main() { } function lint() { - flake8 "$@" src bin + poe lint } main diff --git a/docs/conf.py b/docs/conf.py index 01351068..47cc4ae9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,6 +1,6 @@ +# ruff: noqa: F403, F405 from crate.theme.rtd.conf.python import * - if "sphinx.ext.intersphinx" not in extensions: extensions += ["sphinx.ext.intersphinx"] @@ -9,21 +9,25 @@ intersphinx_mapping = {} -intersphinx_mapping.update({ - 'py': ('https://docs.python.org/3/', None), - 'urllib3': ('https://urllib3.readthedocs.io/en/1.26.13/', None), - }) +intersphinx_mapping.update( + { + "py": ("https://docs.python.org/3/", None), + "urllib3": ("https://urllib3.readthedocs.io/en/1.26.13/", None), + } +) linkcheck_anchors = True linkcheck_ignore = [] # Disable version chooser. -html_context.update({ - "display_version": False, - "current_version": None, - "versions": [], -}) +html_context.update( + { + "display_version": False, + "current_version": None, + "versions": [], + } +) rst_prolog = """ .. |nbsp| unicode:: 0xA0 diff --git a/pyproject.toml b/pyproject.toml index 2f6fe486..31717680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,102 @@ [tool.mypy] +mypy_path = "src" +packages = [ + "crate", +] +exclude = [ +] +check_untyped_defs = true +explicit_package_bases = true +ignore_missing_imports = true +implicit_optional = true +install_types = true +namespace_packages = true +non_interactive = true -# Needed until `mypy-0.990` for `ConverterDefinition` in `converter.py`. 
-# https://github.com/python/mypy/issues/731#issuecomment-1260976955 -enable_recursive_aliases = true + +[tool.ruff] +line-length = 80 + +extend-exclude = [ + "/example_*", +] + +lint.select = [ + # Builtins + "A", + # Bugbear + "B", + # comprehensions + "C4", + # Pycodestyle + "E", + # eradicate + "ERA", + # Pyflakes + "F", + # isort + "I", + # pandas-vet + "PD", + # return + "RET", + # Bandit + "S", + # print + "T20", + "W", + # flake8-2020 + "YTT", +] + +lint.extend-ignore = [ + # Unnecessary variable assignment before `return` statement + "RET504", + # Unnecessary `elif` after `return` statement + "RET505", +] + +lint.per-file-ignores."example_*" = [ + "ERA001", # Found commented-out code + "T201", # Allow `print` +] +lint.per-file-ignores."devtools/*" = [ + "T201", # Allow `print` +] +lint.per-file-ignores."examples/*" = [ + "ERA001", # Found commented-out code + "T201", # Allow `print` +] +lint.per-file-ignores."tests/*" = [ + "S106", # Possible hardcoded password assigned to argument: "password" + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes +] + + +# =================== +# Tasks configuration +# =================== + +[tool.poe.tasks] + +check = [ + "lint", + "test", +] + +format = [ + { cmd = "ruff format ." }, + # Configure Ruff not to auto-fix (remove!): + # unused imports (F401), unused variables (F841), `print` statements (T201), and commented-out code (ERA001). + { cmd = "ruff check --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 --ignore=ERA001 ." }, +] + +lint = [ + { cmd = "ruff format --check ." }, + { cmd = "ruff check ." 
}, + { cmd = "mypy" }, +] + +test = [ + { cmd = "bin/test" }, +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 79c80a4c..00000000 --- a/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[flake8] -ignore = E501, C901, W503, W504 diff --git a/setup.py b/setup.py index ab6d001b..958b746f 100644 --- a/setup.py +++ b/setup.py @@ -19,78 +19,84 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. -from setuptools import setup, find_packages import os import re +from setuptools import find_packages, setup + def read(path): with open(os.path.join(os.path.dirname(__file__), path)) as f: return f.read() -long_description = read('README.rst') +long_description = read("README.rst") versionf_content = read("src/crate/client/__init__.py") version_rex = r'^__version__ = [\'"]([^\'"]*)[\'"]$' m = re.search(version_rex, versionf_content, re.M) if m: version = m.group(1) else: - raise RuntimeError('Unable to find version string') + raise RuntimeError("Unable to find version string") setup( - name='crate', + name="crate", version=version, - url='https://github.com/crate/crate-python', - author='Crate.io', - author_email='office@crate.io', - package_dir={'': 'src'}, - description='CrateDB Python Client', + url="https://github.com/crate/crate-python", + author="Crate.io", + author_email="office@crate.io", + package_dir={"": "src"}, + description="CrateDB Python Client", long_description=long_description, - long_description_content_type='text/x-rst', - platforms=['any'], - license='Apache License 2.0', - keywords='cratedb db api dbapi database sql http rdbms olap', - packages=find_packages('src'), - namespace_packages=['crate'], + long_description_content_type="text/x-rst", + platforms=["any"], + license="Apache License 2.0", + keywords="cratedb db api dbapi database sql http rdbms olap", + packages=find_packages("src"), + namespace_packages=["crate"], install_requires=[ - 
'urllib3<2.3', - 'verlib2==0.2.0', + "urllib3<2.3", + "verlib2==0.2.0", ], - extras_require=dict( - test=['tox>=3,<5', - 'zope.testing>=4,<6', - 'zope.testrunner>=5,<7', - 'zc.customdoctests>=1.0.1,<2', - 'backports.zoneinfo<1; python_version<"3.9"', - 'certifi', - 'createcoverage>=1,<2', - 'stopit>=1.1.2,<2', - 'flake8>=4,<8', - 'pytz', - ], - doc=['sphinx>=3.5,<9', - 'crate-docs-theme>=0.26.5'], - ), - python_requires='>=3.6', - package_data={'': ['*.txt']}, + extras_require={ + "doc": [ + "crate-docs-theme>=0.26.5", + "sphinx>=3.5,<9", + ], + "test": [ + 'backports.zoneinfo<1; python_version<"3.9"', + "certifi", + "createcoverage>=1,<2", + "mypy<1.14", + "poethepoet<0.30", + "ruff<0.8", + "stopit>=1.1.2,<2", + "tox>=3,<5", + "pytz", + "zc.customdoctests>=1.0.1,<2", + "zope.testing>=4,<6", + "zope.testrunner>=5,<7", + ], + }, + python_requires=">=3.6", + package_data={"": ["*.txt"]}, classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: 3.13', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', - 'Topic :: Database' + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: 
Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Database", ], ) diff --git a/src/crate/__init__.py b/src/crate/__init__.py index 1fcff2bb..026c0677 100644 --- a/src/crate/__init__.py +++ b/src/crate/__init__.py @@ -22,7 +22,9 @@ # this is a namespace package try: import pkg_resources + pkg_resources.declare_namespace(__name__) except ImportError: import pkgutil + __path__ = pkgutil.extend_path(__path__, __name__) diff --git a/src/crate/client/__init__.py b/src/crate/client/__init__.py index 7e6e610e..639ab201 100644 --- a/src/crate/client/__init__.py +++ b/src/crate/client/__init__.py @@ -23,8 +23,8 @@ from .exceptions import Error __all__ = [ - connect, - Error, + "connect", + "Error", ] # version string read from setup.py using a regex. Take care not to break the diff --git a/src/crate/client/blob.py b/src/crate/client/blob.py index 73d733ef..4b0528ba 100644 --- a/src/crate/client/blob.py +++ b/src/crate/client/blob.py @@ -22,8 +22,8 @@ import hashlib -class BlobContainer(object): - """ class that represents a blob collection in crate. +class BlobContainer: + """class that represents a blob collection in crate. 
can be used to download, upload and delete blobs """ @@ -34,7 +34,7 @@ def __init__(self, container_name, connection): def _compute_digest(self, f): f.seek(0) - m = hashlib.sha1() + m = hashlib.sha1() # noqa: S324 while True: d = f.read(1024 * 32) if not d: @@ -64,8 +64,9 @@ def put(self, f, digest=None): else: actual_digest = self._compute_digest(f) - created = self.conn.client.blob_put(self.container_name, - actual_digest, f) + created = self.conn.client.blob_put( + self.container_name, actual_digest, f + ) if digest: return created return actual_digest @@ -78,8 +79,9 @@ def get(self, digest, chunk_size=1024 * 128): :param chunk_size: the size of the chunks returned on each iteration :return: generator returning chunks of data """ - return self.conn.client.blob_get(self.container_name, digest, - chunk_size) + return self.conn.client.blob_get( + self.container_name, digest, chunk_size + ) def delete(self, digest): """ diff --git a/src/crate/client/connection.py b/src/crate/client/connection.py index 9e72b2f7..de7682f6 100644 --- a/src/crate/client/connection.py +++ b/src/crate/client/connection.py @@ -19,37 +19,38 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. 
+from verlib2 import Version + +from .blob import BlobContainer from .cursor import Cursor -from .exceptions import ProgrammingError, ConnectionError +from .exceptions import ConnectionError, ProgrammingError from .http import Client -from .blob import BlobContainer -from verlib2 import Version -class Connection(object): - - def __init__(self, - servers=None, - timeout=None, - backoff_factor=0, - client=None, - verify_ssl_cert=True, - ca_cert=None, - error_trace=False, - cert_file=None, - key_file=None, - ssl_relax_minimum_version=False, - username=None, - password=None, - schema=None, - pool_size=None, - socket_keepalive=True, - socket_tcp_keepidle=None, - socket_tcp_keepintvl=None, - socket_tcp_keepcnt=None, - converter=None, - time_zone=None, - ): +class Connection: + def __init__( + self, + servers=None, + timeout=None, + backoff_factor=0, + client=None, + verify_ssl_cert=True, + ca_cert=None, + error_trace=False, + cert_file=None, + key_file=None, + ssl_relax_minimum_version=False, + username=None, + password=None, + schema=None, + pool_size=None, + socket_keepalive=True, + socket_tcp_keepidle=None, + socket_tcp_keepintvl=None, + socket_tcp_keepcnt=None, + converter=None, + time_zone=None, + ): """ :param servers: either a string in the form of ':' @@ -123,7 +124,7 @@ def __init__(self, When `time_zone` is given, the returned `datetime` objects are "aware", with `tzinfo` set, converted using ``datetime.fromtimestamp(..., tz=...)``. 
- """ + """ # noqa: E501 self._converter = converter self.time_zone = time_zone @@ -131,24 +132,25 @@ def __init__(self, if client: self.client = client else: - self.client = Client(servers, - timeout=timeout, - backoff_factor=backoff_factor, - verify_ssl_cert=verify_ssl_cert, - ca_cert=ca_cert, - error_trace=error_trace, - cert_file=cert_file, - key_file=key_file, - ssl_relax_minimum_version=ssl_relax_minimum_version, - username=username, - password=password, - schema=schema, - pool_size=pool_size, - socket_keepalive=socket_keepalive, - socket_tcp_keepidle=socket_tcp_keepidle, - socket_tcp_keepintvl=socket_tcp_keepintvl, - socket_tcp_keepcnt=socket_tcp_keepcnt, - ) + self.client = Client( + servers, + timeout=timeout, + backoff_factor=backoff_factor, + verify_ssl_cert=verify_ssl_cert, + ca_cert=ca_cert, + error_trace=error_trace, + cert_file=cert_file, + key_file=key_file, + ssl_relax_minimum_version=ssl_relax_minimum_version, + username=username, + password=password, + schema=schema, + pool_size=pool_size, + socket_keepalive=socket_keepalive, + socket_tcp_keepidle=socket_tcp_keepidle, + socket_tcp_keepintvl=socket_tcp_keepintvl, + socket_tcp_keepcnt=socket_tcp_keepcnt, + ) self.lowest_server_version = self._lowest_server_version() self._closed = False @@ -182,7 +184,7 @@ def commit(self): raise ProgrammingError("Connection closed") def get_blob_container(self, container_name): - """ Retrieve a BlobContainer for `container_name` + """Retrieve a BlobContainer for `container_name` :param container_name: the name of the BLOB container. 
:returns: a :class:ContainerObject @@ -199,10 +201,10 @@ def _lowest_server_version(self): continue if not lowest or version < lowest: lowest = version - return lowest or Version('0.0.0') + return lowest or Version("0.0.0") def __repr__(self): - return ''.format(repr(self.client)) + return "".format(repr(self.client)) def __enter__(self): return self diff --git a/src/crate/client/converter.py b/src/crate/client/converter.py index c4dbf598..dd29e868 100644 --- a/src/crate/client/converter.py +++ b/src/crate/client/converter.py @@ -23,6 +23,7 @@ https://crate.io/docs/crate/reference/en/latest/interfaces/http.html#column-types """ + import ipaddress from copy import deepcopy from datetime import datetime @@ -33,7 +34,9 @@ ColTypesDefinition = Union[int, List[Union[int, "ColTypesDefinition"]]] -def _to_ipaddress(value: Optional[str]) -> Optional[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]: +def _to_ipaddress( + value: Optional[str], +) -> Optional[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]: """ https://docs.python.org/3/library/ipaddress.html """ @@ -55,7 +58,7 @@ def _to_default(value: Optional[Any]) -> Optional[Any]: return value -# Symbolic aliases for the numeric data type identifiers defined by the CrateDB HTTP interface. +# Data type identifiers defined by the CrateDB HTTP interface. 
# https://crate.io/docs/crate/reference/en/latest/interfaces/http.html#column-types class DataType(Enum): NULL = 0 @@ -112,7 +115,9 @@ def get(self, type_: ColTypesDefinition) -> ConverterFunction: return self._mappings.get(DataType(type_), self._default) type_, inner_type = type_ if DataType(type_) is not DataType.ARRAY: - raise ValueError(f"Data type {type_} is not implemented as collection type") + raise ValueError( + f"Data type {type_} is not implemented as collection type" + ) inner_convert = self.get(inner_type) @@ -128,11 +133,11 @@ def set(self, type_: DataType, converter: ConverterFunction): class DefaultTypeConverter(Converter): - def __init__(self, more_mappings: Optional[ConverterMapping] = None) -> None: + def __init__( + self, more_mappings: Optional[ConverterMapping] = None + ) -> None: mappings: ConverterMapping = {} mappings.update(deepcopy(_DEFAULT_CONVERTERS)) if more_mappings: mappings.update(deepcopy(more_mappings)) - super().__init__( - mappings=mappings, default=_to_default - ) + super().__init__(mappings=mappings, default=_to_default) diff --git a/src/crate/client/cursor.py b/src/crate/client/cursor.py index c458ae1b..cf79efa7 100644 --- a/src/crate/client/cursor.py +++ b/src/crate/client/cursor.py @@ -18,21 +18,20 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. 
-from datetime import datetime, timedelta, timezone - -from .converter import DataType -import warnings import typing as t +import warnings +from datetime import datetime, timedelta, timezone -from .converter import Converter +from .converter import Converter, DataType from .exceptions import ProgrammingError -class Cursor(object): +class Cursor: """ not thread-safe by intention should not be shared between different threads """ + lastrowid = None # currently not supported def __init__(self, connection, converter: Converter, **kwargs): @@ -40,7 +39,7 @@ def __init__(self, connection, converter: Converter, **kwargs): self.connection = connection self._converter = converter self._closed = False - self._result = None + self._result: t.Dict[str, t.Any] = {} self.rows = None self._time_zone = None self.time_zone = kwargs.get("time_zone") @@ -55,8 +54,9 @@ def execute(self, sql, parameters=None, bulk_parameters=None): if self._closed: raise ProgrammingError("Cursor closed") - self._result = self.connection.client.sql(sql, parameters, - bulk_parameters) + self._result = self.connection.client.sql( + sql, parameters, bulk_parameters + ) if "rows" in self._result: if self._converter is None: self.rows = iter(self._result["rows"]) @@ -73,9 +73,9 @@ def executemany(self, sql, seq_of_parameters): durations = [] self.execute(sql, bulk_parameters=seq_of_parameters) - for result in self._result.get('results', []): - if result.get('rowcount') > -1: - row_counts.append(result.get('rowcount')) + for result in self._result.get("results", []): + if result.get("rowcount") > -1: + row_counts.append(result.get("rowcount")) if self.duration > -1: durations.append(self.duration) @@ -85,7 +85,7 @@ def executemany(self, sql, seq_of_parameters): "rows": [], "cols": self._result.get("cols", []), "col_types": self._result.get("col_types", []), - "results": self._result.get("results") + "results": self._result.get("results"), } if self._converter is None: self.rows = iter(self._result["rows"]) 
@@ -112,7 +112,7 @@ def __iter__(self): This iterator is shared. Advancing this iterator will advance other iterators created from this cursor. """ - warnings.warn("DB-API extension cursor.__iter__() used") + warnings.warn("DB-API extension cursor.__iter__() used", stacklevel=2) return self def fetchmany(self, count=None): @@ -126,7 +126,7 @@ def fetchmany(self, count=None): if count == 0: return self.fetchall() result = [] - for i in range(count): + for _ in range(count): try: result.append(self.next()) except StopIteration: @@ -153,7 +153,7 @@ def close(self): Close the cursor now """ self._closed = True - self._result = None + self._result = {} def setinputsizes(self, sizes): """ @@ -174,7 +174,7 @@ def rowcount(self): .execute*() produced (for DQL statements like ``SELECT``) or affected (for DML statements like ``UPDATE`` or ``INSERT``). """ - if (self._closed or not self._result or "rows" not in self._result): + if self._closed or not self._result or "rows" not in self._result: return -1 return self._result.get("rowcount", -1) @@ -185,10 +185,10 @@ def next(self): """ if self.rows is None: raise ProgrammingError( - "No result available. " + - "execute() or executemany() must be called first." + "No result available. " + + "execute() or executemany() must be called first." ) - elif not self._closed: + if not self._closed: return next(self.rows) else: raise ProgrammingError("Cursor closed") @@ -201,17 +201,11 @@ def description(self): This read-only attribute is a sequence of 7-item sequences. """ if self._closed: - return + return None description = [] for col in self._result["cols"]: - description.append((col, - None, - None, - None, - None, - None, - None)) + description.append((col, None, None, None, None, None, None)) return tuple(description) @property @@ -220,9 +214,7 @@ def duration(self): This read-only attribute specifies the server-side duration of a query in milliseconds. 
""" - if self._closed or \ - not self._result or \ - "duration" not in self._result: + if self._closed or not self._result or "duration" not in self._result: return -1 return self._result.get("duration", 0) @@ -230,22 +222,19 @@ def _convert_rows(self): """ Iterate rows, apply type converters, and generate converted rows. """ - assert "col_types" in self._result and self._result["col_types"], \ - "Unable to apply type conversion without `col_types` information" + assert ( # noqa: S101 + "col_types" in self._result and self._result["col_types"] + ), "Unable to apply type conversion without `col_types` information" - # Resolve `col_types` definition to converter functions. Running the lookup - # redundantly on each row loop iteration would be a huge performance hog. + # Resolve `col_types` definition to converter functions. Running + # the lookup redundantly on each row loop iteration would be a + # huge performance hog. types = self._result["col_types"] - converters = [ - self._converter.get(type) for type in types - ] + converters = [self._converter.get(type_) for type_ in types] # Process result rows with conversion. for row in self._result["rows"]: - yield [ - convert(value) - for convert, value in zip(converters, row) - ] + yield [convert(value) for convert, value in zip(converters, row)] @property def time_zone(self): @@ -268,10 +257,11 @@ def time_zone(self, tz): - ``+0530`` (UTC offset in string format) When `time_zone` is `None`, the returned `datetime` objects are - "naive", without any `tzinfo`, converted using ``datetime.utcfromtimestamp(...)``. + "naive", without any `tzinfo`, converted using + `datetime.utcfromtimestamp(...)`. When `time_zone` is given, the returned `datetime` objects are "aware", - with `tzinfo` set, converted using ``datetime.fromtimestamp(..., tz=...)``. + with `tzinfo` set, converted by `datetime.fromtimestamp(..., tz=...)`. """ # Do nothing when time zone is reset. 
@@ -279,18 +269,22 @@ def time_zone(self, tz): self._time_zone = None return - # Requesting datetime-aware `datetime` objects needs the data type converter. + # Requesting datetime-aware `datetime` objects + # needs the data type converter. # Implicitly create one, when needed. if self._converter is None: self._converter = Converter() - # When the time zone is given as a string, assume UTC offset format, e.g. `+0530`. + # When the time zone is given as a string, + # assume UTC offset format, e.g. `+0530`. if isinstance(tz, str): tz = self._timezone_from_utc_offset(tz) self._time_zone = tz - def _to_datetime_with_tz(value: t.Optional[float]) -> t.Optional[datetime]: + def _to_datetime_with_tz( + value: t.Optional[float], + ) -> t.Optional[datetime]: """ Convert CrateDB's `TIMESTAMP` value to a native Python `datetime` object, with timezone-awareness. @@ -306,12 +300,17 @@ def _to_datetime_with_tz(value: t.Optional[float]) -> t.Optional[datetime]: @staticmethod def _timezone_from_utc_offset(tz) -> timezone: """ - Convert UTC offset in string format (e.g. `+0530`) into `datetime.timezone` object. + UTC offset in string format (e.g. `+0530`) to `datetime.timezone`. """ - assert len(tz) == 5, f"Time zone '{tz}' is given in invalid UTC offset format" + # TODO: Remove use of `assert`. Better use exceptions? 
+ assert ( # noqa: S101 + len(tz) == 5 + ), f"Time zone '{tz}' is given in invalid UTC offset format" try: hours = int(tz[:3]) minutes = int(tz[0] + tz[3:]) return timezone(timedelta(hours=hours, minutes=minutes), name=tz) except Exception as ex: - raise ValueError(f"Time zone '{tz}' is given in invalid UTC offset format: {ex}") + raise ValueError( + f"Time zone '{tz}' is given in invalid UTC offset format: {ex}" + ) from ex diff --git a/src/crate/client/exceptions.py b/src/crate/client/exceptions.py index 175cb30c..3833eecc 100644 --- a/src/crate/client/exceptions.py +++ b/src/crate/client/exceptions.py @@ -21,7 +21,6 @@ class Error(Exception): - def __init__(self, msg=None, error_trace=None): # for compatibility reasons we want to keep the exception message # attribute because clients may depend on it @@ -36,7 +35,8 @@ def __str__(self): return "\n".join([super().__str__(), str(self.error_trace)]) -class Warning(Exception): +# A001 Variable `Warning` is shadowing a Python builtin +class Warning(Exception): # noqa: A001 pass @@ -74,7 +74,9 @@ class NotSupportedError(DatabaseError): # exceptions not in db api -class ConnectionError(OperationalError): + +# A001 Variable `ConnectionError` is shadowing a Python builtin +class ConnectionError(OperationalError): # noqa: A001 pass diff --git a/src/crate/client/http.py b/src/crate/client/http.py index 78e0e594..d0b18a04 100644 --- a/src/crate/client/http.py +++ b/src/crate/client/http.py @@ -30,11 +30,11 @@ import socket import ssl import threading -from urllib.parse import urlparse from base64 import b64encode -from time import time -from datetime import datetime, date, timezone +from datetime import date, datetime, timezone from decimal import Decimal +from time import time +from urllib.parse import urlparse from uuid import UUID import urllib3 @@ -52,42 +52,41 @@ from verlib2 import Version from crate.client.exceptions import ( - ConnectionError, BlobLocationNotFoundException, + ConnectionError, 
DigestNotFoundException, - ProgrammingError, IntegrityError, + ProgrammingError, ) - logger = logging.getLogger(__name__) -_HTTP_PAT = pat = re.compile('https?://.+', re.I) -SRV_UNAVAILABLE_STATUSES = set((502, 503, 504, 509)) -PRESERVE_ACTIVE_SERVER_EXCEPTIONS = set((ConnectionResetError, BrokenPipeError)) -SSL_ONLY_ARGS = set(('ca_certs', 'cert_reqs', 'cert_file', 'key_file')) +_HTTP_PAT = pat = re.compile("https?://.+", re.I) +SRV_UNAVAILABLE_STATUSES = {502, 503, 504, 509} +PRESERVE_ACTIVE_SERVER_EXCEPTIONS = {ConnectionResetError, BrokenPipeError} +SSL_ONLY_ARGS = {"ca_certs", "cert_reqs", "cert_file", "key_file"} def super_len(o): - if hasattr(o, '__len__'): + if hasattr(o, "__len__"): return len(o) - if hasattr(o, 'len'): + if hasattr(o, "len"): return o.len - if hasattr(o, 'fileno'): + if hasattr(o, "fileno"): try: fileno = o.fileno() except io.UnsupportedOperation: pass else: return os.fstat(fileno).st_size - if hasattr(o, 'getvalue'): + if hasattr(o, "getvalue"): # e.g. BytesIO, cStringIO.StringI return len(o.getvalue()) + return None class CrateJsonEncoder(json.JSONEncoder): - epoch_aware = datetime(1970, 1, 1, tzinfo=timezone.utc) epoch_naive = datetime(1970, 1, 1) @@ -99,21 +98,22 @@ def default(self, o): delta = o - self.epoch_aware else: delta = o - self.epoch_naive - return int(delta.microseconds / 1000.0 + - (delta.seconds + delta.days * 24 * 3600) * 1000.0) + return int( + delta.microseconds / 1000.0 + + (delta.seconds + delta.days * 24 * 3600) * 1000.0 + ) if isinstance(o, date): return calendar.timegm(o.timetuple()) * 1000 return json.JSONEncoder.default(self, o) -class Server(object): - +class Server: def __init__(self, server, **pool_kw): socket_options = _get_socket_opts( - pool_kw.pop('socket_keepalive', False), - pool_kw.pop('socket_tcp_keepidle', None), - pool_kw.pop('socket_tcp_keepintvl', None), - pool_kw.pop('socket_tcp_keepcnt', None), + pool_kw.pop("socket_keepalive", False), + pool_kw.pop("socket_tcp_keepidle", None), + 
pool_kw.pop("socket_tcp_keepintvl", None), + pool_kw.pop("socket_tcp_keepcnt", None), ) self.pool = connection_from_url( server, @@ -121,53 +121,57 @@ def __init__(self, server, **pool_kw): **pool_kw, ) - def request(self, - method, - path, - data=None, - stream=False, - headers=None, - username=None, - password=None, - schema=None, - backoff_factor=0, - **kwargs): + def request( + self, + method, + path, + data=None, + stream=False, + headers=None, + username=None, + password=None, + schema=None, + backoff_factor=0, + **kwargs, + ): """Send a request Always set the Content-Length and the Content-Type header. """ if headers is None: headers = {} - if 'Content-Length' not in headers: + if "Content-Length" not in headers: length = super_len(data) if length is not None: - headers['Content-Length'] = length + headers["Content-Length"] = length # Authentication credentials if username is not None: - if 'Authorization' not in headers and username is not None: - credentials = username + ':' + if "Authorization" not in headers and username is not None: + credentials = username + ":" if password is not None: credentials += password - headers['Authorization'] = 'Basic %s' % b64encode(credentials.encode('utf-8')).decode('utf-8') + headers["Authorization"] = "Basic %s" % b64encode( + credentials.encode("utf-8") + ).decode("utf-8") # For backwards compatibility with Crate <= 2.2 - if 'X-User' not in headers: - headers['X-User'] = username + if "X-User" not in headers: + headers["X-User"] = username if schema is not None: - headers['Default-Schema'] = schema - headers['Accept'] = 'application/json' - headers['Content-Type'] = 'application/json' - kwargs['assert_same_host'] = False - kwargs['redirect'] = False - kwargs['retries'] = Retry(read=0, backoff_factor=backoff_factor) + headers["Default-Schema"] = schema + headers["Accept"] = "application/json" + headers["Content-Type"] = "application/json" + kwargs["assert_same_host"] = False + kwargs["redirect"] = False + 
kwargs["retries"] = Retry(read=0, backoff_factor=backoff_factor) return self.pool.urlopen( method, path, body=data, preload_content=not stream, headers=headers, - **kwargs + **kwargs, ) def close(self): @@ -176,24 +180,27 @@ def close(self): def _json_from_response(response): try: - return json.loads(response.data.decode('utf-8')) - except ValueError: + return json.loads(response.data.decode("utf-8")) + except ValueError as ex: raise ProgrammingError( - "Invalid server response of content-type '{}':\n{}" - .format(response.headers.get("content-type", "unknown"), response.data.decode('utf-8'))) + "Invalid server response of content-type '{}':\n{}".format( + response.headers.get("content-type", "unknown"), + response.data.decode("utf-8"), + ) + ) from ex def _blob_path(table, digest): - return '/_blobs/{table}/{digest}'.format(table=table, digest=digest) + return "/_blobs/{table}/{digest}".format(table=table, digest=digest) def _ex_to_message(ex): - return getattr(ex, 'message', None) or str(ex) or repr(ex) + return getattr(ex, "message", None) or str(ex) or repr(ex) def _raise_for_status(response): """ - Properly raise `IntegrityError` exceptions for CrateDB's `DuplicateKeyException` errors. + Raise `IntegrityError` exceptions for `DuplicateKeyException` errors. 
""" try: return _raise_for_status_real(response) @@ -204,29 +211,33 @@ def _raise_for_status(response): def _raise_for_status_real(response): - """ make sure that only crate.exceptions are raised that are defined in - the DB-API specification """ - message = '' + """make sure that only crate.exceptions are raised that are defined in + the DB-API specification""" + message = "" if 400 <= response.status < 500: - message = '%s Client Error: %s' % (response.status, response.reason) + message = "%s Client Error: %s" % (response.status, response.reason) elif 500 <= response.status < 600: - message = '%s Server Error: %s' % (response.status, response.reason) + message = "%s Server Error: %s" % (response.status, response.reason) else: return if response.status == 503: raise ConnectionError(message) if response.headers.get("content-type", "").startswith("application/json"): - data = json.loads(response.data.decode('utf-8')) - error = data.get('error', {}) - error_trace = data.get('error_trace', None) + data = json.loads(response.data.decode("utf-8")) + error = data.get("error", {}) + error_trace = data.get("error_trace", None) if "results" in data: - errors = [res["error_message"] for res in data["results"] - if res.get("error_message")] + errors = [ + res["error_message"] + for res in data["results"] + if res.get("error_message") + ] if errors: raise ProgrammingError("\n".join(errors)) if isinstance(error, dict): - raise ProgrammingError(error.get('message', ''), - error_trace=error_trace) + raise ProgrammingError( + error.get("message", ""), error_trace=error_trace + ) raise ProgrammingError(error, error_trace=error_trace) raise ProgrammingError(message) @@ -247,9 +258,9 @@ def _server_url(server): http://demo.crate.io """ if not _HTTP_PAT.match(server): - server = 'http://%s' % server + server = "http://%s" % server parsed = urlparse(server) - url = '%s://%s' % (parsed.scheme, parsed.netloc) + url = "%s://%s" % (parsed.scheme, parsed.netloc) return url @@ -259,30 
+270,36 @@ def _to_server_list(servers): return [_server_url(s) for s in servers] -def _pool_kw_args(verify_ssl_cert, ca_cert, client_cert, client_key, - timeout=None, pool_size=None): - ca_cert = ca_cert or os.environ.get('REQUESTS_CA_BUNDLE', None) +def _pool_kw_args( + verify_ssl_cert, + ca_cert, + client_cert, + client_key, + timeout=None, + pool_size=None, +): + ca_cert = ca_cert or os.environ.get("REQUESTS_CA_BUNDLE", None) if ca_cert and not os.path.exists(ca_cert): # Sanity check raise IOError('CA bundle file "{}" does not exist.'.format(ca_cert)) kw = { - 'ca_certs': ca_cert, - 'cert_reqs': ssl.CERT_REQUIRED if verify_ssl_cert else ssl.CERT_NONE, - 'cert_file': client_cert, - 'key_file': client_key, + "ca_certs": ca_cert, + "cert_reqs": ssl.CERT_REQUIRED if verify_ssl_cert else ssl.CERT_NONE, + "cert_file": client_cert, + "key_file": client_key, } if timeout is not None: if isinstance(timeout, str): timeout = float(timeout) - kw['timeout'] = timeout + kw["timeout"] = timeout if pool_size is not None: - kw['maxsize'] = int(pool_size) + kw["maxsize"] = int(pool_size) return kw def _remove_certs_for_non_https(server, kwargs): - if server.lower().startswith('https'): + if server.lower().startswith("https"): return kwargs used_ssl_args = SSL_ONLY_ARGS & set(kwargs.keys()) if used_ssl_args: @@ -300,6 +317,7 @@ def _update_pool_kwargs_for_ssl_minimum_version(server, kwargs): """ if Version(urllib3.__version__) >= Version("2"): from urllib3.util import parse_url + scheme, _, host, port, *_ = parse_url(server) if scheme == "https": kwargs["ssl_minimum_version"] = ssl.TLSVersion.MINIMUM_SUPPORTED @@ -307,24 +325,21 @@ def _update_pool_kwargs_for_ssl_minimum_version(server, kwargs): def _create_sql_payload(stmt, args, bulk_args): if not isinstance(stmt, str): - raise ValueError('stmt is not a string') + raise ValueError("stmt is not a string") if args and bulk_args: - raise ValueError('Cannot provide both: args and bulk_args') + raise ValueError("Cannot provide both: 
args and bulk_args") - data = { - 'stmt': stmt - } + data = {"stmt": stmt} if args: - data['args'] = args + data["args"] = args if bulk_args: - data['bulk_args'] = bulk_args + data["bulk_args"] = bulk_args return json.dumps(data, cls=CrateJsonEncoder) -def _get_socket_opts(keepalive=True, - tcp_keepidle=None, - tcp_keepintvl=None, - tcp_keepcnt=None): +def _get_socket_opts( + keepalive=True, tcp_keepidle=None, tcp_keepintvl=None, tcp_keepcnt=None +): """ Return an optional list of socket options for urllib3's HTTPConnection constructor. @@ -337,23 +352,23 @@ def _get_socket_opts(keepalive=True, # hasattr check because some options depend on system capabilities # see https://docs.python.org/3/library/socket.html#socket.SOMAXCONN - if hasattr(socket, 'TCP_KEEPIDLE') and tcp_keepidle is not None: + if hasattr(socket, "TCP_KEEPIDLE") and tcp_keepidle is not None: opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, tcp_keepidle)) - if hasattr(socket, 'TCP_KEEPINTVL') and tcp_keepintvl is not None: + if hasattr(socket, "TCP_KEEPINTVL") and tcp_keepintvl is not None: opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, tcp_keepintvl)) - if hasattr(socket, 'TCP_KEEPCNT') and tcp_keepcnt is not None: + if hasattr(socket, "TCP_KEEPCNT") and tcp_keepcnt is not None: opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPCNT, tcp_keepcnt)) # additionally use urllib3's default socket options return HTTPConnection.default_socket_options + opts -class Client(object): +class Client: """ Crate connection client using CrateDB's HTTP API. 
""" - SQL_PATH = '/_sql?types=true' + SQL_PATH = "/_sql?types=true" """Crate URI path for issuing SQL statements.""" retry_interval = 30 @@ -362,25 +377,26 @@ class Client(object): default_server = "http://127.0.0.1:4200" """Default server to use if no servers are given on instantiation.""" - def __init__(self, - servers=None, - timeout=None, - backoff_factor=0, - verify_ssl_cert=True, - ca_cert=None, - error_trace=False, - cert_file=None, - key_file=None, - ssl_relax_minimum_version=False, - username=None, - password=None, - schema=None, - pool_size=None, - socket_keepalive=True, - socket_tcp_keepidle=None, - socket_tcp_keepintvl=None, - socket_tcp_keepcnt=None, - ): + def __init__( + self, + servers=None, + timeout=None, + backoff_factor=0, + verify_ssl_cert=True, + ca_cert=None, + error_trace=False, + cert_file=None, + key_file=None, + ssl_relax_minimum_version=False, + username=None, + password=None, + schema=None, + pool_size=None, + socket_keepalive=True, + socket_tcp_keepidle=None, + socket_tcp_keepintvl=None, + socket_tcp_keepcnt=None, + ): if not servers: servers = [self.default_server] else: @@ -396,22 +412,30 @@ def __init__(self, if url.password is not None: password = url.password except Exception as ex: - logger.warning("Unable to decode credentials from database " - "URI, so connecting to CrateDB without " - "authentication: {ex}" - .format(ex=ex)) + logger.warning( + "Unable to decode credentials from database " + "URI, so connecting to CrateDB without " + "authentication: {ex}".format(ex=ex) + ) self._active_servers = servers self._inactive_servers = [] pool_kw = _pool_kw_args( - verify_ssl_cert, ca_cert, cert_file, key_file, timeout, pool_size, + verify_ssl_cert, + ca_cert, + cert_file, + key_file, + timeout, + pool_size, + ) + pool_kw.update( + { + "socket_keepalive": socket_keepalive, + "socket_tcp_keepidle": socket_tcp_keepidle, + "socket_tcp_keepintvl": socket_tcp_keepintvl, + "socket_tcp_keepcnt": socket_tcp_keepcnt, + } ) - pool_kw.update({ 
- 'socket_keepalive': socket_keepalive, - 'socket_tcp_keepidle': socket_tcp_keepidle, - 'socket_tcp_keepintvl': socket_tcp_keepintvl, - 'socket_tcp_keepcnt': socket_tcp_keepcnt, - }) self.ssl_relax_minimum_version = ssl_relax_minimum_version self.backoff_factor = backoff_factor self.server_pool = {} @@ -425,7 +449,7 @@ def __init__(self, self.path = self.SQL_PATH if error_trace: - self.path += '&error_trace=true' + self.path += "&error_trace=true" def close(self): for server in self.server_pool.values(): @@ -433,8 +457,9 @@ def close(self): def _create_server(self, server, **pool_kw): kwargs = _remove_certs_for_non_https(server, pool_kw) - # After updating to urllib3 v2, optionally retain support for TLS 1.0 and TLS 1.1, - # in order to support connectivity to older versions of CrateDB. + # After updating to urllib3 v2, optionally retain support + # for TLS 1.0 and TLS 1.1, in order to support connectivity + # to older versions of CrateDB. if self.ssl_relax_minimum_version: _update_pool_kwargs_for_ssl_minimum_version(server, kwargs) self.server_pool[server] = Server(server, **kwargs) @@ -451,28 +476,26 @@ def sql(self, stmt, parameters=None, bulk_parameters=None): return None data = _create_sql_payload(stmt, parameters, bulk_parameters) - logger.debug( - 'Sending request to %s with payload: %s', self.path, data) - content = self._json_request('POST', self.path, data=data) + logger.debug("Sending request to %s with payload: %s", self.path, data) + content = self._json_request("POST", self.path, data=data) logger.debug("JSON response for stmt(%s): %s", stmt, content) return content def server_infos(self, server): - response = self._request('GET', '/', server=server) + response = self._request("GET", "/", server=server) _raise_for_status(response) content = _json_from_response(response) node_name = content.get("name") - node_version = content.get('version', {}).get('number', '0.0.0') + node_version = content.get("version", {}).get("number", "0.0.0") return server, 
node_name, node_version - def blob_put(self, table, digest, data): + def blob_put(self, table, digest, data) -> bool: """ Stores the contents of the file like @data object in a blob under the given table and digest. """ - response = self._request('PUT', _blob_path(table, digest), - data=data) + response = self._request("PUT", _blob_path(table, digest), data=data) if response.status == 201: # blob created return True @@ -482,40 +505,43 @@ def blob_put(self, table, digest, data): if response.status in (400, 404): raise BlobLocationNotFoundException(table, digest) _raise_for_status(response) + return False - def blob_del(self, table, digest): + def blob_del(self, table, digest) -> bool: """ Deletes the blob with given digest under the given table. """ - response = self._request('DELETE', _blob_path(table, digest)) + response = self._request("DELETE", _blob_path(table, digest)) if response.status == 204: return True if response.status == 404: return False _raise_for_status(response) + return False def blob_get(self, table, digest, chunk_size=1024 * 128): """ Returns a file like object representing the contents of the blob with the given digest. """ - response = self._request('GET', _blob_path(table, digest), stream=True) + response = self._request("GET", _blob_path(table, digest), stream=True) if response.status == 404: raise DigestNotFoundException(table, digest) _raise_for_status(response) return response.stream(amt=chunk_size) - def blob_exists(self, table, digest): + def blob_exists(self, table, digest) -> bool: """ Returns true if the blob with the given digest exists under the given table. 
""" - response = self._request('HEAD', _blob_path(table, digest)) + response = self._request("HEAD", _blob_path(table, digest)) if response.status == 200: return True elif response.status == 404: return False _raise_for_status(response) + return False def _add_server(self, server): with self._lock: @@ -537,42 +563,45 @@ def _request(self, method, path, server=None, **kwargs): password=self.password, backoff_factor=self.backoff_factor, schema=self.schema, - **kwargs + **kwargs, ) redirect_location = response.get_redirect_location() if redirect_location and 300 <= response.status <= 308: redirect_server = _server_url(redirect_location) self._add_server(redirect_server) return self._request( - method, path, server=redirect_server, **kwargs) + method, path, server=redirect_server, **kwargs + ) if not server and response.status in SRV_UNAVAILABLE_STATUSES: with self._lock: # drop server from active ones self._drop_server(next_server, response.reason) else: return response - except (MaxRetryError, - ReadTimeoutError, - SSLError, - HTTPError, - ProxyError,) as ex: + except ( + MaxRetryError, + ReadTimeoutError, + SSLError, + HTTPError, + ProxyError, + ) as ex: ex_message = _ex_to_message(ex) if server: raise ConnectionError( "Server not available, exception: %s" % ex_message - ) + ) from ex preserve_server = False if isinstance(ex, ProtocolError): preserve_server = any( t in [type(arg) for arg in ex.args] for t in PRESERVE_ACTIVE_SERVER_EXCEPTIONS ) - if (not preserve_server): + if not preserve_server: with self._lock: # drop server from active ones self._drop_server(next_server, ex_message) except Exception as e: - raise ProgrammingError(_ex_to_message(e)) + raise ProgrammingError(_ex_to_message(e)) from e def _json_request(self, method, path, data): """ @@ -592,7 +621,7 @@ def _get_server(self): """ with self._lock: inactive_server_count = len(self._inactive_servers) - for i in range(inactive_server_count): + for _ in range(inactive_server_count): try: ts, server, 
message = heapq.heappop(self._inactive_servers) except IndexError: @@ -600,12 +629,14 @@ def _get_server(self): else: if (ts + self.retry_interval) > time(): # Not yet, put it back - heapq.heappush(self._inactive_servers, - (ts, server, message)) + heapq.heappush( + self._inactive_servers, (ts, server, message) + ) else: self._active_servers.append(server) - logger.warning("Restored server %s into active pool", - server) + logger.warning( + "Restored server %s into active pool", server + ) # if none is old enough, use oldest if not self._active_servers: @@ -639,8 +670,9 @@ def _drop_server(self, server, message): # if this is the last server raise exception, otherwise try next if not self._active_servers: raise ConnectionError( - ("No more Servers available, " - "exception from last server: %s") % message) + ("No more Servers available, " "exception from last server: %s") + % message + ) def _roundrobin(self): """ @@ -649,4 +681,4 @@ def _roundrobin(self): self._active_servers.append(self._active_servers.pop(0)) def __repr__(self): - return ''.format(str(self._active_servers)) + return "".format(str(self._active_servers)) diff --git a/src/crate/testing/layer.py b/src/crate/testing/layer.py index ef8bfe2b..8ff9f24c 100644 --- a/src/crate/testing/layer.py +++ b/src/crate/testing/layer.py @@ -19,38 +19,44 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. 
+# ruff: noqa: S603 # `subprocess` call: check for execution of untrusted input +# ruff: noqa: S202 # Uses of `tarfile.extractall()` + +import io +import json +import logging import os import re -import sys -import time -import json -import urllib3 -import tempfile import shutil import subprocess +import sys import tarfile -import io +import tempfile import threading -import logging +import time + +import urllib3 try: from urllib.request import urlopen except ImportError: - from urllib import urlopen + from urllib import urlopen # type: ignore[attr-defined,no-redef] log = logging.getLogger(__name__) -CRATE_CONFIG_ERROR = 'crate_config must point to a folder or to a file named "crate.yml"' +CRATE_CONFIG_ERROR = ( + 'crate_config must point to a folder or to a file named "crate.yml"' +) HTTP_ADDRESS_RE = re.compile( - r'.*\[(http|.*HttpServer.*)\s*] \[.*\] .*' - 'publish_address {' - r'(?:inet\[[\w\d\.-]*/|\[)?' - r'(?:[\w\d\.-]+/)?' - r'(?P[\d\.:]+)' - r'(?:\])?' - '}' + r".*\[(http|.*HttpServer.*)\s*] \[.*\] .*" + "publish_address {" + r"(?:inet\[[\w\d\.-]*/|\[)?" + r"(?:[\w\d\.-]+/)?" + r"(?P[\d\.:]+)" + r"(?:\])?" 
+ "}" ) @@ -61,18 +67,22 @@ def http_url_from_host_port(host, port): port = int(port) except ValueError: return None - return '{}:{}'.format(prepend_http(host), port) + return "{}:{}".format(prepend_http(host), port) return None def prepend_http(host): - if not re.match(r'^https?\:\/\/.*', host): - return 'http://{}'.format(host) + if not re.match(r"^https?\:\/\/.*", host): + return "http://{}".format(host) return host def _download_and_extract(uri, directory): - sys.stderr.write("\nINFO: Downloading CrateDB archive from {} into {}".format(uri, directory)) + sys.stderr.write( + "\nINFO: Downloading CrateDB archive from {} into {}".format( + uri, directory + ) + ) sys.stderr.flush() with io.BytesIO(urlopen(uri).read()) as tmpfile: with tarfile.open(fileobj=tmpfile) as t: @@ -82,19 +92,18 @@ def _download_and_extract(uri, directory): def wait_for_http_url(log, timeout=30, verbose=False): start = time.monotonic() while True: - line = log.readline().decode('utf-8').strip() + line = log.readline().decode("utf-8").strip() elapsed = time.monotonic() - start if verbose: - sys.stderr.write('[{:>4.1f}s]{}\n'.format(elapsed, line)) + sys.stderr.write("[{:>4.1f}s]{}\n".format(elapsed, line)) m = HTTP_ADDRESS_RE.match(line) if m: - return prepend_http(m.group('addr')) + return prepend_http(m.group("addr")) elif elapsed > timeout: return None class OutputMonitor: - def __init__(self): self.consumers = [] @@ -105,7 +114,9 @@ def consume(self, iterable): def start(self, proc): self._stop_out_thread = threading.Event() - self._out_thread = threading.Thread(target=self.consume, args=(proc.stdout,)) + self._out_thread = threading.Thread( + target=self.consume, args=(proc.stdout,) + ) self._out_thread.daemon = True self._out_thread.start() @@ -116,7 +127,6 @@ def stop(self): class LineBuffer: - def __init__(self): self.lines = [] @@ -124,7 +134,7 @@ def send(self, line): self.lines.append(line.strip()) -class CrateLayer(object): +class CrateLayer: """ This layer starts a Crate server. 
""" @@ -135,14 +145,16 @@ class CrateLayer(object): wait_interval = 0.2 @staticmethod - def from_uri(uri, - name, - http_port='4200-4299', - transport_port='4300-4399', - settings=None, - directory=None, - cleanup=True, - verbose=False): + def from_uri( + uri, + name, + http_port="4200-4299", + transport_port="4300-4399", + settings=None, + directory=None, + cleanup=True, + verbose=False, + ): """Download the Crate tarball from a URI and create a CrateLayer :param uri: The uri that points to the Crate tarball @@ -158,11 +170,14 @@ def from_uri(uri, """ directory = directory or tempfile.mkdtemp() filename = os.path.basename(uri) - crate_dir = re.sub(r'\.tar(\.gz)?$', '', filename) + crate_dir = re.sub(r"\.tar(\.gz)?$", "", filename) crate_home = os.path.join(directory, crate_dir) if os.path.exists(crate_home): - sys.stderr.write("\nWARNING: Not extracting Crate tarball because folder already exists") + sys.stderr.write( + "\nWARNING: Not extracting CrateDB tarball" + " because folder already exists" + ) sys.stderr.flush() else: _download_and_extract(uri, directory) @@ -173,29 +188,33 @@ def from_uri(uri, port=http_port, transport_port=transport_port, settings=settings, - verbose=verbose) + verbose=verbose, + ) if cleanup: tearDown = layer.tearDown def new_teardown(*args, **kws): shutil.rmtree(directory) tearDown(*args, **kws) - layer.tearDown = new_teardown + + layer.tearDown = new_teardown # type: ignore[method-assign] return layer - def __init__(self, - name, - crate_home, - crate_config=None, - port=None, - keepRunning=False, - transport_port=None, - crate_exec=None, - cluster_name=None, - host="127.0.0.1", - settings=None, - verbose=False, - env=None): + def __init__( + self, + name, + crate_home, + crate_config=None, + port=None, + keepRunning=False, + transport_port=None, + crate_exec=None, + cluster_name=None, + host="127.0.0.1", + settings=None, + verbose=False, + env=None, + ): """ :param name: layer name, is also used as the cluser name :param crate_home: 
path to home directory of the crate installation @@ -216,52 +235,69 @@ def __init__(self, self.__name__ = name if settings and isinstance(settings, dict): # extra settings may override host/port specification! - self.http_url = http_url_from_host_port(settings.get('network.host', host), - settings.get('http.port', port)) + self.http_url = http_url_from_host_port( + settings.get("network.host", host), + settings.get("http.port", port), + ) else: self.http_url = http_url_from_host_port(host, port) self.process = None self.verbose = verbose self.env = env or {} - self.env.setdefault('CRATE_USE_IPV4', 'true') - self.env.setdefault('JAVA_HOME', os.environ.get('JAVA_HOME', '')) + self.env.setdefault("CRATE_USE_IPV4", "true") + self.env.setdefault("JAVA_HOME", os.environ.get("JAVA_HOME", "")) self._stdout_consumers = [] self.conn_pool = urllib3.PoolManager(num_pools=1) crate_home = os.path.abspath(crate_home) if crate_exec is None: - start_script = 'crate.bat' if sys.platform == 'win32' else 'crate' - crate_exec = os.path.join(crate_home, 'bin', start_script) + start_script = "crate.bat" if sys.platform == "win32" else "crate" + crate_exec = os.path.join(crate_home, "bin", start_script) if crate_config is None: - crate_config = os.path.join(crate_home, 'config', 'crate.yml') - elif (os.path.isfile(crate_config) and - os.path.basename(crate_config) != 'crate.yml'): + crate_config = os.path.join(crate_home, "config", "crate.yml") + elif ( + os.path.isfile(crate_config) + and os.path.basename(crate_config) != "crate.yml" + ): raise ValueError(CRATE_CONFIG_ERROR) if cluster_name is None: - cluster_name = "Testing{0}".format(port or 'Dynamic') - settings = self.create_settings(crate_config, - cluster_name, - name, - host, - port or '4200-4299', - transport_port or '4300-4399', - settings) + cluster_name = "Testing{0}".format(port or "Dynamic") + settings = self.create_settings( + crate_config, + cluster_name, + name, + host, + port or "4200-4299", + transport_port or 
"4300-4399", + settings, + ) # ES 5 cannot parse 'True'/'False' as booleans so convert to lowercase - start_cmd = (crate_exec, ) + tuple(["-C%s=%s" % ((key, str(value).lower()) if isinstance(value, bool) else (key, value)) - for key, value in settings.items()]) - - self._wd = wd = os.path.join(CrateLayer.tmpdir, 'crate_layer', name) - self.start_cmd = start_cmd + ('-Cpath.data=%s' % wd,) - - def create_settings(self, - crate_config, - cluster_name, - node_name, - host, - http_port, - transport_port, - further_settings=None): + start_cmd = (crate_exec,) + tuple( + [ + "-C%s=%s" + % ( + (key, str(value).lower()) + if isinstance(value, bool) + else (key, value) + ) + for key, value in settings.items() + ] + ) + + self._wd = wd = os.path.join(CrateLayer.tmpdir, "crate_layer", name) + self.start_cmd = start_cmd + ("-Cpath.data=%s" % wd,) + + def create_settings( + self, + crate_config, + cluster_name, + node_name, + host, + http_port, + transport_port, + further_settings=None, + ): settings = { "discovery.type": "zen", "discovery.initial_state_timeout": 0, @@ -294,20 +330,23 @@ def _clean(self): def start(self): self._clean() - self.process = subprocess.Popen(self.start_cmd, - env=self.env, - stdout=subprocess.PIPE) + self.process = subprocess.Popen( + self.start_cmd, env=self.env, stdout=subprocess.PIPE + ) returncode = self.process.poll() if returncode is not None: raise SystemError( - 'Failed to start server rc={0} cmd={1}'.format(returncode, - self.start_cmd) + "Failed to start server rc={0} cmd={1}".format( + returncode, self.start_cmd + ) ) if not self.http_url: # try to read http_url from startup logs # this is necessary if no static port is assigned - self.http_url = wait_for_http_url(self.process.stdout, verbose=self.verbose) + self.http_url = wait_for_http_url( + self.process.stdout, verbose=self.verbose + ) self.monitor = OutputMonitor() self.monitor.start(self.process) @@ -315,10 +354,10 @@ def start(self): if not self.http_url: self.stop() else: - 
sys.stderr.write('HTTP: {}\n'.format(self.http_url)) + sys.stderr.write("HTTP: {}\n".format(self.http_url)) self._wait_for_start() self._wait_for_master() - sys.stderr.write('\nCrate instance ready.\n') + sys.stderr.write("\nCrate instance ready.\n") def stop(self): self.conn_pool.clear() @@ -352,10 +391,9 @@ def _wait_for(self, validator): for line in line_buf.lines: log.error(line) self.stop() - raise SystemError('Failed to start Crate instance in time.') - else: - sys.stderr.write('.') - time.sleep(self.wait_interval) + raise SystemError("Failed to start Crate instance in time.") + sys.stderr.write(".") + time.sleep(self.wait_interval) self.monitor.consumers.remove(line_buf) @@ -367,7 +405,7 @@ def _wait_for_start(self): # after the layer starts don't result in 503 def validator(): try: - resp = self.conn_pool.request('HEAD', self.http_url) + resp = self.conn_pool.request("HEAD", self.http_url) return resp.status == 200 except Exception: return False @@ -379,12 +417,12 @@ def _wait_for_master(self): def validator(): resp = self.conn_pool.urlopen( - 'POST', - '{server}/_sql'.format(server=self.http_url), - headers={'Content-Type': 'application/json'}, - body='{"stmt": "select master_node from sys.cluster"}' + "POST", + "{server}/_sql".format(server=self.http_url), + headers={"Content-Type": "application/json"}, + body='{"stmt": "select master_node from sys.cluster"}', ) - data = json.loads(resp.data.decode('utf-8')) - return resp.status == 200 and data['rows'][0][0] + data = json.loads(resp.data.decode("utf-8")) + return resp.status == 200 and data["rows"][0][0] self._wait_for(validator) diff --git a/src/crate/testing/util.py b/src/crate/testing/util.py index 54f9098c..6f25b276 100644 --- a/src/crate/testing/util.py +++ b/src/crate/testing/util.py @@ -21,8 +21,7 @@ import unittest -class ClientMocked(object): - +class ClientMocked: active_servers = ["http://localhost:4200"] def __init__(self): @@ -52,14 +51,15 @@ class ParametrizedTestCase(unittest.TestCase): 
https://eli.thegreenplace.net/2011/08/02/python-unit-testing-parametrized-test-cases """ + def __init__(self, methodName="runTest", param=None): super(ParametrizedTestCase, self).__init__(methodName) self.param = param @staticmethod def parametrize(testcase_klass, param=None): - """ Create a suite containing all tests taken from the given - subclass, passing them the parameter 'param'. + """Create a suite containing all tests taken from the given + subclass, passing them the parameter 'param'. """ testloader = unittest.TestLoader() testnames = testloader.getTestCaseNames(testcase_klass) @@ -69,7 +69,7 @@ def parametrize(testcase_klass, param=None): return suite -class ExtraAssertions: +class ExtraAssertions(unittest.TestCase): """ Additional assert methods for unittest. @@ -83,9 +83,13 @@ def assertIsSubclass(self, cls, superclass, msg=None): r = issubclass(cls, superclass) except TypeError: if not isinstance(cls, type): - self.fail(self._formatMessage(msg, - '%r is not a class' % (cls,))) + self.fail( + self._formatMessage(msg, "%r is not a class" % (cls,)) + ) raise if not r: - self.fail(self._formatMessage(msg, - '%r is not a subclass of %r' % (cls, superclass))) + self.fail( + self._formatMessage( + msg, "%r is not a subclass of %r" % (cls, superclass) + ) + ) diff --git a/tests/client/layer.py b/tests/client/layer.py index b2d521e7..c381299d 100644 --- a/tests/client/layer.py +++ b/tests/client/layer.py @@ -22,28 +22,32 @@ from __future__ import absolute_import import json -import os +import logging import socket -import unittest -from pprint import pprint -from http.server import HTTPServer, BaseHTTPRequestHandler import ssl -import time import threading -import logging +import time +import unittest +from http.server import BaseHTTPRequestHandler, HTTPServer +from pprint import pprint import stopit from crate.client import connect from crate.testing.layer import CrateLayer -from .settings import \ - assets_path, crate_host, crate_path, crate_port, \ - 
crate_transport_port, localhost +from .settings import ( + assets_path, + crate_host, + crate_path, + crate_port, + crate_transport_port, + localhost, +) makeSuite = unittest.TestLoader().loadTestsFromTestCase -log = logging.getLogger('crate.testing.layer') +log = logging.getLogger("crate.testing.layer") ch = logging.StreamHandler() ch.setLevel(logging.ERROR) log.addHandler(ch) @@ -51,20 +55,20 @@ def cprint(s): if isinstance(s, bytes): - s = s.decode('utf-8') - print(s) + s = s.decode("utf-8") + print(s) # noqa: T201 settings = { - 'udc.enabled': 'false', - 'lang.js.enabled': 'true', - 'auth.host_based.enabled': 'true', - 'auth.host_based.config.0.user': 'crate', - 'auth.host_based.config.0.method': 'trust', - 'auth.host_based.config.98.user': 'trusted_me', - 'auth.host_based.config.98.method': 'trust', - 'auth.host_based.config.99.user': 'me', - 'auth.host_based.config.99.method': 'password', + "udc.enabled": "false", + "lang.js.enabled": "true", + "auth.host_based.enabled": "true", + "auth.host_based.config.0.user": "crate", + "auth.host_based.config.0.method": "trust", + "auth.host_based.config.98.user": "trusted_me", + "auth.host_based.config.98.method": "trust", + "auth.host_based.config.99.user": "me", + "auth.host_based.config.99.method": "password", } crate_layer = None @@ -86,40 +90,46 @@ def ensure_cratedb_layer(): global crate_layer if crate_layer is None: - crate_layer = CrateLayer('crate', - crate_home=crate_path(), - port=crate_port, - host=localhost, - transport_port=crate_transport_port, - settings=settings) + crate_layer = CrateLayer( + "crate", + crate_home=crate_path(), + port=crate_port, + host=localhost, + transport_port=crate_transport_port, + settings=settings, + ) return crate_layer def setUpCrateLayerBaseline(test): if hasattr(test, "globs"): - test.globs['crate_host'] = crate_host - test.globs['pprint'] = pprint - test.globs['print'] = cprint + test.globs["crate_host"] = crate_host + test.globs["pprint"] = pprint + test.globs["print"] = 
cprint with connect(crate_host) as conn: cursor = conn.cursor() - with open(assets_path('mappings/locations.sql')) as s: + with open(assets_path("mappings/locations.sql")) as s: stmt = s.read() cursor.execute(stmt) - stmt = ("select count(*) from information_schema.tables " - "where table_name = 'locations'") + stmt = ( + "select count(*) from information_schema.tables " + "where table_name = 'locations'" + ) cursor.execute(stmt) - assert cursor.fetchall()[0][0] == 1 + assert cursor.fetchall()[0][0] == 1 # noqa: S101 - data_path = assets_path('import/test_a.json') + data_path = assets_path("import/test_a.json") # load testing data into crate cursor.execute("copy locations from ?", (data_path,)) # refresh location table so imported data is visible immediately cursor.execute("refresh table locations") # create blob table - cursor.execute("create blob table myfiles clustered into 1 shards " + - "with (number_of_replicas=0)") + cursor.execute( + "create blob table myfiles clustered into 1 shards " + + "with (number_of_replicas=0)" + ) # create users cursor.execute("CREATE USER me WITH (password = 'my_secret_pw')") @@ -149,20 +159,20 @@ class HttpsTestServerLayer: CACERT_FILE = assets_path("pki/cacert_valid.pem") __name__ = "httpsserver" - __bases__ = tuple() + __bases__ = () class HttpsServer(HTTPServer): def get_request(self): - # Prepare SSL context. - context = ssl._create_unverified_context( + context = ssl._create_unverified_context( # noqa: S323 protocol=ssl.PROTOCOL_TLS_SERVER, cert_reqs=ssl.CERT_OPTIONAL, check_hostname=False, purpose=ssl.Purpose.CLIENT_AUTH, certfile=HttpsTestServerLayer.CERT_FILE, keyfile=HttpsTestServerLayer.CERT_FILE, - cafile=HttpsTestServerLayer.CACERT_FILE) + cafile=HttpsTestServerLayer.CACERT_FILE, + ) # noqa: S323 # Set minimum protocol version, TLSv1 and TLSv1.1 are unsafe. 
context.minimum_version = ssl.TLSVersion.TLSv1_2 @@ -174,12 +184,16 @@ def get_request(self): return socket, client_address class HttpsHandler(BaseHTTPRequestHandler): - - payload = json.dumps({"name": "test", "status": 200, }) + payload = json.dumps( + { + "name": "test", + "status": 200, + } + ) def do_GET(self): self.send_response(200) - payload = self.payload.encode('UTF-8') + payload = self.payload.encode("UTF-8") self.send_header("Content-Length", len(payload)) self.send_header("Content-Type", "application/json; charset=UTF-8") self.end_headers() @@ -187,8 +201,7 @@ def do_GET(self): def setUp(self): self.server = self.HttpsServer( - (self.HOST, self.PORT), - self.HttpsHandler + (self.HOST, self.PORT), self.HttpsHandler ) thread = threading.Thread(target=self.serve_forever) thread.daemon = True # quit interpreter when only thread exists @@ -196,9 +209,9 @@ def setUp(self): self.waitForServer() def serve_forever(self): - print("listening on", self.HOST, self.PORT) + log.info("listening on", self.HOST, self.PORT) self.server.serve_forever() - print("server stopped.") + log.info("server stopped.") def tearDown(self): self.server.shutdown() @@ -224,21 +237,23 @@ def waitForServer(self, timeout=5): time.sleep(0.001) if not to_ctx_mgr: - raise TimeoutError("Could not properly start embedded webserver " - "within {} seconds".format(timeout)) + raise TimeoutError( + "Could not properly start embedded webserver " + "within {} seconds".format(timeout) + ) def setUpWithHttps(test): - test.globs['crate_host'] = "https://{0}:{1}".format( + test.globs["crate_host"] = "https://{0}:{1}".format( HttpsTestServerLayer.HOST, HttpsTestServerLayer.PORT ) - test.globs['pprint'] = pprint - test.globs['print'] = cprint + test.globs["pprint"] = pprint + test.globs["print"] = cprint - test.globs['cacert_valid'] = assets_path("pki/cacert_valid.pem") - test.globs['cacert_invalid'] = assets_path("pki/cacert_invalid.pem") - test.globs['clientcert_valid'] = 
assets_path("pki/client_valid.pem") - test.globs['clientcert_invalid'] = assets_path("pki/client_invalid.pem") + test.globs["cacert_valid"] = assets_path("pki/cacert_valid.pem") + test.globs["cacert_invalid"] = assets_path("pki/cacert_invalid.pem") + test.globs["clientcert_valid"] = assets_path("pki/client_valid.pem") + test.globs["clientcert_invalid"] = assets_path("pki/client_invalid.pem") def _execute_statements(statements, on_error="ignore"): @@ -253,10 +268,10 @@ def _execute_statement(cursor, stmt, on_error="ignore"): try: cursor.execute(stmt) except Exception: # pragma: no cover - # FIXME: Why does this croak on statements like ``DROP TABLE cities``? + # FIXME: Why does this trip on statements like `DROP TABLE cities`? # Note: When needing to debug the test environment, you may want to # enable this logger statement. - # log.exception("Executing SQL statement failed") + # log.exception("Executing SQL statement failed") # noqa: ERA001 if on_error == "ignore": pass elif on_error == "raise": diff --git a/tests/client/settings.py b/tests/client/settings.py index 228222fd..516da19c 100644 --- a/tests/client/settings.py +++ b/tests/client/settings.py @@ -25,7 +25,9 @@ def assets_path(*parts) -> str: - return str((project_root() / "tests" / "assets").joinpath(*parts).absolute()) + return str( + (project_root() / "tests" / "assets").joinpath(*parts).absolute() + ) def crate_path() -> str: @@ -36,9 +38,8 @@ def project_root() -> Path: return Path(__file__).parent.parent.parent - crate_port = 44209 crate_transport_port = 44309 -localhost = '127.0.0.1' +localhost = "127.0.0.1" crate_host = "{host}:{port}".format(host=localhost, port=crate_port) crate_uri = "http://%s" % crate_host diff --git a/tests/client/test_connection.py b/tests/client/test_connection.py index 5badfab2..0cc5e1ef 100644 --- a/tests/client/test_connection.py +++ b/tests/client/test_connection.py @@ -1,24 +1,23 @@ import datetime +from unittest import TestCase from urllib3 import Timeout +from 
crate.client import connect from crate.client.connection import Connection from crate.client.http import Client -from crate.client import connect -from unittest import TestCase from .settings import crate_host class ConnectionTest(TestCase): - def test_connection_mock(self): """ For testing purposes it is often useful to replace the client used for communication with the CrateDB server with a stub or mock. - This can be done by passing an object of the Client class when calling the - ``connect`` method. + This can be done by passing an object of the Client class when calling + the `connect` method. """ class MyConnectionClient: @@ -32,12 +31,17 @@ def server_infos(self, server): connection = connect([crate_host], client=MyConnectionClient()) self.assertIsInstance(connection, Connection) - self.assertEqual(connection.client.server_infos("foo"), ('localhost:4200', 'my server', '0.42.0')) + self.assertEqual( + connection.client.server_infos("foo"), + ("localhost:4200", "my server", "0.42.0"), + ) def test_lowest_server_version(self): - infos = [(None, None, '0.42.3'), - (None, None, '0.41.8'), - (None, None, 'not a version')] + infos = [ + (None, None, "0.42.3"), + (None, None, "0.41.8"), + (None, None, "not a version"), + ] client = Client(servers="localhost:4200 localhost:4201 localhost:4202") client.server_infos = lambda server: infos.pop() @@ -53,40 +57,45 @@ def test_invalid_server_version(self): connection.close() def test_context_manager(self): - with connect('localhost:4200') as conn: + with connect("localhost:4200") as conn: pass self.assertEqual(conn._closed, True) def test_with_timezone(self): """ - Verify the cursor objects will return timezone-aware `datetime` objects when requested to. - When switching the time zone at runtime on the connection object, only new cursor objects - will inherit the new time zone. + The cursor can return timezone-aware `datetime` objects when requested. 
+ + When switching the time zone at runtime on the connection object, only + new cursor objects will inherit the new time zone. """ tz_mst = datetime.timezone(datetime.timedelta(hours=7), name="MST") - connection = connect('localhost:4200', time_zone=tz_mst) + connection = connect("localhost:4200", time_zone=tz_mst) cursor = connection.cursor() self.assertEqual(cursor.time_zone.tzname(None), "MST") - self.assertEqual(cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=25200)) + self.assertEqual( + cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=25200) + ) connection.time_zone = datetime.timezone.utc cursor = connection.cursor() self.assertEqual(cursor.time_zone.tzname(None), "UTC") - self.assertEqual(cursor.time_zone.utcoffset(None), datetime.timedelta(0)) + self.assertEqual( + cursor.time_zone.utcoffset(None), datetime.timedelta(0) + ) def test_timeout_float(self): """ Verify setting the timeout value as a scalar (float) works. """ - with connect('localhost:4200', timeout=2.42) as conn: + with connect("localhost:4200", timeout=2.42) as conn: self.assertEqual(conn.client._pool_kw["timeout"], 2.42) def test_timeout_string(self): """ Verify setting the timeout value as a scalar (string) works. """ - with connect('localhost:4200', timeout="2.42") as conn: + with connect("localhost:4200", timeout="2.42") as conn: self.assertEqual(conn.client._pool_kw["timeout"], 2.42) def test_timeout_object(self): @@ -94,5 +103,5 @@ def test_timeout_object(self): Verify setting the timeout value as a Timeout object works. 
""" timeout = Timeout(connect=2.42, read=0.01) - with connect('localhost:4200', timeout=timeout) as conn: + with connect("localhost:4200", timeout=timeout) as conn: self.assertEqual(conn.client._pool_kw["timeout"], timeout) diff --git a/tests/client/test_cursor.py b/tests/client/test_cursor.py index 318c172b..a1013979 100644 --- a/tests/client/test_cursor.py +++ b/tests/client/test_cursor.py @@ -23,6 +23,7 @@ from ipaddress import IPv4Address from unittest import TestCase from unittest.mock import MagicMock + try: import zoneinfo except ImportError: @@ -37,7 +38,6 @@ class CursorTest(TestCase): - @staticmethod def get_mocked_connection(): client = MagicMock(spec=Client) @@ -45,7 +45,7 @@ def get_mocked_connection(): def test_create_with_timezone_as_datetime_object(self): """ - Verify the cursor returns timezone-aware `datetime` objects when requested to. + The cursor can return timezone-aware `datetime` objects when requested. Switching the time zone at runtime on the cursor object is possible. Here: Use a `datetime.timezone` instance. """ @@ -56,63 +56,81 @@ def test_create_with_timezone_as_datetime_object(self): cursor = connection.cursor(time_zone=tz_mst) self.assertEqual(cursor.time_zone.tzname(None), "MST") - self.assertEqual(cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=25200)) + self.assertEqual( + cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=25200) + ) cursor.time_zone = datetime.timezone.utc self.assertEqual(cursor.time_zone.tzname(None), "UTC") - self.assertEqual(cursor.time_zone.utcoffset(None), datetime.timedelta(0)) + self.assertEqual( + cursor.time_zone.utcoffset(None), datetime.timedelta(0) + ) def test_create_with_timezone_as_pytz_object(self): """ - Verify the cursor returns timezone-aware `datetime` objects when requested to. + The cursor can return timezone-aware `datetime` objects when requested. Here: Use a `pytz.timezone` instance. 
""" connection = self.get_mocked_connection() - cursor = connection.cursor(time_zone=pytz.timezone('Australia/Sydney')) + cursor = connection.cursor(time_zone=pytz.timezone("Australia/Sydney")) self.assertEqual(cursor.time_zone.tzname(None), "Australia/Sydney") - # Apparently, when using `pytz`, the timezone object does not return an offset. - # Nevertheless, it works, as demonstrated per doctest in `cursor.txt`. + # Apparently, when using `pytz`, the timezone object does not return + # an offset. Nevertheless, it works, as demonstrated per doctest in + # `cursor.txt`. self.assertEqual(cursor.time_zone.utcoffset(None), None) def test_create_with_timezone_as_zoneinfo_object(self): """ - Verify the cursor returns timezone-aware `datetime` objects when requested to. + The cursor can return timezone-aware `datetime` objects when requested. Here: Use a `zoneinfo.ZoneInfo` instance. """ connection = self.get_mocked_connection() - cursor = connection.cursor(time_zone=zoneinfo.ZoneInfo('Australia/Sydney')) - self.assertEqual(cursor.time_zone.key, 'Australia/Sydney') + cursor = connection.cursor( + time_zone=zoneinfo.ZoneInfo("Australia/Sydney") + ) + self.assertEqual(cursor.time_zone.key, "Australia/Sydney") def test_create_with_timezone_as_utc_offset_success(self): """ - Verify the cursor returns timezone-aware `datetime` objects when requested to. + The cursor can return timezone-aware `datetime` objects when requested. Here: Use a UTC offset in string format. 
""" connection = self.get_mocked_connection() cursor = connection.cursor(time_zone="+0530") self.assertEqual(cursor.time_zone.tzname(None), "+0530") - self.assertEqual(cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=19800)) + self.assertEqual( + cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=19800) + ) connection = self.get_mocked_connection() cursor = connection.cursor(time_zone="-1145") self.assertEqual(cursor.time_zone.tzname(None), "-1145") - self.assertEqual(cursor.time_zone.utcoffset(None), datetime.timedelta(days=-1, seconds=44100)) + self.assertEqual( + cursor.time_zone.utcoffset(None), + datetime.timedelta(days=-1, seconds=44100), + ) def test_create_with_timezone_as_utc_offset_failure(self): """ - Verify the cursor croaks when trying to create it with invalid UTC offset strings. + Verify the cursor trips when trying to use invalid UTC offset strings. """ connection = self.get_mocked_connection() with self.assertRaises(AssertionError) as ex: connection.cursor(time_zone="foobar") - self.assertEqual(str(ex.exception), "Time zone 'foobar' is given in invalid UTC offset format") + self.assertEqual( + str(ex.exception), + "Time zone 'foobar' is given in invalid UTC offset format", + ) connection = self.get_mocked_connection() with self.assertRaises(ValueError) as ex: connection.cursor(time_zone="+abcd") - self.assertEqual(str(ex.exception), "Time zone '+abcd' is given in invalid UTC offset format: " - "invalid literal for int() with base 10: '+ab'") + self.assertEqual( + str(ex.exception), + "Time zone '+abcd' is given in invalid UTC offset format: " + "invalid literal for int() with base 10: '+ab'", + ) def test_create_with_timezone_connection_cursor_precedence(self): """ @@ -120,16 +138,20 @@ def test_create_with_timezone_connection_cursor_precedence(self): takes precedence over the one specified on the connection instance. 
""" client = MagicMock(spec=Client) - connection = connect(client=client, time_zone=pytz.timezone('Australia/Sydney')) + connection = connect( + client=client, time_zone=pytz.timezone("Australia/Sydney") + ) cursor = connection.cursor(time_zone="+0530") self.assertEqual(cursor.time_zone.tzname(None), "+0530") - self.assertEqual(cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=19800)) + self.assertEqual( + cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=19800) + ) def test_execute_with_args(self): client = MagicMock(spec=Client) conn = connect(client=client) c = conn.cursor() - statement = 'select * from locations where position = ?' + statement = "select * from locations where position = ?" c.execute(statement, 1) client.sql.assert_called_once_with(statement, 1, None) conn.close() @@ -138,7 +160,7 @@ def test_execute_with_bulk_args(self): client = MagicMock(spec=Client) conn = connect(client=client) c = conn.cursor() - statement = 'select * from locations where position = ?' + statement = "select * from locations where position = ?" c.execute(statement, bulk_parameters=[[1]]) client.sql.assert_called_once_with(statement, None, [[1]]) conn.close() @@ -150,30 +172,45 @@ def test_execute_with_converter(self): # Use the set of data type converters from `DefaultTypeConverter` # and add another custom converter. converter = DefaultTypeConverter( - {DataType.BIT: lambda value: value is not None and int(value[2:-1], 2) or None}) + { + DataType.BIT: lambda value: value is not None + and int(value[2:-1], 2) + or None + } + ) # Create a `Cursor` object with converter. c = conn.cursor(converter=converter) # Make up a response using CrateDB data types `TEXT`, `IP`, # `TIMESTAMP`, `BIT`. 
- conn.client.set_next_response({ - "col_types": [4, 5, 11, 25], - "cols": ["name", "address", "timestamp", "bitmask"], - "rows": [ - ["foo", "10.10.10.1", 1658167836758, "B'0110'"], - [None, None, None, None], - ], - "rowcount": 1, - "duration": 123 - }) + conn.client.set_next_response( + { + "col_types": [4, 5, 11, 25], + "cols": ["name", "address", "timestamp", "bitmask"], + "rows": [ + ["foo", "10.10.10.1", 1658167836758, "B'0110'"], + [None, None, None, None], + ], + "rowcount": 1, + "duration": 123, + } + ) c.execute("") result = c.fetchall() - self.assertEqual(result, [ - ['foo', IPv4Address('10.10.10.1'), datetime.datetime(2022, 7, 18, 18, 10, 36, 758000), 6], - [None, None, None, None], - ]) + self.assertEqual( + result, + [ + [ + "foo", + IPv4Address("10.10.10.1"), + datetime.datetime(2022, 7, 18, 18, 10, 36, 758000), + 6, + ], + [None, None, None, None], + ], + ) conn.close() @@ -187,15 +224,17 @@ def test_execute_with_converter_and_invalid_data_type(self): # Make up a response using CrateDB data types `TEXT`, `IP`, # `TIMESTAMP`, `BIT`. 
- conn.client.set_next_response({ - "col_types": [999], - "cols": ["foo"], - "rows": [ - ["n/a"], - ], - "rowcount": 1, - "duration": 123 - }) + conn.client.set_next_response( + { + "col_types": [999], + "cols": ["foo"], + "rows": [ + ["n/a"], + ], + "rowcount": 1, + "duration": 123, + } + ) c.execute("") with self.assertRaises(ValueError) as ex: @@ -208,20 +247,25 @@ def test_execute_array_with_converter(self): converter = DefaultTypeConverter() cursor = conn.cursor(converter=converter) - conn.client.set_next_response({ - "col_types": [4, [100, 5]], - "cols": ["name", "address"], - "rows": [["foo", ["10.10.10.1", "10.10.10.2"]]], - "rowcount": 1, - "duration": 123 - }) + conn.client.set_next_response( + { + "col_types": [4, [100, 5]], + "cols": ["name", "address"], + "rows": [["foo", ["10.10.10.1", "10.10.10.2"]]], + "rowcount": 1, + "duration": 123, + } + ) cursor.execute("") result = cursor.fetchone() - self.assertEqual(result, [ - 'foo', - [IPv4Address('10.10.10.1'), IPv4Address('10.10.10.2')], - ]) + self.assertEqual( + result, + [ + "foo", + [IPv4Address("10.10.10.1"), IPv4Address("10.10.10.2")], + ], + ) def test_execute_array_with_converter_and_invalid_collection_type(self): client = ClientMocked() @@ -231,19 +275,24 @@ def test_execute_array_with_converter_and_invalid_collection_type(self): # Converting collections only works for `ARRAY`s. (ID=100). # When using `DOUBLE` (ID=6), it should croak. 
- conn.client.set_next_response({ - "col_types": [4, [6, 5]], - "cols": ["name", "address"], - "rows": [["foo", ["10.10.10.1", "10.10.10.2"]]], - "rowcount": 1, - "duration": 123 - }) + conn.client.set_next_response( + { + "col_types": [4, [6, 5]], + "cols": ["name", "address"], + "rows": [["foo", ["10.10.10.1", "10.10.10.2"]]], + "rowcount": 1, + "duration": 123, + } + ) cursor.execute("") with self.assertRaises(ValueError) as ex: cursor.fetchone() - self.assertEqual(ex.exception.args, ("Data type 6 is not implemented as collection type",)) + self.assertEqual( + ex.exception.args, + ("Data type 6 is not implemented as collection type",), + ) def test_execute_nested_array_with_converter(self): client = ClientMocked() @@ -251,20 +300,40 @@ def test_execute_nested_array_with_converter(self): converter = DefaultTypeConverter() cursor = conn.cursor(converter=converter) - conn.client.set_next_response({ - "col_types": [4, [100, [100, 5]]], - "cols": ["name", "address_buckets"], - "rows": [["foo", [["10.10.10.1", "10.10.10.2"], ["10.10.10.3"], [], None]]], - "rowcount": 1, - "duration": 123 - }) + conn.client.set_next_response( + { + "col_types": [4, [100, [100, 5]]], + "cols": ["name", "address_buckets"], + "rows": [ + [ + "foo", + [ + ["10.10.10.1", "10.10.10.2"], + ["10.10.10.3"], + [], + None, + ], + ] + ], + "rowcount": 1, + "duration": 123, + } + ) cursor.execute("") result = cursor.fetchone() - self.assertEqual(result, [ - 'foo', - [[IPv4Address('10.10.10.1'), IPv4Address('10.10.10.2')], [IPv4Address('10.10.10.3')], [], None], - ]) + self.assertEqual( + result, + [ + "foo", + [ + [IPv4Address("10.10.10.1"), IPv4Address("10.10.10.2")], + [IPv4Address("10.10.10.3")], + [], + None, + ], + ], + ) def test_executemany_with_converter(self): client = ClientMocked() @@ -272,19 +341,21 @@ def test_executemany_with_converter(self): converter = DefaultTypeConverter() cursor = conn.cursor(converter=converter) - conn.client.set_next_response({ - "col_types": [4, 5], - "cols": 
["name", "address"], - "rows": [["foo", "10.10.10.1"]], - "rowcount": 1, - "duration": 123 - }) + conn.client.set_next_response( + { + "col_types": [4, 5], + "cols": ["name", "address"], + "rows": [["foo", "10.10.10.1"]], + "rowcount": 1, + "duration": 123, + } + ) cursor.executemany("", []) result = cursor.fetchall() - # ``executemany()`` is not intended to be used with statements returning result - # sets. The result will always be empty. + # ``executemany()`` is not intended to be used with statements + # returning result sets. The result will always be empty. self.assertEqual(result, []) def test_execute_with_timezone(self): @@ -296,46 +367,73 @@ def test_execute_with_timezone(self): c = conn.cursor(time_zone=tz_mst) # Make up a response using CrateDB data type `TIMESTAMP`. - conn.client.set_next_response({ - "col_types": [4, 11], - "cols": ["name", "timestamp"], - "rows": [ - ["foo", 1658167836758], - [None, None], - ], - }) - - # Run execution and verify the returned `datetime` object is timezone-aware, - # using the designated timezone object. + conn.client.set_next_response( + { + "col_types": [4, 11], + "cols": ["name", "timestamp"], + "rows": [ + ["foo", 1658167836758], + [None, None], + ], + } + ) + + # Run execution and verify the returned `datetime` object is + # timezone-aware, using the designated timezone object. c.execute("") result = c.fetchall() - self.assertEqual(result, [ + self.assertEqual( + result, [ - 'foo', - datetime.datetime(2022, 7, 19, 1, 10, 36, 758000, - tzinfo=datetime.timezone(datetime.timedelta(seconds=25200), 'MST')), + [ + "foo", + datetime.datetime( + 2022, + 7, + 19, + 1, + 10, + 36, + 758000, + tzinfo=datetime.timezone( + datetime.timedelta(seconds=25200), "MST" + ), + ), + ], + [ + None, + None, + ], ], - [ - None, - None, - ], - ]) + ) self.assertEqual(result[0][1].tzname(), "MST") # Change timezone and verify the returned `datetime` object is using it. 
c.time_zone = datetime.timezone.utc c.execute("") result = c.fetchall() - self.assertEqual(result, [ - [ - 'foo', - datetime.datetime(2022, 7, 18, 18, 10, 36, 758000, tzinfo=datetime.timezone.utc), - ], + self.assertEqual( + result, [ - None, - None, + [ + "foo", + datetime.datetime( + 2022, + 7, + 18, + 18, + 10, + 36, + 758000, + tzinfo=datetime.timezone.utc, + ), + ], + [ + None, + None, + ], ], - ]) + ) self.assertEqual(result[0][1].tzname(), "UTC") conn.close() diff --git a/tests/client/test_exceptions.py b/tests/client/test_exceptions.py index 23f5ad68..cb91e1a9 100644 --- a/tests/client/test_exceptions.py +++ b/tests/client/test_exceptions.py @@ -4,7 +4,6 @@ class ErrorTestCase(unittest.TestCase): - def test_error_with_msg(self): err = Error("foo") self.assertEqual(str(err), "foo") diff --git a/tests/client/test_http.py b/tests/client/test_http.py index fd538fc1..610197a8 100644 --- a/tests/client/test_http.py +++ b/tests/client/test_http.py @@ -19,34 +19,42 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. 
+import datetime as dt import json -import time -import socket import multiprocessing -import sys import os import queue import random +import socket +import sys +import time import traceback +import uuid +from base64 import b64decode +from decimal import Decimal from http.server import BaseHTTPRequestHandler, HTTPServer from multiprocessing.context import ForkProcess +from threading import Event, Thread from unittest import TestCase -from unittest.mock import patch, MagicMock -from threading import Thread, Event -from decimal import Decimal -import datetime as dt - -import urllib3.exceptions -from base64 import b64decode -from urllib.parse import urlparse, parse_qs +from unittest.mock import MagicMock, patch +from urllib.parse import parse_qs, urlparse -import uuid import certifi +import urllib3.exceptions -from crate.client.http import Client, CrateJsonEncoder, _get_socket_opts, _remove_certs_for_non_https -from crate.client.exceptions import ConnectionError, ProgrammingError, IntegrityError - -REQUEST = 'crate.client.http.Server.request' +from crate.client.exceptions import ( + ConnectionError, + IntegrityError, + ProgrammingError, +) +from crate.client.http import ( + Client, + CrateJsonEncoder, + _get_socket_opts, + _remove_certs_for_non_https, +) + +REQUEST = "crate.client.http.Server.request" CA_CERT_PATH = certifi.where() @@ -60,14 +68,15 @@ def request(*args, **kwargs): return response else: return MagicMock(spec=urllib3.response.HTTPResponse) + return request -def fake_response(status, reason=None, content_type='application/json'): +def fake_response(status, reason=None, content_type="application/json"): m = MagicMock(spec=urllib3.response.HTTPResponse) m.status = status - m.reason = reason or '' - m.headers = {'content-type': content_type} + m.reason = reason or "" + m.headers = {"content-type": content_type} return m @@ -78,47 +87,61 @@ def fake_redirect(location): def bad_bulk_response(): - r = fake_response(400, 'Bad Request') - r.data = json.dumps({ 
- "results": [ - {"rowcount": 1}, - {"error_message": "an error occured"}, - {"error_message": "another error"}, - {"error_message": ""}, - {"error_message": None} - ]}).encode() + r = fake_response(400, "Bad Request") + r.data = json.dumps( + { + "results": [ + {"rowcount": 1}, + {"error_message": "an error occured"}, + {"error_message": "another error"}, + {"error_message": ""}, + {"error_message": None}, + ] + } + ).encode() return r def duplicate_key_exception(): - r = fake_response(409, 'Conflict') - r.data = json.dumps({ - "error": { - "code": 4091, - "message": "DuplicateKeyException[A document with the same primary key exists already]" + r = fake_response(409, "Conflict") + r.data = json.dumps( + { + "error": { + "code": 4091, + "message": "DuplicateKeyException[A document with the " + "same primary key exists already]", + } } - }).encode() + ).encode() return r def fail_sometimes(*args, **kwargs): if random.randint(1, 100) % 10 == 0: - raise urllib3.exceptions.MaxRetryError(None, '/_sql', '') + raise urllib3.exceptions.MaxRetryError(None, "/_sql", "") return fake_response(200) class HttpClientTest(TestCase): - - @patch(REQUEST, fake_request([fake_response(200), - fake_response(104, 'Connection reset by peer'), - fake_response(503, 'Service Unavailable')])) + @patch( + REQUEST, + fake_request( + [ + fake_response(200), + fake_response(104, "Connection reset by peer"), + fake_response(503, "Service Unavailable"), + ] + ), + ) def test_connection_reset_exception(self): client = Client(servers="localhost:4200") - client.sql('select 1') - client.sql('select 2') - self.assertEqual(['http://localhost:4200'], list(client._active_servers)) + client.sql("select 1") + client.sql("select 2") + self.assertEqual( + ["http://localhost:4200"], list(client._active_servers) + ) try: - client.sql('select 3') + client.sql("select 3") except ProgrammingError: self.assertEqual([], list(client._active_servers)) else: @@ -128,7 +151,7 @@ def test_connection_reset_exception(self): 
def test_no_connection_exception(self): client = Client(servers="localhost:9999") - self.assertRaises(ConnectionError, client.sql, 'select foo') + self.assertRaises(ConnectionError, client.sql, "select foo") client.close() @patch(REQUEST) @@ -136,16 +159,18 @@ def test_http_error_is_re_raised(self, request): request.side_effect = Exception client = Client() - self.assertRaises(ProgrammingError, client.sql, 'select foo') + self.assertRaises(ProgrammingError, client.sql, "select foo") client.close() @patch(REQUEST) - def test_programming_error_contains_http_error_response_content(self, request): + def test_programming_error_contains_http_error_response_content( + self, request + ): request.side_effect = Exception("this shouldn't be raised") client = Client() try: - client.sql('select 1') + client.sql("select 1") except ProgrammingError as e: self.assertEqual("this shouldn't be raised", e.message) else: @@ -153,18 +178,24 @@ def test_programming_error_contains_http_error_response_content(self, request): finally: client.close() - @patch(REQUEST, fake_request([fake_response(200), - fake_response(503, 'Service Unavailable')])) + @patch( + REQUEST, + fake_request( + [fake_response(200), fake_response(503, "Service Unavailable")] + ), + ) def test_server_error_50x(self): client = Client(servers="localhost:4200 localhost:4201") - client.sql('select 1') - client.sql('select 2') + client.sql("select 1") + client.sql("select 2") try: - client.sql('select 3') + client.sql("select 3") except ProgrammingError as e: - self.assertEqual("No more Servers available, " + - "exception from last server: Service Unavailable", - e.message) + self.assertEqual( + "No more Servers available, " + + "exception from last server: Service Unavailable", + e.message, + ) self.assertEqual([], list(client._active_servers)) else: self.assertTrue(False) @@ -173,8 +204,10 @@ def test_server_error_50x(self): def test_connect(self): client = Client(servers="localhost:4200 localhost:4201") - 
self.assertEqual(client._active_servers, - ["http://localhost:4200", "http://localhost:4201"]) + self.assertEqual( + client._active_servers, + ["http://localhost:4200", "http://localhost:4201"], + ) client.close() client = Client(servers="localhost:4200") @@ -186,54 +219,60 @@ def test_connect(self): client.close() client = Client(servers=["localhost:4200", "127.0.0.1:4201"]) - self.assertEqual(client._active_servers, - ["http://localhost:4200", "http://127.0.0.1:4201"]) + self.assertEqual( + client._active_servers, + ["http://localhost:4200", "http://127.0.0.1:4201"], + ) client.close() - @patch(REQUEST, fake_request(fake_redirect('http://localhost:4201'))) + @patch(REQUEST, fake_request(fake_redirect("http://localhost:4201"))) def test_redirect_handling(self): - client = Client(servers='localhost:4200') + client = Client(servers="localhost:4200") try: - client.blob_get('blobs', 'fake_digest') + client.blob_get("blobs", "fake_digest") except ProgrammingError: # 4201 gets added to serverpool but isn't available # that's why we run into an infinite recursion # exception message is: maximum recursion depth exceeded pass self.assertEqual( - ['http://localhost:4200', 'http://localhost:4201'], - sorted(list(client.server_pool.keys())) + ["http://localhost:4200", "http://localhost:4201"], + sorted(client.server_pool.keys()), ) # the new non-https server must not contain any SSL only arguments # regression test for github issue #179/#180 self.assertEqual( - {'socket_options': _get_socket_opts(keepalive=True)}, - client.server_pool['http://localhost:4201'].pool.conn_kw + {"socket_options": _get_socket_opts(keepalive=True)}, + client.server_pool["http://localhost:4201"].pool.conn_kw, ) client.close() @patch(REQUEST) def test_server_infos(self, request): request.side_effect = urllib3.exceptions.MaxRetryError( - None, '/', "this shouldn't be raised") + None, "/", "this shouldn't be raised" + ) client = Client(servers="localhost:4200 localhost:4201") self.assertRaises( - 
ConnectionError, client.server_infos, 'http://localhost:4200') + ConnectionError, client.server_infos, "http://localhost:4200" + ) client.close() @patch(REQUEST, fake_request(fake_response(503))) def test_server_infos_503(self): client = Client(servers="localhost:4200") self.assertRaises( - ConnectionError, client.server_infos, 'http://localhost:4200') + ConnectionError, client.server_infos, "http://localhost:4200" + ) client.close() - @patch(REQUEST, fake_request( - fake_response(401, 'Unauthorized', 'text/html'))) + @patch( + REQUEST, fake_request(fake_response(401, "Unauthorized", "text/html")) + ) def test_server_infos_401(self): client = Client(servers="localhost:4200") try: - client.server_infos('http://localhost:4200') + client.server_infos("http://localhost:4200") except ProgrammingError as e: self.assertEqual("401 Client Error: Unauthorized", e.message) else: @@ -245,8 +284,10 @@ def test_server_infos_401(self): def test_bad_bulk_400(self): client = Client(servers="localhost:4200") try: - client.sql("Insert into users (name) values(?)", - bulk_parameters=[["douglas"], ["monthy"]]) + client.sql( + "Insert into users (name) values(?)", + bulk_parameters=[["douglas"], ["monthy"]], + ) except ProgrammingError as e: self.assertEqual("an error occured\nanother error", e.message) else: @@ -260,10 +301,10 @@ def test_decimal_serialization(self, request): request.return_value = fake_response(200) dec = Decimal(0.12) - client.sql('insert into users (float_col) values (?)', (dec,)) + client.sql("insert into users (float_col) values (?)", (dec,)) - data = json.loads(request.call_args[1]['data']) - self.assertEqual(data['args'], [str(dec)]) + data = json.loads(request.call_args[1]["data"]) + self.assertEqual(data["args"], [str(dec)]) client.close() @patch(REQUEST, autospec=True) @@ -272,12 +313,12 @@ def test_datetime_is_converted_to_ts(self, request): request.return_value = fake_response(200) datetime = dt.datetime(2015, 2, 28, 7, 31, 40) - client.sql('insert into 
users (dt) values (?)', (datetime,)) + client.sql("insert into users (dt) values (?)", (datetime,)) # convert string to dict # because the order of the keys isn't deterministic - data = json.loads(request.call_args[1]['data']) - self.assertEqual(data['args'], [1425108700000]) + data = json.loads(request.call_args[1]["data"]) + self.assertEqual(data["args"], [1425108700000]) client.close() @patch(REQUEST, autospec=True) @@ -286,17 +327,18 @@ def test_date_is_converted_to_ts(self, request): request.return_value = fake_response(200) day = dt.date(2016, 4, 21) - client.sql('insert into users (dt) values (?)', (day,)) - data = json.loads(request.call_args[1]['data']) - self.assertEqual(data['args'], [1461196800000]) + client.sql("insert into users (dt) values (?)", (day,)) + data = json.loads(request.call_args[1]["data"]) + self.assertEqual(data["args"], [1461196800000]) client.close() def test_socket_options_contain_keepalive(self): - server = 'http://localhost:4200' + server = "http://localhost:4200" client = Client(servers=server) conn_kw = client.server_pool[server].pool.conn_kw self.assertIn( - (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), conn_kw['socket_options'] + (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), + conn_kw["socket_options"], ) client.close() @@ -306,10 +348,10 @@ def test_uuid_serialization(self, request): request.return_value = fake_response(200) uid = uuid.uuid4() - client.sql('insert into my_table (str_col) values (?)', (uid,)) + client.sql("insert into my_table (str_col) values (?)", (uid,)) - data = json.loads(request.call_args[1]['data']) - self.assertEqual(data['args'], [str(uid)]) + data = json.loads(request.call_args[1]["data"]) + self.assertEqual(data["args"], [str(uid)]) client.close() @patch(REQUEST, fake_request(duplicate_key_exception())) @@ -320,9 +362,12 @@ def test_duplicate_key_error(self): """ client = Client(servers="localhost:4200") with self.assertRaises(IntegrityError) as cm: - client.sql('INSERT INTO testdrive (foo) VALUES 
(42)') - self.assertEqual(cm.exception.message, - "DuplicateKeyException[A document with the same primary key exists already]") + client.sql("INSERT INTO testdrive (foo) VALUES (42)") + self.assertEqual( + cm.exception.message, + "DuplicateKeyException[A document with the " + "same primary key exists already]", + ) @patch(REQUEST, fail_sometimes) @@ -334,6 +379,7 @@ class ThreadSafeHttpClientTest(TestCase): check if number of servers in _inactive_servers and _active_servers always equals the number of servers initially given. """ + servers = [ "127.0.0.1:44209", "127.0.0.2:44209", @@ -358,20 +404,21 @@ def tearDown(self): def _run(self): self.event.wait() # wait for the others expected_num_servers = len(self.servers) - for x in range(self.num_commands): + for _ in range(self.num_commands): try: - self.client.sql('select name from sys.cluster') + self.client.sql("select name from sys.cluster") except ConnectionError: pass try: with self.client._lock: - num_servers = len(self.client._active_servers) + \ - len(self.client._inactive_servers) + num_servers = len(self.client._active_servers) + len( + self.client._inactive_servers + ) self.assertEqual( expected_num_servers, num_servers, - "expected %d but got %d" % (expected_num_servers, - num_servers) + "expected %d but got %d" + % (expected_num_servers, num_servers), ) except AssertionError: self.err_queue.put(sys.exc_info()) @@ -397,8 +444,12 @@ def test_client_threaded(self): t.join(self.thread_timeout) if not self.err_queue.empty(): - self.assertTrue(False, "".join( - traceback.format_exception(*self.err_queue.get(block=False)))) + self.assertTrue( + False, + "".join( + traceback.format_exception(*self.err_queue.get(block=False)) + ), + ) class ClientAddressRequestHandler(BaseHTTPRequestHandler): @@ -407,31 +458,30 @@ class ClientAddressRequestHandler(BaseHTTPRequestHandler): returns client host and port in crate-conform-responses """ - protocol_version = 'HTTP/1.1' + + protocol_version = "HTTP/1.1" def do_GET(self): 
content_length = self.headers.get("content-length") if content_length: self.rfile.read(int(content_length)) - response = json.dumps({ - "cols": ["host", "port"], - "rows": [ - self.client_address[0], - self.client_address[1] - ], - "rowCount": 1, - }) + response = json.dumps( + { + "cols": ["host", "port"], + "rows": [self.client_address[0], self.client_address[1]], + "rowCount": 1, + } + ) self.send_response(200) self.send_header("Content-Length", len(response)) self.send_header("Content-Type", "application/json; charset=UTF-8") self.end_headers() - self.wfile.write(response.encode('UTF-8')) + self.wfile.write(response.encode("UTF-8")) do_POST = do_PUT = do_DELETE = do_HEAD = do_GET class KeepAliveClientTest(TestCase): - server_address = ("127.0.0.1", 65535) def __init__(self, *args, **kwargs): @@ -442,7 +492,7 @@ def setUp(self): super(KeepAliveClientTest, self).setUp() self.client = Client(["%s:%d" % self.server_address]) self.server_process.start() - time.sleep(.10) + time.sleep(0.10) def tearDown(self): self.server_process.terminate() @@ -450,12 +500,13 @@ def tearDown(self): super(KeepAliveClientTest, self).tearDown() def _run_server(self): - self.server = HTTPServer(self.server_address, - ClientAddressRequestHandler) + self.server = HTTPServer( + self.server_address, ClientAddressRequestHandler + ) self.server.handle_request() def test_client_keepalive(self): - for x in range(10): + for _ in range(10): result = self.client.sql("select * from fake") another_result = self.client.sql("select again from fake") @@ -463,9 +514,8 @@ def test_client_keepalive(self): class ParamsTest(TestCase): - def test_params(self): - client = Client(['127.0.0.1:4200'], error_trace=True) + client = Client(["127.0.0.1:4200"], error_trace=True) parsed = urlparse(client.path) params = parse_qs(parsed.query) self.assertEqual(params["error_trace"], ["true"]) @@ -478,26 +528,25 @@ def test_no_params(self): class RequestsCaBundleTest(TestCase): - def test_open_client(self): 
os.environ["REQUESTS_CA_BUNDLE"] = CA_CERT_PATH try: - Client('http://127.0.0.1:4200') + Client("http://127.0.0.1:4200") except ProgrammingError: self.fail("HTTP not working with REQUESTS_CA_BUNDLE") finally: - os.unsetenv('REQUESTS_CA_BUNDLE') - os.environ["REQUESTS_CA_BUNDLE"] = '' + os.unsetenv("REQUESTS_CA_BUNDLE") + os.environ["REQUESTS_CA_BUNDLE"] = "" def test_remove_certs_for_non_https(self): - d = _remove_certs_for_non_https('https', {"ca_certs": 1}) - self.assertIn('ca_certs', d) + d = _remove_certs_for_non_https("https", {"ca_certs": 1}) + self.assertIn("ca_certs", d) - kwargs = {'ca_certs': 1, 'foobar': 2, 'cert_file': 3} - d = _remove_certs_for_non_https('http', kwargs) - self.assertNotIn('ca_certs', d) - self.assertNotIn('cert_file', d) - self.assertIn('foobar', d) + kwargs = {"ca_certs": 1, "foobar": 2, "cert_file": 3} + d = _remove_certs_for_non_https("http", kwargs) + self.assertNotIn("ca_certs", d) + self.assertNotIn("cert_file", d) + self.assertIn("foobar", d) class TimeoutRequestHandler(BaseHTTPRequestHandler): @@ -507,7 +556,7 @@ class TimeoutRequestHandler(BaseHTTPRequestHandler): """ def do_POST(self): - self.server.SHARED['count'] += 1 + self.server.SHARED["count"] += 1 time.sleep(5) @@ -518,45 +567,46 @@ class SharedStateRequestHandler(BaseHTTPRequestHandler): """ def do_POST(self): - self.server.SHARED['count'] += 1 - self.server.SHARED['schema'] = self.headers.get('Default-Schema') + self.server.SHARED["count"] += 1 + self.server.SHARED["schema"] = self.headers.get("Default-Schema") - if self.headers.get('Authorization') is not None: - auth_header = self.headers['Authorization'].replace('Basic ', '') - credentials = b64decode(auth_header).decode('utf-8').split(":", 1) - self.server.SHARED['username'] = credentials[0] + if self.headers.get("Authorization") is not None: + auth_header = self.headers["Authorization"].replace("Basic ", "") + credentials = b64decode(auth_header).decode("utf-8").split(":", 1) + self.server.SHARED["username"] = 
credentials[0] if len(credentials) > 1 and credentials[1]: - self.server.SHARED['password'] = credentials[1] + self.server.SHARED["password"] = credentials[1] else: - self.server.SHARED['password'] = None + self.server.SHARED["password"] = None else: - self.server.SHARED['username'] = None + self.server.SHARED["username"] = None - if self.headers.get('X-User') is not None: - self.server.SHARED['usernameFromXUser'] = self.headers['X-User'] + if self.headers.get("X-User") is not None: + self.server.SHARED["usernameFromXUser"] = self.headers["X-User"] else: - self.server.SHARED['usernameFromXUser'] = None + self.server.SHARED["usernameFromXUser"] = None # send empty response - response = '{}' + response = "{}" self.send_response(200) self.send_header("Content-Length", len(response)) self.send_header("Content-Type", "application/json; charset=UTF-8") self.end_headers() - self.wfile.write(response.encode('utf-8')) + self.wfile.write(response.encode("utf-8")) class TestingHTTPServer(HTTPServer): """ http server providing a shared dict """ + manager = multiprocessing.Manager() SHARED = manager.dict() - SHARED['count'] = 0 - SHARED['usernameFromXUser'] = None - SHARED['username'] = None - SHARED['password'] = None - SHARED['schema'] = None + SHARED["count"] = 0 + SHARED["usernameFromXUser"] = None + SHARED["username"] = None + SHARED["password"] = None + SHARED["schema"] = None @classmethod def run_server(cls, server_address, request_handler_cls): @@ -564,13 +614,14 @@ def run_server(cls, server_address, request_handler_cls): class TestingHttpServerTestCase(TestCase): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.assertIsNotNone(self.request_handler) - self.server_address = ('127.0.0.1', random.randint(65000, 65535)) - self.server_process = ForkProcess(target=TestingHTTPServer.run_server, - args=(self.server_address, self.request_handler)) + self.server_address = ("127.0.0.1", random.randint(65000, 65535)) + self.server_process = 
ForkProcess( + target=TestingHTTPServer.run_server, + args=(self.server_address, self.request_handler), + ) def setUp(self): self.server_process.start() @@ -582,7 +633,7 @@ def wait_for_server(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.connect(self.server_address) except Exception: - time.sleep(.25) + time.sleep(0.25) else: break @@ -594,7 +645,6 @@ def clientWithKwargs(self, **kwargs): class RetryOnTimeoutServerTest(TestingHttpServerTestCase): - request_handler = TimeoutRequestHandler def setUp(self): @@ -609,38 +659,40 @@ def test_no_retry_on_read_timeout(self): try: self.client.sql("select * from fake") except ConnectionError as e: - self.assertIn('Read timed out', e.message, - msg='Error message must contain: Read timed out') - self.assertEqual(TestingHTTPServer.SHARED['count'], 1) + self.assertIn( + "Read timed out", + e.message, + msg="Error message must contain: Read timed out", + ) + self.assertEqual(TestingHTTPServer.SHARED["count"], 1) class TestDefaultSchemaHeader(TestingHttpServerTestCase): - request_handler = SharedStateRequestHandler def setUp(self): super().setUp() - self.client = self.clientWithKwargs(schema='my_custom_schema') + self.client = self.clientWithKwargs(schema="my_custom_schema") def tearDown(self): self.client.close() super().tearDown() def test_default_schema(self): - self.client.sql('SELECT 1') - self.assertEqual(TestingHTTPServer.SHARED['schema'], 'my_custom_schema') + self.client.sql("SELECT 1") + self.assertEqual(TestingHTTPServer.SHARED["schema"], "my_custom_schema") class TestUsernameSentAsHeader(TestingHttpServerTestCase): - request_handler = SharedStateRequestHandler def setUp(self): super().setUp() self.clientWithoutUsername = self.clientWithKwargs() - self.clientWithUsername = self.clientWithKwargs(username='testDBUser') - self.clientWithUsernameAndPassword = self.clientWithKwargs(username='testDBUser', - password='test:password') + self.clientWithUsername = 
self.clientWithKwargs(username="testDBUser") + self.clientWithUsernameAndPassword = self.clientWithKwargs( + username="testDBUser", password="test:password" + ) def tearDown(self): self.clientWithoutUsername.close() @@ -650,23 +702,26 @@ def tearDown(self): def test_username(self): self.clientWithoutUsername.sql("select * from fake") - self.assertEqual(TestingHTTPServer.SHARED['usernameFromXUser'], None) - self.assertEqual(TestingHTTPServer.SHARED['username'], None) - self.assertEqual(TestingHTTPServer.SHARED['password'], None) + self.assertEqual(TestingHTTPServer.SHARED["usernameFromXUser"], None) + self.assertEqual(TestingHTTPServer.SHARED["username"], None) + self.assertEqual(TestingHTTPServer.SHARED["password"], None) self.clientWithUsername.sql("select * from fake") - self.assertEqual(TestingHTTPServer.SHARED['usernameFromXUser'], 'testDBUser') - self.assertEqual(TestingHTTPServer.SHARED['username'], 'testDBUser') - self.assertEqual(TestingHTTPServer.SHARED['password'], None) + self.assertEqual( + TestingHTTPServer.SHARED["usernameFromXUser"], "testDBUser" + ) + self.assertEqual(TestingHTTPServer.SHARED["username"], "testDBUser") + self.assertEqual(TestingHTTPServer.SHARED["password"], None) self.clientWithUsernameAndPassword.sql("select * from fake") - self.assertEqual(TestingHTTPServer.SHARED['usernameFromXUser'], 'testDBUser') - self.assertEqual(TestingHTTPServer.SHARED['username'], 'testDBUser') - self.assertEqual(TestingHTTPServer.SHARED['password'], 'test:password') + self.assertEqual( + TestingHTTPServer.SHARED["usernameFromXUser"], "testDBUser" + ) + self.assertEqual(TestingHTTPServer.SHARED["username"], "testDBUser") + self.assertEqual(TestingHTTPServer.SHARED["password"], "test:password") class TestCrateJsonEncoder(TestCase): - def test_naive_datetime(self): data = dt.datetime.fromisoformat("2023-06-26T09:24:00.123") result = json.dumps(data, cls=CrateJsonEncoder) diff --git a/tests/client/tests.py b/tests/client/tests.py index 10c2f03d..2e6619b9 
100644 --- a/tests/client/tests.py +++ b/tests/client/tests.py @@ -1,18 +1,32 @@ import doctest import unittest +from .layer import ( + HttpsTestServerLayer, + ensure_cratedb_layer, + makeSuite, + setUpCrateLayerBaseline, + setUpWithHttps, + tearDownDropEntitiesBaseline, +) from .test_connection import ConnectionTest from .test_cursor import CursorTest -from .test_http import HttpClientTest, KeepAliveClientTest, ThreadSafeHttpClientTest, ParamsTest, \ - RetryOnTimeoutServerTest, RequestsCaBundleTest, TestUsernameSentAsHeader, TestCrateJsonEncoder, \ - TestDefaultSchemaHeader -from .layer import makeSuite, setUpWithHttps, HttpsTestServerLayer, setUpCrateLayerBaseline, \ - tearDownDropEntitiesBaseline, ensure_cratedb_layer +from .test_http import ( + HttpClientTest, + KeepAliveClientTest, + ParamsTest, + RequestsCaBundleTest, + RetryOnTimeoutServerTest, + TestCrateJsonEncoder, + TestDefaultSchemaHeader, + TestUsernameSentAsHeader, + ThreadSafeHttpClientTest, +) def test_suite(): suite = unittest.TestSuite() - flags = (doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS) + flags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS # Unit tests. 
suite.addTest(makeSuite(CursorTest)) @@ -26,24 +40,24 @@ def test_suite(): suite.addTest(makeSuite(TestUsernameSentAsHeader)) suite.addTest(makeSuite(TestCrateJsonEncoder)) suite.addTest(makeSuite(TestDefaultSchemaHeader)) - suite.addTest(doctest.DocTestSuite('crate.client.connection')) - suite.addTest(doctest.DocTestSuite('crate.client.http')) + suite.addTest(doctest.DocTestSuite("crate.client.connection")) + suite.addTest(doctest.DocTestSuite("crate.client.http")) s = doctest.DocFileSuite( - 'docs/by-example/connection.rst', - 'docs/by-example/cursor.rst', + "docs/by-example/connection.rst", + "docs/by-example/cursor.rst", module_relative=False, optionflags=flags, - encoding='utf-8' + encoding="utf-8", ) suite.addTest(s) s = doctest.DocFileSuite( - 'docs/by-example/https.rst', + "docs/by-example/https.rst", module_relative=False, setUp=setUpWithHttps, optionflags=flags, - encoding='utf-8' + encoding="utf-8", ) s.layer = HttpsTestServerLayer() suite.addTest(s) @@ -52,14 +66,14 @@ def test_suite(): layer = ensure_cratedb_layer() s = doctest.DocFileSuite( - 'docs/by-example/http.rst', - 'docs/by-example/client.rst', - 'docs/by-example/blob.rst', + "docs/by-example/http.rst", + "docs/by-example/client.rst", + "docs/by-example/blob.rst", module_relative=False, setUp=setUpCrateLayerBaseline, tearDown=tearDownDropEntitiesBaseline, optionflags=flags, - encoding='utf-8' + encoding="utf-8", ) s.layer = layer suite.addTest(s) diff --git a/tests/testing/test_layer.py b/tests/testing/test_layer.py index 38d53922..60e88b88 100644 --- a/tests/testing/test_layer.py +++ b/tests/testing/test_layer.py @@ -22,93 +22,111 @@ import os import tempfile import urllib -from verlib2 import Version -from unittest import TestCase, mock from io import BytesIO +from unittest import TestCase, mock import urllib3 +from verlib2 import Version import crate -from crate.testing.layer import CrateLayer, prepend_http, http_url_from_host_port, wait_for_http_url +from crate.testing.layer import ( + 
CrateLayer, + http_url_from_host_port, + prepend_http, + wait_for_http_url, +) + from .settings import crate_path class LayerUtilsTest(TestCase): - def test_prepend_http(self): - host = prepend_http('localhost') - self.assertEqual('http://localhost', host) - host = prepend_http('http://localhost') - self.assertEqual('http://localhost', host) - host = prepend_http('https://localhost') - self.assertEqual('https://localhost', host) - host = prepend_http('http') - self.assertEqual('http://http', host) + host = prepend_http("localhost") + self.assertEqual("http://localhost", host) + host = prepend_http("http://localhost") + self.assertEqual("http://localhost", host) + host = prepend_http("https://localhost") + self.assertEqual("https://localhost", host) + host = prepend_http("http") + self.assertEqual("http://http", host) def test_http_url(self): url = http_url_from_host_port(None, None) self.assertEqual(None, url) - url = http_url_from_host_port('localhost', None) + url = http_url_from_host_port("localhost", None) self.assertEqual(None, url) url = http_url_from_host_port(None, 4200) self.assertEqual(None, url) - url = http_url_from_host_port('localhost', 4200) - self.assertEqual('http://localhost:4200', url) - url = http_url_from_host_port('https://crate', 4200) - self.assertEqual('https://crate:4200', url) + url = http_url_from_host_port("localhost", 4200) + self.assertEqual("http://localhost:4200", url) + url = http_url_from_host_port("https://crate", 4200) + self.assertEqual("https://crate:4200", url) def test_wait_for_http(self): - log = BytesIO(b'[i.c.p.h.CrateNettyHttpServerTransport] [crate] publish_address {127.0.0.1:4200}') + log = BytesIO( + b"[i.c.p.h.CrateNettyHttpServerTransport] [crate] publish_address {127.0.0.1:4200}" # noqa: E501 + ) addr = wait_for_http_url(log) - self.assertEqual('http://127.0.0.1:4200', addr) - log = BytesIO(b'[i.c.p.h.CrateNettyHttpServerTransport] [crate] publish_address {}') + self.assertEqual("http://127.0.0.1:4200", addr) + log 
= BytesIO( + b"[i.c.p.h.CrateNettyHttpServerTransport] [crate] publish_address {}" # noqa: E501 + ) addr = wait_for_http_url(log=log, timeout=1) self.assertEqual(None, addr) - @mock.patch.object(crate.testing.layer, "_download_and_extract", lambda uri, directory: None) + @mock.patch.object( + crate.testing.layer, + "_download_and_extract", + lambda uri, directory: None, + ) def test_layer_from_uri(self): """ The CrateLayer can also be created by providing an URI that points to a CrateDB tarball. """ - with urllib.request.urlopen("https://crate.io/versions.json") as response: + with urllib.request.urlopen( + "https://crate.io/versions.json" + ) as response: versions = json.loads(response.read().decode()) version = versions["crate_testing"] self.assertGreaterEqual(Version(version), Version("4.5.0")) - uri = "https://cdn.crate.io/downloads/releases/crate-{}.tar.gz".format(version) + uri = "https://cdn.crate.io/downloads/releases/crate-{}.tar.gz".format( + version + ) layer = CrateLayer.from_uri(uri, name="crate-by-uri", http_port=42203) self.assertIsInstance(layer, CrateLayer) - @mock.patch.dict('os.environ', {}, clear=True) + @mock.patch.dict("os.environ", {}, clear=True) def test_java_home_env_not_set(self): with tempfile.TemporaryDirectory() as tmpdir: - layer = CrateLayer('java-home-test', tmpdir) - # JAVA_HOME must not be set to `None`, since it would be interpreted as a - # string 'None', and therefore intepreted as a path - self.assertEqual(layer.env['JAVA_HOME'], '') + layer = CrateLayer("java-home-test", tmpdir) + # JAVA_HOME must not be set to `None`: It would be literally + # interpreted as a string 'None', which is an invalid path. 
+ self.assertEqual(layer.env["JAVA_HOME"], "") - @mock.patch.dict('os.environ', {}, clear=True) + @mock.patch.dict("os.environ", {}, clear=True) def test_java_home_env_set(self): - java_home = '/usr/lib/jvm/java-11-openjdk-amd64' + java_home = "/usr/lib/jvm/java-11-openjdk-amd64" with tempfile.TemporaryDirectory() as tmpdir: - os.environ['JAVA_HOME'] = java_home - layer = CrateLayer('java-home-test', tmpdir) - self.assertEqual(layer.env['JAVA_HOME'], java_home) + os.environ["JAVA_HOME"] = java_home + layer = CrateLayer("java-home-test", tmpdir) + self.assertEqual(layer.env["JAVA_HOME"], java_home) - @mock.patch.dict('os.environ', {}, clear=True) + @mock.patch.dict("os.environ", {}, clear=True) def test_java_home_env_override(self): - java_11_home = '/usr/lib/jvm/java-11-openjdk-amd64' - java_12_home = '/usr/lib/jvm/java-12-openjdk-amd64' + java_11_home = "/usr/lib/jvm/java-11-openjdk-amd64" + java_12_home = "/usr/lib/jvm/java-12-openjdk-amd64" with tempfile.TemporaryDirectory() as tmpdir: - os.environ['JAVA_HOME'] = java_11_home - layer = CrateLayer('java-home-test', tmpdir, env={'JAVA_HOME': java_12_home}) - self.assertEqual(layer.env['JAVA_HOME'], java_12_home) + os.environ["JAVA_HOME"] = java_11_home + layer = CrateLayer( + "java-home-test", tmpdir, env={"JAVA_HOME": java_12_home} + ) + self.assertEqual(layer.env["JAVA_HOME"], java_12_home) class LayerTest(TestCase): - def test_basic(self): """ This layer starts and stops a ``Crate`` instance on a given host, port, @@ -118,13 +136,14 @@ def test_basic(self): port = 44219 transport_port = 44319 - layer = CrateLayer('crate', - crate_home=crate_path(), - host='127.0.0.1', - port=port, - transport_port=transport_port, - cluster_name='my_cluster' - ) + layer = CrateLayer( + "crate", + crate_home=crate_path(), + host="127.0.0.1", + port=port, + transport_port=transport_port, + cluster_name="my_cluster", + ) # The working directory is defined on layer instantiation. 
# It is sometimes required to know it before starting the layer. @@ -142,7 +161,7 @@ def test_basic(self): http = urllib3.PoolManager() stats_uri = "http://127.0.0.1:{0}/".format(port) - response = http.request('GET', stats_uri) + response = http.request("GET", stats_uri) self.assertEqual(response.status, 200) # The layer can be shutdown using its `stop()` method. @@ -150,91 +169,98 @@ def test_basic(self): def test_dynamic_http_port(self): """ - It is also possible to define a port range instead of a static HTTP port for the layer. + Verify defining a port range instead of a static HTTP port. + + CrateDB will start with the first available port in the given range and + the test layer obtains the chosen port from the startup logs of the + CrateDB process. - Crate will start with the first available port in the given range and the test - layer obtains the chosen port from the startup logs of the Crate process. - Note, that this feature requires a logging configuration with at least loglevel - ``INFO`` on ``http``. + Note that this feature requires a logging configuration with at least + loglevel ``INFO`` on ``http``. """ - port = '44200-44299' - layer = CrateLayer('crate', crate_home=crate_path(), port=port) + port = "44200-44299" + layer = CrateLayer("crate", crate_home=crate_path(), port=port) layer.start() self.assertRegex(layer.crate_servers[0], r"http://127.0.0.1:442\d\d") layer.stop() def test_default_settings(self): """ - Starting a CrateDB layer leaving out optional parameters will apply the following - defaults. + Starting a CrateDB layer leaving out optional parameters will apply + the following defaults. - The default http port is the first free port in the range of ``4200-4299``, - the default transport port is the first free port in the range of ``4300-4399``, - the host defaults to ``127.0.0.1``. 
+        The default http port is the first free port in the range of +        ``4200-4299``, the default transport port is the first free port in +        the range of ``4300-4399``, the host defaults to ``127.0.0.1``. The command to call is ``bin/crate`` inside the ``crate_home`` path. The default config file is ``config/crate.yml`` inside ``crate_home``. The default cluster name will be auto generated using the HTTP port. """ -        layer = CrateLayer('crate_defaults', crate_home=crate_path()) +        layer = CrateLayer("crate_defaults", crate_home=crate_path()) layer.start() self.assertEqual(layer.crate_servers[0], "http://127.0.0.1:4200") layer.stop() def test_additional_settings(self): """ -        The ``Crate`` layer can be started with additional settings as well. -        Add a dictionary for keyword argument ``settings`` which contains your settings. -        Those additional setting will override settings given as keyword argument. +        The CrateDB test layer can be started with additional settings as well. -        The settings will be handed over to the ``Crate`` process with the ``-C`` flag. -        So the setting ``threadpool.bulk.queue_size: 100`` becomes -        the command line flag: ``-Cthreadpool.bulk.queue_size=100``:: +        Add a dictionary for keyword argument ``settings`` which contains your +        settings. Those additional settings will override settings given as +        keyword argument. + +        The settings will be handed over to the ``Crate`` process with the +        ``-C`` flag. 
So, the setting ``threadpool.bulk.queue_size: 100`` + becomes the command line flag: ``-Cthreadpool.bulk.queue_size=100``:: """ layer = CrateLayer( - 'custom', + "custom", crate_path(), port=44401, settings={ "cluster.graceful_stop.min_availability": "none", - "http.port": 44402 - } + "http.port": 44402, + }, ) layer.start() self.assertEqual(layer.crate_servers[0], "http://127.0.0.1:44402") - self.assertIn("-Ccluster.graceful_stop.min_availability=none", layer.start_cmd) + self.assertIn( + "-Ccluster.graceful_stop.min_availability=none", layer.start_cmd + ) layer.stop() def test_verbosity(self): """ - The test layer hides the standard output of Crate per default. To increase the - verbosity level the additional keyword argument ``verbose`` needs to be set - to ``True``:: + The test layer hides the standard output of Crate per default. + + To increase the verbosity level, the additional keyword argument + ``verbose`` needs to be set to ``True``:: """ - layer = CrateLayer('crate', - crate_home=crate_path(), - verbose=True) + layer = CrateLayer("crate", crate_home=crate_path(), verbose=True) layer.start() self.assertTrue(layer.verbose) layer.stop() def test_environment_variables(self): """ - It is possible to provide environment variables for the ``Crate`` testing - layer. + Verify providing environment variables for the CrateDB testing layer. 
""" - layer = CrateLayer('crate', - crate_home=crate_path(), - env={"CRATE_HEAP_SIZE": "300m"}) + layer = CrateLayer( + "crate", crate_home=crate_path(), env={"CRATE_HEAP_SIZE": "300m"} + ) layer.start() sql_uri = layer.crate_servers[0] + "/_sql" http = urllib3.PoolManager() - response = http.urlopen('POST', sql_uri, - body='{"stmt": "select heap[\'max\'] from sys.nodes"}') - json_response = json.loads(response.data.decode('utf-8')) + response = http.urlopen( + "POST", + sql_uri, + body='{"stmt": "select heap[\'max\'] from sys.nodes"}', + ) + json_response = json.loads(response.data.decode("utf-8")) self.assertEqual(json_response["rows"][0][0], 314572800) @@ -243,25 +269,25 @@ def test_environment_variables(self): def test_cluster(self): """ To start a cluster of ``Crate`` instances, give each instance the same - ``cluster_name``. If you want to start instances on the same machine then + ``cluster_name``. If you want to start instances on the same machine, use value ``_local_`` for ``host`` and give every node different ports:: """ cluster_layer1 = CrateLayer( - 'crate1', + "crate1", crate_path(), - host='_local_', - cluster_name='my_cluster', + host="_local_", + cluster_name="my_cluster", ) cluster_layer2 = CrateLayer( - 'crate2', + "crate2", crate_path(), - host='_local_', - cluster_name='my_cluster', - settings={"discovery.initial_state_timeout": "10s"} + host="_local_", + cluster_name="my_cluster", + settings={"discovery.initial_state_timeout": "10s"}, ) - # If we start both layers, they will, after a small amount of time, find each other - # and form a cluster. + # If we start both layers, they will, after a small amount of time, + # find each other, and form a cluster. 
cluster_layer1.start() cluster_layer2.start() @@ -270,13 +296,18 @@ def test_cluster(self): def num_cluster_nodes(crate_layer): sql_uri = crate_layer.crate_servers[0] + "/_sql" - response = http.urlopen('POST', sql_uri, body='{"stmt":"select count(*) from sys.nodes"}') - json_response = json.loads(response.data.decode('utf-8')) + response = http.urlopen( + "POST", + sql_uri, + body='{"stmt":"select count(*) from sys.nodes"}', + ) + json_response = json.loads(response.data.decode("utf-8")) return json_response["rows"][0][0] # We might have to wait a moment before the cluster is finally created. num_nodes = num_cluster_nodes(cluster_layer1) import time + retries = 0 while num_nodes < 2: # pragma: no cover time.sleep(1) diff --git a/tests/testing/tests.py b/tests/testing/tests.py index 2a6e06d0..4ba58d91 100644 --- a/tests/testing/tests.py +++ b/tests/testing/tests.py @@ -21,8 +21,8 @@ # software solely pursuant to the terms of the relevant commercial agreement. import unittest -from .test_layer import LayerUtilsTest, LayerTest +from .test_layer import LayerTest, LayerUtilsTest makeSuite = unittest.TestLoader().loadTestsFromTestCase