diff --git a/doc/changes/changes_0.1.1.md b/doc/changes/changes_0.1.1.md index 636ab35e..cc989ecf 100644 --- a/doc/changes/changes_0.1.1.md +++ b/doc/changes/changes_0.1.1.md @@ -18,3 +18,4 @@ Code name: * #208: Replaced access to private attribute by public * #203: Cleaned-up package names and directory structure * #217: Rename dataflow abstraction files +* #219: Applied PTB checks and fixes diff --git a/exasol/analytics/__init__.py b/exasol/analytics/__init__.py index b794fd40..3dc1f76b 100644 --- a/exasol/analytics/__init__.py +++ b/exasol/analytics/__init__.py @@ -1 +1 @@ -__version__ = '0.1.0' +__version__ = "0.1.0" diff --git a/exasol/analytics/query_handler/context/connection_name.py b/exasol/analytics/query_handler/context/connection_name.py index 3d042843..5efdc89e 100644 --- a/exasol/analytics/query_handler/context/connection_name.py +++ b/exasol/analytics/query_handler/context/connection_name.py @@ -1,10 +1,7 @@ - -from exasol.analytics.schema import ( - DBObjectNameImpl, - DBObjectName, -) from typeguard import typechecked +from exasol.analytics.schema import DBObjectName, DBObjectNameImpl + class ConnectionName(DBObjectName): """A DBObjectName class which represents the name of a connection object""" diff --git a/exasol/analytics/query_handler/context/connection_name_proxy.py b/exasol/analytics/query_handler/context/connection_name_proxy.py index 726375d6..8daf6561 100644 --- a/exasol/analytics/query_handler/context/connection_name_proxy.py +++ b/exasol/analytics/query_handler/context/connection_name_proxy.py @@ -1,5 +1,10 @@ -from exasol.analytics.query_handler.context.connection_name import ConnectionName, ConnectionNameImpl -from exasol.analytics.query_handler.context.proxy.db_object_name_proxy import DBObjectNameProxy +from exasol.analytics.query_handler.context.connection_name import ( + ConnectionName, + ConnectionNameImpl, +) +from exasol.analytics.query_handler.context.proxy.db_object_name_proxy import ( + DBObjectNameProxy, +) from exasol.analytics.query_handler.query.drop.connection import DropConnectionQuery from exasol.analytics.query_handler.query.interface import Query diff --git a/exasol/analytics/query_handler/context/proxy/bucketfs_location_proxy.py b/exasol/analytics/query_handler/context/proxy/bucketfs_location_proxy.py index 6de8a1e9..73dd435b 100644 --- a/exasol/analytics/query_handler/context/proxy/bucketfs_location_proxy.py +++ b/exasol/analytics/query_handler/context/proxy/bucketfs_location_proxy.py @@ -1,8 +1,9 @@ import logging -from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy import exasol.bucketfs as bfs +from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy + LOGGER = logging.getLogger(__file__) @@ -18,7 +19,9 @@ def bucketfs_location(self) -> bfs.path.PathLike: def cleanup(self): if self._not_released: - raise Exception("Cleanup of BucketFSLocationProxy only allowed after release.") + raise Exception( + "Cleanup of BucketFSLocationProxy only allowed after release." + ) files = self._list_files() for file in files: self._remove_file(file) @@ -33,7 +36,11 @@ def _list_files(self): try: return list(self._bucketfs_location.iterdir()) except FileNotFoundError as e: - LOGGER.debug(f"File not found {self._bucketfs_location.as_udf_path} during cleanup.") + LOGGER.debug( + f"File not found {self._bucketfs_location.as_udf_path} during cleanup." 
+ ) except Exception as e: - LOGGER.exception(f"Got exception during listing files in temporary BucketFSLocation") + LOGGER.exception( + f"Got exception during listing files in temporary BucketFSLocation" + ) return [] diff --git a/exasol/analytics/query_handler/context/proxy/db_object_name_proxy.py b/exasol/analytics/query_handler/context/proxy/db_object_name_proxy.py index e8f68850..26a33559 100644 --- a/exasol/analytics/query_handler/context/proxy/db_object_name_proxy.py +++ b/exasol/analytics/query_handler/context/proxy/db_object_name_proxy.py @@ -1,14 +1,12 @@ from abc import abstractmethod -from typing import TypeVar, Generic - -from exasol.analytics.schema import DBObjectName - -from exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object +from typing import Generic, TypeVar from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy from exasol.analytics.query_handler.query.interface import Query +from exasol.analytics.schema import DBObjectName +from exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object -NameType = TypeVar('NameType', bound=DBObjectName) +NameType = TypeVar("NameType", bound=DBObjectName) class DBObjectNameProxy(ObjectProxy, DBObjectName, Generic[NameType]): diff --git a/exasol/analytics/query_handler/context/proxy/db_object_name_with_schema_proxy.py b/exasol/analytics/query_handler/context/proxy/db_object_name_with_schema_proxy.py index 3691eec6..8b403291 100644 --- a/exasol/analytics/query_handler/context/proxy/db_object_name_with_schema_proxy.py +++ b/exasol/analytics/query_handler/context/proxy/db_object_name_with_schema_proxy.py @@ -1,16 +1,16 @@ -from typing import TypeVar, Generic +from typing import Generic, TypeVar -from exasol.analytics.schema import ( - DBObjectNameWithSchema, - SchemaName, +from exasol.analytics.query_handler.context.proxy.db_object_name_proxy import ( + DBObjectNameProxy, ) +from exasol.analytics.schema import DBObjectNameWithSchema, SchemaName -from exasol.analytics.query_handler.context.proxy.db_object_name_proxy import DBObjectNameProxy +NameType = TypeVar("NameType", bound=DBObjectNameWithSchema) -NameType = TypeVar('NameType', bound=DBObjectNameWithSchema) - -class DBObjectNameWithSchemaProxy(DBObjectNameProxy[NameType], DBObjectNameWithSchema, Generic[NameType]): +class DBObjectNameWithSchemaProxy( + DBObjectNameProxy[NameType], DBObjectNameWithSchema, Generic[NameType] +): def __init__(self, db_object_name_with_schema: NameType, global_counter_value: int): super().__init__(db_object_name_with_schema, global_counter_value) diff --git a/exasol/analytics/query_handler/context/proxy/drop_udf_query.py b/exasol/analytics/query_handler/context/proxy/drop_udf_query.py index 0ec793e3..f09cac04 100644 --- a/exasol/analytics/query_handler/context/proxy/drop_udf_query.py +++ b/exasol/analytics/query_handler/context/proxy/drop_udf_query.py @@ -1,7 +1,5 @@ -from exasol.analytics.schema import UDFName - - from exasol.analytics.query_handler.query.drop.interface import DropQuery +from exasol.analytics.schema import UDFName class DropUDFQuery(DropQuery): diff --git a/exasol/analytics/query_handler/context/proxy/table_like_name_proxy.py b/exasol/analytics/query_handler/context/proxy/table_like_name_proxy.py index 5e88103c..dfabf310 100644 --- a/exasol/analytics/query_handler/context/proxy/table_like_name_proxy.py +++ b/exasol/analytics/query_handler/context/proxy/table_like_name_proxy.py @@ -1,14 +1,16 @@ from typing import Generic, TypeVar +from 
exasol.analytics.query_handler.context.proxy.db_object_name_with_schema_proxy import ( + DBObjectNameWithSchemaProxy, +) from exasol.analytics.schema import TableLikeName +NameType = TypeVar("NameType", bound=TableLikeName) -from exasol.analytics.query_handler.context.proxy.db_object_name_with_schema_proxy import DBObjectNameWithSchemaProxy -NameType = TypeVar('NameType', bound=TableLikeName) - - -class TableLikeNameProxy(DBObjectNameWithSchemaProxy[NameType], TableLikeName, Generic[NameType]): +class TableLikeNameProxy( + DBObjectNameWithSchemaProxy[NameType], TableLikeName, Generic[NameType] +): def __init__(self, table_like_name: NameType, global_counter_value: int): super().__init__(table_like_name, global_counter_value) diff --git a/exasol/analytics/query_handler/context/proxy/table_name.py b/exasol/analytics/query_handler/context/proxy/table_name.py index 09778331..f77a2dcc 100644 --- a/exasol/analytics/query_handler/context/proxy/table_name.py +++ b/exasol/analytics/query_handler/context/proxy/table_name.py @@ -1,9 +1,9 @@ -from exasol.analytics.schema import TableName - - -from exasol.analytics.query_handler.context.proxy.table_like_name_proxy import TableLikeNameProxy +from exasol.analytics.query_handler.context.proxy.table_like_name_proxy import ( + TableLikeNameProxy, +) from exasol.analytics.query_handler.query.drop.table import DropTableQuery from exasol.analytics.query_handler.query.interface import Query +from exasol.analytics.schema import TableName class TableNameProxy(TableLikeNameProxy[TableName], TableName): diff --git a/exasol/analytics/query_handler/context/proxy/udf_name.py b/exasol/analytics/query_handler/context/proxy/udf_name.py index 0ada4999..556f71db 100644 --- a/exasol/analytics/query_handler/context/proxy/udf_name.py +++ b/exasol/analytics/query_handler/context/proxy/udf_name.py @@ -1,9 +1,9 @@ -from exasol.analytics.schema import UDFName - - -from exasol.analytics.query_handler.context.proxy.db_object_name_with_schema_proxy import DBObjectNameWithSchemaProxy +from exasol.analytics.query_handler.context.proxy.db_object_name_with_schema_proxy import ( + DBObjectNameWithSchemaProxy, +) from exasol.analytics.query_handler.context.proxy.drop_udf_query import DropUDFQuery from exasol.analytics.query_handler.query.interface import Query +from exasol.analytics.schema import UDFName class UDFNameProxy(DBObjectNameWithSchemaProxy[UDFName], UDFName): diff --git a/exasol/analytics/query_handler/context/proxy/view_name_proxy.py b/exasol/analytics/query_handler/context/proxy/view_name_proxy.py index 876c959d..5343ec4c 100644 --- a/exasol/analytics/query_handler/context/proxy/view_name_proxy.py +++ b/exasol/analytics/query_handler/context/proxy/view_name_proxy.py @@ -1,9 +1,9 @@ -from exasol.analytics.schema import ViewName - - -from exasol.analytics.query_handler.context.proxy.table_like_name_proxy import TableLikeNameProxy +from exasol.analytics.query_handler.context.proxy.table_like_name_proxy import ( + TableLikeNameProxy, +) from exasol.analytics.query_handler.query.drop.view import DropViewQuery from exasol.analytics.query_handler.query.interface import Query +from exasol.analytics.schema import ViewName class ViewNameProxy(TableLikeNameProxy[ViewName], ViewName): diff --git a/exasol/analytics/query_handler/context/query_handler_context.py b/exasol/analytics/query_handler/context/query_handler_context.py index 12880a33..5672efb4 100644 --- a/exasol/analytics/query_handler/context/query_handler_context.py +++ 
b/exasol/analytics/query_handler/context/query_handler_context.py @@ -1,14 +1,11 @@ import abc from abc import ABC -from exasol.analytics.schema import ( - TableName, - UDFName, - ViewName, -) - from exasol.analytics.query_handler.context.connection_name import ConnectionName -from exasol.analytics.query_handler.context.proxy.bucketfs_location_proxy import BucketFSLocationProxy +from exasol.analytics.query_handler.context.proxy.bucketfs_location_proxy import ( + BucketFSLocationProxy, +) +from exasol.analytics.schema import TableName, UDFName, ViewName class QueryHandlerContext(ABC): diff --git a/exasol/analytics/query_handler/context/scope.py b/exasol/analytics/query_handler/context/scope.py index 7ae8a0d5..d68d27b5 100644 --- a/exasol/analytics/query_handler/context/scope.py +++ b/exasol/analytics/query_handler/context/scope.py @@ -1,8 +1,10 @@ import enum -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod -from exasol.analytics.query_handler.context.query_handler_context import QueryHandlerContext from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy +from exasol.analytics.query_handler.context.query_handler_context import ( + QueryHandlerContext, +) class Connection(ABC): @@ -42,8 +44,11 @@ def get_child_query_handler_context(self) -> "ScopeQueryHandlerContext": pass @abstractmethod - def transfer_object_to(self, object_proxy: ObjectProxy, - scope_query_handler_context: "ScopeQueryHandlerContext"): + def transfer_object_to( + self, + object_proxy: ObjectProxy, + scope_query_handler_context: "ScopeQueryHandlerContext", + ): """ This function transfers the ownership of the object to a different context. That means, that the object isn't released if this context is released, diff --git a/exasol/analytics/query_handler/context/top_level_query_handler_context.py b/exasol/analytics/query_handler/context/top_level_query_handler_context.py index 92e14d98..f71215c2 100644 --- a/exasol/analytics/query_handler/context/top_level_query_handler_context.py +++ b/exasol/analytics/query_handler/context/top_level_query_handler_context.py @@ -1,29 +1,41 @@ import textwrap import traceback from abc import ABC -from typing import Set, List, Callable +from typing import Callable, List, Set import exasol.bucketfs as bfs -from exasol.analytics.schema import ( - SchemaName, - TableNameBuilder, - UDFNameBuilder, - UDFName, - ViewNameBuilder, - TableName, - ViewName, -) -from exasol.analytics.query_handler.context.connection_name_proxy import ConnectionNameProxy -from exasol.analytics.query_handler.context.connection_name import ConnectionName, ConnectionNameImpl -from exasol.analytics.query_handler.context.proxy.bucketfs_location_proxy import BucketFSLocationProxy -from exasol.analytics.query_handler.context.proxy.db_object_name_proxy import DBObjectNameProxy +from exasol.analytics.query_handler.context.connection_name import ( + ConnectionName, + ConnectionNameImpl, +) +from exasol.analytics.query_handler.context.connection_name_proxy import ( + ConnectionNameProxy, +) +from exasol.analytics.query_handler.context.proxy.bucketfs_location_proxy import ( + BucketFSLocationProxy, +) +from exasol.analytics.query_handler.context.proxy.db_object_name_proxy import ( + DBObjectNameProxy, +) from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy from exasol.analytics.query_handler.context.proxy.table_name import TableNameProxy from exasol.analytics.query_handler.context.proxy.udf_name import UDFNameProxy from 
exasol.analytics.query_handler.context.proxy.view_name_proxy import ViewNameProxy -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext, Connection +from exasol.analytics.query_handler.context.scope import ( + Connection, + ScopeQueryHandlerContext, +) from exasol.analytics.query_handler.query.interface import Query +from exasol.analytics.schema import ( + SchemaName, + TableName, + TableNameBuilder, + UDFName, + UDFNameBuilder, + ViewName, + ViewNameBuilder, +) class TemporaryObjectCounter: @@ -38,26 +50,40 @@ def get_current_value(self) -> int: class ChildContextNotReleasedError(Exception): - def __init__(self, - not_released_child_contexts: List[ScopeQueryHandlerContext], - exceptions_thrown_by_not_released_child_contexts: List["ChildContextNotReleasedError"]): + def __init__( + self, + not_released_child_contexts: List[ScopeQueryHandlerContext], + exceptions_thrown_by_not_released_child_contexts: List[ + "ChildContextNotReleasedError" + ], + ): """ :param not_released_child_contexts: A list of child contexts which were not yet released :param exceptions_thrown_by_not_released_child_contexts: A list of ChildContextNotReleasedError thrown by the call to _invalidate of the child contexts """ - self.exceptions_thrown_by_not_released_child_contexts = exceptions_thrown_by_not_released_child_contexts + self.exceptions_thrown_by_not_released_child_contexts = ( + exceptions_thrown_by_not_released_child_contexts + ) self.not_released_child_contexts = not_released_child_contexts - concatenated_contexts = "\n- ".join([str(c) for c in self.get_all_not_released_contexts()]) - self.message = \ - f"The following child contexts were not released,\n" \ - f"please release all contexts to avoid ressource leakage:\n" \ + concatenated_contexts = "\n- ".join( + [str(c) for c in self.get_all_not_released_contexts()] + ) + self.message = ( + f"The following child contexts were not released,\n" + f"please release all contexts to avoid resource leakage:\n" f"- {concatenated_contexts}\n" - super(ChildContextNotReleasedError, self).__init__(self.message) + ) + super().__init__(self.message) def get_all_not_released_contexts(self): - result = sum([e.get_all_not_released_contexts() for e in - self.exceptions_thrown_by_not_released_child_contexts], []) + result = sum( + [ + e.get_all_not_released_contexts() + for e in self.exceptions_thrown_by_not_released_child_contexts + ], + [], + ) result = self.not_released_child_contexts + result return result @@ -66,12 +92,14 @@ def get_all_not_released_contexts(self): class _ScopeQueryHandlerContextBase(ScopeQueryHandlerContext, ABC): - def __init__(self, - temporary_bucketfs_location: bfs.path.PathLike, - temporary_db_object_name_prefix: str, - temporary_schema_name: str, - connection_lookup: ConnectionLookup, - global_temporary_object_counter: TemporaryObjectCounter): + def __init__( + self, + temporary_bucketfs_location: bfs.path.PathLike, + temporary_db_object_name_prefix: str, + temporary_schema_name: str, + connection_lookup: ConnectionLookup, + global_temporary_object_counter: TemporaryObjectCounter, + ): self._connection_lookup = connection_lookup self._global_temporary_object_counter = global_temporary_object_counter self._temporary_schema_name = temporary_schema_name @@ -104,7 +132,8 @@ def _get_temporary_table_name(self) -> TableName: temporary_name = self._get_temporary_db_object_name() temporary_table_name = TableNameBuilder.create( name=temporary_name, - schema=SchemaName(schema_name=self._temporary_schema_name)) + 
schema=SchemaName(schema_name=self._temporary_schema_name), + ) return temporary_table_name def _get_temporary_view_name(self) -> ViewName: @@ -112,7 +141,8 @@ def _get_temporary_view_name(self) -> ViewName: temporary_name = self._get_temporary_db_object_name() temporary_view_name = ViewNameBuilder.create( name=temporary_name, - schema=SchemaName(schema_name=self._temporary_schema_name)) + schema=SchemaName(schema_name=self._temporary_schema_name), + ) return temporary_view_name def _get_temporary_udf_name(self) -> UDFName: @@ -120,7 +150,8 @@ def _get_temporary_udf_name(self) -> UDFName: temporary_name = self._get_temporary_db_object_name() temporary_script_name = UDFNameBuilder.create( name=temporary_name, - schema=SchemaName(schema_name=self._temporary_schema_name)) + schema=SchemaName(schema_name=self._temporary_schema_name), + ) return temporary_script_name def _get_temporary_connection_name(self) -> ConnectionName: @@ -130,7 +161,9 @@ def _get_temporary_connection_name(self) -> ConnectionName: return temporary_connection_name def _get_temporary_db_object_name(self) -> str: - temporary_name = f"{self._temporary_db_object_name_prefix}_{self._get_counter_value()}" + temporary_name = ( + f"{self._temporary_db_object_name_prefix}_{self._get_counter_value()}" + ) return temporary_name def _own_object(self, object_proxy: ObjectProxy): @@ -145,39 +178,49 @@ def get_temporary_name(self) -> str: def get_temporary_table_name(self) -> TableName: self._check_if_released() temporary_table_name = self._get_temporary_table_name() - object_proxy = TableNameProxy(temporary_table_name, - self._global_temporary_object_counter.get_current_value()) + object_proxy = TableNameProxy( + temporary_table_name, + self._global_temporary_object_counter.get_current_value(), + ) self._own_object(object_proxy) return object_proxy def get_temporary_view_name(self) -> ViewName: self._check_if_released() temporary_view_name = self._get_temporary_view_name() - object_proxy = ViewNameProxy(temporary_view_name, - self._global_temporary_object_counter.get_current_value()) + object_proxy = ViewNameProxy( + temporary_view_name, + self._global_temporary_object_counter.get_current_value(), + ) self._own_object(object_proxy) return object_proxy def get_temporary_udf_name(self) -> UDFName: self._check_if_released() temporary_script_name = self._get_temporary_udf_name() - object_proxy = UDFNameProxy(temporary_script_name, - self._global_temporary_object_counter.get_current_value()) + object_proxy = UDFNameProxy( + temporary_script_name, + self._global_temporary_object_counter.get_current_value(), + ) self._own_object(object_proxy) return object_proxy def get_temporary_connection_name(self) -> ConnectionName: self._check_if_released() temporary_connection_name = self._get_temporary_connection_name() - object_proxy = ConnectionNameProxy(connection_name=temporary_connection_name, - global_counter_value=self._global_temporary_object_counter.get_current_value()) + object_proxy = ConnectionNameProxy( + connection_name=temporary_connection_name, + global_counter_value=self._global_temporary_object_counter.get_current_value(), + ) self._own_object(object_proxy) return object_proxy def get_temporary_bucketfs_location(self) -> BucketFSLocationProxy: self._check_if_released() temporary_path = self._get_temporary_path() - child_bucketfs_location = self._temporary_bucketfs_location.joinpath(temporary_path) + child_bucketfs_location = self._temporary_bucketfs_location.joinpath( + temporary_path + ) object_proxy = 
BucketFSLocationProxy(child_bucketfs_location) self._own_object(object_proxy) return object_proxy @@ -189,25 +232,32 @@ def _get_temporary_path(self): def get_child_query_handler_context(self) -> ScopeQueryHandlerContext: self._check_if_released() temporary_path = self._get_temporary_path() - new_temporary_bucketfs_location = self._temporary_bucketfs_location.joinpath(temporary_path) + new_temporary_bucketfs_location = self._temporary_bucketfs_location.joinpath( + temporary_path + ) child_query_handler_context = _ChildQueryHandlerContext( self, new_temporary_bucketfs_location, self._get_temporary_db_object_name(), self._temporary_schema_name, self._connection_lookup, - self._global_temporary_object_counter + self._global_temporary_object_counter, ) self._child_query_handler_context_list.append(child_query_handler_context) return child_query_handler_context def _is_child(self, scope_query_handler_context: ScopeQueryHandlerContext) -> bool: - result = isinstance(scope_query_handler_context, _ChildQueryHandlerContext) and \ - scope_query_handler_context._parent == self + result = ( + isinstance(scope_query_handler_context, _ChildQueryHandlerContext) + and scope_query_handler_context._parent == self + ) return result - def _transfer_object_to(self, object_proxy: ObjectProxy, - scope_query_handler_context: ScopeQueryHandlerContext) -> None: + def _transfer_object_to( + self, + object_proxy: ObjectProxy, + scope_query_handler_context: ScopeQueryHandlerContext, + ) -> None: self._check_if_released() if object_proxy in self._owned_object_proxies: if isinstance(scope_query_handler_context, _ScopeQueryHandlerContextBase): @@ -216,8 +266,10 @@ def _transfer_object_to(self, object_proxy: ObjectProxy, if not self._is_child(scope_query_handler_context): self._remove_object(object_proxy) else: - raise ValueError(f"{scope_query_handler_context.__class__} not allowed, " - f"use a context created with get_child_query_handler_context") + raise ValueError( + f"{scope_query_handler_context.__class__} not allowed, " + f"use a context created with get_child_query_handler_context" + ) else: raise RuntimeError("Object not owned by this ScopeQueryHandlerContext.") @@ -233,7 +285,9 @@ def _check_if_released(self): def _release(self): self._check_if_released() - self._released_object_proxies = self._released_object_proxies.union(self._not_released_object_proxies) + self._released_object_proxies = self._released_object_proxies.union( + self._not_released_object_proxies + ) self._not_released_object_proxies = set() self._owned_object_proxies = set() self._not_released = False @@ -252,7 +306,7 @@ def _check_if_children_released(self): if not_released_child_contexts: raise ChildContextNotReleasedError( not_released_child_contexts=not_released_child_contexts, - exceptions_thrown_by_not_released_child_contexts=exceptions_from_not_released_child_contexts + exceptions_thrown_by_not_released_child_contexts=exceptions_from_not_released_child_contexts, ) def _register_object(self, object_proxy: ObjectProxy): @@ -271,17 +325,21 @@ def get_connection(self, name: str) -> Connection: class TopLevelQueryHandlerContext(_ScopeQueryHandlerContextBase): - def __init__(self, - temporary_bucketfs_location: bfs.path.PathLike, - temporary_db_object_name_prefix: str, - temporary_schema_name: str, - connection_lookup: ConnectionLookup, - global_temporary_object_counter: TemporaryObjectCounter = TemporaryObjectCounter()): - super().__init__(temporary_bucketfs_location, - temporary_db_object_name_prefix, - temporary_schema_name, - 
connection_lookup, - global_temporary_object_counter) + def __init__( + self, + temporary_bucketfs_location: bfs.path.PathLike, + temporary_db_object_name_prefix: str, + temporary_schema_name: str, + connection_lookup: ConnectionLookup, + global_temporary_object_counter: TemporaryObjectCounter = TemporaryObjectCounter(), + ): + super().__init__( + temporary_bucketfs_location, + temporary_db_object_name_prefix, + temporary_schema_name, + connection_lookup, + global_temporary_object_counter, + ) def _release_object(self, object_proxy: ObjectProxy): super()._release_object(object_proxy) @@ -295,17 +353,25 @@ def cleanup_released_object_proxies(self) -> List[Query]: The clean up queries are sorted in reverse order of their creation, such that, we remove first objects that might depend on previous objects. """ - db_objects: List[DBObjectNameProxy] = \ - [object_proxy for object_proxy in self._released_object_proxies - if isinstance(object_proxy, DBObjectNameProxy)] - bucketfs_objects: List[BucketFSLocationProxy] = \ - [object_proxy for object_proxy in self._released_object_proxies - if isinstance(object_proxy, BucketFSLocationProxy)] + db_objects: List[DBObjectNameProxy] = [ + object_proxy + for object_proxy in self._released_object_proxies + if isinstance(object_proxy, DBObjectNameProxy) + ] + bucketfs_objects: List[BucketFSLocationProxy] = [ + object_proxy + for object_proxy in self._released_object_proxies + if isinstance(object_proxy, BucketFSLocationProxy) + ] self._released_object_proxies = set() self._remove_bucketfs_objects(bucketfs_objects) - reverse_sorted_db_objects = sorted(db_objects, key=lambda x: x._global_counter_value, reverse=True) - cleanup_queries = [object_proxy.get_cleanup_query() - for object_proxy in reverse_sorted_db_objects] + reverse_sorted_db_objects = sorted( + db_objects, key=lambda x: x._global_counter_value, reverse=True + ) + cleanup_queries = [ + object_proxy.get_cleanup_query() + for object_proxy in reverse_sorted_db_objects + ] return cleanup_queries @staticmethod @@ -313,8 +379,11 @@ def _remove_bucketfs_objects(bucketfs_object_proxies: List[BucketFSLocationProxy for object_proxy in bucketfs_object_proxies: object_proxy.cleanup() - def transfer_object_to(self, object_proxy: ObjectProxy, - scope_query_handler_context: ScopeQueryHandlerContext): + def transfer_object_to( + self, + object_proxy: ObjectProxy, + scope_query_handler_context: ScopeQueryHandlerContext, + ): if self._is_child(scope_query_handler_context): self._transfer_object_to(object_proxy, scope_query_handler_context) else: @@ -322,17 +391,22 @@ def transfer_object_to(self, object_proxy: ObjectProxy, class _ChildQueryHandlerContext(_ScopeQueryHandlerContextBase): - def __init__(self, parent: _ScopeQueryHandlerContextBase, - temporary_bucketfs_location: bfs.path.PathLike, - temporary_db_object_name_prefix: str, - temporary_schema_name: str, - connection_lookup: ConnectionLookup, - global_temporary_object_counter: TemporaryObjectCounter): - super().__init__(temporary_bucketfs_location, - temporary_db_object_name_prefix, - temporary_schema_name, - connection_lookup, - global_temporary_object_counter) + def __init__( + self, + parent: _ScopeQueryHandlerContextBase, + temporary_bucketfs_location: bfs.path.PathLike, + temporary_db_object_name_prefix: str, + temporary_schema_name: str, + connection_lookup: ConnectionLookup, + global_temporary_object_counter: TemporaryObjectCounter, + ): + super().__init__( + temporary_bucketfs_location, + temporary_db_object_name_prefix, + temporary_schema_name, + 
connection_lookup, + global_temporary_object_counter, + ) self.__parent = parent @property @@ -351,16 +425,27 @@ def _is_parent(self, scope_query_handler_context: ScopeQueryHandlerContext) -> b result = self._parent == scope_query_handler_context return result - def _is_sibling(self, scope_query_handler_context: ScopeQueryHandlerContext) -> bool: - result = isinstance(scope_query_handler_context, _ChildQueryHandlerContext) and \ - scope_query_handler_context._parent == self._parent + def _is_sibling( + self, scope_query_handler_context: ScopeQueryHandlerContext + ) -> bool: + result = ( + isinstance(scope_query_handler_context, _ChildQueryHandlerContext) + and scope_query_handler_context._parent == self._parent + ) return result - def transfer_object_to(self, object_proxy: ObjectProxy, - scope_query_handler_context: ScopeQueryHandlerContext): - if self._is_child(scope_query_handler_context) or \ - self._is_parent(scope_query_handler_context) or \ - self._is_sibling(scope_query_handler_context): + def transfer_object_to( + self, + object_proxy: ObjectProxy, + scope_query_handler_context: ScopeQueryHandlerContext, + ): + if ( + self._is_child(scope_query_handler_context) + or self._is_parent(scope_query_handler_context) + or self._is_sibling(scope_query_handler_context) + ): self._transfer_object_to(object_proxy, scope_query_handler_context) else: - raise RuntimeError("Given ScopeQueryHandlerContext not a child, parent or sibling.") + raise RuntimeError( + "Given ScopeQueryHandlerContext not a child, parent or sibling." + ) diff --git a/exasol/analytics/query_handler/deployment/aaf_exasol_lua_script_generator.py b/exasol/analytics/query_handler/deployment/aaf_exasol_lua_script_generator.py index 6a5f425c..ed5af7d3 100644 --- a/exasol/analytics/query_handler/deployment/aaf_exasol_lua_script_generator.py +++ b/exasol/analytics/query_handler/deployment/aaf_exasol_lua_script_generator.py @@ -1,20 +1,26 @@ import importlib_resources from exasol.analytics.query_handler.deployment import constants -from exasol.analytics.query_handler.deployment.exasol_lua_script_generator import ExasolLuaScriptGenerator -from exasol.analytics.query_handler.deployment.jinja_template_location import JinjaTemplateLocation -from exasol.analytics.query_handler.deployment.lua_script_bundle import LuaScriptBundle, logger +from exasol.analytics.query_handler.deployment.exasol_lua_script_generator import ( + ExasolLuaScriptGenerator, +) +from exasol.analytics.query_handler.deployment.jinja_template_location import ( + JinjaTemplateLocation, +) +from exasol.analytics.query_handler.deployment.lua_script_bundle import ( + LuaScriptBundle, + logger, +) def get_aaf_query_loop_lua_script_generator() -> ExasolLuaScriptGenerator: - base_dir = importlib_resources.files( - constants.BASE_PACKAGE) + base_dir = importlib_resources.files(constants.BASE_PACKAGE) lua_src_dir = base_dir / "lua" / "src" lua_source_files = [ lua_src_dir.joinpath("query_handler_runner_main.lua"), lua_src_dir.joinpath("query_handler_runner.lua"), lua_src_dir.joinpath("query_loop.lua"), - lua_src_dir.joinpath("exasol_script_tools.lua") + lua_src_dir.joinpath("exasol_script_tools.lua"), ] lua_main_file = lua_src_dir.joinpath("query_handler_runner_main.lua") lua_modules = [ @@ -22,17 +28,20 @@ def get_aaf_query_loop_lua_script_generator() -> ExasolLuaScriptGenerator: "query_handler_runner", "exasol_script_tools", "ExaError", - "MessageExpander" + "MessageExpander", ] jinja_template_location = JinjaTemplateLocation( package_name=constants.BASE_PACKAGE, 
package_path=constants.TEMPLATES_DIR, - template_file_name=constants.LUA_SCRIPT_TEMPLATE) + template_file_name=constants.LUA_SCRIPT_TEMPLATE, + ) generator = ExasolLuaScriptGenerator( - LuaScriptBundle(lua_main_file=lua_main_file, - lua_modules=lua_modules, - lua_source_files=lua_source_files), - jinja_template_location.get_template() + LuaScriptBundle( + lua_main_file=lua_main_file, + lua_modules=lua_modules, + lua_source_files=lua_source_files, + ), + jinja_template_location.get_template(), ) return generator diff --git a/exasol/analytics/query_handler/deployment/constants.py b/exasol/analytics/query_handler/deployment/constants.py index 587cb57c..32a212cb 100644 --- a/exasol/analytics/query_handler/deployment/constants.py +++ b/exasol/analytics/query_handler/deployment/constants.py @@ -1,6 +1,6 @@ import pathlib -from importlib_resources import files +from importlib_resources import files BASE_PACKAGE = "exasol.analytics" BASE_DIR = BASE_PACKAGE.replace(".", "/") @@ -8,8 +8,6 @@ OUTPUTS_DIR = pathlib.Path("resources", "outputs") SOURCE_DIR = files(f"{BASE_PACKAGE}.query_handler.udf.runner") -UDF_CALL_TEMPLATES = { - "call_udf.py": "create_query_handler.jinja.sql" -} +UDF_CALL_TEMPLATES = {"call_udf.py": "create_query_handler.jinja.sql"} LUA_SCRIPT_TEMPLATE = "create_query_loop.jinja.sql" LUA_SCRIPT_OUTPUT = pathlib.Path(BASE_DIR, OUTPUTS_DIR, "create_query_loop.sql") diff --git a/exasol/analytics/query_handler/deployment/deploy.py b/exasol/analytics/query_handler/deployment/deploy.py index 33a9c684..b6017c4e 100644 --- a/exasol/analytics/query_handler/deployment/deploy.py +++ b/exasol/analytics/query_handler/deployment/deploy.py @@ -1,16 +1,16 @@ import logging + import click -from exasol.analytics.query_handler.deployment.slc import ( - SLC_FILE_NAME, - SLC_URL_FORMATTER, -) -from exasol.analytics.query_handler.deployment import ( - scripts_deployer_cli, -) from exasol.python_extension_common.deployment.language_container_deployer_cli import ( + CustomizableParameters, language_container_deployer_main, slc_parameter_formatters, - CustomizableParameters, +) + +from exasol.analytics.query_handler.deployment import scripts_deployer_cli +from exasol.analytics.query_handler.deployment.slc import ( + SLC_FILE_NAME, + SLC_URL_FORMATTER, ) @@ -19,16 +19,20 @@ def main(): pass -slc_parameter_formatters.set_formatter(CustomizableParameters.container_url, SLC_URL_FORMATTER) -slc_parameter_formatters.set_formatter(CustomizableParameters.container_name, SLC_FILE_NAME) +slc_parameter_formatters.set_formatter( + CustomizableParameters.container_url, SLC_URL_FORMATTER +) +slc_parameter_formatters.set_formatter( + CustomizableParameters.container_name, SLC_FILE_NAME +) main.add_command(scripts_deployer_cli.scripts_deployer_main) main.add_command(language_container_deployer_main) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig( - format='%(asctime)s - %(module)s - %(message)s', - level=logging.DEBUG) + format="%(asctime)s - %(module)s - %(message)s", level=logging.DEBUG + ) main() diff --git a/exasol/analytics/query_handler/deployment/exasol_lua_script_generator.py b/exasol/analytics/query_handler/deployment/exasol_lua_script_generator.py index 4527ca55..fb816ef6 100644 --- a/exasol/analytics/query_handler/deployment/exasol_lua_script_generator.py +++ b/exasol/analytics/query_handler/deployment/exasol_lua_script_generator.py @@ -7,9 +7,9 @@ class ExasolLuaScriptGenerator: - def __init__(self, lua_script_bundle: LuaScriptBundle, - jinja_template: Template, - **kwargs): + def 
__init__( + self, lua_script_bundle: LuaScriptBundle, jinja_template: Template, **kwargs + ): self._jinja_template = jinja_template self._lua_script_bundle = lua_script_bundle self._kwargs = kwargs @@ -17,6 +17,10 @@ def __init__(self, lua_script_bundle: LuaScriptBundle, def generate_script(self, output_buffer: typing.IO) -> None: bundle_output_buffer = StringIO() self._lua_script_bundle.bundle_lua_scripts(bundle_output_buffer) - output = self._jinja_template.render(bundled_script=bundle_output_buffer.getvalue(), **self._kwargs) - output_buffer.write("-- This file was generated by the ExasolLuaScriptGenerator.\n\n") - output_buffer.write(output) \ No newline at end of file + output = self._jinja_template.render( + bundled_script=bundle_output_buffer.getvalue(), **self._kwargs + ) + output_buffer.write( + "-- This file was generated by the ExasolLuaScriptGenerator.\n\n" + ) + output_buffer.write(output) diff --git a/exasol/analytics/query_handler/deployment/jinja_template_location.py b/exasol/analytics/query_handler/deployment/jinja_template_location.py index 4df23f23..c445e52b 100644 --- a/exasol/analytics/query_handler/deployment/jinja_template_location.py +++ b/exasol/analytics/query_handler/deployment/jinja_template_location.py @@ -1,4 +1,4 @@ -from jinja2 import Template, Environment, PackageLoader, select_autoescape +from jinja2 import Environment, PackageLoader, Template, select_autoescape class JinjaTemplateLocation: @@ -10,7 +10,9 @@ def __init__(self, package_name: str, package_path: str, template_file_name: str def get_template(self) -> Template: env = Environment( - loader=PackageLoader(package_name=self.package_name, package_path=self.package_path), - autoescape=select_autoescape() + loader=PackageLoader( + package_name=self.package_name, package_path=self.package_path + ), + autoescape=select_autoescape(), ) return env.get_template(self.template_file_name) diff --git a/exasol/analytics/query_handler/deployment/lua_script_bundle.py b/exasol/analytics/query_handler/deployment/lua_script_bundle.py index af128193..19dbfb30 100644 --- a/exasol/analytics/query_handler/deployment/lua_script_bundle.py +++ b/exasol/analytics/query_handler/deployment/lua_script_bundle.py @@ -4,7 +4,7 @@ import tempfile import time from pathlib import Path -from typing import List, IO, Union +from typing import IO, List, Union from importlib_resources.abc import Traversable @@ -14,10 +14,12 @@ class LuaScriptBundle: - def __init__(self, - lua_main_file: PathLike, - lua_source_files: List[PathLike], - lua_modules: List[str]): + def __init__( + self, + lua_main_file: PathLike, + lua_source_files: List[PathLike], + lua_modules: List[str], + ): self.lua_main_file = lua_main_file self.lua_modules = lua_modules self.lua_source_files = lua_source_files @@ -36,12 +38,12 @@ def copy_lua_source_files(self, tmp_dir: Path): def run_lua_amlg(self, tmp_dir: Path, output_buffer: IO): output_file = tmp_dir / f"bundle_{time.time()}.lua" - bash_command = \ - "amalg.lua -o {out_path} -s {main_file} {modules}".format( - tmp_dir=tmp_dir, - out_path=output_file, - main_file=self.lua_main_file.name, - modules=" ".join(self.lua_modules)) + bash_command = "amalg.lua -o {out_path} -s {main_file} {modules}".format( + tmp_dir=tmp_dir, + out_path=output_file, + main_file=self.lua_main_file.name, + modules=" ".join(self.lua_modules), + ) subprocess.check_call(bash_command, shell=True, cwd=tmp_dir) with output_file.open() as f: shutil.copyfileobj(f, output_buffer) diff --git 
a/exasol/analytics/query_handler/deployment/regenerate_scripts.py b/exasol/analytics/query_handler/deployment/regenerate_scripts.py index dbb3f1a6..5476063c 100644 --- a/exasol/analytics/query_handler/deployment/regenerate_scripts.py +++ b/exasol/analytics/query_handler/deployment/regenerate_scripts.py @@ -1,7 +1,9 @@ import logging -from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import save_aaf_query_loop_lua_script -from exasol.analytics.query_handler.deployment.lua_script_bundle import LuaScriptBundle +from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import ( + save_aaf_query_loop_lua_script, +) +from exasol.analytics.query_handler.deployment.lua_script_bundle import LuaScriptBundle def generate_scripts(): @@ -9,8 +11,8 @@ def generate_scripts(): Generate the Lua sql statement of the Query-Loop from scratch and save it. """ logging.basicConfig( - format='%(asctime)s - %(module)s - %(message)s', - level=logging.DEBUG) + format="%(asctime)s - %(module)s - %(message)s", level=logging.DEBUG + ) save_aaf_query_loop_lua_script() diff --git a/exasol/analytics/query_handler/deployment/scripts_deployer.py b/exasol/analytics/query_handler/deployment/scripts_deployer.py index 93caf76f..d2603e19 100644 --- a/exasol/analytics/query_handler/deployment/scripts_deployer.py +++ b/exasol/analytics/query_handler/deployment/scripts_deployer.py @@ -2,48 +2,53 @@ import pyexasol -from exasol.analytics.query_handler.deployment import ( - constants, - utils, +from exasol.analytics.query_handler.deployment import constants, utils +from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import ( + save_aaf_query_loop_lua_script, ) -from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import save_aaf_query_loop_lua_script logger = logging.getLogger(__name__) class ScriptsDeployer: - def __init__(self, language_alias: str, schema: str, - pyexasol_conn: pyexasol.ExaConnection): + def __init__( + self, language_alias: str, schema: str, pyexasol_conn: pyexasol.ExaConnection + ): self._language_alias = language_alias self._schema = schema self._pyexasol_conn = pyexasol_conn logger.debug(f"Init {ScriptsDeployer.__name__}.") def _open_schema(self) -> None: - queries = ["CREATE SCHEMA IF NOT EXISTS {schema_name}", - "OPEN SCHEMA {schema_name}"] + queries = [ + "CREATE SCHEMA IF NOT EXISTS {schema_name}", + "OPEN SCHEMA {schema_name}", + ] for query in queries: self._pyexasol_conn.execute(query.format(schema_name=self._schema)) logger.debug(f"Schema {self._schema} is opened.") def _deploy_udf_scripts(self) -> None: for udf_call_src, template_src in constants.UDF_CALL_TEMPLATES.items(): - udf_content = constants.SOURCE_DIR.joinpath( - udf_call_src).read_text() + udf_content = constants.SOURCE_DIR.joinpath(udf_call_src).read_text() udf_query = utils.load_and_render_statement( template_src, script_content=udf_content, - language_alias=self._language_alias) + language_alias=self._language_alias, + ) self._pyexasol_conn.execute(udf_query) - logger.debug(f"UDF statement of the template " - f"{template_src} is executed.") + logger.debug( + f"UDF statement of the template " f"{template_src} is executed." 
+ ) def _deploy_lua_scripts(self) -> None: - with open(constants.LUA_SCRIPT_OUTPUT, "r") as file: + with open(constants.LUA_SCRIPT_OUTPUT) as file: lua_query = file.read() self._pyexasol_conn.execute(lua_query) - logger.debug(f"The Lua statement of the template " - f"{constants.LUA_SCRIPT_TEMPLATE} is executed.") + logger.debug( + f"The Lua statement of the template " + f"{constants.LUA_SCRIPT_TEMPLATE} is executed." + ) def deploy_scripts(self) -> None: self._open_schema() @@ -52,8 +57,15 @@ def deploy_scripts(self) -> None: logger.debug(f"Scripts are deployed.") @classmethod - def run(cls, dsn: str, user: str, password: str, - schema: str, language_alias: str, develop: bool): + def run( + cls, + dsn: str, + user: str, + password: str, + schema: str, + language_alias: str, + develop: bool, + ): if develop: save_aaf_query_loop_lua_script() pyexasol_conn = pyexasol.connect(dsn=dsn, user=user, password=password) diff --git a/exasol/analytics/query_handler/deployment/scripts_deployer_cli.py b/exasol/analytics/query_handler/deployment/scripts_deployer_cli.py index 1ab1aa59..be3b99c8 100644 --- a/exasol/analytics/query_handler/deployment/scripts_deployer_cli.py +++ b/exasol/analytics/query_handler/deployment/scripts_deployer_cli.py @@ -1,20 +1,23 @@ import click + from exasol.analytics.query_handler.deployment import utils -from exasol.analytics.query_handler.deployment.scripts_deployer import ScriptsDeployer +from exasol.analytics.query_handler.deployment.scripts_deployer import ScriptsDeployer from exasol.analytics.query_handler.deployment.slc import LANGUAGE_ALIAS + @click.command(name="scripts") -@click.option('--dsn', type=str, required=True) -@click.option('--user', type=str, required=True) -@click.option('--pass', 'pwd', type=str) -@click.option('--schema', type=str, required=True) -@click.option('--language-alias', type=str, default=LANGUAGE_ALIAS) -@click.option('--develop', type=bool, is_flag=True) +@click.option("--dsn", type=str, required=True) +@click.option("--user", type=str, required=True) +@click.option("--pass", "pwd", type=str) +@click.option("--schema", type=str, required=True) +@click.option("--language-alias", type=str, default=LANGUAGE_ALIAS) +@click.option("--develop", type=bool, is_flag=True) def scripts_deployer_main( - dsn: str, user: str, pwd: str, schema: str, - language_alias: str, develop: bool): + dsn: str, user: str, pwd: str, schema: str, language_alias: str, develop: bool +): password = utils.get_password( - pwd, user, utils.DB_PASSWORD_ENVIRONMENT_VARIABLE, "DB Password") + pwd, user, utils.DB_PASSWORD_ENVIRONMENT_VARIABLE, "DB Password" + ) ScriptsDeployer.run( dsn=dsn, user=user, @@ -25,10 +28,11 @@ def scripts_deployer_main( ) -if __name__ == '__main__': +if __name__ == "__main__": import logging + logging.basicConfig( - format='%(asctime)s - %(module)s - %(message)s', - level=logging.DEBUG) + format="%(asctime)s - %(module)s - %(message)s", level=logging.DEBUG + ) scripts_deployer_main() diff --git a/exasol/analytics/query_handler/deployment/slc.py b/exasol/analytics/query_handler/deployment/slc.py index dfa3c444..6fc2b1c3 100644 --- a/exasol/analytics/query_handler/deployment/slc.py +++ b/exasol/analytics/query_handler/deployment/slc.py @@ -1,13 +1,17 @@ from contextlib import contextmanager + from exasol.python_extension_common.deployment.language_container_builder import ( LanguageContainerBuilder, - find_path_backwards + find_path_backwards, ) LANGUAGE_ALIAS = "PYTHON3_AAF" SLC_NAME = "exasol_advanced_analytics_framework_container" SLC_FILE_NAME = 
SLC_NAME + "_release.tar.gz" -SLC_URL_FORMATTER = "https://github.com/exasol/advanced-analytics-framework/releases/download/{version}/" + SLC_FILE_NAME +SLC_URL_FORMATTER = ( + "https://github.com/exasol/advanced-analytics-framework/releases/download/{version}/" + + SLC_FILE_NAME +) @contextmanager diff --git a/exasol/analytics/query_handler/deployment/utils.py b/exasol/analytics/query_handler/deployment/utils.py index 5b0d9a6d..e02fc7ae 100644 --- a/exasol/analytics/query_handler/deployment/utils.py +++ b/exasol/analytics/query_handler/deployment/utils.py @@ -1,6 +1,6 @@ +import logging import os from getpass import getpass -import logging from jinja2 import Environment, PackageLoader, select_autoescape @@ -24,9 +24,8 @@ def get_password(pwd: str, user: str, env_var: str, descr: str) -> str: def load_and_render_statement(template_name, **kwargs) -> str: env = Environment( - loader=PackageLoader( - constants.BASE_PACKAGE, constants.TEMPLATES_DIR), - autoescape=select_autoescape() + loader=PackageLoader(constants.BASE_PACKAGE, constants.TEMPLATES_DIR), + autoescape=select_autoescape(), ) template = env.get_template(template_name) statement = template.render(**kwargs) diff --git a/exasol/analytics/query_handler/graph/execution_graph.py b/exasol/analytics/query_handler/graph/execution_graph.py index 04d1da95..c1b4ede6 100644 --- a/exasol/analytics/query_handler/graph/execution_graph.py +++ b/exasol/analytics/query_handler/graph/execution_graph.py @@ -1,18 +1,15 @@ import json import typing -from typing import TypeVar, Generic, Set, Tuple, List +from typing import Generic, List, Set, Tuple, TypeVar import networkx as nx -T = TypeVar('T') +T = TypeVar("T") class ExecutionGraph(Generic[T]): - def __init__(self, - start_node: T, - end_node: T, - edges: Set[Tuple[T, T]]): + def __init__(self, start_node: T, end_node: T, edges: Set[Tuple[T, T]]): self._graph = nx.DiGraph() self._graph.add_edges_from(edges) self._graph.add_node(start_node) @@ -22,18 +19,24 @@ def __init__(self, if not nx.is_directed_acyclic_graph(self._graph): raise Exception("Graph not directed acyclic") nodes = set(self._graph) - descendants_plus_start_node = nx.descendants(self._graph, self._start_node) | {self._start_node} + descendants_plus_start_node = nx.descendants(self._graph, self._start_node) | { + self._start_node + } if not descendants_plus_start_node == nodes: raise Exception("Not all Nodes are reachable from start node") - ancestors_plus_end_node = nx.ancestors(self._graph, self._end_node) | {self._end_node} + ancestors_plus_end_node = nx.ancestors(self._graph, self._end_node) | { + self._end_node + } if not ancestors_plus_end_node == nodes: raise Exception("End node not reachable by all nodes") def __eq__(self, other) -> bool: if isinstance(other, self.__class__): - result = self._start_node == other._start_node and \ - self._end_node == other._end_node and \ - self._graph.edges == other._graph.edges + result = ( + self._start_node == other._start_node + and self._end_node == other._end_node + and self._graph.edges == other._graph.edges + ) return result else: return False @@ -43,7 +46,7 @@ def __repr__(self) -> str: result = { "start_node": str(self._start_node), "end_node": str(self._end_node), - "edges": sorted_edges + "edges": sorted_edges, } # return f"ExecutionGraph(start_node={self._start_node},end_node={self._end_node},edges={sorted_edges})" return json.dumps(result, indent=2) @@ -74,13 +77,17 @@ def edges(self) -> Set[Tuple[T, T]]: def compute_reverse_dependency_order(self) -> List[T]: reversed_graph = 
self._graph.reverse() - post_order_of_reversed_graph = \ - list(nx.traversal.dfs_postorder_nodes(reversed_graph, self._end_node)) - reversed_post_order_of_reversed_graph = list(reversed(post_order_of_reversed_graph)) + post_order_of_reversed_graph = list( + nx.traversal.dfs_postorder_nodes(reversed_graph, self._end_node) + ) + reversed_post_order_of_reversed_graph = list( + reversed(post_order_of_reversed_graph) + ) return reversed_post_order_of_reversed_graph def compute_dependency_order(self) -> List[T]: - post_order_of__graph = \ - list(nx.traversal.dfs_postorder_nodes(self._graph, self._start_node)) + post_order_of__graph = list( + nx.traversal.dfs_postorder_nodes(self._graph, self._start_node) + ) reversed_post_order_of_graph = list(reversed(post_order_of__graph)) return reversed_post_order_of_graph diff --git a/exasol/analytics/query_handler/graph/parameter.py b/exasol/analytics/query_handler/graph/parameter.py index 41c781ee..a28179e7 100644 --- a/exasol/analytics/query_handler/graph/parameter.py +++ b/exasol/analytics/query_handler/graph/parameter.py @@ -1,5 +1,6 @@ import dataclasses + @dataclasses.dataclass(frozen=True) -class Parameter(): +class Parameter: pass diff --git a/exasol/analytics/query_handler/graph/result.py b/exasol/analytics/query_handler/graph/result.py index 184cec64..85531810 100644 --- a/exasol/analytics/query_handler/graph/result.py +++ b/exasol/analytics/query_handler/graph/result.py @@ -1,4 +1,4 @@ -from typing import Type, TypeVar, Any, Tuple +from typing import Any, Tuple, Type, TypeVar def _init(self): @@ -12,18 +12,23 @@ def _setattr(self, key: str, value: Any): If an attribute, already exists this functions raises a AttributeError. """ if key not in self.__annotations__ and key != "result_id": - raise AttributeError(f"Attribute '{key}' is not defined for class '{self.__class__.__name__}'.") + raise AttributeError( + f"Attribute '{key}' is not defined for class '{self.__class__.__name__}'." + ) if hasattr(self, key): raise AttributeError(f"Attribute '{key}' is already set.") object.__setattr__(self, key, value) + def _delattr(self, key: str): """ With this __delattr__ implementation, we disallow the deletion of attributes. If an attribute, already exists this functions raises a AttributeError. """ if key not in self.__annotations__ and key != "result_id": - raise AttributeError(f"Attribute '{key}' is not defined for class '{self.__class__.__name__}'.") + raise AttributeError( + f"Attribute '{key}' is not defined for class '{self.__class__.__name__}'." + ) else: raise AttributeError(f"Attribute '{key}' cannnot be deleted.") @@ -103,14 +108,18 @@ class Result(metaclass=_Meta): def update(self, other: "Result"): if self.__class__ != other.__class__: - raise TypeError(f"Incompatible classes for " - f"self '{self.__class__.__name__}' and " - f"other '{other.__class__.__name__}'.") + raise TypeError( + f"Incompatible classes for " + f"self '{self.__class__.__name__}' and " + f"other '{other.__class__.__name__}'." + ) if self.result_id != other.result_id: raise ValueError("Self and other have different result ids.") for key in self.__annotations__.keys(): if not hasattr(other, key) and hasattr(self, key): - raise AttributeError(f"Attribute '{key}' is set in self, but not in other.") + raise AttributeError( + f"Attribute '{key}' is set in self, but not in other." 
+ ) if not hasattr(other, key): continue other_value = getattr(other, key) @@ -119,7 +128,8 @@ def update(self, other: "Result"): if other_value == self_value: continue raise AttributeError( - f"Values for attribute '{key}' are different in self '{self_value}' and other '{other_value}'") + f"Values for attribute '{key}' are different in self '{self_value}' and other '{other_value}'" + ) setattr(self, key, other_value) return self diff --git a/exasol/analytics/query_handler/graph/stage/sql/data_partition.py b/exasol/analytics/query_handler/graph/stage/sql/data_partition.py index 06a6a48d..b9433126 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/data_partition.py +++ b/exasol/analytics/query_handler/graph/stage/sql/data_partition.py @@ -1,11 +1,9 @@ import dataclasses +from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependencies from exasol.analytics.schema import TableLike - from exasol.analytics.utils.data_classes_runtime_type_check import check_dataclass_types -from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependencies - @dataclasses.dataclass(frozen=True) class DataPartition: @@ -14,6 +12,7 @@ class DataPartition: the dependencies contain everything which is needed to execute the view, such as tables, udfs, connection objects, .... """ + table_like: TableLike dependencies: Dependencies = dataclasses.field(default_factory=dict) """ diff --git a/exasol/analytics/query_handler/graph/stage/sql/dataset.py b/exasol/analytics/query_handler/graph/stage/sql/dataset.py index 7930c80a..b7cde316 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/dataset.py +++ b/exasol/analytics/query_handler/graph/stage/sql/dataset.py @@ -1,13 +1,11 @@ import dataclasses from enum import Enum -from typing import Dict, Union, Tuple, List +from typing import Dict, List, Tuple, Union +from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition from exasol.analytics.schema import Column - from exasol.analytics.utils.data_classes_runtime_type_check import check_dataclass_types -from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition - DataPartitionName = Union[Enum, Tuple[Enum, int]] @@ -17,6 +15,7 @@ class Dataset: A Dataset consists of multiple data partitions and column lists which indicate the identifier, sample and target columns, The data paritions can be used to describe train and test sets. """ + data_partitions: Dict[DataPartitionName, DataPartition] identifier_columns: List[Column] sample_columns: List[Column] @@ -28,17 +27,25 @@ def __post_init__(self): self._check_columns() def _check_table_name(self): - all_table_like_names = {data_partition.table_like.name - for data_partition in self.data_partitions.values()} + all_table_like_names = { + data_partition.table_like.name + for data_partition in self.data_partitions.values() + } if len(all_table_like_names) != len(self.data_partitions): - raise ValueError("The names of table likes of the data partitions should be different.") + raise ValueError( + "The names of table likes of the data partitions should be different." 
+ ) def _check_columns(self): - all_columns = {column for data_partition in self.data_partitions.values() - for column in data_partition.table_like.columns} + all_columns = { + column + for data_partition in self.data_partitions.values() + for column in data_partition.table_like.columns + } all_data_partition_have_same_columns = all( len(data_partition.table_like.columns) == len(all_columns) - for data_partition in self.data_partitions.values()) + for data_partition in self.data_partitions.values() + ) if not all_data_partition_have_same_columns: raise ValueError("Not all data partitions have the same columns.") if not all_columns.issuperset(self.sample_columns): diff --git a/exasol/analytics/query_handler/graph/stage/sql/dependency.py b/exasol/analytics/query_handler/graph/stage/sql/dependency.py index 50eb1458..00c42698 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/dependency.py +++ b/exasol/analytics/query_handler/graph/stage/sql/dependency.py @@ -27,8 +27,9 @@ def __post_init__(self): # We can't use check_dataclass_types(self) here, because the forward definition of "Dependency" # can be only resolved if check_type uses the locals and globals of this frame try: - typeguard.check_type(value=self.dependencies, - expected_type=Dict[Enum, "Dependency"]) + typeguard.check_type( + value=self.dependencies, expected_type=Dict[Enum, "Dependency"] + ) except TypeCheckError as e: raise TypeCheckError(f"Field 'dependencies' has wrong type: {e}") diff --git a/exasol/analytics/query_handler/graph/stage/sql/execution/find_object_proxies.py b/exasol/analytics/query_handler/graph/stage/sql/execution/find_object_proxies.py index 41ea7f48..082191c1 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/execution/find_object_proxies.py +++ b/exasol/analytics/query_handler/graph/stage/sql/execution/find_object_proxies.py @@ -1,6 +1,8 @@ -from typing import List, Any, Iterable, Dict, Set +from typing import Any, Dict, Iterable, List, Set -from exasol.analytics.query_handler.context.proxy.db_object_name_proxy import ObjectProxy +from exasol.analytics.query_handler.context.proxy.db_object_name_proxy import ( + ObjectProxy, +) def find_object_proxies(obj: Any) -> List[ObjectProxy]: diff --git a/exasol/analytics/query_handler/graph/stage/sql/execution/input.py b/exasol/analytics/query_handler/graph/stage/sql/execution/input.py index 04ed601a..d1ec4b54 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/execution/input.py +++ b/exasol/analytics/query_handler/graph/stage/sql/execution/input.py @@ -1,9 +1,13 @@ import dataclasses -from exasol_bucketfs_utils_python.abstract_bucketfs_location import AbstractBucketFSLocation +from exasol_bucketfs_utils_python.abstract_bucketfs_location import ( + AbstractBucketFSLocation, +) +from exasol.analytics.query_handler.graph.stage.sql.input_output import ( + SQLStageInputOutput, +) from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph -from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput @dataclasses.dataclass(frozen=True, eq=True) diff --git a/exasol/analytics/query_handler/graph/stage/sql/execution/object_proxy_reference_counter.py b/exasol/analytics/query_handler/graph/stage/sql/execution/object_proxy_reference_counter.py index e7043ad9..9dbe9a60 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/execution/object_proxy_reference_counter.py +++ b/exasol/analytics/query_handler/graph/stage/sql/execution/object_proxy_reference_counter.py @@ -17,18 +17,30 @@ class 
ObjectProxyReferenceCounter: calls release on it, when the counter gets 0. This releases the ObjectProxy. """ - def __init__(self, parent_query_context_handler: ScopeQueryHandlerContext, object_proxy: ObjectProxy): + def __init__( + self, + parent_query_context_handler: ScopeQueryHandlerContext, + object_proxy: ObjectProxy, + ): self._object_proxy = object_proxy self._valid = True self._parent_query_context_handler = parent_query_context_handler - self._child_query_context_handler = self._parent_query_context_handler.get_child_query_handler_context() - self._parent_query_context_handler.transfer_object_to(object_proxy, self._child_query_context_handler) - self._counter = 1 # counter is one, because with zero this object wouldn't exist + self._child_query_context_handler = ( + self._parent_query_context_handler.get_child_query_handler_context() + ) + self._parent_query_context_handler.transfer_object_to( + object_proxy, self._child_query_context_handler + ) + self._counter = ( + 1 # counter is one, because with zero this object wouldn't exist + ) def _check_if_valid(self): if not self._valid: - raise RuntimeError("ReferenceCounter not valid anymore. " - "ObjectProxy got already garbage collected or transfered back.") + raise RuntimeError( + "ReferenceCounter not valid anymore. " + "ObjectProxy was already garbage collected or transferred back." + ) def add(self): self._check_if_valid() @@ -49,7 +61,8 @@ def _release_if_not_used(self) -> ReferenceCounterStatus: def transfer_back_to_parent_query_handler_context(self): self._check_if_valid() self._child_query_context_handler.transfer_object_to( - self._object_proxy, self._parent_query_context_handler) + self._object_proxy, self._parent_query_context_handler + ) self._invalidate_and_release() def _invalidate_and_release(self): diff --git a/exasol/analytics/query_handler/graph/stage/sql/execution/object_proxy_reference_counting_bag.py b/exasol/analytics/query_handler/graph/stage/sql/execution/object_proxy_reference_counting_bag.py index 4d05e761..863f35d0 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/execution/object_proxy_reference_counting_bag.py +++ b/exasol/analytics/query_handler/graph/stage/sql/execution/object_proxy_reference_counting_bag.py @@ -1,11 +1,15 @@ -from typing import Dict, Callable +from typing import Callable, Dict from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counter import ( + ObjectProxyReferenceCounter, + ReferenceCounterStatus, +) -from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counter import ObjectProxyReferenceCounter, ReferenceCounterStatus - -ObjectProxyReferenceCounterFactory = Callable[[ScopeQueryHandlerContext, ObjectProxy], ObjectProxyReferenceCounter] +ObjectProxyReferenceCounterFactory = Callable[ + [ScopeQueryHandlerContext, ObjectProxy], ObjectProxyReferenceCounter +] class ObjectProxyReferenceCountingBag: @@ -15,18 +19,24 @@ class ObjectProxyReferenceCountingBag: reaches zero the corresponding ScopeQueryHandlerContext gets released.
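
# ----- editor's aside (illustrative, not part of the diff) -----
# The scheme behind ObjectProxyReferenceCounter, reduced to a self-contained
# sketch: the count starts at one, add()/remove() adjust it, and a cleanup
# callback fires when it reaches zero. The callback stands in for releasing
# the child ScopeQueryHandlerContext; all names below are hypothetical.
from typing import Callable


class RefCounter:
    def __init__(self, on_release: Callable[[], None]):
        self._count = 1  # starts at one: the counter exists because one user exists
        self._on_release = on_release

    def add(self) -> None:
        self._count += 1

    def remove(self) -> None:
        self._count -= 1
        if self._count == 0:
            self._on_release()  # last user gone: release the counted object


counter = RefCounter(on_release=lambda: print("released"))
counter.add()     # a second user appears
counter.remove()  # first user done
counter.remove()  # prints "released"
# ----- end editor's aside -----
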
""" - def __init__(self, - parent_query_context_handler: ScopeQueryHandlerContext, - object_proxy_reference_counter_factory: ObjectProxyReferenceCounterFactory = - ObjectProxyReferenceCounter): - self._object_proxy_reference_counter_factory = object_proxy_reference_counter_factory + def __init__( + self, + parent_query_context_handler: ScopeQueryHandlerContext, + object_proxy_reference_counter_factory: ObjectProxyReferenceCounterFactory = ObjectProxyReferenceCounter, + ): + self._object_proxy_reference_counter_factory = ( + object_proxy_reference_counter_factory + ) self._parent_query_context_handler = parent_query_context_handler self._reference_counter_map: Dict[ObjectProxy, ObjectProxyReferenceCounter] = {} def add(self, object_proxy: ObjectProxy): if object_proxy not in self._reference_counter_map: - self._reference_counter_map[object_proxy] = \ - self._object_proxy_reference_counter_factory(self._parent_query_context_handler, object_proxy) + self._reference_counter_map[object_proxy] = ( + self._object_proxy_reference_counter_factory( + self._parent_query_context_handler, object_proxy + ) + ) else: self._reference_counter_map[object_proxy].add() @@ -40,5 +50,7 @@ def __contains__(self, item: ObjectProxy): return result def transfer_back_to_parent_query_handler_context(self, object_proxy: ObjectProxy): - self._reference_counter_map[object_proxy].transfer_back_to_parent_query_handler_context() + self._reference_counter_map[ + object_proxy + ].transfer_back_to_parent_query_handler_context() del self._reference_counter_map[object_proxy] diff --git a/exasol/analytics/query_handler/graph/stage/sql/execution/query_handler.py b/exasol/analytics/query_handler/graph/stage/sql/execution/query_handler.py index a524f188..414145ed 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/execution/query_handler.py +++ b/exasol/analytics/query_handler/graph/stage/sql/execution/query_handler.py @@ -1,26 +1,35 @@ -from typing import Union, Callable +from typing import Callable, Union from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.graph.stage.sql.execution.input import ( + SQLStageGraphExecutionInput, +) +from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import ( + ResultHandlerReturnValue, + SQLStageGraphExecutionQueryHandlerState, +) +from exasol.analytics.query_handler.graph.stage.sql.input_output import ( + SQLStageInputOutput, +) +from exasol.analytics.query_handler.query.result.interface import QueryResult from exasol.analytics.query_handler.query_handler import QueryHandler from exasol.analytics.query_handler.result import Continue, Finish -from exasol.analytics.query_handler.query.result.interface import QueryResult - -from exasol.analytics.query_handler.graph.stage.sql.execution.input import SQLStageGraphExecutionInput -from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import SQLStageGraphExecutionQueryHandlerState, ResultHandlerReturnValue -from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput -SQLStageGraphExecutionQueryHandlerStateFactory = \ - Callable[ - [SQLStageGraphExecutionInput, ScopeQueryHandlerContext], - SQLStageGraphExecutionQueryHandlerState] +SQLStageGraphExecutionQueryHandlerStateFactory = Callable[ + [SQLStageGraphExecutionInput, ScopeQueryHandlerContext], + SQLStageGraphExecutionQueryHandlerState, +] -class SQLStageGraphExecutionQueryHandler(QueryHandler[SQLStageGraphExecutionInput, 
SQLStageInputOutput]): - def __init__(self, - parameter: SQLStageGraphExecutionInput, - query_handler_context: ScopeQueryHandlerContext, - query_handler_state_factory: SQLStageGraphExecutionQueryHandlerStateFactory = - SQLStageGraphExecutionQueryHandlerState): +class SQLStageGraphExecutionQueryHandler( + QueryHandler[SQLStageGraphExecutionInput, SQLStageInputOutput] +): + def __init__( + self, + parameter: SQLStageGraphExecutionInput, + query_handler_context: ScopeQueryHandlerContext, + query_handler_state_factory: SQLStageGraphExecutionQueryHandlerStateFactory = SQLStageGraphExecutionQueryHandlerState, + ): super().__init__(parameter, query_handler_context) self._state = query_handler_state_factory(parameter, query_handler_context) @@ -28,26 +37,36 @@ def start(self) -> Union[Continue, Finish[SQLStageInputOutput]]: result = self._run_until_continue_or_last_stage_finished() return result - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[SQLStageInputOutput]]: - result = self._state.get_current_query_handler().handle_query_result(query_result) + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[SQLStageInputOutput]]: + result = self._state.get_current_query_handler().handle_query_result( + query_result + ) result_handler_return_value = self._state.handle_result(result) if result_handler_return_value == ResultHandlerReturnValue.RETURN_RESULT: return result - elif result_handler_return_value == ResultHandlerReturnValue.CONTINUE_PROCESSING: + elif ( + result_handler_return_value == ResultHandlerReturnValue.CONTINUE_PROCESSING + ): result = self._run_until_continue_or_last_stage_finished() else: raise RuntimeError("Unknown result_handler_return_value") return result - def _run_until_continue_or_last_stage_finished(self) \ - -> Union[Continue, Finish[SQLStageInputOutput]]: + def _run_until_continue_or_last_stage_finished( + self, + ) -> Union[Continue, Finish[SQLStageInputOutput]]: while True: handler = self._state.get_current_query_handler() result = handler.start() result_handler_return_value = self._state.handle_result(result) if result_handler_return_value == ResultHandlerReturnValue.RETURN_RESULT: return result - elif result_handler_return_value == ResultHandlerReturnValue.CONTINUE_PROCESSING: + elif ( + result_handler_return_value + == ResultHandlerReturnValue.CONTINUE_PROCESSING + ): pass else: raise RuntimeError("Unknown result_handler_return_value") diff --git a/exasol/analytics/query_handler/graph/stage/sql/execution/query_handler_state.py b/exasol/analytics/query_handler/graph/stage/sql/execution/query_handler_state.py index 1a07aa6e..a563ab36 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/execution/query_handler_state.py +++ b/exasol/analytics/query_handler/graph/stage/sql/execution/query_handler_state.py @@ -1,18 +1,27 @@ import enum -from typing import DefaultDict, List, Optional, Union, Callable +from typing import Callable, DefaultDict, List, Optional, Union from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.graph.stage.sql.execution.find_object_proxies import ( + find_object_proxies, +) +from exasol.analytics.query_handler.graph.stage.sql.execution.input import ( + SQLStageGraphExecutionInput, +) +from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counting_bag import ( + ObjectProxyReferenceCountingBag, +) +from 
exasol.analytics.query_handler.graph.stage.sql.input_output import ( + SQLStageInputOutput, +) +from exasol.analytics.query_handler.graph.stage.sql.sql_stage import SQLStage +from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import ( + SQLStageTrainQueryHandlerInput, +) from exasol.analytics.query_handler.query_handler import QueryHandler from exasol.analytics.query_handler.result import Continue, Finish -from exasol.analytics.query_handler.graph.stage.sql.execution.find_object_proxies import find_object_proxies -from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counting_bag import ObjectProxyReferenceCountingBag -from exasol.analytics.query_handler.graph.stage.sql.execution.input import SQLStageGraphExecutionInput -from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput -from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import SQLStageTrainQueryHandlerInput -from exasol.analytics.query_handler.graph.stage.sql.sql_stage import SQLStage - class ResultHandlerReturnValue(enum.Enum): RETURN_RESULT = enum.auto() @@ -20,21 +29,32 @@ class ResultHandlerReturnValue(enum.Enum): class SQLStageGraphExecutionQueryHandlerState: - def __init__(self, - parameter: SQLStageGraphExecutionInput, - query_handler_context: ScopeQueryHandlerContext, - reference_counting_bag_factory: Callable[ - [ScopeQueryHandlerContext], ObjectProxyReferenceCountingBag] = ObjectProxyReferenceCountingBag): + def __init__( + self, + parameter: SQLStageGraphExecutionInput, + query_handler_context: ScopeQueryHandlerContext, + reference_counting_bag_factory: Callable[ + [ScopeQueryHandlerContext], ObjectProxyReferenceCountingBag + ] = ObjectProxyReferenceCountingBag, + ): self._query_handler_context = query_handler_context self._sql_stage_graph = parameter.sql_stage_graph self._result_bucketfs_location = parameter.result_bucketfs_location - self._reference_counting_bag = reference_counting_bag_factory(query_handler_context) + self._reference_counting_bag = reference_counting_bag_factory( + query_handler_context + ) self._stage_inputs_map = DefaultDict[SQLStage, List[SQLStageInputOutput]](list) - self._stages_in_execution_order = self._sql_stage_graph.compute_dependency_order() + self._stages_in_execution_order = ( + self._sql_stage_graph.compute_dependency_order() + ) self._current_stage_index = 0 - self._current_stage: Optional[SQLStage] = self._stages_in_execution_order[self._current_stage_index] + self._current_stage: Optional[SQLStage] = self._stages_in_execution_order[ + self._current_stage_index + ] self._stage_inputs_map[self._current_stage].append(parameter.input) - self._current_query_handler: Optional[QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]] = None + self._current_query_handler: Optional[ + QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput] + ] = None self._current_query_handler_context: Optional[ScopeQueryHandlerContext] = None self._create_current_query_handler() @@ -42,11 +62,15 @@ def _check_is_valid(self): if self._current_query_handler is None: raise RuntimeError("No current query handler set.") - def get_current_query_handler(self) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: + def get_current_query_handler( + self, + ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: self._check_is_valid() return self._current_query_handler - def handle_result(self, result: Union[Continue, Finish[SQLStageInputOutput]]) -> 
ResultHandlerReturnValue: + def handle_result( + self, result: Union[Continue, Finish[SQLStageInputOutput]] + ) -> ResultHandlerReturnValue: self._check_is_valid() if isinstance(result, Finish): return self._handle_finished_result(result) @@ -55,7 +79,9 @@ def handle_result(self, result: Union[Continue, Finish[SQLStageInputOutput]]) -> else: raise RuntimeError("Unknown result type") - def _handle_finished_result(self, result: Finish[SQLStageInputOutput]) -> ResultHandlerReturnValue: + def _handle_finished_result( + self, result: Finish[SQLStageInputOutput] + ) -> ResultHandlerReturnValue: if self._is_not_last_stage(): self._add_result_to_successors(result.result) else: @@ -87,15 +113,19 @@ def _move_to_next_stage(self): def _create_current_query_handler(self): stage_inputs = self._stage_inputs_map[self._current_stage] - self._current_query_handler_context = self._query_handler_context.get_child_query_handler_context() - result_bucketfs_location = self._result_bucketfs_location.joinpath(str(self._current_stage_index)) + self._current_query_handler_context = ( + self._query_handler_context.get_child_query_handler_context() + ) + result_bucketfs_location = self._result_bucketfs_location.joinpath( + str(self._current_stage_index) + ) stage_input = SQLStageTrainQueryHandlerInput( result_bucketfs_location=result_bucketfs_location, - sql_stage_inputs=stage_inputs + sql_stage_inputs=stage_inputs, + ) + self._current_query_handler = self._current_stage.create_train_query_handler( + stage_input, self._current_query_handler_context ) - self._current_query_handler = \ - self._current_stage.create_train_query_handler( - stage_input, self._current_query_handler_context) def _add_result_to_successors(self, result: SQLStageInputOutput): successors = self._sql_stage_graph.successors(self._current_stage) @@ -104,16 +134,21 @@ def _add_result_to_successors(self, result: SQLStageInputOutput): self._add_result_to_inputs_of_successors(result, successors) self._add_result_to_reference_counting_bag(result, successors) - def _add_result_to_inputs_of_successors(self, result: SQLStageInputOutput, successors: List[SQLStage]): + def _add_result_to_inputs_of_successors( + self, result: SQLStageInputOutput, successors: List[SQLStage] + ): for successor in successors: self._stage_inputs_map[successor].append(result) def _add_result_to_reference_counting_bag( - self, result: SQLStageInputOutput, successors: List[SQLStage]): + self, result: SQLStageInputOutput, successors: List[SQLStage] + ): object_proxies = find_object_proxies(result) for object_proxy in object_proxies: if object_proxy not in self._reference_counting_bag: - self._current_query_handler_context.transfer_object_to(object_proxy, self._query_handler_context) + self._current_query_handler_context.transfer_object_to( + object_proxy, self._query_handler_context + ) for _ in successors: self._reference_counting_bag.add(object_proxy) @@ -121,16 +156,22 @@ def _transfer_ownership_of_result_to_query_result_handler(self, result): object_proxies = find_object_proxies(result) for object_proxy in object_proxies: if object_proxy in self._reference_counting_bag: - self._reference_counting_bag.transfer_back_to_parent_query_handler_context(object_proxy) + self._reference_counting_bag.transfer_back_to_parent_query_handler_context( + object_proxy + ) else: - self._current_query_handler_context.transfer_object_to(object_proxy, self._query_handler_context) + self._current_query_handler_context.transfer_object_to( + object_proxy, self._query_handler_context + ) def 
_remove_current_stage_inputs(self): for stage_input in self._stage_inputs_map[self._current_stage]: object_proxies = find_object_proxies(stage_input) self._remove_object_proxies_from_reference_counting_bag(object_proxies) - def _remove_object_proxies_from_reference_counting_bag(self, object_proxies: List[ObjectProxy]): + def _remove_object_proxies_from_reference_counting_bag( + self, object_proxies: List[ObjectProxy] + ): for object_proxy in object_proxies: if object_proxy in self._reference_counting_bag: self._reference_counting_bag.remove(object_proxy) diff --git a/exasol/analytics/query_handler/graph/stage/sql/input_output.py b/exasol/analytics/query_handler/graph/stage/sql/input_output.py index 066ea88f..2c25a887 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/input_output.py +++ b/exasol/analytics/query_handler/graph/stage/sql/input_output.py @@ -1,9 +1,8 @@ import dataclasses -from exasol.analytics.utils.data_classes_runtime_type_check import check_dataclass_types - from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependencies +from exasol.analytics.utils.data_classes_runtime_type_check import check_dataclass_types @dataclasses.dataclass(frozen=True, eq=True) diff --git a/exasol/analytics/query_handler/graph/stage/sql/sql_stage.py b/exasol/analytics/query_handler/graph/stage/sql/sql_stage.py index b3f5823f..ffc5e79a 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/sql_stage.py +++ b/exasol/analytics/query_handler/graph/stage/sql/sql_stage.py @@ -1,14 +1,18 @@ import abc from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import ( + SQLStageQueryHandler, + SQLStageTrainQueryHandlerInput, +) from exasol.analytics.query_handler.graph.stage.stage import Stage -from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import SQLStageQueryHandler, SQLStageTrainQueryHandlerInput + class SQLStage(Stage): @abc.abstractmethod def create_train_query_handler( - self, - stage_input: SQLStageTrainQueryHandlerInput, - query_handler_context: ScopeQueryHandlerContext, + self, + stage_input: SQLStageTrainQueryHandlerInput, + query_handler_context: ScopeQueryHandlerContext, ) -> SQLStageQueryHandler: pass diff --git a/exasol/analytics/query_handler/graph/stage/sql/sql_stage_query_handler.py b/exasol/analytics/query_handler/graph/stage/sql/sql_stage_query_handler.py index 9b2fe470..2f9b96e4 100644 --- a/exasol/analytics/query_handler/graph/stage/sql/sql_stage_query_handler.py +++ b/exasol/analytics/query_handler/graph/stage/sql/sql_stage_query_handler.py @@ -2,10 +2,14 @@ from abc import ABC from typing import List, Sized -from exasol.analytics.query_handler.query_handler import QueryHandler -from exasol_bucketfs_utils_python.abstract_bucketfs_location import AbstractBucketFSLocation +from exasol_bucketfs_utils_python.abstract_bucketfs_location import ( + AbstractBucketFSLocation, +) -from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput +from exasol.analytics.query_handler.graph.stage.sql.input_output import ( + SQLStageInputOutput, +) +from exasol.analytics.query_handler.query_handler import QueryHandler def is_empty(obj: Sized): @@ -23,5 +27,6 @@ def __post_init__(self): class SQLStageQueryHandler( - QueryHandler[SQLStageTrainQueryHandlerInput, SQLStageInputOutput], ABC): + QueryHandler[SQLStageTrainQueryHandlerInput, 
SQLStageInputOutput], ABC +): pass diff --git a/exasol/analytics/query_handler/json_udf_query_handler.py b/exasol/analytics/query_handler/json_udf_query_handler.py index 113ca4fc..d4a5b2c7 100644 --- a/exasol/analytics/query_handler/json_udf_query_handler.py +++ b/exasol/analytics/query_handler/json_udf_query_handler.py @@ -1,4 +1,4 @@ -from typing import Dict, Any +from typing import Any, Dict from exasol.analytics.query_handler.query_handler import QueryHandler diff --git a/exasol/analytics/query_handler/python_query_handler_runner.py b/exasol/analytics/query_handler/python_query_handler_runner.py index 92f52a77..57b0090d 100644 --- a/exasol/analytics/query_handler/python_query_handler_runner.py +++ b/exasol/analytics/query_handler/python_query_handler_runner.py @@ -1,18 +1,22 @@ import logging import textwrap from inspect import cleandoc -from typing import Callable, TypeVar, Generic, Tuple, Union, List +from typing import Callable, Generic, List, Tuple, TypeVar, Union -from exasol.analytics.sql_executor.interface import SQLExecutor - -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.context.top_level_query_handler_context import TopLevelQueryHandlerContext, ConnectionLookup +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.context.top_level_query_handler_context import ( + ConnectionLookup, + TopLevelQueryHandlerContext, ) from exasol.analytics.query_handler.query.interface import Query +from exasol.analytics.query_handler.query.result.python_query_result import ( + PythonQueryResult, +) from exasol.analytics.query_handler.query.select import SelectQueryWithColumnDefinition from exasol.analytics.query_handler.query_handler import QueryHandler from exasol.analytics.query_handler.result import Continue, Finish -from exasol.analytics.query_handler.query.result.python_query_result import PythonQueryResult from exasol.analytics.query_handler.udf.runner.state import QueryHandlerRunnerState +from exasol.analytics.sql_executor.interface import SQLExecutor LOGGER = logging.getLogger(__file__) @@ -22,19 +26,24 @@ class PythonQueryHandlerRunner(Generic[ParameterType, ResultType]): - def __init__(self, - sql_executor: SQLExecutor, - top_level_query_handler_context: TopLevelQueryHandlerContext, - parameter: ParameterType, - query_handler_factory: Callable[ - [ParameterType, ScopeQueryHandlerContext], - QueryHandler[ParameterType, ResultType]]): + def __init__( + self, + sql_executor: SQLExecutor, + top_level_query_handler_context: TopLevelQueryHandlerContext, + parameter: ParameterType, + query_handler_factory: Callable[ + [ParameterType, ScopeQueryHandlerContext], + QueryHandler[ParameterType, ResultType], + ], + ): self._sql_executor = sql_executor - query_handler = query_handler_factory(parameter, top_level_query_handler_context) + query_handler = query_handler_factory( + parameter, top_level_query_handler_context + ) self._state = QueryHandlerRunnerState( top_level_query_handler_context=top_level_query_handler_context, query_handler=query_handler, - connection_lookup=None + connection_lookup=None, ) def run(self) -> ResultType: @@ -51,8 +60,12 @@ def run(self) -> ResultType: try: self._handle_finish() except Exception as e1: - LOGGER.exception("Catched exeception during cleanup after an exception.") + LOGGER.exception( + "Caught exception during cleanup after 
an exception." + ) + raise RuntimeError( + f"Execution of query handler {self._state.query_handler} failed." + ) from e def _handle_continue(self, result: Continue) -> Union[Continue, Finish[ResultType]]: self._release_and_create_query_handler_context_of_input_query() @@ -67,11 +80,14 @@ def _run_input_query(self, result: Continue) -> PythonQueryResult: self._sql_executor.execute(input_query_view) input_query_result_set = self._sql_executor.execute(input_query) if input_query_result_set.columns() != result.input_query.output_columns: - raise RuntimeError(f"Specified columns {result.input_query.output_columns} of the input query " - f"are not equal to the actual received columns {input_query_result_set.columns()}") + raise RuntimeError( + f"Specified columns {result.input_query.output_columns} of the input query " + f"are not equal to the actual received columns {input_query_result_set.columns()}" + ) input_query_result_table = input_query_result_set.fetchall() - input_query_result = PythonQueryResult(data=input_query_result_table, - columns=result.input_query.output_columns) + input_query_result = PythonQueryResult( + data=input_query_result_table, columns=result.input_query.output_columns + ) return input_query_result def _handle_finish(self): @@ -81,8 +97,9 @@ def _handle_finish(self): self._cleanup_query_handler_context() def _cleanup_query_handler_context(self): - cleanup_query_list = \ + cleanup_query_list = ( self._state.top_level_query_handler_context.cleanup_released_object_proxies() + ) self._execute_queries(cleanup_query_list) def _execute_queries(self, queries: List[Query]): @@ -92,23 +109,31 @@ def _execute_queries(self, queries: List[Query]): def _release_and_create_query_handler_context_of_input_query(self): if self._state.input_query_query_handler_context is not None: self._state.input_query_query_handler_context.release() - self._state.input_query_query_handler_context = \ + self._state.input_query_query_handler_context = ( self._state.top_level_query_handler_context.get_child_query_handler_context() + ) - def _wrap_return_query(self, input_query: SelectQueryWithColumnDefinition) -> Tuple[str, str]: - temporary_view_name = self._state.input_query_query_handler_context.get_temporary_view_name() + def _wrap_return_query( + self, input_query: SelectQueryWithColumnDefinition + ) -> Tuple[str, str]: + temporary_view_name = ( + self._state.input_query_query_handler_context.get_temporary_view_name() + ) input_query_create_view_string = cleandoc( f""" CREATE OR REPLACE VIEW {temporary_view_name.fully_qualified} AS {input_query.query_string}; -""") - full_qualified_columns = [col.name.fully_qualified - for col in input_query.output_columns] +""" + ) + full_qualified_columns = [ + col.name.fully_qualified for col in input_query.output_columns + ] columns_str = ",\n".join(full_qualified_columns) input_query_string = cleandoc( f""" SELECT {textwrap.indent(columns_str, " " * 4)} FROM {temporary_view_name.fully_qualified}; -""") +""" + ) return input_query_create_view_string, input_query_string diff --git a/exasol/analytics/query_handler/query/drop/interface.py b/exasol/analytics/query_handler/query/drop/interface.py index dbb583d5..ef09c2a5 100644 --- a/exasol/analytics/query_handler/query/drop/interface.py +++ b/exasol/analytics/query_handler/query/drop/interface.py @@ -1,4 +1,5 @@ from exasol.analytics.query_handler.query.interface import Query + class DropQuery(Query): pass diff --git a/exasol/analytics/query_handler/query/drop/table.py 
b/exasol/analytics/query_handler/query/drop/table.py index 4dd92e2e..2d616040 100644 --- a/exasol/analytics/query_handler/query/drop/table.py +++ b/exasol/analytics/query_handler/query/drop/table.py @@ -1,7 +1,5 @@ -from exasol.analytics.schema import TableName - - from exasol.analytics.query_handler.query.drop.interface import DropQuery +from exasol.analytics.schema import TableName class DropTableQuery(DropQuery): diff --git a/exasol/analytics/query_handler/query/drop/view.py b/exasol/analytics/query_handler/query/drop/view.py index fc806167..24ca6b41 100644 --- a/exasol/analytics/query_handler/query/drop/view.py +++ b/exasol/analytics/query_handler/query/drop/view.py @@ -1,10 +1,5 @@ - -from exasol.analytics.schema import ( - TableName, - ViewName, -) - from exasol.analytics.query_handler.query.drop.interface import DropQuery +from exasol.analytics.schema import TableName, ViewName class DropViewQuery(DropQuery): @@ -17,5 +12,5 @@ def query_string(self) -> str: return f"DROP VIEW IF EXISTS {self._view_name.fully_qualified};" @property - def view_name(self)-> TableName: + def view_name(self) -> TableName: return self._view_name diff --git a/exasol/analytics/query_handler/query/result/interface.py b/exasol/analytics/query_handler/query/result/interface.py index c5552240..dbf92362 100644 --- a/exasol/analytics/query_handler/query/result/interface.py +++ b/exasol/analytics/query_handler/query/result/interface.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod -from typing import Union, List, Any, Tuple, Iterator -from exasol.analytics.schema.column import \ - Column +from typing import Any, Iterator, List, Tuple, Union + +from exasol.analytics.schema.column import Column Row = Tuple[Any, ...] @@ -33,8 +33,7 @@ def rowcount(self) -> int: pass @abstractmethod - def fetch_as_dataframe( - self, num_rows: Union[str, int], start_col: int = 0): + def fetch_as_dataframe(self, num_rows: Union[str, int], start_col: int = 0): pass @abstractmethod diff --git a/exasol/analytics/query_handler/query/result/python_query_result.py b/exasol/analytics/query_handler/query/result/python_query_result.py index a3ddd3e2..26e4e963 100644 --- a/exasol/analytics/query_handler/query/result/python_query_result.py +++ b/exasol/analytics/query_handler/query/result/python_query_result.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Any, Union, Optional, Iterator +from typing import Any, Iterator, List, Optional, Tuple, Union import pandas as pd from exasol_udf_mock_python.column import Column @@ -42,10 +42,14 @@ def __init__(self, data: List[Tuple[Any, ...]], columns: List[Column]): self._columns = columns self._data = data self._iter = iter(data) - self._column_name_index_mapping = {column.name.name: index for index, column in enumerate(columns)} + self._column_name_index_mapping = { + column.name.name: index for index, column in enumerate(columns) + } self._next() - def fetch_as_dataframe(self, num_rows: Union[int, str], start_col=0) -> Optional[pd.DataFrame]: + def fetch_as_dataframe( + self, num_rows: Union[int, str], start_col=0 + ) -> Optional[pd.DataFrame]: batch_list = [] if num_rows == "all": num_rows = len(self._data) @@ -59,8 +63,9 @@ def fetch_as_dataframe(self, num_rows: Union[int, str], start_col=0) -> Optional break self._next() if len(batch_list) > 0: - df = pd.DataFrame(data=batch_list, - columns=[column.name.name for column in self._columns]) # TODO dtype + df = pd.DataFrame( + data=batch_list, columns=[column.name.name for column in self._columns] + ) # TODO dtype df = df.iloc[:, start_col:] 
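
# ----- editor's aside (illustrative, not part of the diff) -----
# The batched-fetch behaviour implemented by PythonQueryResult above, reduced
# to a standalone function: return up to num_rows rows as a pandas DataFrame
# (or all remaining rows for "all"), and None once the data is exhausted.
# The function and its names are illustrative, not the repo's API.
from typing import Any, List, Optional, Tuple, Union

import pandas as pd


def fetch_batch(
    rows: List[Tuple[Any, ...]], columns: List[str], num_rows: Union[int, str]
) -> Optional[pd.DataFrame]:
    count = len(rows) if num_rows == "all" else int(num_rows)
    batch = rows[:count]
    if not batch:
        return None
    return pd.DataFrame(data=batch, columns=columns)


df = fetch_batch([(1, "a"), (2, "b")], ["id", "name"], num_rows="all")
# ----- end editor's aside -----
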
return df else: diff --git a/exasol/analytics/query_handler/query/result/udf_query_result.py b/exasol/analytics/query_handler/query/result/udf_query_result.py index e6f196fa..99664cf4 100644 --- a/exasol/analytics/query_handler/query/result/udf_query_result.py +++ b/exasol/analytics/query_handler/query/result/udf_query_result.py @@ -1,26 +1,23 @@ import collections -from typing import Union, List, Any, OrderedDict, Iterator +from typing import Any, Iterator, List, OrderedDict, Union -from exasol.analytics.schema.column import \ - Column -from exasol.analytics.schema.column_name import \ - ColumnName -from exasol.analytics.schema.column_type import \ - ColumnType - -from exasol.analytics.query_handler.query.result.interface import QueryResult, Row +from exasol.analytics.query_handler.query.result.interface import QueryResult, Row +from exasol.analytics.schema.column import Column +from exasol.analytics.schema.column_name import ColumnName +from exasol.analytics.schema.column_type import ColumnType class UDFQueryResult(QueryResult): - def __init__(self, ctx, exa, column_mapping: OrderedDict[str, str], - start_col: int = 0): + def __init__( + self, ctx, exa, column_mapping: OrderedDict[str, str], start_col: int = 0 + ): self._start_col = start_col self._ctx = ctx self._has_next = True - self._reverse_column_mapping = \ - collections.OrderedDict( - [(value, key) for key, value in column_mapping.items()]) + self._reverse_column_mapping = collections.OrderedDict( + [(value, key) for key, value in column_mapping.items()] + ) self._columns = self._compute_columns(exa) self._initialized = False @@ -52,7 +49,9 @@ def __next__(self) -> Row: def rowcount(self) -> int: return self._ctx.size() - def fetch_as_dataframe(self, num_rows: Union[str, int], start_col: int = 0) -> "pandas.DataFrame": + def fetch_as_dataframe( + self, num_rows: Union[str, int], start_col: int = 0 + ) -> "pandas.DataFrame": df = self._ctx.get_dataframe(num_rows, start_col=self._start_col) self._initialized = True if df is None: @@ -67,10 +66,13 @@ def columns(self) -> List[Column]: return list(self._columns) def _compute_columns(self, exa) -> List[Column]: - column_dict = {column.name: column.sql_type - for column in exa.meta.input_columns} - columns = [Column(ColumnName(key), ColumnType(column_dict[value])) - for key, value in self._reverse_column_mapping.items()] + column_dict = { + column.name: column.sql_type for column in exa.meta.input_columns + } + columns = [ + Column(ColumnName(key), ColumnType(column_dict[value])) + for key, value in self._reverse_column_mapping.items() + ] return columns def column_names(self) -> List[str]: diff --git a/exasol/analytics/query_handler/query/select.py b/exasol/analytics/query_handler/query/select.py index 124d9e0b..eac26641 100644 --- a/exasol/analytics/query_handler/query/select.py +++ b/exasol/analytics/query_handler/query/select.py @@ -1,10 +1,8 @@ from abc import abstractmethod from typing import List -from exasol.analytics.schema import Column - - from exasol.analytics.query_handler.query.interface import Query +from exasol.analytics.schema import Column class SelectQuery(Query): diff --git a/exasol/analytics/query_handler/query_handler.py b/exasol/analytics/query_handler/query_handler.py index b8a583c6..1d2ebcae 100644 --- a/exasol/analytics/query_handler/query_handler.py +++ b/exasol/analytics/query_handler/query_handler.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, TypeVar, Generic, Union +from typing import Any, Dict, Generic, TypeVar, 
Union +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext from exasol.analytics.query_handler.query.result.interface import QueryResult -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.result import Result, Continue, Finish +from exasol.analytics.query_handler.result import Continue, Finish, Result ResultType = TypeVar("ResultType") ParameterType = TypeVar("ParameterType") @@ -11,9 +11,9 @@ class QueryHandler(ABC, Generic[ParameterType, ResultType]): - def __init__(self, - parameter: ParameterType, - query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: ParameterType, query_handler_context: ScopeQueryHandlerContext + ): self._query_handler_context = query_handler_context @abstractmethod @@ -21,6 +21,7 @@ def start(self) -> Union[Continue, Finish[ResultType]]: raise NotImplementedError() @abstractmethod - def handle_query_result(self, query_result: QueryResult) \ - -> Union[Continue, Finish[ResultType]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[ResultType]]: raise NotImplementedError() diff --git a/exasol/analytics/query_handler/result.py b/exasol/analytics/query_handler/result.py index 22d81b76..8fe65830 100644 --- a/exasol/analytics/query_handler/result.py +++ b/exasol/analytics/query_handler/result.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Generic, TypeVar +from typing import Generic, List, TypeVar from exasol.analytics.query_handler.query.interface import Query from exasol.analytics.query_handler.query.select import SelectQueryWithColumnDefinition diff --git a/exasol/analytics/query_handler/udf/factory.py b/exasol/analytics/query_handler/udf/factory.py deleted file mode 100644 index 63c85a31..00000000 --- a/exasol/analytics/query_handler/udf/factory.py +++ /dev/null @@ -1,15 +0,0 @@ -from abc import ABC, abstractmethod - -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.udf.interface import UDFQueryHandler - - -class UDFQueryHandlerFactory(ABC): - """ - An abstract class for factories which are injected by name to the QueryHandlerRunnerUDF - which then will create the instance from the name. 
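
# ----- editor's aside (illustrative, not part of the diff) -----
# The QueryHandler contract above is a small state machine: start() and
# handle_query_result() each return either Continue (queries still to run) or
# Finish (the final result). A minimal handler that finishes immediately,
# using module paths from this diff; EchoQueryHandler itself is hypothetical,
# and Finish(result=...) assumes the dataclass field shown in result.py.
from typing import Union

from exasol.analytics.query_handler.query_handler import QueryHandler
from exasol.analytics.query_handler.result import Continue, Finish


class EchoQueryHandler(QueryHandler[str, str]):
    def __init__(self, parameter, query_handler_context):
        super().__init__(parameter, query_handler_context)
        self._parameter = parameter  # the base class does not store the parameter

    def start(self) -> Union[Continue, Finish[str]]:
        return Finish(result=self._parameter)  # no queries needed

    def handle_query_result(self, query_result) -> Union[Continue, Finish[str]]:
        raise AssertionError("start() never returns Continue")
# ----- end editor's aside -----
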
- """ - - @abstractmethod - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - """Creates a UDFQueryHandler""" diff --git a/exasol/analytics/query_handler/udf/interface.py b/exasol/analytics/query_handler/udf/interface.py index 176018d8..22083ad6 100644 --- a/exasol/analytics/query_handler/udf/interface.py +++ b/exasol/analytics/query_handler/udf/interface.py @@ -1,9 +1,7 @@ -from exasol.analytics.query_handler.query_handler import QueryHandler - - from abc import ABC, abstractmethod -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.query_handler import QueryHandler class UDFQueryHandler(QueryHandler[str, str]): @@ -17,7 +15,7 @@ class UDFQueryHandlerFactory(ABC): """ @abstractmethod - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: """Creates a UDFQueryHandler""" - - diff --git a/exasol/analytics/query_handler/udf/json_impl.py b/exasol/analytics/query_handler/udf/json_impl.py index ac7d3d3c..bc41094e 100644 --- a/exasol/analytics/query_handler/udf/json_impl.py +++ b/exasol/analytics/query_handler/udf/json_impl.py @@ -2,35 +2,47 @@ from abc import ABC from typing import Type, Union -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.result import Continue, Finish +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.json_udf_query_handler import ( + JSONQueryHandler, + JSONType, +) from exasol.analytics.query_handler.query.result.interface import QueryResult -from exasol.analytics.query_handler.json_udf_query_handler import JSONQueryHandler, JSONType -from exasol.analytics.query_handler.udf.interface import UDFQueryHandler -from exasol.analytics.query_handler.udf.interface import UDFQueryHandlerFactory +from exasol.analytics.query_handler.result import Continue, Finish +from exasol.analytics.query_handler.udf.interface import ( + UDFQueryHandler, + UDFQueryHandlerFactory, +) class JsonUDFQueryHandler(UDFQueryHandler): - def __init__(self, parameter: str, - query_handler_context: ScopeQueryHandlerContext, - wrapped_json_query_handler_class: Type[JSONQueryHandler]): + def __init__( + self, + parameter: str, + query_handler_context: ScopeQueryHandlerContext, + wrapped_json_query_handler_class: Type[JSONQueryHandler], + ): super().__init__(parameter, query_handler_context) json_parameter = json.loads(parameter) self._wrapped_json_query_handler = wrapped_json_query_handler_class( - parameter=json_parameter, - query_handler_context=query_handler_context) + parameter=json_parameter, query_handler_context=query_handler_context + ) def start(self) -> Union[Continue, Finish[str]]: result = self._wrapped_json_query_handler.start() return self._handle_result(result) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: result = self._wrapped_json_query_handler.handle_query_result(query_result) return self._handle_result(result) @staticmethod - def _handle_result(result: Union[Continue, Finish[JSONType]]) -> Union[Continue, Finish[str]]: + def _handle_result( + result: 
Union[Continue, Finish[JSONType]] + ) -> Union[Continue, Finish[str]]: if isinstance(result, Continue): return result elif isinstance(result, Finish): @@ -45,8 +57,11 @@ class JsonUDFQueryHandlerFactory(UDFQueryHandlerFactory, ABC): def __init__(self, wrapped_json_query_handler_class: Type[JSONQueryHandler]): self._wrapped_json_query_handler_class = wrapped_json_query_handler_class - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: return JsonUDFQueryHandler( parameter=parameter, query_handler_context=query_handler_context, - wrapped_json_query_handler_class=self._wrapped_json_query_handler_class) + wrapped_json_query_handler_class=self._wrapped_json_query_handler_class, + ) diff --git a/exasol/analytics/query_handler/udf/runner/call_udf.py b/exasol/analytics/query_handler/udf/runner/call_udf.py index 1396014b..3c52b31e 100644 --- a/exasol/analytics/query_handler/udf/runner/call_udf.py +++ b/exasol/analytics/query_handler/udf/runner/call_udf.py @@ -1,4 +1,4 @@ -from exasol.analytics.query_handler.udf.runner.udf import QueryHandlerRunnerUDF +from exasol.analytics.query_handler.udf.runner.udf import QueryHandlerRunnerUDF udf = QueryHandlerRunnerUDF(exa) diff --git a/exasol/analytics/query_handler/udf/runner/state.py b/exasol/analytics/query_handler/udf/runner/state.py index c4f9de14..8870efe9 100644 --- a/exasol/analytics/query_handler/udf/runner/state.py +++ b/exasol/analytics/query_handler/udf/runner/state.py @@ -1,13 +1,13 @@ from dataclasses import dataclass from typing import List, Optional -from exasol.analytics.schema.column import \ - Column - -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.context.top_level_query_handler_context import TopLevelQueryHandlerContext -from exasol.analytics.query_handler.query_handler import QueryHandler +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.context.top_level_query_handler_context import ( + TopLevelQueryHandlerContext, +) +from exasol.analytics.query_handler.query_handler import QueryHandler from exasol.analytics.query_handler.udf.connection_lookup import UDFConnectionLookup +from exasol.analytics.schema.column import Column @dataclass() diff --git a/exasol/analytics/query_handler/udf/runner/udf.py b/exasol/analytics/query_handler/udf/runner/udf.py index 71e92cd3..1310a68d 100644 --- a/exasol/analytics/query_handler/udf/runner/udf.py +++ b/exasol/analytics/query_handler/udf/runner/udf.py @@ -1,32 +1,33 @@ import dataclasses import importlib import json -import joblib import logging import traceback from collections import OrderedDict from enum import Enum, auto -from typing import Any, Tuple, List, Optional +from io import BytesIO +from typing import Any, List, Optional, Tuple import exasol.bucketfs as bfs -from io import BytesIO +import joblib +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.context.top_level_query_handler_context import ( + TopLevelQueryHandlerContext, +) +from exasol.analytics.query_handler.query.result.udf_query_result import UDFQueryResult +from exasol.analytics.query_handler.query.select import SelectQueryWithColumnDefinition +from exasol.analytics.query_handler.result import Continue, Finish, Result +from 
exasol.analytics.query_handler.udf.connection_lookup import UDFConnectionLookup +from exasol.analytics.query_handler.udf.runner.state import QueryHandlerRunnerState from exasol.analytics.schema import ( - SchemaName, Column, - UDFNameBuilder, ColumnName, ColumnType, + SchemaName, + UDFNameBuilder, ) -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.context.top_level_query_handler_context import TopLevelQueryHandlerContext -from exasol.analytics.query_handler.query.select import SelectQueryWithColumnDefinition -from exasol.analytics.query_handler.result import Finish, Continue, Result -from exasol.analytics.query_handler.query.result.udf_query_result import UDFQueryResult -from exasol.analytics.query_handler.udf.runner.state import QueryHandlerRunnerState -from exasol.analytics.query_handler.udf.connection_lookup import UDFConnectionLookup - def create_bucketfs_location_from_conn_object(bfs_conn_obj) -> bfs.path.PathLike: bfs_params = json.loads(bfs_conn_obj.address) @@ -91,10 +92,16 @@ def run(self, ctx) -> None: if self.parameter.iter_num == 0: query_handler_result = current_state.query_handler.start() else: - query_result = self._create_udf_query_result(ctx, current_state.input_query_output_columns) - query_handler_result = current_state.query_handler.handle_query_result(query_result) - - udf_result = self.handle_query_handler_result(query_handler_result, current_state) + query_result = self._create_udf_query_result( + ctx, current_state.input_query_output_columns + ) + query_handler_result = current_state.query_handler.handle_query_result( + query_result + ) + + udf_result = self.handle_query_handler_result( + query_handler_result, current_state + ) if isinstance(query_handler_result, Continue): self._save_current_state(current_state) if self.parameter.iter_num > 0: @@ -103,38 +110,45 @@ def run(self, ctx) -> None: except Exception as e: self.handle_exception(ctx, current_state) - def handle_exception(self, ctx, - current_state: QueryHandlerRunnerState): + def handle_exception(self, ctx, current_state: QueryHandlerRunnerState): stacktrace = traceback.format_exc() logging.exception("Caught exception, starting cleanup.") try: self.release_query_handler_context(current_state) except: - logging.exception("Catched exception during handling cleanup of another exception") - cleanup_queries = current_state.top_level_query_handler_context.cleanup_released_object_proxies() + logging.exception( + "Caught exception during handling cleanup of another exception" + ) + cleanup_queries = ( + current_state.top_level_query_handler_context.cleanup_released_object_proxies() + ) udf_result = UDFResult() udf_result.cleanup_query_list = cleanup_queries udf_result.final_result = stacktrace udf_result.status = QueryHandlerStatus.ERROR self.emit_udf_result(ctx, udf_result) - def handle_query_handler_result(self, - query_handler_result: Result, - current_state: QueryHandlerRunnerState) -> UDFResult: + def handle_query_handler_result( + self, query_handler_result: Result, current_state: QueryHandlerRunnerState + ) -> UDFResult: if isinstance(query_handler_result, Finish): - udf_result = self.handle_query_handler_result_finished(current_state, query_handler_result) + udf_result = self.handle_query_handler_result_finished( + current_state, query_handler_result + ) elif isinstance(query_handler_result, Continue): - udf_result = 
self.handle_query_handler_result_continue( + current_state, query_handler_result + ) else: raise RuntimeError(f"Unknown query_handler_result {query_handler_result}") - udf_result.cleanup_query_list = \ + udf_result.cleanup_query_list = ( current_state.top_level_query_handler_context.cleanup_released_object_proxies() + ) return udf_result def handle_query_handler_result_finished( - self, - current_state: QueryHandlerRunnerState, - query_handler_result: Finish) -> UDFResult: + self, current_state: QueryHandlerRunnerState, query_handler_result: Finish + ) -> UDFResult: udf_result = UDFResult() udf_result.final_result = query_handler_result.result udf_result.status = QueryHandlerStatus.FINISHED @@ -147,25 +161,31 @@ def release_query_handler_context(current_state: QueryHandlerRunnerState): current_state.input_query_query_handler_context.release() current_state.top_level_query_handler_context.release() - def handle_query_handler_result_continue(self, - current_state: QueryHandlerRunnerState, - query_handler_result: Continue) -> UDFResult: + def handle_query_handler_result_continue( + self, current_state: QueryHandlerRunnerState, query_handler_result: Continue + ) -> UDFResult: udf_result = UDFResult() udf_result.status = QueryHandlerStatus.CONTINUE udf_result.query_list = query_handler_result.query_list - current_state.input_query_output_columns = query_handler_result.input_query.output_columns + current_state.input_query_output_columns = ( + query_handler_result.input_query.output_columns + ) self.release_and_create_query_handler_context_if_input_query(current_state) - udf_result.input_query_view, udf_result.input_query = \ - self._wrap_return_query(current_state.input_query_query_handler_context, - query_handler_result.input_query) + udf_result.input_query_view, udf_result.input_query = self._wrap_return_query( + current_state.input_query_query_handler_context, + query_handler_result.input_query, + ) return udf_result @staticmethod - def release_and_create_query_handler_context_if_input_query(current_state: QueryHandlerRunnerState): + def release_and_create_query_handler_context_if_input_query( + current_state: QueryHandlerRunnerState, + ): if current_state.input_query_query_handler_context is not None: current_state.input_query_query_handler_context.release() - current_state.input_query_query_handler_context = \ + current_state.input_query_query_handler_context = ( current_state.top_level_query_handler_context.get_child_query_handler_context() + ) def _get_parameter(self, ctx): iter_num = ctx[0] @@ -178,21 +198,26 @@ def _get_parameter(self, ctx): temporary_schema_name=ctx[4], python_class_name=ctx[5], python_class_module=ctx[6], - parameter=ctx[7]) + parameter=ctx[7], + ) else: self.parameter = UDFParameter( iter_num=iter_num, temporary_bfs_location_conn=ctx[1], temporary_bfs_location_directory=ctx[2], - temporary_name_prefix=ctx[3]) + temporary_name_prefix=ctx[3], + ) def _create_bucketfs_location(self): - bucketfs_connection_obj = self.exa.get_connection(self.parameter.temporary_bfs_location_conn) + bucketfs_connection_obj = self.exa.get_connection( + self.parameter.temporary_bfs_location_conn + ) bucketfs_location_from_con = create_bucketfs_location_from_conn_object( - bucketfs_connection_obj) - self.bucketfs_location = bucketfs_location_from_con \ - .joinpath(self.parameter.temporary_bfs_location_directory) \ - .joinpath(self.parameter.temporary_name_prefix) + bucketfs_connection_obj + ) + self.bucketfs_location = bucketfs_location_from_con.joinpath( + 
self.parameter.temporary_bfs_location_directory + ).joinpath(self.parameter.temporary_name_prefix) def _create_state_or_load_latest_state(self) -> QueryHandlerRunnerState: if self.parameter.iter_num > 0: @@ -207,15 +232,17 @@ def _create_state(self) -> QueryHandlerRunnerState: self.bucketfs_location, self.parameter.temporary_name_prefix, self.parameter.temporary_schema_name, - connection_lookup + connection_lookup, ) module = importlib.import_module(self.parameter.python_class_module) query_handler_factory_class = getattr(module, self.parameter.python_class_name) - query_handler_obj = query_handler_factory_class().create(self.parameter.parameter, context) + query_handler_obj = query_handler_factory_class().create( + self.parameter.parameter, context + ) query_handler_state = QueryHandlerRunnerState( top_level_query_handler_context=context, query_handler=query_handler_obj, - connection_lookup=connection_lookup + connection_lookup=connection_lookup, ) return query_handler_state @@ -233,27 +260,31 @@ def _remove_previous_state(self) -> None: self._state_file_bucketfs_location().rm() def _create_udf_query_result( - self, ctx, query_columns: List[Column]) -> UDFQueryResult: + self, ctx, query_columns: List[Column] + ) -> UDFQueryResult: colum_start_ix = 8 if self.parameter.iter_num == 0 else 4 - column_mapping = OrderedDict([ - (str(colum_start_ix + index), column.name.name) - for index, column in enumerate(query_columns)]) + column_mapping = OrderedDict( + [ + (str(colum_start_ix + index), column.name.name) + for index, column in enumerate(query_columns) + ] + ) return UDFQueryResult(ctx, self.exa, column_mapping=column_mapping) - def _wrap_return_query(self, - query_handler_context: ScopeQueryHandlerContext, - input_query: SelectQueryWithColumnDefinition) \ - -> Tuple[str, str]: + def _wrap_return_query( + self, + query_handler_context: ScopeQueryHandlerContext, + input_query: SelectQueryWithColumnDefinition, + ) -> Tuple[str, str]: temporary_view_name = query_handler_context.get_temporary_view_name() - query_handler_udf_name = \ - UDFNameBuilder.create( - name=self.exa.meta.script_name, - schema=SchemaName(self.exa.meta.script_schema) - ) - query_create_view = \ - f"CREATE VIEW {temporary_view_name.fully_qualified} AS {input_query.query_string};" - full_qualified_columns = [col.name.fully_qualified - for col in input_query.output_columns] + query_handler_udf_name = UDFNameBuilder.create( + name=self.exa.meta.script_name, + schema=SchemaName(self.exa.meta.script_schema), + ) + query_create_view = f"CREATE VIEW {temporary_view_name.fully_qualified} AS {input_query.query_string};" + full_qualified_columns = [ + col.name.fully_qualified for col in input_query.output_columns + ] call_columns = [ f"{self.parameter.iter_num + 1}", f"'{self.parameter.temporary_bfs_location_conn}'", @@ -261,9 +292,10 @@ def _wrap_return_query(self, f"'{self.parameter.temporary_name_prefix}'", ] columns_str = ",".join(call_columns + full_qualified_columns) - query_query_handler = \ - f"SELECT {query_handler_udf_name.fully_qualified}({columns_str}) " \ + query_query_handler = ( + f"SELECT {query_handler_udf_name.fully_qualified}({columns_str}) " f"FROM {temporary_view_name.fully_qualified};" + ) return query_create_view, query_query_handler def _get_query_columns(self): @@ -271,8 +303,7 @@ def _get_query_columns(self): for i in range(len(self.exa.meta.input_columns)): col_name = self.exa.meta.input_columns[i].name col_type = self.exa.meta.input_columns[i].sql_type - query_columns.append( - Column(ColumnName(col_name), 
ColumnType(col_type))) + query_columns.append(Column(ColumnName(col_name), ColumnType(col_type))) return query_columns def _state_file_bucketfs_location(self, iter_offset: int = 0) -> bfs.path.PathLike: diff --git a/exasol/analytics/schema/__init__.py b/exasol/analytics/schema/__init__.py index 3af1d3f1..92826833 100644 --- a/exasol/analytics/schema/__init__.py +++ b/exasol/analytics/schema/__init__.py @@ -1,37 +1,33 @@ -from exasol.analytics.schema.exasol_identifier import ExasolIdentifier -from exasol.analytics.schema.exasol_identifier_impl import UnicodeCategories -from exasol.analytics.schema.exasol_identifier_impl import ExasolIdentifierImpl - -from exasol.analytics.schema.dbobject_name import DBObjectName -from exasol.analytics.schema.dbobject_name_impl import DBObjectNameImpl - -from exasol.analytics.schema.schema_name import SchemaName -from exasol.analytics.schema.dbobject_name_with_schema import DBObjectNameWithSchema -from exasol.analytics.schema.table_like_name import TableLikeName -from exasol.analytics.schema.table_name import TableName +from exasol.analytics.schema.column import Column +from exasol.analytics.schema.column_builder import ColumnBuilder from exasol.analytics.schema.column_name import ColumnName +from exasol.analytics.schema.column_name_builder import ColumnNameBuilder from exasol.analytics.schema.column_type import ColumnType from exasol.analytics.schema.connection_object_name import ConnectionObjectName -from exasol.analytics.schema.experiment_name import ExperimentName -from exasol.analytics.schema.udf_name import UDFName -from exasol.analytics.schema.view_name import ViewName - -from exasol.analytics.schema.column import Column +from exasol.analytics.schema.connection_object_name_builder import ConnectionObjectNameBuilder from exasol.analytics.schema.connection_object_name_impl import ConnectionObjectNameImpl -from exasol.analytics.schema.dbobject_name_with_schema_impl import DBObjectNameWithSchemaImpl from exasol.analytics.schema.dbobject import DBObject -from exasol.analytics.schema.table_like_name_impl import TableLikeNameImpl +from exasol.analytics.schema.dbobject_name import DBObjectName +from exasol.analytics.schema.dbobject_name_impl import DBObjectNameImpl +from exasol.analytics.schema.dbobject_name_with_schema import DBObjectNameWithSchema +from exasol.analytics.schema.dbobject_name_with_schema_impl import DBObjectNameWithSchemaImpl +from exasol.analytics.schema.exasol_identifier import ExasolIdentifier +from exasol.analytics.schema.exasol_identifier_impl import ExasolIdentifierImpl +from exasol.analytics.schema.exasol_identifier_impl import UnicodeCategories +from exasol.analytics.schema.experiment_name import ExperimentName +from exasol.analytics.schema.schema_name import SchemaName +from exasol.analytics.schema.table import Table +from exasol.analytics.schema.table_builder import TableBuilder from exasol.analytics.schema.table_like import TableLike +from exasol.analytics.schema.table_like_name import TableLikeName +from exasol.analytics.schema.table_like_name_impl import TableLikeNameImpl +from exasol.analytics.schema.table_name import TableName +from exasol.analytics.schema.table_name_builder import TableNameBuilder from exasol.analytics.schema.table_name_impl import TableNameImpl -from exasol.analytics.schema.table import Table +from exasol.analytics.schema.udf_name import UDFName +from exasol.analytics.schema.udf_name_builder import UDFNameBuilder from exasol.analytics.schema.udf_name_impl import UDFNameImpl -from 
exasol.analytics.schema.view_name_impl import ViewNameImpl from exasol.analytics.schema.view import View - +from exasol.analytics.schema.view_name import ViewName from exasol.analytics.schema.view_name_builder import ViewNameBuilder -from exasol.analytics.schema.udf_name_builder import UDFNameBuilder -from exasol.analytics.schema.table_name_builder import TableNameBuilder -from exasol.analytics.schema.column_builder import ColumnBuilder -from exasol.analytics.schema.column_name_builder import ColumnNameBuilder -from exasol.analytics.schema.table_builder import TableBuilder -from exasol.analytics.schema.connection_object_name_builder import ConnectionObjectNameBuilder +from exasol.analytics.schema.view_name_impl import ViewNameImpl diff --git a/exasol/analytics/schema/column.py b/exasol/analytics/schema/column.py index 3535aa3d..dc01a1ad 100644 --- a/exasol/analytics/schema/column.py +++ b/exasol/analytics/schema/column.py @@ -2,10 +2,8 @@ import typeguard -from exasol.analytics.schema import ( - ColumnType, - ColumnName, -) +from exasol.analytics.schema.column_name import ColumnName +from exasol.analytics.schema.column_type import ColumnType from exasol.analytics.utils.data_classes_runtime_type_check import check_dataclass_types @@ -15,4 +13,4 @@ class Column: type: ColumnType def __post_init__(self): - check_dataclass_types(self) \ No newline at end of file + check_dataclass_types(self) diff --git a/exasol/analytics/schema/column_builder.py b/exasol/analytics/schema/column_builder.py index 49046961..58af0932 100644 --- a/exasol/analytics/schema/column_builder.py +++ b/exasol/analytics/schema/column_builder.py @@ -1,10 +1,8 @@ from typing import Union -from exasol.analytics.schema import ( - Column, - ColumnType, - ColumnName, -) +from exasol.analytics.schema.column import Column +from exasol.analytics.schema.column_name import ColumnName +from exasol.analytics.schema.column_type import ColumnType class ColumnBuilder: diff --git a/exasol/analytics/schema/column_name.py b/exasol/analytics/schema/column_name.py index c7dcf6cc..7370c889 100644 --- a/exasol/analytics/schema/column_name.py +++ b/exasol/analytics/schema/column_name.py @@ -1,17 +1,15 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - ExasolIdentifierImpl, - TableLikeName, - ExasolIdentifier, -) +from exasol.analytics.schema.exasol_identifier import ExasolIdentifier +from exasol.analytics.schema.exasol_identifier_impl import ExasolIdentifierImpl +from exasol.analytics.schema.table_like_name import TableLikeName from exasol.analytics.utils.hash_generation_for_object import generate_hash_for_object from exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object class ColumnName(ExasolIdentifierImpl): @typechecked - def __init__(self, name: str, table_like_name: TableLikeName|None = None): + def __init__(self, name: str, table_like_name: TableLikeName | None = None): super().__init__(name) self._table_like_name = table_like_name @@ -22,14 +20,16 @@ def table_like_name(self): @property def fully_qualified(self) -> str: if self.table_like_name is not None: - return f'{self._table_like_name.fully_qualified}.{self.quoted_name}' + return f"{self._table_like_name.fully_qualified}.{self.quoted_name}" else: return self.quoted_name def __eq__(self, other): - return isinstance(other, ColumnName) and \ - self._name == other.name and \ - self._table_like_name == other.table_like_name + return ( + isinstance(other, ColumnName) + and self._name == other.name + and self._table_like_name == 
other.table_like_name + ) def __repr__(self): return generate_repr_for_object(self) diff --git a/exasol/analytics/schema/column_name_builder.py b/exasol/analytics/schema/column_name_builder.py index fd77c8cb..15728da8 100644 --- a/exasol/analytics/schema/column_name_builder.py +++ b/exasol/analytics/schema/column_name_builder.py @@ -1,16 +1,16 @@ from typing import Optional -from exasol.analytics.schema import ( - ColumnName, - TableLikeName, -) +from exasol.analytics.schema.column_name import ColumnName +from exasol.analytics.schema.table_like_name import TableLikeName class ColumnNameBuilder: - def __init__(self, - name: Optional[str] = None, - table_like_name: Optional[TableLikeName] = None, - column_name: Optional[ColumnName] = None): + def __init__( + self, + name: Optional[str] = None, + table_like_name: Optional[TableLikeName] = None, + column_name: Optional[ColumnName] = None, + ): """ Creates a builder for ColumnName objects, either by copying a ColumnName object or @@ -30,7 +30,9 @@ def with_name(self, name: str) -> "ColumnNameBuilder": self._name = name return self - def with_table_like_name(self, table_like_name: TableLikeName) -> "ColumnNameBuilder": + def with_table_like_name( + self, table_like_name: TableLikeName + ) -> "ColumnNameBuilder": self._table_like_name = table_like_name return self @@ -39,5 +41,7 @@ def build(self) -> ColumnName: return name @staticmethod - def create(name: str, table_like_name: Optional[TableLikeName] = None) -> ColumnName: + def create( + name: str, table_like_name: Optional[TableLikeName] = None + ) -> ColumnName: return ColumnName(name, table_like_name) diff --git a/exasol/analytics/schema/column_type.py b/exasol/analytics/schema/column_type.py index 9633c7cd..81d59ca1 100644 --- a/exasol/analytics/schema/column_type.py +++ b/exasol/analytics/schema/column_type.py @@ -18,4 +18,4 @@ class ColumnType: srid: Optional[int] = None def __post_init__(self): - check_dataclass_types(self) \ No newline at end of file + check_dataclass_types(self) diff --git a/exasol/analytics/schema/connection_object_name.py b/exasol/analytics/schema/connection_object_name.py index 8c1af2f2..ece0903b 100644 --- a/exasol/analytics/schema/connection_object_name.py +++ b/exasol/analytics/schema/connection_object_name.py @@ -1,7 +1,6 @@ from abc import abstractmethod -from exasol.analytics.schema import DBObjectName - +from exasol.analytics.schema.dbobject_name import DBObjectName class ConnectionObjectName(DBObjectName): diff --git a/exasol/analytics/schema/connection_object_name_builder.py b/exasol/analytics/schema/connection_object_name_builder.py index 695eb28c..365809c0 100644 --- a/exasol/analytics/schema/connection_object_name_builder.py +++ b/exasol/analytics/schema/connection_object_name_builder.py @@ -1,14 +1,12 @@ -from typing import Union, Optional +from typing import Optional, Union -from exasol.analytics.schema import ( - ConnectionObjectName, - SchemaName, - ConnectionObjectNameImpl, - TableName, - ViewNameImpl, - TableNameImpl, - ViewName, -) +from exasol.analytics.schema.connection_object_name import ConnectionObjectName +from exasol.analytics.schema.connection_object_name_impl import ConnectionObjectNameImpl +from exasol.analytics.schema.schema_name import SchemaName +from exasol.analytics.schema.table_name import TableName +from exasol.analytics.schema.table_name_impl import TableNameImpl +from exasol.analytics.schema.view_name import ViewName +from exasol.analytics.schema.view_name_impl import ViewNameImpl class ConnectionObjectNameBuilder:
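# The builders in this package follow one pattern: construct an immutable
# *Name object either field by field or by copying an existing one. A minimal
# usage sketch, assuming only the API visible in this diff (the exact
# identifier quoting may differ):
#
#     from exasol.analytics.schema import ColumnNameBuilder
#
#     col = ColumnNameBuilder(name="ID").build()
#     same = ColumnNameBuilder.create("ID")
#     assert col == same          # __eq__ compares name and table_like_name
#     print(col.fully_qualified)  # quoted name, table prefix only if one is set
#
diff --git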
a/exasol/analytics/schema/connection_object_name_impl.py b/exasol/analytics/schema/connection_object_name_impl.py index ba951703..c6ebdf64 100644 --- a/exasol/analytics/schema/connection_object_name_impl.py +++ b/exasol/analytics/schema/connection_object_name_impl.py @@ -1,9 +1,7 @@ from typing import cast -from exasol.analytics.schema import ( - DBObjectNameImpl, - ConnectionObjectName, -) +from exasol.analytics.schema.connection_object_name import ConnectionObjectName +from exasol.analytics.schema.dbobject_name_impl import DBObjectNameImpl from exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object @@ -25,11 +23,15 @@ def __repr__(self): def __eq__(self, other): # Connection names are case-insensitive https://docs.exasol.com/db/latest/sql/create_connection.htm - return type(other) == type(self) and \ - self._name.upper() == cast(ConnectionObjectName, other).name.upper() + return ( + type(other) == type(self) + and self._name.upper() == cast(ConnectionObjectName, other).name.upper() + ) def __hash__(self): # Connection names are case-insensitive https://docs.exasol.com/db/latest/sql/create_connection.htm - assert len(self.__dict__) == 1, f"The attributes of {self.__class__} changed, " \ - f"you need to update the __hash__ method" + assert len(self.__dict__) == 1, ( + f"The attributes of {self.__class__} changed, " + f"you need to update the __hash__ method" + ) return hash(self._name.upper()) diff --git a/exasol/analytics/schema/dbobject.py b/exasol/analytics/schema/dbobject.py index 0b8c21e8..040051a8 100644 --- a/exasol/analytics/schema/dbobject.py +++ b/exasol/analytics/schema/dbobject.py @@ -1,14 +1,13 @@ from abc import ABC -from typing import TypeVar, Generic +from typing import Generic, TypeVar from typeguard import typechecked -from exasol.analytics.schema import DBObjectName - +from exasol.analytics.schema.dbobject_name import DBObjectName from exasol.analytics.utils.hash_generation_for_object import generate_hash_for_object from exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object -NameType = TypeVar('NameType', bound=DBObjectName) +NameType = TypeVar("NameType", bound=DBObjectName) class DBObject(Generic[NameType], ABC): @@ -22,8 +21,7 @@ def name(self) -> NameType: return self._name def __eq__(self, other): - return type(other) == type(self) and \ - self._name == other.name + return type(other) == type(self) and self._name == other.name def __repr__(self): return generate_repr_for_object(self) diff --git a/exasol/analytics/schema/dbobject_name.py b/exasol/analytics/schema/dbobject_name.py index 5a93377c..13b49a07 100644 --- a/exasol/analytics/schema/dbobject_name.py +++ b/exasol/analytics/schema/dbobject_name.py @@ -1,5 +1,4 @@ -from exasol.analytics.schema import ExasolIdentifier - +from exasol.analytics.schema.exasol_identifier import ExasolIdentifier class DBObjectName(ExasolIdentifier): diff --git a/exasol/analytics/schema/dbobject_name_impl.py b/exasol/analytics/schema/dbobject_name_impl.py index 1fd58643..3c0891ce 100644 --- a/exasol/analytics/schema/dbobject_name_impl.py +++ b/exasol/analytics/schema/dbobject_name_impl.py @@ -1,9 +1,7 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - ExasolIdentifierImpl, - DBObjectName, -) +from exasol.analytics.schema.dbobject_name import DBObjectName +from exasol.analytics.schema.exasol_identifier_impl import ExasolIdentifierImpl from exasol.analytics.utils.hash_generation_for_object import generate_hash_for_object from 
exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object @@ -18,8 +16,7 @@ def __repr__(self) -> str: return generate_repr_for_object(self) def __eq__(self, other) -> bool: - return type(other) == type(self) and \ - self.name == other.name + return type(other) == type(self) and self.name == other.name def __hash__(self): return generate_hash_for_object(self) diff --git a/exasol/analytics/schema/dbobject_name_with_schema.py b/exasol/analytics/schema/dbobject_name_with_schema.py index bff34797..93ee2022 100644 --- a/exasol/analytics/schema/dbobject_name_with_schema.py +++ b/exasol/analytics/schema/dbobject_name_with_schema.py @@ -1,9 +1,7 @@ from abc import abstractmethod -from exasol.analytics.schema import ( - SchemaName, - DBObjectName, -) +from exasol.analytics.schema.dbobject_name import DBObjectName +from exasol.analytics.schema.schema_name import SchemaName class DBObjectNameWithSchema(DBObjectName): diff --git a/exasol/analytics/schema/dbobject_name_with_schema_impl.py b/exasol/analytics/schema/dbobject_name_with_schema_impl.py index d545a718..c6fc14d6 100644 --- a/exasol/analytics/schema/dbobject_name_with_schema_impl.py +++ b/exasol/analytics/schema/dbobject_name_with_schema_impl.py @@ -2,11 +2,9 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - DBObjectNameWithSchema, - DBObjectNameImpl, - SchemaName, -) +from exasol.analytics.schema.dbobject_name_impl import DBObjectNameImpl +from exasol.analytics.schema.dbobject_name_with_schema import DBObjectNameWithSchema +from exasol.analytics.schema.schema_name import SchemaName from exasol.analytics.utils.hash_generation_for_object import generate_hash_for_object from exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object @@ -25,7 +23,7 @@ def schema_name(self) -> SchemaName: @property def fully_qualified(self) -> str: if self.schema_name is not None: - return f'{self._schema_name.fully_qualified}.{self.quoted_name}' + return f"{self._schema_name.fully_qualified}.{self.quoted_name}" else: return self.quoted_name @@ -33,9 +31,11 @@ def __repr__(self) -> str: return generate_repr_for_object(self) def __eq__(self, other) -> bool: - return type(other) == type(self) and \ - self._name == other.name and \ - self._schema_name == other.schema_name + return ( + type(other) == type(self) + and self._name == other.name + and self._schema_name == other.schema_name + ) def __hash__(self): return generate_hash_for_object(self) diff --git a/exasol/analytics/schema/exasol_identifier_impl.py b/exasol/analytics/schema/exasol_identifier_impl.py index 407a0a68..ebd1b1f5 100644 --- a/exasol/analytics/schema/exasol_identifier_impl.py +++ b/exasol/analytics/schema/exasol_identifier_impl.py @@ -2,22 +2,21 @@ from typeguard import typechecked -from exasol.analytics.schema import ExasolIdentifier - +from exasol.analytics.schema.exasol_identifier import ExasolIdentifier class UnicodeCategories: - UPPERCASE_LETTER = 'Lu' - LOWERCASE_LETTER = 'Ll' - TITLECASE_LETTER = 'Lt' - MODIFIER_LETTER = 'Lm' - OTHER_LETTER = 'Lo' - LETTER_NUMBER = 'Nl' - NON_SPACING_MARK = 'Mn' - COMBINING_SPACING_MARK = 'Mc' - DECIMAL_DIGIT_NUMBER = 'Nd' - CONNECTOR_PUNCTUATION = 'Pc' - FORMAT = 'Cf' + UPPERCASE_LETTER = "Lu" + LOWERCASE_LETTER = "Ll" + TITLECASE_LETTER = "Lt" + MODIFIER_LETTER = "Lm" + OTHER_LETTER = "Lo" + LETTER_NUMBER = "Nl" + NON_SPACING_MARK = "Mn" + COMBINING_SPACING_MARK = "Mc" + DECIMAL_DIGIT_NUMBER = "Nd" + CONNECTOR_PUNCTUATION = "Pc" + FORMAT = "Cf" class 
ExasolIdentifierImpl(ExasolIdentifier): @@ -50,28 +49,30 @@ def _validate_name(cls, name: str) -> bool: @classmethod def _validate_first_character(cls, chararcter: str) -> bool: unicode_category = unicodedata.category(chararcter) - return \ - unicode_category == UnicodeCategories.UPPERCASE_LETTER or \ - unicode_category == UnicodeCategories.LOWERCASE_LETTER or \ - unicode_category == UnicodeCategories.TITLECASE_LETTER or \ - unicode_category == UnicodeCategories.MODIFIER_LETTER or \ - unicode_category == UnicodeCategories.OTHER_LETTER or \ - unicode_category == UnicodeCategories.LETTER_NUMBER or \ - unicode_category == UnicodeCategories.DECIMAL_DIGIT_NUMBER + return ( + unicode_category == UnicodeCategories.UPPERCASE_LETTER + or unicode_category == UnicodeCategories.LOWERCASE_LETTER + or unicode_category == UnicodeCategories.TITLECASE_LETTER + or unicode_category == UnicodeCategories.MODIFIER_LETTER + or unicode_category == UnicodeCategories.OTHER_LETTER + or unicode_category == UnicodeCategories.LETTER_NUMBER + or unicode_category == UnicodeCategories.DECIMAL_DIGIT_NUMBER + ) @classmethod def _validate_follow_up_character(cls, chararcter: str) -> bool: unicode_category = unicodedata.category(chararcter) - return \ - unicode_category == UnicodeCategories.UPPERCASE_LETTER or \ - unicode_category == UnicodeCategories.LOWERCASE_LETTER or \ - unicode_category == UnicodeCategories.TITLECASE_LETTER or \ - unicode_category == UnicodeCategories.MODIFIER_LETTER or \ - unicode_category == UnicodeCategories.OTHER_LETTER or \ - unicode_category == UnicodeCategories.LETTER_NUMBER or \ - unicode_category == UnicodeCategories.NON_SPACING_MARK or \ - unicode_category == UnicodeCategories.COMBINING_SPACING_MARK or \ - unicode_category == UnicodeCategories.DECIMAL_DIGIT_NUMBER or \ - unicode_category == UnicodeCategories.CONNECTOR_PUNCTUATION or \ - unicode_category == UnicodeCategories.FORMAT or \ - chararcter == '\u00B7' + return ( + unicode_category == UnicodeCategories.UPPERCASE_LETTER + or unicode_category == UnicodeCategories.LOWERCASE_LETTER + or unicode_category == UnicodeCategories.TITLECASE_LETTER + or unicode_category == UnicodeCategories.MODIFIER_LETTER + or unicode_category == UnicodeCategories.OTHER_LETTER + or unicode_category == UnicodeCategories.LETTER_NUMBER + or unicode_category == UnicodeCategories.NON_SPACING_MARK + or unicode_category == UnicodeCategories.COMBINING_SPACING_MARK + or unicode_category == UnicodeCategories.DECIMAL_DIGIT_NUMBER + or unicode_category == UnicodeCategories.CONNECTOR_PUNCTUATION + or unicode_category == UnicodeCategories.FORMAT + or chararcter == "\u00B7" + ) diff --git a/exasol/analytics/schema/experiment_name.py b/exasol/analytics/schema/experiment_name.py index 2111ce25..6de0694d 100644 --- a/exasol/analytics/schema/experiment_name.py +++ b/exasol/analytics/schema/experiment_name.py @@ -1,9 +1,7 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - ExasolIdentifierImpl, - ExasolIdentifier, -) +from exasol.analytics.schema.exasol_identifier import ExasolIdentifier +from exasol.analytics.schema.exasol_identifier_impl import ExasolIdentifierImpl from exasol.analytics.utils.hash_generation_for_object import generate_hash_for_object from exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object @@ -18,8 +16,7 @@ def fully_qualified(self) -> str: return self.quoted_name def __eq__(self, other): - return isinstance(other, ExperimentName) and \ - self._name == other.name + return isinstance(other, ExperimentName) and 
self._name == other.name def __repr__(self): return generate_repr_for_object(self) diff --git a/exasol/analytics/schema/schema_name.py b/exasol/analytics/schema/schema_name.py index 9de4aad2..2133195d 100644 --- a/exasol/analytics/schema/schema_name.py +++ b/exasol/analytics/schema/schema_name.py @@ -1,7 +1,6 @@ from typeguard import typechecked -from exasol.analytics.schema import DBObjectNameImpl - +from exasol.analytics.schema.dbobject_name_impl import DBObjectNameImpl class SchemaName(DBObjectNameImpl): diff --git a/exasol/analytics/schema/table.py b/exasol/analytics/schema/table.py index 786e87cf..985c9e34 100644 --- a/exasol/analytics/schema/table.py +++ b/exasol/analytics/schema/table.py @@ -2,11 +2,9 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - Column, - TableName, - TableLike, -) +from exasol.analytics.schema.column import Column +from exasol.analytics.schema.table_like import TableLike +from exasol.analytics.schema.table_name import TableName class Table(TableLike[TableName]): diff --git a/exasol/analytics/schema/table_builder.py b/exasol/analytics/schema/table_builder.py index 449b5e00..9831d7db 100644 --- a/exasol/analytics/schema/table_builder.py +++ b/exasol/analytics/schema/table_builder.py @@ -1,10 +1,8 @@ -from typing import Union, List +from typing import List, Union -from exasol.analytics.schema import ( - Column, - TableName, - Table, -) +from exasol.analytics.schema.column import Column +from exasol.analytics.schema.table import Table +from exasol.analytics.schema.table_name import TableName class TableBuilder: diff --git a/exasol/analytics/schema/table_like.py b/exasol/analytics/schema/table_like.py index 23986020..c27421de 100644 --- a/exasol/analytics/schema/table_like.py +++ b/exasol/analytics/schema/table_like.py @@ -1,17 +1,15 @@ from abc import ABC -from typing import List, TypeVar, Generic +from typing import Generic, List, TypeVar from typeguard import typechecked -from exasol.analytics.schema import ( - Column, - DBObject, - TableLikeName, -) +from exasol.analytics.schema.column import Column +from exasol.analytics.schema.dbobject import DBObject +from exasol.analytics.schema.table_like_name import TableLikeName from exasol.analytics.utils.hash_generation_for_object import generate_hash_for_object from exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object -NameType = TypeVar('NameType', bound=TableLikeName) +NameType = TypeVar("NameType", bound=TableLikeName) class TableLike(DBObject[NameType], ABC): @@ -20,7 +18,7 @@ class TableLike(DBObject[NameType], ABC): def __init__(self, name: NameType, columns: List[Column]): super().__init__(name) self._columns = columns - if len(self._columns)==0: + if len(self._columns) == 0: raise ValueError("At least one column needed.") unique_column_names = {column.name for column in self.columns} if len(unique_column_names) != len(columns): @@ -31,8 +29,7 @@ def columns(self) -> List[Column]: return list(self._columns) def __eq__(self, other): - return super().__eq__(other) and \ - self._columns == other.columns + return super().__eq__(other) and self._columns == other.columns def __hash__(self): return generate_hash_for_object(self) diff --git a/exasol/analytics/schema/table_like_name.py b/exasol/analytics/schema/table_like_name.py index 6fd8f26e..aa0a9bc7 100644 --- a/exasol/analytics/schema/table_like_name.py +++ b/exasol/analytics/schema/table_like_name.py @@ -1,7 +1,6 @@ from abc import ABC -from exasol.analytics.schema import DBObjectNameWithSchema - +from 
exasol.analytics.schema.dbobject_name_with_schema import DBObjectNameWithSchema class TableLikeName(DBObjectNameWithSchema, ABC): diff --git a/exasol/analytics/schema/table_like_name_impl.py b/exasol/analytics/schema/table_like_name_impl.py index 352cf91a..d471376d 100644 --- a/exasol/analytics/schema/table_like_name_impl.py +++ b/exasol/analytics/schema/table_like_name_impl.py @@ -2,11 +2,9 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - DBObjectNameWithSchemaImpl, - SchemaName, - TableLikeName, -) +from exasol.analytics.schema.dbobject_name_with_schema_impl import DBObjectNameWithSchemaImpl +from exasol.analytics.schema.schema_name import SchemaName +from exasol.analytics.schema.table_like_name import TableLikeName class TableLikeNameImpl(DBObjectNameWithSchemaImpl, TableLikeName): diff --git a/exasol/analytics/schema/table_name.py b/exasol/analytics/schema/table_name.py index 42b1607c..67b08df0 100644 --- a/exasol/analytics/schema/table_name.py +++ b/exasol/analytics/schema/table_name.py @@ -1,5 +1,4 @@ -from exasol.analytics.schema import TableLikeName - +from exasol.analytics.schema.table_like_name import TableLikeName class TableName(TableLikeName): diff --git a/exasol/analytics/schema/table_name_builder.py b/exasol/analytics/schema/table_name_builder.py index bc52eebb..1d0cab76 100644 --- a/exasol/analytics/schema/table_name_builder.py +++ b/exasol/analytics/schema/table_name_builder.py @@ -1,18 +1,18 @@ -from typing import Union, Optional +from typing import Optional, Union -from exasol.analytics.schema import ( - SchemaName, - TableNameImpl, - TableName, -) +from exasol.analytics.schema.schema_name import SchemaName +from exasol.analytics.schema.table_name import TableName +from exasol.analytics.schema.table_name_impl import TableNameImpl class TableNameBuilder: - def __init__(self, - name: Optional[str] = None, - schema: Optional[SchemaName] = None, - table_name: Optional[TableName] = None): + def __init__( + self, + name: Optional[str] = None, + schema: Optional[SchemaName] = None, + table_name: Optional[TableName] = None, + ): """ Creates a builder for TableName objects, either by copying a TableName object (parameter "table_like_name") or diff --git a/exasol/analytics/schema/table_name_impl.py b/exasol/analytics/schema/table_name_impl.py index 1dceccd7..ad56c64a 100644 --- a/exasol/analytics/schema/table_name_impl.py +++ b/exasol/analytics/schema/table_name_impl.py @@ -2,11 +2,9 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - SchemaName, - TableLikeNameImpl, - TableName, -) +from exasol.analytics.schema.schema_name import SchemaName +from exasol.analytics.schema.table_like_name_impl import TableLikeNameImpl +from exasol.analytics.schema.table_name import TableName class TableNameImpl(TableLikeNameImpl, TableName): diff --git a/exasol/analytics/schema/udf_name.py b/exasol/analytics/schema/udf_name.py index 7201c658..e18b3239 100644 --- a/exasol/analytics/schema/udf_name.py +++ b/exasol/analytics/schema/udf_name.py @@ -1,5 +1,4 @@ -from exasol.analytics.schema import DBObjectNameWithSchema - +from exasol.analytics.schema.dbobject_name_with_schema import DBObjectNameWithSchema class UDFName(DBObjectNameWithSchema): diff --git a/exasol/analytics/schema/udf_name_builder.py b/exasol/analytics/schema/udf_name_builder.py index cf8ef84f..8b224d3e 100644 --- a/exasol/analytics/schema/udf_name_builder.py +++ b/exasol/analytics/schema/udf_name_builder.py @@ -1,18 +1,18 @@ from typing import Optional -from exasol.analytics.schema 
import ( - UDFNameImpl, - SchemaName, - UDFName, -) +from exasol.analytics.schema.schema_name import SchemaName +from exasol.analytics.schema.udf_name import UDFName +from exasol.analytics.schema.udf_name_impl import UDFNameImpl class UDFNameBuilder: - def __init__(self, - name: Optional[str] = None, - schema: Optional[SchemaName] = None, - udf_name: Optional[UDFName] = None): + def __init__( + self, + name: Optional[str] = None, + schema: Optional[SchemaName] = None, + udf_name: Optional[UDFName] = None, + ): """ Creates a builder for UDFName objects, either by copying a UDFName object (parameter "udf_name") or diff --git a/exasol/analytics/schema/udf_name_impl.py b/exasol/analytics/schema/udf_name_impl.py index 67da9c8b..ed7c65ea 100644 --- a/exasol/analytics/schema/udf_name_impl.py +++ b/exasol/analytics/schema/udf_name_impl.py @@ -2,11 +2,9 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - DBObjectNameWithSchemaImpl, - SchemaName, - UDFName, -) +from exasol.analytics.schema.dbobject_name_with_schema_impl import DBObjectNameWithSchemaImpl +from exasol.analytics.schema.schema_name import SchemaName +from exasol.analytics.schema.udf_name import UDFName class UDFNameImpl(DBObjectNameWithSchemaImpl, UDFName): diff --git a/exasol/analytics/schema/view.py b/exasol/analytics/schema/view.py index 907a84e9..3baa0b2e 100644 --- a/exasol/analytics/schema/view.py +++ b/exasol/analytics/schema/view.py @@ -2,11 +2,9 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - Column, - TableLike, - ViewName, -) +from exasol.analytics.schema.column import Column +from exasol.analytics.schema.table_like import TableLike +from exasol.analytics.schema.view_name import ViewName class View(TableLike[ViewName]): diff --git a/exasol/analytics/schema/view_name.py b/exasol/analytics/schema/view_name.py index 9353e385..202f1980 100644 --- a/exasol/analytics/schema/view_name.py +++ b/exasol/analytics/schema/view_name.py @@ -1,5 +1,4 @@ -from exasol.analytics.schema import TableLikeName - +from exasol.analytics.schema.table_like_name import TableLikeName class ViewName(TableLikeName): diff --git a/exasol/analytics/schema/view_name_builder.py b/exasol/analytics/schema/view_name_builder.py index 33b9e328..6bab966c 100644 --- a/exasol/analytics/schema/view_name_builder.py +++ b/exasol/analytics/schema/view_name_builder.py @@ -1,19 +1,18 @@ from typing import Optional -from exasol.analytics.schema import ( - SchemaName, - ViewNameImpl, - ViewName, -) - +from exasol.analytics.schema.schema_name import SchemaName +from exasol.analytics.schema.view_name import ViewName +from exasol.analytics.schema.view_name_impl import ViewNameImpl class ViewNameBuilder: - def __init__(self, - name: Optional[str] = None, - schema: Optional[SchemaName] = None, - view_name: Optional[ViewName] = None): + def __init__( + self, + name: Optional[str] = None, + schema: Optional[SchemaName] = None, + view_name: Optional[ViewName] = None, + ): """ Creates a builder for ViewName objects, either by copying a ViewName (parameter view_name) object or diff --git a/exasol/analytics/schema/view_name_impl.py b/exasol/analytics/schema/view_name_impl.py index ab0fa471..1de9c452 100644 --- a/exasol/analytics/schema/view_name_impl.py +++ b/exasol/analytics/schema/view_name_impl.py @@ -2,11 +2,9 @@ from typeguard import typechecked -from exasol.analytics.schema import ( - SchemaName, - TableLikeNameImpl, - ViewName, -) +from exasol.analytics.schema.schema_name import SchemaName +from 
exasol.analytics.schema.table_like_name_impl import TableLikeNameImpl +from exasol.analytics.schema.view_name import ViewName class ViewNameImpl(TableLikeNameImpl, ViewName): diff --git a/exasol/analytics/sql_executor/interface.py b/exasol/analytics/sql_executor/interface.py index 73d73b7f..bb9f4079 100644 --- a/exasol/analytics/sql_executor/interface.py +++ b/exasol/analytics/sql_executor/interface.py @@ -4,7 +4,6 @@ from exasol.analytics.schema import Column - class ResultSet(ABC): @abstractmethod def __iter__(self): diff --git a/exasol/analytics/sql_executor/pyexasol_impl.py b/exasol/analytics/sql_executor/pyexasol_impl.py index 2287e96d..32d4c69f 100644 --- a/exasol/analytics/sql_executor/pyexasol_impl.py +++ b/exasol/analytics/sql_executor/pyexasol_impl.py @@ -1,15 +1,10 @@ -from typing import List, Any, Tuple +from typing import Any, List, Tuple import pyexasol from pyexasol import ExaStatement -from exasol.analytics.schema import ( - Column, - ColumnType, - ColumnName, - ColumnNameBuilder, -) -from exasol.analytics.sql_executor.interface import SQLExecutor, ResultSet +from exasol.analytics.schema import Column, ColumnName, ColumnNameBuilder, ColumnType +from exasol.analytics.sql_executor.interface import ResultSet, SQLExecutor SRID = "srid" @@ -56,17 +51,27 @@ def columns(self) -> List[Column]: ColumnNameBuilder.create(column_name), ColumnType( name=column_type["type"], - precision=column_type[PRECISION] if PRECISION in column_type else None, + precision=( + column_type[PRECISION] if PRECISION in column_type else None + ), scale=column_type[SCALE] if SCALE in column_type else None, size=column_type[SIZE] if SIZE in column_type else None, - characterSet=column_type[CHARACTER_SET] if CHARACTER_SET in column_type else None, - withLocalTimeZone=column_type[ - WITH_LOCAL_TIME_ZONE] if WITH_LOCAL_TIME_ZONE in column_type else None, + characterSet=( + column_type[CHARACTER_SET] + if CHARACTER_SET in column_type + else None + ), + withLocalTimeZone=( + column_type[WITH_LOCAL_TIME_ZONE] + if WITH_LOCAL_TIME_ZONE in column_type + else None + ), fraction=column_type[FRACTION] if FRACTION in column_type else None, srid=column_type[SRID] if SRID in column_type else None, - ) + ), ) - for column_name, column_type in self.statement.columns().items()] + for column_name, column_type in self.statement.columns().items() + ] return columns def close(self): diff --git a/exasol/analytics/sql_executor/testing/mock_result_set.py b/exasol/analytics/sql_executor/testing/mock_result_set.py index e5e9f949..56b97b2c 100644 --- a/exasol/analytics/sql_executor/testing/mock_result_set.py +++ b/exasol/analytics/sql_executor/testing/mock_result_set.py @@ -1,24 +1,24 @@ import itertools -from typing import Tuple, List, Optional +from typing import List, Optional, Tuple from exasol.analytics.schema import Column - from exasol.analytics.sql_executor.interface import ResultSet class MockResultSet(ResultSet): - def __init__(self, - rows: Optional[List[Tuple]] = None, - columns: Optional[List[Column]] = None - ): + def __init__( + self, rows: Optional[List[Tuple]] = None, columns: Optional[List[Column]] = None + ): self._columns = columns self._rows = rows if rows is not None: if self._columns is not None: for row in rows: if len(row) != len(self._columns): - raise AssertionError(f"Row {row} doesn't fit columns {self._columns}") + raise AssertionError( + f"Row {row} doesn't fit columns {self._columns}" + ) self._iter = self._rows.__iter__() def __iter__(self):
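# MockResultSet validates eagerly: every row must carry exactly one value per
# declared column. A short sketch with hypothetical test data:
#
#     from exasol.analytics.schema import Column, ColumnName, ColumnType
#     from exasol.analytics.sql_executor.testing.mock_result_set import MockResultSet
#
#     cols = [Column(ColumnName("ID"), ColumnType("INTEGER"))]
#     ok = MockResultSet(rows=[(1,), (2,)], columns=cols)  # iterates (1,), (2,)
#     bad = MockResultSet(rows=[(1, 2)], columns=cols)     # raises AssertionError
#
diff --git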
a/exasol/analytics/sql_executor/testing/mock_sql_executor.py b/exasol/analytics/sql_executor/testing/mock_sql_executor.py index 61a8bfed..70f34a5e 100644 --- a/exasol/analytics/sql_executor/testing/mock_sql_executor.py +++ b/exasol/analytics/sql_executor/testing/mock_sql_executor.py @@ -1,9 +1,9 @@ import dataclasses import difflib from inspect import cleandoc -from typing import Optional, List, Dict, Tuple +from typing import Dict, List, Optional, Tuple -from exasol.analytics.sql_executor.interface import SQLExecutor, ResultSet +from exasol.analytics.sql_executor.interface import ResultSet, SQLExecutor from exasol.analytics.sql_executor.testing.mock_result_set import MockResultSet @@ -26,10 +26,15 @@ def execute(self, actual_query: str) -> ResultSet: next_expected_query = next(self._expected_query_iterator) expected_query = next_expected_query.expected_query diff = "\n".join( - difflib.unified_diff(str(expected_query).split("\n"), actual_query.split("\n"), - "Expected Query", "Actual Query")) - assert expected_query == actual_query, \ - cleandoc(f"""Expected and actual query don't match: + difflib.unified_diff( + str(expected_query).split("\n"), + actual_query.split("\n"), + "Expected Query", + "Actual Query", + ) + ) + assert expected_query == actual_query, cleandoc( + f"""Expected and actual query don't match: Expected Query: --------------- {expected_query} @@ -44,7 +49,8 @@ def execute(self, actual_query: str) -> ResultSet: ----- {diff} -""") +""" + ) return next_expected_query.mock_result_set except StopIteration as e: raise RuntimeError(f"No result set found for query {actual_query}") diff --git a/exasol/analytics/udf/communication/broadcast_operation.py b/exasol/analytics/udf/communication/broadcast_operation.py index b6cfe0e6..4a9bc0b3 100644 --- a/exasol/analytics/udf/communication/broadcast_operation.py +++ b/exasol/analytics/udf/communication/broadcast_operation.py @@ -6,8 +6,14 @@ from exasol.analytics.udf.communication import messages from exasol.analytics.udf.communication.peer import Peer from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator -from exasol.analytics.udf.communication.serialization import serialize_message, deserialize_message -from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory, Frame +from exasol.analytics.udf.communication.serialization import ( + deserialize_message, + serialize_message, +) +from exasol.analytics.udf.communication.socket_factory.abstract import ( + Frame, + SocketFactory, +) _LOGGER: FilteringBoundLogger = structlog.getLogger() @@ -17,12 +23,14 @@ class BroadcastOperation: - def __init__(self, - sequence_number: int, - value: Optional[bytes], - localhost_communicator: PeerCommunicator, - multi_node_communicator: PeerCommunicator, - socket_factory: SocketFactory): + def __init__( + self, + sequence_number: int, + value: Optional[bytes], + localhost_communicator: PeerCommunicator, + multi_node_communicator: PeerCommunicator, + socket_factory: SocketFactory, + ): self._socket_factory = socket_factory self._value = value self._sequence_number = sequence_number @@ -55,13 +63,13 @@ def _forward_from_multi_node_leader(self) -> bytes: self._logger.info("_forward_from_multi_node_leader") value_frame = self.receive_value_frame_from_multi_node_leader() leader = self._localhost_communicator.leader - peers = [peer for peer in self._localhost_communicator.peers() if peer != leader] + peers = [ + peer for peer in self._localhost_communicator.peers() if peer != leader + ] for peer in peers: 
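# Inside this loop the node-local leader fans the received value frame out to
# every other instance on its node; the leader itself was filtered out of
# `peers` above, so it never sends to itself. Each message keeps the two-frame
# layout used throughout this file:
#
#     [serialized messages.Broadcast header, value_frame]
#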
frames = self._construct_broadcast_message( - destination=peer, - leader=leader, - value_frame=value_frame + destination=peer, leader=leader, value_frame=value_frame ) self._localhost_communicator.send(peer=peer, message=frames) @@ -87,14 +95,14 @@ def _send_messages_to_local_leaders(self): self._logger.info("_send_messages_to_local_leaders") leader = self._multi_node_communicator.leader - peers = [peer for peer in self._multi_node_communicator.peers() if peer != leader] + peers = [ + peer for peer in self._multi_node_communicator.peers() if peer != leader + ] for peer in peers: value_frame = self._socket_factory.create_frame(self._value) frames = self._construct_broadcast_message( - destination=peer, - leader=leader, - value_frame=value_frame + destination=peer, leader=leader, value_frame=value_frame ) self._multi_node_communicator.send(peer=peer, message=frames) @@ -105,9 +113,7 @@ def _send_messages_to_local_peers_from_multi_node_leaders(self): for peer in peers: value_frame = self._socket_factory.create_frame(self._value) frames = self._construct_broadcast_message( - destination=peer, - leader=leader, - value_frame=value_frame + destination=peer, leader=leader, value_frame=value_frame ) self._localhost_communicator.send(peer=peer, message=frames) @@ -116,23 +122,29 @@ def _check_sequence_number(self, specific_message_obj: messages.Broadcast): raise RuntimeError( f"Got message with different sequence number. " f"We expect the sequence number {self._sequence_number} " - f"but we got {self._sequence_number} in message {specific_message_obj}") + f"but we got {self._sequence_number} in message {specific_message_obj}" + ) - def _get_and_check_specific_message_obj(self, message: messages.Message) -> messages.Broadcast: + def _get_and_check_specific_message_obj( + self, message: messages.Message + ) -> messages.Broadcast: specific_message_obj = message.__root__ if not isinstance(specific_message_obj, messages.Broadcast): - raise TypeError(f"Received the wrong message type. " - f"Expected {messages.Broadcast.__name__} got {type(message)}. " - f"For message {message}.") + raise TypeError( + f"Received the wrong message type. " + f"Expected {messages.Broadcast.__name__} got {type(message)}. " + f"For message {message}." 
+ ) return specific_message_obj - def _construct_broadcast_message(self, destination: Peer, leader: Peer, value_frame: Frame): - message = messages.Broadcast(sequence_number=self._sequence_number, - destination=destination, - source=leader) + def _construct_broadcast_message( + self, destination: Peer, leader: Peer, value_frame: Frame + ): + message = messages.Broadcast( + sequence_number=self._sequence_number, + destination=destination, + source=leader, + ) serialized_message = serialize_message(message) - frames = [ - self._socket_factory.create_frame(serialized_message), - value_frame - ] + frames = [self._socket_factory.create_frame(serialized_message), value_frame] return frames diff --git a/exasol/analytics/udf/communication/communicator.py b/exasol/analytics/udf/communication/communicator.py index e0edf8af..88f78c9a 100644 --- a/exasol/analytics/udf/communication/communicator.py +++ b/exasol/analytics/udf/communication/communicator.py @@ -1,9 +1,9 @@ -from typing import Optional, List +from typing import List, Optional from exasol.analytics.udf.communication.broadcast_operation import BroadcastOperation from exasol.analytics.udf.communication.discovery import localhost, multi_node from exasol.analytics.udf.communication.gather_operation import GatherOperation -from exasol.analytics.udf.communication.ip_address import Port, IPAddress +from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory @@ -13,21 +13,22 @@ class Communicator: - def __init__(self, - multi_node_discovery_ip: IPAddress, - multi_node_discovery_port: Port, - local_discovery_port: Port, - node_name: str, - instance_name: str, - listen_ip: IPAddress, - group_identifier: str, - number_of_nodes: int, - number_of_instances_per_node: int, - is_discovery_leader_node: bool, - socket_factory: SocketFactory, - localhost_communicator_factory: localhost.CommunicatorFactory = localhost.CommunicatorFactory(), - multi_node_communicator_factory: multi_node.CommunicatorFactory = multi_node.CommunicatorFactory(), - ): + def __init__( + self, + multi_node_discovery_ip: IPAddress, + multi_node_discovery_port: Port, + local_discovery_port: Port, + node_name: str, + instance_name: str, + listen_ip: IPAddress, + group_identifier: str, + number_of_nodes: int, + number_of_instances_per_node: int, + is_discovery_leader_node: bool, + socket_factory: SocketFactory, + localhost_communicator_factory: localhost.CommunicatorFactory = localhost.CommunicatorFactory(), + multi_node_communicator_factory: multi_node.CommunicatorFactory = multi_node.CommunicatorFactory(), + ): self._number_of_nodes = number_of_nodes self._number_of_instances_per_node = number_of_instances_per_node self._group_identifier = group_identifier @@ -57,8 +58,8 @@ def _create_multi_node_communicator(self) -> Optional[PeerCommunicator]: if self._localhost_communicator.rank == LOCALHOST_LEADER_RANK: discovery_socket_factory = multi_node.DiscoverySocketFactory() is_discovery_leader = ( - self._localhost_communicator.rank == LOCALHOST_LEADER_RANK - and self._is_discovery_leader_node + self._localhost_communicator.rank == LOCALHOST_LEADER_RANK + and self._is_discovery_leader_node ) peer_communicator = self._multi_node_communicator_factory.create( group_identifier=multi_node_group_identifier, @@ -69,7 +70,8 @@ def _create_multi_node_communicator(self) -> Optional[PeerCommunicator]: 
discovery_ip=self._multi_node_discovery_ip, discovery_port=self._multi_node_discovery_port, socket_factory=self._socket_factory, - discovery_socket_factory=discovery_socket_factory) + discovery_socket_factory=discovery_socket_factory, + ) return peer_communicator else: return None @@ -85,24 +87,31 @@ def _create_localhost_communicator(self) -> PeerCommunicator: listen_ip=self._localhost_listen_ip, discovery_port=self._localhost_discovery_port, socket_factory=self._socket_factory, - discovery_socket_factory=discovery_socket_factory) + discovery_socket_factory=discovery_socket_factory, + ) return peer_communicator def gather(self, value: bytes) -> Optional[List[bytes]]: sequence_number = self._next_sequence_number() - gather = GatherOperation(sequence_number=sequence_number, value=value, - localhost_communicator=self._localhost_communicator, - multi_node_communicator=self._multi_node_communicator, - socket_factory=self._socket_factory, - number_of_instances_per_node=self._number_of_instances_per_node) + gather = GatherOperation( + sequence_number=sequence_number, + value=value, + localhost_communicator=self._localhost_communicator, + multi_node_communicator=self._multi_node_communicator, + socket_factory=self._socket_factory, + number_of_instances_per_node=self._number_of_instances_per_node, + ) return gather() def broadcast(self, value: Optional[bytes]) -> bytes: sequence_number = self._next_sequence_number() - operation = BroadcastOperation(sequence_number=sequence_number, value=value, - localhost_communicator=self._localhost_communicator, - multi_node_communicator=self._multi_node_communicator, - socket_factory=self._socket_factory) + operation = BroadcastOperation( + sequence_number=sequence_number, + value=value, + localhost_communicator=self._localhost_communicator, + multi_node_communicator=self._multi_node_communicator, + socket_factory=self._socket_factory, + ) return operation() def is_multi_node_leader(self): diff --git a/exasol/analytics/udf/communication/connection_info.py b/exasol/analytics/udf/communication/connection_info.py index 5802c647..23039083 100644 --- a/exasol/analytics/udf/communication/connection_info.py +++ b/exasol/analytics/udf/communication/connection_info.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from exasol.analytics.udf.communication.ip_address import Port, IPAddress +from exasol.analytics.udf.communication.ip_address import IPAddress, Port class ConnectionInfo(BaseModel, frozen=True): diff --git a/exasol/analytics/udf/communication/discovery/localhost/__init__.py b/exasol/analytics/udf/communication/discovery/localhost/__init__.py index 303ab9e0..eb84abd5 100644 --- a/exasol/analytics/udf/communication/discovery/localhost/__init__.py +++ b/exasol/analytics/udf/communication/discovery/localhost/__init__.py @@ -1,3 +1,3 @@ +from .communicator import CommunicatorFactory from .discovery_socket import DiscoverySocket, DiscoverySocketFactory from .discovery_strategy import DiscoveryStrategy -from .communicator import CommunicatorFactory diff --git a/exasol/analytics/udf/communication/discovery/localhost/communicator.py b/exasol/analytics/udf/communication/discovery/localhost/communicator.py index 66fbac7a..48c63c41 100644 --- a/exasol/analytics/udf/communication/discovery/localhost/communicator.py +++ b/exasol/analytics/udf/communication/discovery/localhost/communicator.py @@ -1,24 +1,28 @@ -from exasol.analytics.udf.communication.discovery import localhost +from exasol.analytics.udf.communication.discovery.localhost.discovery_socket import 
DiscoverySocketFactory +from exasol.analytics.udf.communication.discovery.localhost.discovery_strategy import DiscoveryStrategy from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator -from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \ - ForwardRegisterPeerConfig -from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \ - PeerCommunicatorConfig +from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import ( + ForwardRegisterPeerConfig, +) +from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import ( + PeerCommunicatorConfig, +) from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory class CommunicatorFactory: def create( - self, - name: str, - group_identifier: str, - number_of_instances: int, - listen_ip: IPAddress, - discovery_port: Port, - socket_factory: SocketFactory, - discovery_socket_factory: localhost.DiscoverySocketFactory) -> PeerCommunicator: + self, + name: str, + group_identifier: str, + number_of_instances: int, + listen_ip: IPAddress, + discovery_port: Port, + socket_factory: SocketFactory, + discovery_socket_factory: DiscoverySocketFactory, + ) -> PeerCommunicator: peer_communicator = PeerCommunicator( name=name, number_of_peers=number_of_instances, @@ -30,9 +34,9 @@ def create( is_enabled=False, ) ), - socket_factory=socket_factory + socket_factory=socket_factory, ) - discovery = localhost.DiscoveryStrategy( + discovery = DiscoveryStrategy( port=discovery_port, timeout_in_seconds=120, time_between_ping_messages_in_seconds=1, diff --git a/exasol/analytics/udf/communication/discovery/localhost/discovery_socket.py b/exasol/analytics/udf/communication/discovery/localhost/discovery_socket.py index 845ea3e0..b230c021 100644 --- a/exasol/analytics/udf/communication/discovery/localhost/discovery_socket.py +++ b/exasol/analytics/udf/communication/discovery/localhost/discovery_socket.py @@ -8,21 +8,27 @@ class DiscoverySocket: def __init__(self, port: Port): self._port = port self._broadcast_ip = IPAddress(ip_address="127.255.255.255") - self._udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + self._udp_socket = socket.socket( + socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP + ) self._udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) self._udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) self._udp_socket.bind((self._broadcast_ip.ip_address, self._port.port)) def send(self, message: bytes): - self._udp_socket.sendto(message, (self._broadcast_ip.ip_address, self._port.port)) + self._udp_socket.sendto( + message, (self._broadcast_ip.ip_address, self._port.port) + ) def recvfrom(self, timeout_in_seconds: float) -> bytes: if timeout_in_seconds < 0.0: - raise ValueError(f"Timeout needs to be larger than or equal to 0.0, but got {timeout_in_seconds}") + raise ValueError( + f"Timeout needs to be larger than or equal to 0.0, but got {timeout_in_seconds}" + ) # We need to adjust the timeout with a very small number, to avoid 0.0, # because this leads the following error # BlockingIOError: [Errno 11] Resource temporarily unavailable - adjusted_timeout = timeout_in_seconds + 10 ** -9 + adjusted_timeout = timeout_in_seconds + 10**-9 self._udp_socket.settimeout(adjusted_timeout) data = self._udp_socket.recv(1024) return data
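# Both discovery sockets share one pattern: a UDP socket with SO_REUSEPORT and
# SO_BROADCAST, bound to a broadcast address, with the receive timeout nudged
# away from exactly 0.0. A self-contained sketch of that pattern (the port
# number is hypothetical):
#
#     import socket
#
#     sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
#     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
#     sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
#     sock.bind(("127.255.255.255", 44444))
#     sock.sendto(b"ping", ("127.255.255.255", 44444))
#     sock.settimeout(0.5 + 10**-9)  # 0.0 would raise BlockingIOError
#     print(sock.recv(1024))        # the broadcast loops back: b"ping"
#
diff --git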
a/exasol/analytics/udf/communication/discovery/localhost/discovery_strategy.py b/exasol/analytics/udf/communication/discovery/localhost/discovery_strategy.py index 9aa5e1cd..1252e285 100644 --- a/exasol/analytics/udf/communication/discovery/localhost/discovery_strategy.py +++ b/exasol/analytics/udf/communication/discovery/localhost/discovery_strategy.py @@ -1,37 +1,52 @@ import socket import time -from typing import cast, Optional +from typing import Optional, cast from exasol.analytics.udf.communication import messages -from exasol.analytics.udf.communication.discovery.localhost import DiscoverySocket, \ - DiscoverySocketFactory +from exasol.analytics.udf.communication.discovery.localhost.discovery_socket import ( + DiscoverySocket, + DiscoverySocketFactory, +) from exasol.analytics.udf.communication.ip_address import Port -from exasol.analytics.udf.communication.peer_communicator.peer_communicator import PeerCommunicator -from exasol.analytics.udf.communication.serialization import serialize_message, deserialize_message +from exasol.analytics.udf.communication.peer_communicator.peer_communicator import ( + PeerCommunicator, +) +from exasol.analytics.udf.communication.serialization import ( + deserialize_message, + serialize_message, +) -NANOSECONDS_PER_SECOND = 10 ** 9 +NANOSECONDS_PER_SECOND = 10**9 def _convert_to_ping_message(serialized_message: bytes) -> messages.Ping: - ping_message = cast(messages.Ping, deserialize_message(serialized_message, messages.Ping)) + ping_message = cast( + messages.Ping, deserialize_message(serialized_message, messages.Ping) + ) return ping_message class DiscoveryStrategy: - def __init__(self, - port: Port, - timeout_in_seconds: int, - time_between_ping_messages_in_seconds: float, - peer_communicator: PeerCommunicator, - local_discovery_socket_factory: DiscoverySocketFactory): + def __init__( + self, + port: Port, + timeout_in_seconds: int, + time_between_ping_messages_in_seconds: float, + peer_communicator: PeerCommunicator, + local_discovery_socket_factory: DiscoverySocketFactory, + ): self._peer_communicator = peer_communicator - self._time_between_ping_messages_in_seconds = float(time_between_ping_messages_in_seconds) + self._time_between_ping_messages_in_seconds = float( + time_between_ping_messages_in_seconds + ) self._local_discovery_socket = local_discovery_socket_factory.create(port=port) self._timeout_in_ns = timeout_in_seconds * NANOSECONDS_PER_SECOND def _has_discovery_timed_out(self, begin_time_ns: int) -> bool: - time_left_until_timeout = self._time_left_until_discovery_timeout_in_ns(begin_time_ns) + time_left_until_timeout = self._time_left_until_discovery_timeout_in_ns( + begin_time_ns + ) return time_left_until_timeout == 0 def _time_left_until_discovery_timeout_in_ns(self, begin_time_ns: int) -> int: @@ -48,7 +63,10 @@ def discover_peers(self): self._send_ping() def _should_discovery_end(self, begin_time_ns: int) -> bool: - result = self._peer_communicator.are_all_peers_connected() or self._has_discovery_timed_out(begin_time_ns) + result = ( + self._peer_communicator.are_all_peers_connected() + or self._has_discovery_timed_out(begin_time_ns) + ) return result def _receive_pings(self, begin_time_ns: int): @@ -63,10 +81,14 @@ def _receive_pings(self, begin_time_ns: int): break def _compute_receive_timeout_in_seconds(self, begin_time_ns: int) -> float: - time_left_until_timeout_in_seconds = \ - self._time_left_until_discovery_timeout_in_ns(begin_time_ns) / NANOSECONDS_PER_SECOND - timeout_in_seconds = min(time_left_until_timeout_in_seconds, 
- self._time_between_ping_messages_in_seconds) + time_left_until_timeout_in_seconds = ( + self._time_left_until_discovery_timeout_in_ns(begin_time_ns) + / NANOSECONDS_PER_SECOND + ) + timeout_in_seconds = min( + time_left_until_timeout_in_seconds, + self._time_between_ping_messages_in_seconds, + ) return timeout_in_seconds def _handle_serialized_message(self, serialized_message) -> float: @@ -78,15 +100,14 @@ def _handle_serialized_message(self, serialized_message) -> float: def _receive_message(self, timeout_in_seconds: float) -> Optional[bytes]: try: - serialized_message = \ - self._local_discovery_socket.recvfrom(timeout_in_seconds=timeout_in_seconds) + serialized_message = self._local_discovery_socket.recvfrom( + timeout_in_seconds=timeout_in_seconds + ) except socket.timeout as e: serialized_message = None return serialized_message def _send_ping(self): - ping_message = messages.Ping( - source=self._peer_communicator.my_connection_info - ) + ping_message = messages.Ping(source=self._peer_communicator.my_connection_info) serialized_message = serialize_message(ping_message) self._local_discovery_socket.send(serialized_message) diff --git a/exasol/analytics/udf/communication/discovery/multi_node/__init__.py b/exasol/analytics/udf/communication/discovery/multi_node/__init__.py index 303ab9e0..eb84abd5 100644 --- a/exasol/analytics/udf/communication/discovery/multi_node/__init__.py +++ b/exasol/analytics/udf/communication/discovery/multi_node/__init__.py @@ -1,3 +1,3 @@ +from .communicator import CommunicatorFactory from .discovery_socket import DiscoverySocket, DiscoverySocketFactory from .discovery_strategy import DiscoveryStrategy -from .communicator import CommunicatorFactory diff --git a/exasol/analytics/udf/communication/discovery/multi_node/communicator.py b/exasol/analytics/udf/communication/discovery/multi_node/communicator.py index 4a2e7247..c1d3812c 100644 --- a/exasol/analytics/udf/communication/discovery/multi_node/communicator.py +++ b/exasol/analytics/udf/communication/discovery/multi_node/communicator.py @@ -1,26 +1,30 @@ -from exasol.analytics.udf.communication.discovery import multi_node +from exasol.analytics.udf.communication.discovery.multi_node.discovery_socket import DiscoverySocketFactory +from exasol.analytics.udf.communication.discovery.multi_node.discovery_strategy import DiscoveryStrategy from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator -from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \ - ForwardRegisterPeerConfig -from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \ - PeerCommunicatorConfig +from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import ( + ForwardRegisterPeerConfig, +) +from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import ( + PeerCommunicatorConfig, +) from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory class CommunicatorFactory: def create( - self, - name: str, - group_identifier: str, - is_discovery_leader: bool, - number_of_instances: int, - listen_ip: IPAddress, - discovery_ip: IPAddress, - discovery_port: Port, - socket_factory: SocketFactory, - discovery_socket_factory: multi_node.DiscoverySocketFactory) -> PeerCommunicator: + self, + name: str, + group_identifier: str, + is_discovery_leader: bool, + number_of_instances: int, + listen_ip: IPAddress, 
+ discovery_ip: IPAddress, + discovery_port: Port, + socket_factory: SocketFactory, + discovery_socket_factory: DiscoverySocketFactory, + ) -> PeerCommunicator: peer_communicator = PeerCommunicator( name=name, number_of_peers=number_of_instances, @@ -32,9 +36,9 @@ def create( is_enabled=True, ) ), - socket_factory=socket_factory + socket_factory=socket_factory, ) - discovery = multi_node.DiscoveryStrategy( + discovery = DiscoveryStrategy( ip_address=discovery_ip, port=discovery_port, timeout_in_seconds=120, diff --git a/exasol/analytics/udf/communication/discovery/multi_node/discovery_socket.py b/exasol/analytics/udf/communication/discovery/multi_node/discovery_socket.py index 4ddfce52..ccc96c02 100644 --- a/exasol/analytics/udf/communication/discovery/multi_node/discovery_socket.py +++ b/exasol/analytics/udf/communication/discovery/multi_node/discovery_socket.py @@ -5,7 +5,7 @@ from exasol.analytics.udf.communication.ip_address import IPAddress, Port -NANO_SECOND = 10 ** -9 +NANO_SECOND = 10**-9 LOGGER: FilteringBoundLogger = structlog.getLogger() @@ -15,12 +15,11 @@ class DiscoverySocket: def __init__(self, ip_address: IPAddress, port: Port): self._port = port self._ip_address = ip_address - self._logger = LOGGER.bind( - ip_address=ip_address.dict(), - port=port.dict() - ) + self._logger = LOGGER.bind(ip_address=ip_address.dict(), port=port.dict()) self._logger.info("create") - self._udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + self._udp_socket = socket.socket( + socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP + ) def bind(self): self._logger.info("bind") @@ -32,7 +31,9 @@ def send(self, message: bytes): def recvfrom(self, timeout_in_seconds: float) -> bytes: if timeout_in_seconds < 0.0: - raise ValueError(f"Timeout needs to be larger than or equal to 0.0, but got {timeout_in_seconds}") + raise ValueError( + f"Timeout needs to be larger than or equal to 0.0, but got {timeout_in_seconds}" + ) # We need to adjust the timeout with a very small number, to avoid 0.0, # because this leads the following error # BlockingIOError: [Errno 11] Resource temporarily unavailable diff --git a/exasol/analytics/udf/communication/discovery/multi_node/discovery_strategy.py b/exasol/analytics/udf/communication/discovery/multi_node/discovery_strategy.py index 4af61044..895811eb 100644 --- a/exasol/analytics/udf/communication/discovery/multi_node/discovery_strategy.py +++ b/exasol/analytics/udf/communication/discovery/multi_node/discovery_strategy.py @@ -1,37 +1,54 @@ import socket import time -from typing import cast, Optional +from typing import Optional, cast from exasol.analytics.udf.communication import messages -from exasol.analytics.udf.communication.discovery.multi_node import DiscoverySocketFactory +from exasol.analytics.udf.communication.discovery.multi_node.discovery_socket import ( + DiscoverySocketFactory, +) from exasol.analytics.udf.communication.ip_address import IPAddress, Port -from exasol.analytics.udf.communication.peer_communicator.peer_communicator import PeerCommunicator -from exasol.analytics.udf.communication.serialization import serialize_message, deserialize_message +from exasol.analytics.udf.communication.peer_communicator.peer_communicator import ( + PeerCommunicator, +) +from exasol.analytics.udf.communication.serialization import ( + deserialize_message, + serialize_message, +) -NANOSECONDS_PER_SECOND = 10 ** 9 +NANOSECONDS_PER_SECOND = 10**9 def _to_ping_message(serialized_message: bytes) -> messages.Ping: - ping_message = 
cast(messages.Ping, deserialize_message(serialized_message, messages.Ping)) + ping_message = cast( + messages.Ping, deserialize_message(serialized_message, messages.Ping) + ) return ping_message class DiscoveryStrategy: - def __init__(self, - ip_address: IPAddress, - port: Port, - timeout_in_seconds: int, - time_between_ping_messages_in_seconds: float, - peer_communicator: PeerCommunicator, - global_discovery_socket_factory: DiscoverySocketFactory): + def __init__( + self, + ip_address: IPAddress, + port: Port, + timeout_in_seconds: int, + time_between_ping_messages_in_seconds: float, + peer_communicator: PeerCommunicator, + global_discovery_socket_factory: DiscoverySocketFactory, + ): self._peer_communicator = peer_communicator - self._time_between_ping_messages_in_seconds = float(time_between_ping_messages_in_seconds) - self._global_discovery_socket = global_discovery_socket_factory.create(ip_address=ip_address, port=port) + self._time_between_ping_messages_in_seconds = float( + time_between_ping_messages_in_seconds + ) + self._global_discovery_socket = global_discovery_socket_factory.create( + ip_address=ip_address, port=port + ) self._timeout_in_ns = timeout_in_seconds * NANOSECONDS_PER_SECOND def _has_discovery_timed_out(self, begin_time_ns: int) -> bool: - time_left_until_timeout = self._time_left_until_discovery_timeout_in_ns(begin_time_ns) + time_left_until_timeout = self._time_left_until_discovery_timeout_in_ns( + begin_time_ns + ) return time_left_until_timeout == 0 def _time_left_until_discovery_timeout_in_ns(self, begin_time_ns: int) -> int: @@ -42,7 +59,9 @@ def _time_left_until_discovery_timeout_in_ns(self, begin_time_ns: int) -> int: def discover_peers(self): if not self._peer_communicator.forward_register_peer_config.is_enabled: - raise ValueError("PeerCommunicator.is_forward_register_peer_enabled needs to be true") + raise ValueError( + "PeerCommunicator.forward_register_peer_config.is_enabled needs to be true" + ) if self._peer_communicator.forward_register_peer_config.is_leader: self._global_discovery_socket.bind() self._send_ping() @@ -69,10 +88,14 @@ def _receive_pings(self, begin_time_ns: int): timeout_in_seconds = self._handle_serialized_message(serialized_message) def _compute_receive_timeout_in_seconds(self, begin_time_ns: int) -> float: - time_left_until_timeout_in_seconds = \ - self._time_left_until_discovery_timeout_in_ns(begin_time_ns) / NANOSECONDS_PER_SECOND - timeout_in_seconds = min(time_left_until_timeout_in_seconds, - self._time_between_ping_messages_in_seconds) + time_left_until_timeout_in_seconds = ( + self._time_left_until_discovery_timeout_in_ns(begin_time_ns) + / NANOSECONDS_PER_SECOND + ) + timeout_in_seconds = min( + time_left_until_timeout_in_seconds, + self._time_between_ping_messages_in_seconds, + ) return timeout_in_seconds def _handle_serialized_message(self, serialized_message) -> float: @@ -84,15 +107,14 @@ def _handle_serialized_message(self, serialized_message) -> float: def _receive_message(self, timeout_in_seconds: float) -> Optional[bytes]: try: - serialized_message = \ - self._global_discovery_socket.recvfrom(timeout_in_seconds=timeout_in_seconds) + serialized_message = self._global_discovery_socket.recvfrom( + timeout_in_seconds=timeout_in_seconds + ) except socket.timeout as e: serialized_message = None return serialized_message def _send_ping(self): - ping_message = messages.Ping( - source=self._peer_communicator.my_connection_info - ) + ping_message = messages.Ping(source=self._peer_communicator.my_connection_info) serialized_message = 
serialize_message(ping_message) self._global_discovery_socket.send(serialized_message) diff --git a/exasol/analytics/udf/communication/gather_operation.py b/exasol/analytics/udf/communication/gather_operation.py index 41eab08c..e87af337 100644 --- a/exasol/analytics/udf/communication/gather_operation.py +++ b/exasol/analytics/udf/communication/gather_operation.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict +from typing import Dict, List, Optional import structlog from structlog.typing import FilteringBoundLogger @@ -7,23 +7,32 @@ from exasol.analytics.udf.communication.messages import Gather from exasol.analytics.udf.communication.peer import Peer from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator -from exasol.analytics.udf.communication.serialization import serialize_message, deserialize_message -from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory, Frame +from exasol.analytics.udf.communication.serialization import ( + deserialize_message, + serialize_message, +) +from exasol.analytics.udf.communication.socket_factory.abstract import ( + Frame, + SocketFactory, +) LOGGER: FilteringBoundLogger = structlog.getLogger() LOCALHOST_LEADER_RANK = 0 MULTI_NODE_LEADER_RANK = 0 + class GatherOperation: - def __init__(self, - sequence_number: int, - value: bytes, - localhost_communicator: PeerCommunicator, - multi_node_communicator: PeerCommunicator, - socket_factory: SocketFactory, - number_of_instances_per_node: int): + def __init__( + self, + sequence_number: int, + value: bytes, + localhost_communicator: PeerCommunicator, + multi_node_communicator: PeerCommunicator, + socket_factory: SocketFactory, + number_of_instances_per_node: int, + ): self._number_of_instances_per_node = number_of_instances_per_node self._socket_factory = socket_factory self._value = value @@ -44,8 +53,9 @@ def _send_to_localhost_leader(self) -> None: position = self._localhost_communicator.rank source = self._localhost_communicator.peer value_frame = self._socket_factory.create_frame(self._value) - frames = self._construct_gather_message(source=source, leader=leader, - position=position, value_frame=value_frame) + frames = self._construct_gather_message( + source=source, leader=leader, position=position, value_frame=value_frame + ) self._logger.info("_send_to_localhost_leader", frame=frames[0].to_bytes()) self._localhost_communicator.send(peer=leader, message=frames) @@ -70,40 +80,51 @@ def _forward_message_for_peer(self, peer: Peer): specific_message_obj = self._get_and_check_specific_message_obj(message) self._check_sequence_number(specific_message_obj) local_position = self._get_and_check_local_position(specific_message_obj) - self._logger.info("_forward_message_for_peer", local_position=local_position, peer=peer) - self._send_to_multi_node_leader(local_position=local_position, - value_frame=frames[1]) + self._logger.info( + "_forward_message_for_peer", local_position=local_position, peer=peer + ) + self._send_to_multi_node_leader( + local_position=local_position, value_frame=frames[1] + ) def _send_local_leader_message_to_multi_node_leader(self): local_position = LOCALHOST_LEADER_RANK value_frame = self._socket_factory.create_frame(self._value) - self._send_to_multi_node_leader(local_position=local_position, value_frame=value_frame) + self._send_to_multi_node_leader( + local_position=local_position, value_frame=value_frame + ) def _send_to_multi_node_leader(self, local_position: int, value_frame: Frame): leader = 
self._multi_node_communicator.leader source = self._multi_node_communicator.peer - base_position = self._multi_node_communicator.rank * self._number_of_instances_per_node + base_position = ( + self._multi_node_communicator.rank * self._number_of_instances_per_node + ) position = base_position + local_position - frames = self._construct_gather_message(source=source, leader=leader, position=position, - value_frame=value_frame) + frames = self._construct_gather_message( + source=source, leader=leader, position=position, value_frame=value_frame + ) self._logger.info("_send_to_multi_node_leader", frame=frames[0].to_bytes()) self._multi_node_communicator.send(peer=leader, message=frames) - def _construct_gather_message(self, source: Peer, leader: Peer, position: int, value_frame: Frame): - message = Gather(sequence_number=self._sequence_number, - destination=leader, - source=source, - position=position) + def _construct_gather_message( + self, source: Peer, leader: Peer, position: int, value_frame: Frame + ): + message = Gather( + sequence_number=self._sequence_number, + destination=leader, + source=source, + position=position, + ) serialized_message = serialize_message(message) - frames = [ - self._socket_factory.create_frame(serialized_message), - value_frame - ] + frames = [self._socket_factory.create_frame(serialized_message), value_frame] return frames def _handle_messages_from_all_nodes(self) -> List[bytes]: - number_of_instances_in_cluster = self._multi_node_communicator.number_of_peers \ - * self._number_of_instances_per_node + number_of_instances_in_cluster = ( + self._multi_node_communicator.number_of_peers + * self._number_of_instances_per_node + ) result: Dict[int, bytes] = {MULTI_NODE_LEADER_RANK: self._value} localhost_messages_are_done = False multi_node_messages_are_done = False @@ -111,7 +132,9 @@ def _handle_messages_from_all_nodes(self) -> List[bytes]: if not localhost_messages_are_done: localhost_messages_are_done = self._receive_localhost_messages(result) if not multi_node_messages_are_done: - multi_node_messages_are_done = self._receive_multi_node_messages(result, number_of_instances_in_cluster) + multi_node_messages_are_done = self._receive_multi_node_messages( + result, number_of_instances_in_cluster + ) sorted_items = sorted(result.items(), key=lambda kv: kv[0]) return [v for k, v in sorted_items] @@ -126,30 +149,44 @@ def _receive_localhost_messages(self, result: Dict[int, bytes]) -> bool: specific_message_obj = self._get_and_check_specific_message_obj(message) self._check_sequence_number(specific_message_obj) local_position = self._get_and_check_local_position(specific_message_obj) - self._check_if_position_is_already_set(local_position, result, specific_message_obj) + self._check_if_position_is_already_set( + local_position, result, specific_message_obj + ) result[local_position] = frames[1].to_bytes() positions_required_by_localhost = range(self._number_of_instances_per_node) is_done = set(positions_required_by_localhost).issubset(result.keys()) return is_done - def _receive_multi_node_messages(self, result: Dict[int, bytes], number_of_instances_in_cluster: int) -> bool: + def _receive_multi_node_messages( + self, result: Dict[int, bytes], number_of_instances_in_cluster: int + ) -> bool: if self._multi_node_communicator.number_of_peers == 1: return True peers_with_messages = self._multi_node_communicator.poll_peers() for peer in peers_with_messages: frames = self._multi_node_communicator.recv(peer) - self._logger.info("_receive_multi_node_messages", 
frame=frames[0].to_bytes()) + self._logger.info( + "_receive_multi_node_messages", frame=frames[0].to_bytes() + ) message = deserialize_message(frames[0].to_bytes(), messages.Message) specific_message_obj = self._get_and_check_specific_message_obj(message) self._check_sequence_number(specific_message_obj) - position = self._get_and_check_multi_node_position(specific_message_obj, number_of_instances_in_cluster) - self._check_if_position_is_already_set(position, result, specific_message_obj) + position = self._get_and_check_multi_node_position( + specific_message_obj, number_of_instances_in_cluster + ) + self._check_if_position_is_already_set( + position, result, specific_message_obj + ) result[position] = frames[1].to_bytes() - positions_required_from_other_nodes = range(self._number_of_instances_per_node, number_of_instances_in_cluster) + positions_required_from_other_nodes = range( + self._number_of_instances_per_node, number_of_instances_in_cluster + ) is_done = set(positions_required_from_other_nodes).issubset(result.keys()) return is_done - def _is_result_complete(self, result: Dict[int, bytes], number_of_instances_in_cluster: int) -> bool: + def _is_result_complete( + self, result: Dict[int, bytes], number_of_instances_in_cluster: int + ) -> bool: complete = len(result) == number_of_instances_in_cluster return complete @@ -159,18 +196,25 @@ def _get_and_check_local_position(self, specific_message_obj: Gather) -> int: raise RuntimeError( f"Got message with not allowed position. " f"Position needs to be greater than 0 and smaller than {self._number_of_instances_per_node}, " - f"but we got {local_position} in message {specific_message_obj}") + f"but we got {local_position} in message {specific_message_obj}" + ) return local_position - def _get_and_check_multi_node_position(self, specific_message_obj: Gather, - number_of_instances_in_cluster: int) -> int: + def _get_and_check_multi_node_position( + self, specific_message_obj: Gather, number_of_instances_in_cluster: int + ) -> int: position = specific_message_obj.position - if not (self._number_of_instances_per_node <= position < number_of_instances_in_cluster): + if not ( + self._number_of_instances_per_node + <= position + < number_of_instances_in_cluster + ): raise RuntimeError( f"Got message with not allowed position. " f"Position needs to be greater equal than {self._number_of_instances_per_node} and " f"smaller than {number_of_instances_in_cluster}, " - f"but we got {position} in message {specific_message_obj}") + f"but we got {position} in message {specific_message_obj}" + ) return position def _check_sequence_number(self, specific_message_obj: Gather): @@ -178,17 +222,24 @@ def _check_sequence_number(self, specific_message_obj: Gather): raise RuntimeError( f"Got message with different sequence number. " f"We expect the sequence number {self._sequence_number} " - f"but we got {self._sequence_number} in message {specific_message_obj}") + f"but we got {specific_message_obj.sequence_number} in message {specific_message_obj}" + ) def _get_and_check_specific_message_obj(self, message: messages.Message) -> Gather: specific_message_obj = message.__root__ if not isinstance(specific_message_obj, Gather): - raise TypeError(f"Received the wrong message type. " - f"Expected {Gather.__name__} got {type(message)}. " - f"For message {message}.") + raise TypeError( + f"Received the wrong message type. " + f"Expected {Gather.__name__} got {type(specific_message_obj)}. " + f"For message {message}." 
+ ) return specific_message_obj - def _check_if_position_is_already_set(self, position: int, result: Dict[int, bytes], specific_message_obj: Gather): + def _check_if_position_is_already_set( + self, position: int, result: Dict[int, bytes], specific_message_obj: Gather + ): if position in result: - raise RuntimeError(f"Already received a message for position {position}. " - f"Got message {specific_message_obj}") + raise RuntimeError( + f"Already received a message for position {position}. " + f"Got message {specific_message_obj}" + ) diff --git a/exasol/analytics/udf/communication/messages.py b/exasol/analytics/udf/communication/messages.py index e75e839d..a0bf7b07 100644 --- a/exasol/analytics/udf/communication/messages.py +++ b/exasol/analytics/udf/communication/messages.py @@ -1,4 +1,4 @@ -from typing import Literal, Union, Optional +from typing import Literal, Optional, Union from pydantic import BaseModel @@ -29,8 +29,9 @@ class RegisterPeerComplete(BaseMessage, frozen=True): class PeerRegisterForwarderIsReady(BaseMessage, frozen=True): - message_type: Literal["PeerRegisterForwarderIsReady"] = \ + message_type: Literal["PeerRegisterForwarderIsReady"] = ( "PeerRegisterForwarderIsReady" + ) peer: Peer @@ -154,5 +155,5 @@ class Message(BaseModel, frozen=True): ConnectionIsClosed, Timeout, Gather, - Broadcast + Broadcast, ] diff --git a/exasol/analytics/udf/communication/peer_communicator/abort_timeout_sender.py b/exasol/analytics/udf/communication/peer_communicator/abort_timeout_sender.py index 07bd59ea..ace5d44b 100644 --- a/exasol/analytics/udf/communication/peer_communicator/abort_timeout_sender.py +++ b/exasol/analytics/udf/communication/peer_communicator/abort_timeout_sender.py @@ -19,19 +19,21 @@ class _States(IntFlag): class AbortTimeoutSender: - def __init__(self, - my_connection_info: ConnectionInfo, - peer: Peer, - reason: str, - out_control_socket: Socket, - timer: Timer): + def __init__( + self, + my_connection_info: ConnectionInfo, + peer: Peer, + reason: str, + out_control_socket: Socket, + timer: Timer, + ): self._reason = reason self._timer = timer self._out_control_socket = out_control_socket self._states = _States.INIT self._logger = LOGGER.bind( - peer=peer.dict(), - my_connection_info=my_connection_info.dict()) + peer=peer.dict(), my_connection_info=my_connection_info.dict() + ) def stop(self): self._logger.info("stop") @@ -61,17 +63,19 @@ def _send_timeout_to_frontend(self): class AbortTimeoutSenderFactory: - def create(self, - my_connection_info: ConnectionInfo, - peer: Peer, - reason: str, - out_control_socket: Socket, - timer: Timer) -> AbortTimeoutSender: + def create( + self, + my_connection_info: ConnectionInfo, + peer: Peer, + reason: str, + out_control_socket: Socket, + timer: Timer, + ) -> AbortTimeoutSender: abort_timeout_sender = AbortTimeoutSender( out_control_socket=out_control_socket, timer=timer, my_connection_info=my_connection_info, peer=peer, - reason=reason + reason=reason, ) return abort_timeout_sender diff --git a/exasol/analytics/udf/communication/peer_communicator/acknowledge_register_peer_sender.py b/exasol/analytics/udf/communication/peer_communicator/acknowledge_register_peer_sender.py index 2b486d86..777fba08 100644 --- a/exasol/analytics/udf/communication/peer_communicator/acknowledge_register_peer_sender.py +++ b/exasol/analytics/udf/communication/peer_communicator/acknowledge_register_peer_sender.py @@ -5,24 +5,29 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.peer 
import Peer -from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import \ - RegisterPeerConnection +from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import ( + RegisterPeerConnection, +) from exasol.analytics.udf.communication.peer_communicator.timer import Timer LOGGER: FilteringBoundLogger = structlog.get_logger() -class AcknowledgeRegisterPeerSender(): - def __init__(self, - register_peer_connection: Optional[RegisterPeerConnection], - needs_to_send_for_peer: bool, - my_connection_info: ConnectionInfo, - peer: Peer, - timer: Timer, ): +class AcknowledgeRegisterPeerSender: + def __init__( + self, + register_peer_connection: Optional[RegisterPeerConnection], + needs_to_send_for_peer: bool, + my_connection_info: ConnectionInfo, + peer: Peer, + timer: Timer, + ): self._needs_to_send_for_peer = needs_to_send_for_peer self._register_peer_connection = register_peer_connection if self._needs_to_send_for_peer and self._register_peer_connection is None: - raise ValueError("_register_peer_connection is None while _needs_to_send_for_peer is true") + raise ValueError( + "_register_peer_connection is None while _needs_to_send_for_peer is true" + ) self._my_connection_info = my_connection_info self._timer = timer self._finished = False @@ -58,22 +63,26 @@ def _should_we_send(self): def is_ready_to_stop(self): result = self._finished or not self._needs_to_send_for_peer - self._logger.debug("is_ready_to_stop", finished=self._finished, is_ready_to_stop=result) + self._logger.debug( + "is_ready_to_stop", finished=self._finished, is_ready_to_stop=result + ) return result -class AcknowledgeRegisterPeerSenderFactory(): - def create(self, - register_peer_connection: Optional[RegisterPeerConnection], - needs_to_send_for_peer: bool, - my_connection_info: ConnectionInfo, - peer: Peer, - timer: Timer, ) -> AcknowledgeRegisterPeerSender: +class AcknowledgeRegisterPeerSenderFactory: + def create( + self, + register_peer_connection: Optional[RegisterPeerConnection], + needs_to_send_for_peer: bool, + my_connection_info: ConnectionInfo, + peer: Peer, + timer: Timer, + ) -> AcknowledgeRegisterPeerSender: acknowledge_register_peer_sender = AcknowledgeRegisterPeerSender( register_peer_connection=register_peer_connection, needs_to_send_for_peer=needs_to_send_for_peer, my_connection_info=my_connection_info, peer=peer, - timer=timer + timer=timer, ) return acknowledge_register_peer_sender diff --git a/exasol/analytics/udf/communication/peer_communicator/background_listener_interface.py b/exasol/analytics/udf/communication/peer_communicator/background_listener_interface.py index 66c0b950..6c53737d 100644 --- a/exasol/analytics/udf/communication/peer_communicator/background_listener_interface.py +++ b/exasol/analytics/udf/communication/peer_communicator/background_listener_interface.py @@ -1,6 +1,6 @@ import threading from dataclasses import asdict -from typing import Optional, Iterator, List, Tuple +from typing import Iterator, List, Optional, Tuple import structlog from structlog.types import FilteringBoundLogger @@ -8,38 +8,53 @@ from exasol.analytics.udf.communication import messages from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress -from exasol.analytics.udf.communication.messages import Message, IsReadyToStop, Stop, PrepareToStop +from exasol.analytics.udf.communication.messages import ( + IsReadyToStop, + Message, + PrepareToStop, + Stop, +) from 
exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.background_listener_thread import \ - BackgroundListenerThread +from exasol.analytics.udf.communication.peer_communicator.background_listener_thread import ( + BackgroundListenerThread, +) from exasol.analytics.udf.communication.peer_communicator.clock import Clock -from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \ - PeerCommunicatorConfig -from exasol.analytics.udf.communication.serialization import deserialize_message, serialize_message -from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory, \ - SocketType, Socket, PollerFlag, Frame +from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import ( + PeerCommunicatorConfig, +) +from exasol.analytics.udf.communication.serialization import ( + deserialize_message, + serialize_message, +) +from exasol.analytics.udf.communication.socket_factory.abstract import ( + Frame, + PollerFlag, + Socket, + SocketFactory, + SocketType, +) LOGGER: FilteringBoundLogger = structlog.get_logger() class BackgroundListenerInterface: - def __init__(self, - name: str, - number_of_peers: int, - socket_factory: SocketFactory, - listen_ip: IPAddress, - group_identifier: str, - config: PeerCommunicatorConfig, - clock: Clock, - trace_logging: bool): + def __init__( + self, + name: str, + number_of_peers: int, + socket_factory: SocketFactory, + listen_ip: IPAddress, + group_identifier: str, + config: PeerCommunicatorConfig, + clock: Clock, + trace_logging: bool, + ): self._socket_factory = socket_factory self._config = config self._name = name self._logger = LOGGER.bind( - name=self._name, - group_identifier=group_identifier, - config=asdict(config) + name=self._name, group_identifier=group_identifier, config=asdict(config) ) out_control_socket_address = self._create_out_control_socket(socket_factory) in_control_socket_address = self._create_in_control_socket(socket_factory) @@ -64,13 +79,17 @@ def __init__(self, def _create_in_control_socket(self, socket_factory: SocketFactory) -> str: self._in_control_socket: Socket = socket_factory.create_socket(SocketType.PAIR) - in_control_socket_address = f"inproc://BackgroundListener_in_control_socket{id(self)}" + in_control_socket_address = ( + f"inproc://BackgroundListener_in_control_socket{id(self)}" + ) self._in_control_socket.bind(in_control_socket_address) return in_control_socket_address def _create_out_control_socket(self, socket_factory: SocketFactory) -> str: self._out_control_socket: Socket = socket_factory.create_socket(SocketType.PAIR) - out_control_socket_address = f"inproc://BackgroundListener_out_control_socket{id(self)}" + out_control_socket_address = ( + f"inproc://BackgroundListener_out_control_socket{id(self)}" + ) self._out_control_socket.bind(out_control_socket_address) return out_control_socket_address @@ -78,7 +97,9 @@ def _set_my_connection_info(self): message = None try: message = self._out_control_socket.receive() - message_obj: messages.Message = deserialize_message(message, messages.Message) + message_obj: messages.Message = deserialize_message( + message, messages.Message + ) specific_message_obj = message_obj.__root__ assert isinstance(specific_message_obj, messages.MyConnectionInfo) self._my_connection_info = specific_message_obj.my_connection_info @@ -98,15 +119,19 @@ def send_payload(self, message: messages.Payload, payload: List[Frame]): frame = 
self._socket_factory.create_frame(serialized_message) self._in_control_socket.send_multipart([frame] + payload) - def receive_messages(self, timeout_in_milliseconds: Optional[int] = 0) -> Iterator[Tuple[Message, List[Frame]]]: + def receive_messages( + self, timeout_in_milliseconds: Optional[int] = 0 + ) -> Iterator[Tuple[Message, List[Frame]]]: while PollerFlag.POLLIN in self._out_control_socket.poll( - flags=PollerFlag.POLLIN, - timeout_in_ms=timeout_in_milliseconds): + flags=PollerFlag.POLLIN, timeout_in_ms=timeout_in_milliseconds + ): message = None try: timeout_in_milliseconds = 0 frames = self._out_control_socket.receive_multipart() - message_obj: Message = deserialize_message(frames[0].to_bytes(), Message) + message_obj: Message = deserialize_message( + frames[0].to_bytes(), Message + ) yield message_obj, frames except Exception as e: self._logger.exception("Exception", raw_message=message) diff --git a/exasol/analytics/udf/communication/peer_communicator/background_listener_thread.py b/exasol/analytics/udf/communication/peer_communicator/background_listener_thread.py index 94f5da17..0c22c540 100644 --- a/exasol/analytics/udf/communication/peer_communicator/background_listener_thread.py +++ b/exasol/analytics/udf/communication/peer_communicator/background_listener_thread.py @@ -10,38 +10,59 @@ from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.messages import PrepareToStop from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.background_peer_state import \ - BackgroundPeerState -from exasol.analytics.udf.communication.peer_communicator.background_peer_state_builder import \ - BackgroundPeerStateBuilder -from exasol.analytics.udf.communication.peer_communicator. 
\ - background_thread.connection_closer.connection_closer_builder import ConnectionCloserBuilder +from exasol.analytics.udf.communication.peer_communicator.background_peer_state import ( + BackgroundPeerState, +) +from exasol.analytics.udf.communication.peer_communicator.background_peer_state_builder import ( + BackgroundPeerStateBuilder, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_builder import ( + ConnectionCloserBuilder, +) from exasol.analytics.udf.communication.peer_communicator.clock import Clock -from exasol.analytics.udf.communication.peer_communicator.connection_establisher_builder import \ - ConnectionEstablisherBuilder -from exasol.analytics.udf.communication.peer_communicator.payload_handler_builder import \ - PayloadHandlerBuilder -from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_factory import \ - PayloadMessageSenderFactory -from exasol.analytics.udf.communication.peer_communicator.payload_sender_factory import \ - PayloadSenderFactory -from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \ - PeerCommunicatorConfig -from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import \ - RegisterPeerConnection -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_behavior_config \ - import RegisterPeerForwarderBehaviorConfig -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder import \ - RegisterPeerForwarderBuilder -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder_parameter \ - import RegisterPeerForwarderBuilderParameter -from exasol.analytics.udf.communication.peer_communicator.send_socket_factory import \ - SendSocketFactory +from exasol.analytics.udf.communication.peer_communicator.connection_establisher_builder import ( + ConnectionEstablisherBuilder, +) +from exasol.analytics.udf.communication.peer_communicator.payload_handler_builder import ( + PayloadHandlerBuilder, +) +from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_factory import ( + PayloadMessageSenderFactory, +) +from exasol.analytics.udf.communication.peer_communicator.payload_sender_factory import ( + PayloadSenderFactory, +) +from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import ( + PeerCommunicatorConfig, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import ( + RegisterPeerConnection, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_behavior_config import ( + RegisterPeerForwarderBehaviorConfig, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder import ( + RegisterPeerForwarderBuilder, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder_parameter import ( + RegisterPeerForwarderBuilderParameter, +) +from exasol.analytics.udf.communication.peer_communicator.send_socket_factory import ( + SendSocketFactory, +) from exasol.analytics.udf.communication.peer_communicator.sender import SenderFactory from exasol.analytics.udf.communication.peer_communicator.timer import TimerFactory -from exasol.analytics.udf.communication.serialization import deserialize_message, serialize_message -from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory, \ - SocketType, Socket, PollerFlag, Frame +from 
exasol.analytics.udf.communication.serialization import ( + deserialize_message, + serialize_message, +) +from exasol.analytics.udf.communication.socket_factory.abstract import ( + Frame, + PollerFlag, + Socket, + SocketFactory, + SocketType, +) LOGGER: FilteringBoundLogger = structlog.get_logger() @@ -49,18 +70,28 @@ def create_background_peer_state_builder() -> BackgroundPeerStateBuilder: timer_factory = TimerFactory() sender_factory = SenderFactory() - connection_establisher_builder = ConnectionEstablisherBuilder(timer_factory=timer_factory) + connection_establisher_builder = ConnectionEstablisherBuilder( + timer_factory=timer_factory + ) connection_closer_builder = ConnectionCloserBuilder(timer_factory=timer_factory) - register_peer_forwarder_builder = RegisterPeerForwarderBuilder(timer_factory=timer_factory) - payload_message_sender_factory = PayloadMessageSenderFactory(timer_factory=timer_factory) - payload_sender_factory = PayloadSenderFactory(payload_message_sender_factory=payload_message_sender_factory) - payload_handler_builder = PayloadHandlerBuilder(payload_sender_factory=payload_sender_factory) + register_peer_forwarder_builder = RegisterPeerForwarderBuilder( + timer_factory=timer_factory + ) + payload_message_sender_factory = PayloadMessageSenderFactory( + timer_factory=timer_factory + ) + payload_sender_factory = PayloadSenderFactory( + payload_message_sender_factory=payload_message_sender_factory + ) + payload_handler_builder = PayloadHandlerBuilder( + payload_sender_factory=payload_sender_factory + ) background_peer_state_factory = BackgroundPeerStateBuilder( sender_factory=sender_factory, connection_establisher_builder=connection_establisher_builder, connection_closer_builder=connection_closer_builder, register_peer_forwarder_builder=register_peer_forwarder_builder, - payload_handler_builder=payload_handler_builder + payload_handler_builder=payload_handler_builder, ) return background_peer_state_factory @@ -71,18 +102,20 @@ class Status(enum.Enum): PREPARE_TO_STOP = enum.auto() STOPPED = enum.auto() - def __init__(self, - name: str, - number_of_peers: int, - socket_factory: SocketFactory, - listen_ip: IPAddress, - group_identifier: str, - out_control_socket_address: str, - in_control_socket_address: str, - clock: Clock, - config: PeerCommunicatorConfig, - trace_logging: bool, - background_peer_state_factory: BackgroundPeerStateBuilder = create_background_peer_state_builder()): + def __init__( + self, + name: str, + number_of_peers: int, + socket_factory: SocketFactory, + listen_ip: IPAddress, + group_identifier: str, + out_control_socket_address: str, + in_control_socket_address: str, + clock: Clock, + config: PeerCommunicatorConfig, + trace_logging: bool, + background_peer_state_factory: BackgroundPeerStateBuilder = create_background_peer_state_builder(), + ): self._number_of_peers = number_of_peers self._config = config self._background_peer_state_factory = background_peer_state_factory @@ -93,7 +126,7 @@ def __init__(self, self._logger = LOGGER.bind( name=self._name, group_identifier=group_identifier, - config=dataclasses.asdict(config) + config=dataclasses.asdict(config), ) self._group_identifier = group_identifier self._listen_ip = listen_ip @@ -124,17 +157,23 @@ def _stop(self): self._logger.info("end") def _create_listener_socket(self): - self._listener_socket: Socket = self._socket_factory.create_socket(SocketType.ROUTER) + self._listener_socket: Socket = self._socket_factory.create_socket( + SocketType.ROUTER + ) self._listener_socket.set_identity(self._name) 
port = self._listener_socket.bind_to_random_port(f"tcp://*") return port def _create_in_control_socket(self): - self._in_control_socket: Socket = self._socket_factory.create_socket(SocketType.PAIR) + self._in_control_socket: Socket = self._socket_factory.create_socket( + SocketType.PAIR + ) self._in_control_socket.connect(self._in_control_socket_address) def _create_out_control_socket(self): - self._out_control_socket: Socket = self._socket_factory.create_socket(SocketType.PAIR) + self._out_control_socket: Socket = self._socket_factory.create_socket( + SocketType.PAIR + ) self._out_control_socket.connect(self._out_control_socket_address) def _create_poller(self): @@ -159,16 +198,24 @@ def _try_send(self): def _handle_message(self): poll = self.poller.poll(timeout_in_ms=self._config.poll_timeout_in_ms) - if self._in_control_socket in poll and PollerFlag.POLLIN in poll[self._in_control_socket]: + if ( + self._in_control_socket in poll + and PollerFlag.POLLIN in poll[self._in_control_socket] + ): message = self._in_control_socket.receive_multipart() self._status = self._handle_control_message(message) - if self._listener_socket in poll and PollerFlag.POLLIN in poll[self._listener_socket]: + if ( + self._listener_socket in poll + and PollerFlag.POLLIN in poll[self._listener_socket] + ): message = self._listener_socket.receive_multipart() self._handle_listener_message(message) def _handle_control_message(self, frames: List[Frame]) -> Status: try: - message_obj: messages.Message = deserialize_message(frames[0].to_bytes(), messages.Message) + message_obj: messages.Message = deserialize_message( + frames[0].to_bytes(), messages.Message + ) specific_message_obj = message_obj.__root__ if isinstance(specific_message_obj, messages.Stop): return BackgroundListenerThread.Status.STOPPED @@ -178,43 +225,52 @@ def _handle_control_message(self, frames: List[Frame]) -> Status: if self._is_register_peer_message_allowed_as_control_message(): self._handle_register_peer_message(specific_message_obj) else: - self._logger.error("RegisterPeer message not allowed", - message_obj=specific_message_obj.dict()) + self._logger.error( + "RegisterPeer message not allowed", + message_obj=specific_message_obj.dict(), + ) elif isinstance(specific_message_obj, messages.Payload): self.send_payload(payload=specific_message_obj, frames=frames) else: - self._logger.error("Unknown message type", message_obj=specific_message_obj.dict()) + self._logger.error( + "Unknown message type", message_obj=specific_message_obj.dict() + ) except Exception as e: self._logger.exception("Exception during handling message", message=frames) return self._status def _is_register_peer_message_allowed_as_control_message(self) -> bool: return ( - ( - self._config.forward_register_peer_config.is_enabled - and self._config.forward_register_peer_config.is_leader - ) - or not self._config.forward_register_peer_config.is_enabled - ) + self._config.forward_register_peer_config.is_enabled + and self._config.forward_register_peer_config.is_leader + ) or not self._config.forward_register_peer_config.is_enabled def send_payload(self, payload: messages.Payload, frames: List[Frame]): self._peer_state[payload.destination].send_payload( - message=payload, frames=frames) - - def _add_peer(self, - peer: Peer, - register_peer_forwarder_behavior_config: RegisterPeerForwarderBehaviorConfig = - RegisterPeerForwarderBehaviorConfig()): - if peer.connection_info.group_identifier != self._my_connection_info.group_identifier: - self._logger.error("Peer belongs to a different 
group", - my_connection_info=self._my_connection_info.dict(), - peer=peer.dict()) + message=payload, frames=frames + ) + + def _add_peer( + self, + peer: Peer, + register_peer_forwarder_behavior_config: RegisterPeerForwarderBehaviorConfig = RegisterPeerForwarderBehaviorConfig(), + ): + if ( + peer.connection_info.group_identifier + != self._my_connection_info.group_identifier + ): + self._logger.error( + "Peer belongs to a different group", + my_connection_info=self._my_connection_info.dict(), + peer=peer.dict(), + ) raise ValueError("Peer belongs to a different group") if peer not in self._peer_state: parameter = RegisterPeerForwarderBuilderParameter( register_peer_connection=self._register_peer_connection, timeout_config=self._config.register_peer_forwarder_timeout_config, - behavior_config=register_peer_forwarder_behavior_config) + behavior_config=register_peer_forwarder_behavior_config, + ) self._peer_state[peer] = self._background_peer_state_factory.create( my_connection_info=self._my_connection_info, peer=peer, @@ -225,16 +281,16 @@ def _add_peer(self, connection_establisher_timeout_config=self._config.connection_establisher_timeout_config, connection_closer_timeout_config=self._config.connection_closer_timeout_config, register_peer_forwarder_builder_parameter=parameter, - payload_message_sender_timeout_config=self._config.payload_message_sender_timeout_config + payload_message_sender_timeout_config=self._config.payload_message_sender_timeout_config, ) def _handle_listener_message(self, frames: List[Frame]): - logger = self._logger.bind( - sender_queue_id=frames[0].to_bytes() - ) + logger = self._logger.bind(sender_queue_id=frames[0].to_bytes()) message_content_bytes = frames[1].to_bytes() try: - message_obj: messages.Message = deserialize_message(message_content_bytes, messages.Message) + message_obj: messages.Message = deserialize_message( + message_content_bytes, messages.Message + ) specific_message_obj = message_obj.__root__ if isinstance(specific_message_obj, messages.SynchronizeConnection): self._handle_synchronize_connection(specific_message_obj) @@ -248,7 +304,10 @@ def _handle_listener_message(self, frames: List[Frame]): if self.is_register_peer_message_allowed_as_listener_message(): self._handle_register_peer_message(specific_message_obj) else: - logger.error("RegisterPeer message not allowed", message_obj=specific_message_obj.dict()) + logger.error( + "RegisterPeer message not allowed", + message_obj=specific_message_obj.dict(), + ) elif isinstance(specific_message_obj, messages.AcknowledgeRegisterPeer): self._handle_acknowledge_register_peer_message(specific_message_obj) elif isinstance(specific_message_obj, messages.RegisterPeerComplete): @@ -258,19 +317,30 @@ def _handle_listener_message(self, frames: List[Frame]): elif isinstance(specific_message_obj, messages.AcknowledgePayload): self._handle_acknowledge_payload_message(specific_message_obj) else: - logger.error("Unknown message type", message_obj=specific_message_obj.dict()) + logger.error( + "Unknown message type", message_obj=specific_message_obj.dict() + ) except Exception as e: - logger.exception("Exception during handling message", message_content=message_content_bytes) + logger.exception( + "Exception during handling message", + message_content=message_content_bytes, + ) def is_register_peer_message_allowed_as_listener_message(self) -> bool: - return not self._config.forward_register_peer_config.is_leader \ - and self._config.forward_register_peer_config.is_enabled + return ( + not 
self._config.forward_register_peer_config.is_leader + and self._config.forward_register_peer_config.is_enabled + ) def _handle_payload_message(self, payload: messages.Payload, frames: List[Frame]): self._peer_state[payload.source].received_payload(payload, frames=frames) - def _handle_acknowledge_payload_message(self, acknowledge_payload: messages.AcknowledgePayload): - self._peer_state[acknowledge_payload.source].received_acknowledge_payload(acknowledge_payload) + def _handle_acknowledge_payload_message( + self, acknowledge_payload: messages.AcknowledgePayload + ): + self._peer_state[acknowledge_payload.source].received_acknowledge_payload( + acknowledge_payload + ) def _handle_synchronize_connection(self, message: messages.SynchronizeConnection): peer = Peer(connection_info=message.source) @@ -287,7 +357,9 @@ def _handle_close_connection(self, message: messages.CloseConnection): self._add_peer(peer) self._peer_state[peer].received_close_connection() - def _handle_acknowledge_close_connection(self, message: messages.AcknowledgeCloseConnection): + def _handle_acknowledge_close_connection( + self, message: messages.AcknowledgeCloseConnection + ): peer = Peer(connection_info=message.source) self._add_peer(peer) self._peer_state[peer].received_acknowledge_close_connection() @@ -297,7 +369,8 @@ def _set_my_connection_info(self, port: int): name=self._name, ipaddress=self._listen_ip, port=Port(port=port), - group_identifier=self._group_identifier) + group_identifier=self._group_identifier, + ) message = messages.MyConnectionInfo(my_connection_info=self._my_connection_info) self._out_control_socket.send(serialize_message(message)) @@ -312,7 +385,7 @@ def _handle_register_peer_message(self, message: messages.RegisterPeer): message.peer, register_peer_forwarder_behavior_config=RegisterPeerForwarderBehaviorConfig( needs_to_send_acknowledge_register_peer=not self._config.forward_register_peer_config.is_leader - ) + ), ) return @@ -321,20 +394,20 @@ def _handle_register_peer_message(self, message: messages.RegisterPeer): register_peer_forwarder_behavior_config=RegisterPeerForwarderBehaviorConfig( needs_to_send_register_peer=True, needs_to_send_acknowledge_register_peer=not self._config.forward_register_peer_config.is_leader, - ) + ), ) def _create_register_peer_connection(self, message: messages.RegisterPeer): successor_send_socket_factory = SendSocketFactory( my_connection_info=self._my_connection_info, peer=message.peer, - socket_factory=self._socket_factory + socket_factory=self._socket_factory, ) if message.source is not None: predecessor_send_socket_factory = SendSocketFactory( my_connection_info=self._my_connection_info, peer=message.source, - socket_factory=self._socket_factory + socket_factory=self._socket_factory, ) else: predecessor_send_socket_factory = None @@ -343,17 +416,27 @@ def _create_register_peer_connection(self, message: messages.RegisterPeer): predecessor_send_socket_factory=predecessor_send_socket_factory, successor=message.peer, successor_send_socket_factory=successor_send_socket_factory, - my_connection_info=self._my_connection_info + my_connection_info=self._my_connection_info, ) - def _handle_acknowledge_register_peer_message(self, message: messages.AcknowledgeRegisterPeer): + def _handle_acknowledge_register_peer_message( + self, message: messages.AcknowledgeRegisterPeer + ): if self._register_peer_connection.successor != message.source: - self._logger.error("AcknowledgeRegisterPeer message not from successor", message_obj=message.dict()) + self._logger.error( + 
"AcknowledgeRegisterPeer message not from successor", + message_obj=message.dict(), + ) peer = message.peer self._peer_state[peer].received_acknowledge_register_peer() - def _handle_register_peer_complete_message(self, message: messages.RegisterPeerComplete): + def _handle_register_peer_complete_message( + self, message: messages.RegisterPeerComplete + ): if self._register_peer_connection.predecessor != message.source: - self._logger.error("RegisterPeerComplete message not from predecessor", message_obj=message.dict()) + self._logger.error( + "RegisterPeerComplete message not from predecessor", + message_obj=message.dict(), + ) peer = message.peer self._peer_state[peer].received_register_peer_complete() diff --git a/exasol/analytics/udf/communication/peer_communicator/background_peer_state.py b/exasol/analytics/udf/communication/peer_communicator/background_peer_state.py index 01c98732..fc13b1c3 100644 --- a/exasol/analytics/udf/communication/peer_communicator/background_peer_state.py +++ b/exasol/analytics/udf/communication/peer_communicator/background_peer_state.py @@ -5,18 +5,21 @@ from exasol.analytics.udf.communication import messages from exasol.analytics.udf.communication.connection_info import ConnectionInfo -from exasol.analytics.udf.communication import messages from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator. \ - background_thread.connection_closer.connection_closer import ConnectionCloser -from exasol.analytics.udf.communication.peer_communicator.connection_establisher import \ - ConnectionEstablisher -from exasol.analytics.udf.communication.peer_communicator.payload_handler import PayloadHandler -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import \ - RegisterPeerForwarder +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer import ( + ConnectionCloser, +) +from exasol.analytics.udf.communication.peer_communicator.connection_establisher import ( + ConnectionEstablisher, +) +from exasol.analytics.udf.communication.peer_communicator.payload_handler import ( + PayloadHandler, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import ( + RegisterPeerForwarder, +) from exasol.analytics.udf.communication.peer_communicator.sender import Sender -from exasol.analytics.udf.communication.socket_factory.abstract \ - import Frame +from exasol.analytics.udf.communication.socket_factory.abstract import Frame LOGGER: FilteringBoundLogger = structlog.get_logger() @@ -24,14 +27,14 @@ class BackgroundPeerState: def __init__( - self, - my_connection_info: ConnectionInfo, - peer: Peer, - sender: Sender, - connection_establisher: ConnectionEstablisher, - connection_closer: ConnectionCloser, - register_peer_forwarder: RegisterPeerForwarder, - payload_handler: PayloadHandler + self, + my_connection_info: ConnectionInfo, + peer: Peer, + sender: Sender, + connection_establisher: ConnectionEstablisher, + connection_closer: ConnectionCloser, + register_peer_forwarder: RegisterPeerForwarder, + payload_handler: PayloadHandler, ): self._connection_closer = connection_closer self._payload_handler = payload_handler @@ -57,9 +60,11 @@ def try_send(self): def _should_we_close_connection(self): is_ready_to_stop = self._is_ready_to_stop() - self._logger.debug("_should_we_send_close_connection", - is_ready_to_stop=is_ready_to_stop, - prepare_to_stop=self._prepare_to_stop) + self._logger.debug( + 
"_should_we_send_close_connection", + is_ready_to_stop=is_ready_to_stop, + prepare_to_stop=self._prepare_to_stop, + ) return self._prepare_to_stop and is_ready_to_stop def received_synchronize_connection(self): @@ -88,18 +93,24 @@ def prepare_to_stop(self): self._prepare_to_stop = True def _is_ready_to_stop(self): - connection_establisher_is_ready = self._connection_establisher.is_ready_to_stop() - register_peer_forwarder_is_ready = self._register_peer_forwarder.is_ready_to_stop() + connection_establisher_is_ready = ( + self._connection_establisher.is_ready_to_stop() + ) + register_peer_forwarder_is_ready = ( + self._register_peer_forwarder.is_ready_to_stop() + ) payload_handler_is_ready = self._payload_handler.is_ready_to_stop() is_ready_to_stop = ( - connection_establisher_is_ready - and register_peer_forwarder_is_ready - and payload_handler_is_ready + connection_establisher_is_ready + and register_peer_forwarder_is_ready + and payload_handler_is_ready + ) + self._logger.debug( + "background_peer_state_is_ready_to_stop", + connection_establisher_is_ready=connection_establisher_is_ready, + register_peer_forwarder_is_ready=register_peer_forwarder_is_ready, + payload_handler_is_ready=payload_handler_is_ready, ) - self._logger.debug("background_peer_state_is_ready_to_stop", - connection_establisher_is_ready=connection_establisher_is_ready, - register_peer_forwarder_is_ready=register_peer_forwarder_is_ready, - payload_handler_is_ready=payload_handler_is_ready) return is_ready_to_stop def received_close_connection(self): diff --git a/exasol/analytics/udf/communication/peer_communicator/background_peer_state_builder.py b/exasol/analytics/udf/communication/peer_communicator/background_peer_state_builder.py index fd877ae4..9eee377b 100644 --- a/exasol/analytics/udf/communication/peer_communicator/background_peer_state_builder.py +++ b/exasol/analytics/udf/communication/peer_communicator/background_peer_state_builder.py @@ -1,40 +1,54 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.background_peer_state import \ - BackgroundPeerState -from exasol.analytics.udf.communication.peer_communicator.background_peer_state_factory import \ - BackgroundPeerStateFactory -from exasol.analytics.udf.communication.peer_communicator. \ - background_thread.connection_closer.connection_closer_builder import ConnectionCloserBuilder -from exasol.analytics.udf.communication.peer_communicator. 
\
-    background_thread.connection_closer.connection_closer_timeout_config import ConnectionCloserTimeoutConfig
+from exasol.analytics.udf.communication.peer_communicator.background_peer_state import (
+    BackgroundPeerState,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_peer_state_factory import (
+    BackgroundPeerStateFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_builder import (
+    ConnectionCloserBuilder,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_timeout_config import (
+    ConnectionCloserTimeoutConfig,
+)
 from exasol.analytics.udf.communication.peer_communicator.clock import Clock
-from exasol.analytics.udf.communication.peer_communicator.connection_establisher_builder import \
-    ConnectionEstablisherBuilder
-from exasol.analytics.udf.communication.peer_communicator.connection_establisher_timeout_config \
-    import ConnectionEstablisherTimeoutConfig
-from exasol.analytics.udf.communication.peer_communicator.payload_handler_builder import \
-    PayloadHandlerBuilder
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config \
-    import PayloadMessageSenderTimeoutConfig
-from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder import \
-    RegisterPeerForwarderBuilder
-from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder_parameter \
-    import RegisterPeerForwarderBuilderParameter
+from exasol.analytics.udf.communication.peer_communicator.connection_establisher_builder import (
+    ConnectionEstablisherBuilder,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_establisher_timeout_config import (
+    ConnectionEstablisherTimeoutConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_handler_builder import (
+    PayloadHandlerBuilder,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config import (
+    PayloadMessageSenderTimeoutConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder import (
+    RegisterPeerForwarderBuilder,
+)
+from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder_parameter import (
+    RegisterPeerForwarderBuilderParameter,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import SenderFactory
-from exasol.analytics.udf.communication.socket_factory.abstract import Socket, \
-    SocketFactory
+from exasol.analytics.udf.communication.socket_factory.abstract import (
+    Socket,
+    SocketFactory,
+)
 
 
 class BackgroundPeerStateBuilder:
-    def __init__(self,
-                 connection_establisher_builder: ConnectionEstablisherBuilder,
-                 connection_closer_builder: ConnectionCloserBuilder,
-                 register_peer_forwarder_builder: RegisterPeerForwarderBuilder,
-                 payload_handler_builder: PayloadHandlerBuilder,
-                 sender_factory: SenderFactory,
-                 background_peer_state_factory: BackgroundPeerStateFactory = BackgroundPeerStateFactory()):
+    def __init__(
+        self,
+        connection_establisher_builder: ConnectionEstablisherBuilder,
+        connection_closer_builder: ConnectionCloserBuilder,
+        register_peer_forwarder_builder: RegisterPeerForwarderBuilder,
+        payload_handler_builder: PayloadHandlerBuilder,
+        sender_factory: SenderFactory,
+        background_peer_state_factory: BackgroundPeerStateFactory = BackgroundPeerStateFactory(),
+    ):
         self._connection_closer_builder = connection_closer_builder
         self._payload_handler_builder = payload_handler_builder
         self._connection_establisher_builder = connection_establisher_builder
@@ -43,30 +57,31 @@ def __init__(self,
         self._sender_factory = sender_factory
 
     def create(
-            self,
-            my_connection_info: ConnectionInfo,
-            out_control_socket: Socket,
-            socket_factory: SocketFactory,
-            peer: Peer,
-            clock: Clock,
-            send_socket_linger_time_in_ms: int,
-            register_peer_forwarder_builder_parameter: RegisterPeerForwarderBuilderParameter,
-            connection_establisher_timeout_config: ConnectionEstablisherTimeoutConfig,
-            connection_closer_timeout_config: ConnectionCloserTimeoutConfig,
-            payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig
+        self,
+        my_connection_info: ConnectionInfo,
+        out_control_socket: Socket,
+        socket_factory: SocketFactory,
+        peer: Peer,
+        clock: Clock,
+        send_socket_linger_time_in_ms: int,
+        register_peer_forwarder_builder_parameter: RegisterPeerForwarderBuilderParameter,
+        connection_establisher_timeout_config: ConnectionEstablisherTimeoutConfig,
+        connection_closer_timeout_config: ConnectionCloserTimeoutConfig,
+        payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig,
     ) -> BackgroundPeerState:
         sender = self._sender_factory.create(
             my_connection_info=my_connection_info,
             socket_factory=socket_factory,
             peer=peer,
-            send_socket_linger_time_in_ms=send_socket_linger_time_in_ms)
+            send_socket_linger_time_in_ms=send_socket_linger_time_in_ms,
+        )
         connection_establisher = self._connection_establisher_builder.create(
             peer=peer,
             my_connection_info=my_connection_info,
             out_control_socket=out_control_socket,
             clock=clock,
             sender=sender,
-            timeout_config=connection_establisher_timeout_config
+            timeout_config=connection_establisher_timeout_config,
         )
         connection_closer = self._connection_closer_builder.create(
             peer=peer,
@@ -74,7 +89,7 @@ def create(
             out_control_socket=out_control_socket,
             clock=clock,
             sender=sender,
-            timeout_config=connection_closer_timeout_config
+            timeout_config=connection_closer_timeout_config,
         )
         register_peer_forwarder = self._register_peer_forwarder_builder.create(
             peer=peer,
@@ -91,7 +106,7 @@ def create(
             out_control_socket=out_control_socket,
             payload_message_sender_timeout_config=payload_message_sender_timeout_config,
             sender=sender,
-            clock=clock
+            clock=clock,
         )
         peer_state = self._background_peer_state_factory.create(
             my_connection_info=my_connection_info,
@@ -100,6 +115,6 @@
             connection_establisher=connection_establisher,
             connection_closer=connection_closer,
             register_peer_forwarder=register_peer_forwarder,
-            payload_handler=payload_handler
+            payload_handler=payload_handler,
         )
         return peer_state
diff --git a/exasol/analytics/udf/communication/peer_communicator/background_peer_state_factory.py b/exasol/analytics/udf/communication/peer_communicator/background_peer_state_factory.py
index 32100d65..5b7e25f4 100644
--- a/exasol/analytics/udf/communication/peer_communicator/background_peer_state_factory.py
+++ b/exasol/analytics/udf/communication/peer_communicator/background_peer_state_factory.py
@@ -1,28 +1,34 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.background_peer_state import \
-    BackgroundPeerState
-from exasol.analytics.udf.communication.peer_communicator. \
-    background_thread.connection_closer.connection_closer import ConnectionCloser
-from exasol.analytics.udf.communication.peer_communicator.connection_establisher import \
-    ConnectionEstablisher
-from exasol.analytics.udf.communication.peer_communicator.payload_handler import PayloadHandler
-from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import \
-    RegisterPeerForwarder
+from exasol.analytics.udf.communication.peer_communicator.background_peer_state import (
+    BackgroundPeerState,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer import (
+    ConnectionCloser,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_establisher import (
+    ConnectionEstablisher,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_handler import (
+    PayloadHandler,
+)
+from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import (
+    RegisterPeerForwarder,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
 
 
 class BackgroundPeerStateFactory:
     def create(
-            self,
-            my_connection_info: ConnectionInfo,
-            peer: Peer,
-            sender: Sender,
-            connection_establisher: ConnectionEstablisher,
-            connection_closer: ConnectionCloser,
-            register_peer_forwarder: RegisterPeerForwarder,
-            payload_handler: PayloadHandler
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        sender: Sender,
+        connection_establisher: ConnectionEstablisher,
+        connection_closer: ConnectionCloser,
+        register_peer_forwarder: RegisterPeerForwarder,
+        payload_handler: PayloadHandler,
     ) -> BackgroundPeerState:
         return BackgroundPeerState(
             my_connection_info=my_connection_info,
@@ -31,5 +37,5 @@ def create(
             connection_establisher=connection_establisher,
             connection_closer=connection_closer,
             register_peer_forwarder=register_peer_forwarder,
-            payload_handler=payload_handler
+            payload_handler=payload_handler,
         )
diff --git a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/close_connection_sender.py b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/close_connection_sender.py
index 4965cc1e..6738cc35 100644
--- a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/close_connection_sender.py
+++ b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/close_connection_sender.py
@@ -11,11 +11,13 @@
 
 
 class CloseConnectionSender:
-    def __init__(self,
-                 my_connection_info: ConnectionInfo,
-                 peer: Peer,
-                 sender: Sender,
-                 timer: Timer):
+    def __init__(
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        sender: Sender,
+        timer: Timer,
+    ):
         self._my_connection_info = my_connection_info
         self._timer = timer
         self._sender = sender
@@ -23,8 +25,8 @@ def __init__(self,
         self._send_attempt_count = 0
         self._peer = peer
         self._logger = LOGGER.bind(
-            peer=peer.dict(),
-            my_connection_info=my_connection_info.dict())
+            peer=peer.dict(), my_connection_info=my_connection_info.dict()
+        )
         self._logger.debug("init")
 
     def stop(self):
@@ -48,7 +50,9 @@ def _send(self):
             __root__=messages.CloseConnection(
                 source=self._my_connection_info,
                 destination=self._peer,
-                attempt=self._send_attempt_count))
+                attempt=self._send_attempt_count,
+            )
+        )
         self._sender.send(message)
 
     def _should_we_send(self):
@@ -58,15 +62,14 @@ def _should_we_send(self):
 
 
 class CloseConnectionSenderFactory:
-    def create(self,
-               my_connection_info: ConnectionInfo,
-               peer: Peer,
-               sender: Sender,
-               timer: Timer) -> CloseConnectionSender:
+    def create(
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        sender: Sender,
+        timer: Timer,
+    ) -> CloseConnectionSender:
         close_connection_sender = CloseConnectionSender(
-            my_connection_info=my_connection_info,
-            peer=peer,
-            sender=sender,
-            timer=timer
+            my_connection_info=my_connection_info, peer=peer, sender=sender, timer=timer
         )
         return close_connection_sender
diff --git a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer.py b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer.py
index 03ad0525..2ec8141e 100644
--- a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer.py
+++ b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer.py
@@ -5,25 +5,30 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.messages import Message
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \
-    AbortTimeoutSender
-from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.close_connection_sender import \
-    CloseConnectionSender
-from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_is_closed_sender import \
-    ConnectionIsClosedSender
+from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import (
+    AbortTimeoutSender,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.close_connection_sender import (
+    CloseConnectionSender,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_is_closed_sender import (
+    ConnectionIsClosedSender,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
 
 LOGGER: FilteringBoundLogger = structlog.get_logger()
 
 
 class ConnectionCloser:
-    def __init__(self,
-                 peer: Peer,
-                 my_connection_info: ConnectionInfo,
-                 sender: Sender,
-                 abort_timeout_sender: AbortTimeoutSender,
-                 connection_is_closed_sender: ConnectionIsClosedSender,
-                 close_connection_sender: CloseConnectionSender):
+    def __init__(
+        self,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        sender: Sender,
+        abort_timeout_sender: AbortTimeoutSender,
+        connection_is_closed_sender: ConnectionIsClosedSender,
+        close_connection_sender: CloseConnectionSender,
+    ):
         self._close_connection_sender = close_connection_sender
         self._connection_is_closed_sender = connection_is_closed_sender
         self._my_connection_info = my_connection_info
@@ -36,10 +41,13 @@ def __init__(self,
 
     def received_close_connection(self):
         self._logger.debug("received_synchronize_connection")
-        self._sender.send(Message(
-            __root__=messages.AcknowledgeCloseConnection(
-                source=self._my_connection_info,
-                destination=self._peer)))
+        self._sender.send(
+            Message(
+                __root__=messages.AcknowledgeCloseConnection(
+                    source=self._my_connection_info, destination=self._peer
+                )
+            )
+        )
         self._connection_is_closed_sender.received_close_connection()
 
     def received_acknowledge_close_connection(self):
diff --git a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_builder.py b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_builder.py
index 7f5395a9..016d6839 100644
--- a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_builder.py
+++ b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_builder.py
@@ -1,17 +1,25 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \
-    AbortTimeoutSenderFactory
-from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.close_connection_sender import \
-    CloseConnectionSenderFactory, CloseConnectionSender
-from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer import \
-    ConnectionCloser
-from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_factory import \
-    ConnectionCloserFactory
-from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_timeout_config import \
-    ConnectionCloserTimeoutConfig
-from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_is_closed_sender import \
-    ConnectionIsClosedSenderFactory, ConnectionIsClosedSender
+from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import (
+    AbortTimeoutSenderFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.close_connection_sender import (
+    CloseConnectionSender,
+    CloseConnectionSenderFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer import (
+    ConnectionCloser,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_factory import (
+    ConnectionCloserFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_timeout_config import (
+    ConnectionCloserTimeoutConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_is_closed_sender import (
+    ConnectionIsClosedSender,
+    ConnectionIsClosedSenderFactory,
+)
 from exasol.analytics.udf.communication.peer_communicator.clock import Clock
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
 from exasol.analytics.udf.communication.peer_communicator.timer import TimerFactory
@@ -20,44 +28,49 @@
 
 
 class ConnectionCloserBuilder:
-    def __init__(self,
-                 timer_factory: TimerFactory,
-                 abort_timeout_sender_factory: AbortTimeoutSenderFactory = AbortTimeoutSenderFactory(),
-                 connection_is_closed_sender_factory: ConnectionIsClosedSenderFactory = ConnectionIsClosedSenderFactory(),
-                 close_connection_sender_factory: CloseConnectionSenderFactory =
-                 CloseConnectionSenderFactory(),
-                 connection_closer_factory: ConnectionCloserFactory = ConnectionCloserFactory()):
+    def __init__(
+        self,
+        timer_factory: TimerFactory,
+        abort_timeout_sender_factory: AbortTimeoutSenderFactory = AbortTimeoutSenderFactory(),
+        connection_is_closed_sender_factory: ConnectionIsClosedSenderFactory = ConnectionIsClosedSenderFactory(),
+        close_connection_sender_factory: CloseConnectionSenderFactory = CloseConnectionSenderFactory(),
+        connection_closer_factory: ConnectionCloserFactory = ConnectionCloserFactory(),
+    ):
         self._connection_closer_factory = connection_closer_factory
         self._timer_factory = timer_factory
         self._close_connection_sender_factory = close_connection_sender_factory
         self._connection_is_closed_sender_factory = connection_is_closed_sender_factory
         self._abort_timeout_sender_factory = abort_timeout_sender_factory
 
-    def create(self,
-               peer: Peer,
-               my_connection_info: ConnectionInfo,
-               out_control_socket: Socket,
-               clock: Clock,
-               sender: Sender,
-               timeout_config: ConnectionCloserTimeoutConfig) -> ConnectionCloser:
+    def create(
+        self,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        out_control_socket: Socket,
+        clock: Clock,
+        sender: Sender,
+        timeout_config: ConnectionCloserTimeoutConfig,
+    ) -> ConnectionCloser:
         close_connection_sender = self._create_close_connection_sender(
             my_connection_info=my_connection_info,
             peer=peer,
             sender=sender,
             clock=clock,
-            timeout_config=timeout_config)
+            timeout_config=timeout_config,
+        )
         abort_timeout_sender = self._create_abort_timeout_sender(
             my_connection_info=my_connection_info,
             peer=peer,
             out_control_socket=out_control_socket,
             clock=clock,
-            timeout_config=timeout_config)
+            timeout_config=timeout_config,
+        )
         connection_is_closed_sender = self._create_connection_is_closed_sender(
             my_connection_info=my_connection_info,
             peer=peer,
             clock=clock,
             out_control_socket=out_control_socket,
-            timeout_config=timeout_config
+            timeout_config=timeout_config,
         )
         return self._connection_closer_factory.create(
             peer=peer,
@@ -65,17 +78,21 @@ def create(self,
             sender=sender,
             abort_timeout_sender=abort_timeout_sender,
             connection_is_closed_sender=connection_is_closed_sender,
-            close_connection_sender=close_connection_sender
+            close_connection_sender=close_connection_sender,
         )
 
     def _create_connection_is_closed_sender(
-            self,
-            my_connection_info: ConnectionInfo, peer: Peer,
-            clock: Clock, out_control_socket: Socket,
-            timeout_config: ConnectionCloserTimeoutConfig
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        clock: Clock,
+        out_control_socket: Socket,
+        timeout_config: ConnectionCloserTimeoutConfig,
    ) -> ConnectionIsClosedSender:
         timer = self._timer_factory.create(
-            clock=clock, timeout_in_ms=timeout_config.connection_is_closed_wait_time_in_ms)
+            clock=clock,
+            timeout_in_ms=timeout_config.connection_is_closed_wait_time_in_ms,
+        )
         connection_is_closed_sender = self._connection_is_closed_sender_factory.create(
             out_control_socket=out_control_socket,
             timer=timer,
@@ -84,33 +101,38 @@ def _create_connection_is_closed_sender(
         )
         return connection_is_closed_sender
 
-    def _create_abort_timeout_sender(self,
-                                     my_connection_info: ConnectionInfo, peer: Peer,
-                                     out_control_socket: Socket, clock: Clock,
-                                     timeout_config: ConnectionCloserTimeoutConfig):
+    def _create_abort_timeout_sender(
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        out_control_socket: Socket,
+        clock: Clock,
+        timeout_config: ConnectionCloserTimeoutConfig,
+    ):
         abort_timeout_sender_timer = self._timer_factory.create(
-            clock=clock, timeout_in_ms=timeout_config.abort_timeout_in_ms)
+            clock=clock, timeout_in_ms=timeout_config.abort_timeout_in_ms
+        )
         abort_timeout_sender = self._abort_timeout_sender_factory.create(
             out_control_socket=out_control_socket,
             timer=abort_timeout_sender_timer,
             my_connection_info=my_connection_info,
             peer=peer,
-            reason="Timeout occurred during establishing connection."
+            reason="Timeout occurred during establishing connection.",
         )
         return abort_timeout_sender
 
     def _create_close_connection_sender(
-            self,
-            my_connection_info: ConnectionInfo, peer: Peer,
-            sender: Sender, clock: Clock,
-            timeout_config: ConnectionCloserTimeoutConfig
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        sender: Sender,
+        clock: Clock,
+        timeout_config: ConnectionCloserTimeoutConfig,
     ) -> CloseConnectionSender:
         timer = self._timer_factory.create(
-            clock=clock, timeout_in_ms=timeout_config.close_retry_timeout_in_ms)
+            clock=clock, timeout_in_ms=timeout_config.close_retry_timeout_in_ms
+        )
         close_connection_sender = self._close_connection_sender_factory.create(
-            my_connection_info=my_connection_info,
-            peer=peer,
-            sender=sender,
-            timer=timer
+            my_connection_info=my_connection_info, peer=peer, sender=sender, timer=timer
         )
         return close_connection_sender
diff --git a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_factory.py b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_factory.py
index b251e4f1..ded21411 100644
--- a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_factory.py
+++ b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_factory.py
@@ -1,31 +1,36 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \
-    AbortTimeoutSender
-from exasol.analytics.udf.communication.peer_communicator. \
-    background_thread.connection_closer.close_connection_sender import CloseConnectionSender
-from exasol.analytics.udf.communication.peer_communicator. \
-    background_thread.connection_closer.connection_closer import ConnectionCloser
-from exasol.analytics.udf.communication.peer_communicator. \
-    background_thread.connection_closer.connection_is_closed_sender import ConnectionIsClosedSender
+from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import (
+    AbortTimeoutSender,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.close_connection_sender import (
+    CloseConnectionSender,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer import (
+    ConnectionCloser,
+)
+from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_is_closed_sender import (
+    ConnectionIsClosedSender,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
 
 
 class ConnectionCloserFactory:
-    def create(self,
-               peer: Peer,
-               my_connection_info: ConnectionInfo,
-               sender: Sender,
-               abort_timeout_sender: AbortTimeoutSender,
-               connection_is_closed_sender: ConnectionIsClosedSender,
-               close_connection_sender: CloseConnectionSender
-               ) -> ConnectionCloser:
+    def create(
+        self,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        sender: Sender,
+        abort_timeout_sender: AbortTimeoutSender,
+        connection_is_closed_sender: ConnectionIsClosedSender,
+        close_connection_sender: CloseConnectionSender,
+    ) -> ConnectionCloser:
         return ConnectionCloser(
             peer=peer,
             my_connection_info=my_connection_info,
             sender=sender,
             abort_timeout_sender=abort_timeout_sender,
             connection_is_closed_sender=connection_is_closed_sender,
-            close_connection_sender=close_connection_sender
+            close_connection_sender=close_connection_sender,
         )
diff --git a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_timeout_config.py b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_timeout_config.py
index fcbe39ec..2ec7c54e 100644
--- a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_timeout_config.py
+++ b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_closer_timeout_config.py
@@ -5,4 +5,4 @@
 class ConnectionCloserTimeoutConfig:
     close_retry_timeout_in_ms: int = 1000
     abort_timeout_in_ms: int = 100000
-    connection_is_closed_wait_time_in_ms: int = 10000
\ No newline at end of file
+    connection_is_closed_wait_time_in_ms: int = 10000
diff --git a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_is_closed_sender.py b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_is_closed_sender.py
index f5c8cebc..d0120838 100644
--- a/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_is_closed_sender.py
+++ b/exasol/analytics/udf/communication/peer_communicator/background_thread/connection_closer/connection_is_closed_sender.py
@@ -21,11 +21,13 @@ class _States(IntFlag):
 
 
 class ConnectionIsClosedSender:
-    def __init__(self,
-                 out_control_socket: Socket,
-                 peer: Peer,
-                 my_connection_info: ConnectionInfo,
-                 timer: Timer):
+    def __init__(
+        self,
+        out_control_socket: Socket,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        timer: Timer,
+    ):
         self._timer = timer
         self._peer = peer
         self._out_control_socket = out_control_socket
@@ -56,23 +58,21 @@ def _should_we_send(self):
         is_time = self._timer.is_time()
         send_time_dependent = _States.RECEIVED_CLOSE_CONNECTION in self._states
         send_time_independent = (
-                _States.RECEIVED_CLOSE_CONNECTION in self._states
-                and _States.RECEIVED_ACKKNOWLEDGE_CLOSE_CONNECTION in self._states
+            _States.RECEIVED_CLOSE_CONNECTION in self._states
+            and _States.RECEIVED_ACKKNOWLEDGE_CLOSE_CONNECTION in self._states
         )
         finished = _States.FINISHED in self._states
-        result = (
-                not finished
-                and (
-                        (is_time and send_time_dependent) or
-                        send_time_independent
-                )
+        result = not finished and (
+            (is_time and send_time_dependent) or send_time_independent
+        )
+        self._logger.debug(
+            "_should_we_send",
+            result=result,
+            is_time=is_time,
+            send_time_dependent=send_time_dependent,
+            send_time_independent=send_time_independent,
+            states=self._states,
         )
-        self._logger.debug("_should_we_send",
-                           result=result,
-                           is_time=is_time,
-                           send_time_dependent=send_time_dependent,
-                           send_time_independent=send_time_independent,
-                           states=self._states)
         return result
 
     def _send_connection_is_closed_to_frontend(self):
@@ -83,11 +83,13 @@
 
 
 class ConnectionIsClosedSenderFactory:
-    def create(self,
-               out_control_socket: Socket,
-               peer: Peer,
-               my_connection_info: ConnectionInfo,
-               timer: Timer) -> ConnectionIsClosedSender:
+    def create(
+        self,
+        out_control_socket: Socket,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        timer: Timer,
+    ) -> ConnectionIsClosedSender:
         peer_is_closed_sender = ConnectionIsClosedSender(
             out_control_socket=out_control_socket,
             timer=timer,
diff --git a/exasol/analytics/udf/communication/peer_communicator/clock.py b/exasol/analytics/udf/communication/peer_communicator/clock.py
index 6f7b6962..8b0a4636 100644
--- a/exasol/analytics/udf/communication/peer_communicator/clock.py
+++ b/exasol/analytics/udf/communication/peer_communicator/clock.py
@@ -1,7 +1,7 @@
 import time
 
 
-class Clock():
+class Clock:
     def current_timestamp_in_ms(self) -> int:
-        timestamp = time.monotonic_ns() // 10 ** 6
+        timestamp = time.monotonic_ns() // 10**6
         return timestamp
diff --git a/exasol/analytics/udf/communication/peer_communicator/connection_establisher.py b/exasol/analytics/udf/communication/peer_communicator/connection_establisher.py
index 2f7860c9..b171c86e 100644
--- a/exasol/analytics/udf/communication/peer_communicator/connection_establisher.py
+++ b/exasol/analytics/udf/communication/peer_communicator/connection_establisher.py
@@ -5,25 +5,30 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.messages import Message
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \
-    AbortTimeoutSender
-from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import \
-    ConnectionIsReadySender
+from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import (
+    AbortTimeoutSender,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import (
+    ConnectionIsReadySender,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
-from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import \
-    SynchronizeConnectionSender
+from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import (
+    SynchronizeConnectionSender,
+)
 
 LOGGER: FilteringBoundLogger = structlog.get_logger()
 
 
 class ConnectionEstablisher:
-    def __init__(self,
-                 peer: Peer,
-                 my_connection_info: ConnectionInfo,
-                 sender: Sender,
-                 abort_timeout_sender: AbortTimeoutSender,
-                 connection_is_ready_sender: ConnectionIsReadySender,
-                 synchronize_connection_sender: SynchronizeConnectionSender):
+    def __init__(
+        self,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        sender: Sender,
+        abort_timeout_sender: AbortTimeoutSender,
+        connection_is_ready_sender: ConnectionIsReadySender,
+        synchronize_connection_sender: SynchronizeConnectionSender,
+    ):
         self._synchronize_connection_sender = synchronize_connection_sender
         self._connection_is_ready_sender = connection_is_ready_sender
         self._abort_timeout_sender = abort_timeout_sender
@@ -44,9 +49,10 @@ def received_synchronize_connection(self):
         self._sender.send(
             Message(
                 __root__=messages.AcknowledgeConnection(
-                    source=self._my_connection_info,
-                    destination=self._peer
-                )))
+                    source=self._my_connection_info, destination=self._peer
+                )
+            )
+        )
         self._connection_is_ready_sender.received_synchronize_connection()
         self._abort_timeout_sender.stop()
diff --git a/exasol/analytics/udf/communication/peer_communicator/connection_establisher_builder.py b/exasol/analytics/udf/communication/peer_communicator/connection_establisher_builder.py
index c88ad13b..43c87454 100644
--- a/exasol/analytics/udf/communication/peer_communicator/connection_establisher_builder.py
+++ b/exasol/analytics/udf/communication/peer_communicator/connection_establisher_builder.py
@@ -1,64 +1,76 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \
-    AbortTimeoutSenderFactory
+from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import (
+    AbortTimeoutSenderFactory,
+)
 from exasol.analytics.udf.communication.peer_communicator.clock import Clock
-from exasol.analytics.udf.communication.peer_communicator.connection_establisher import \
-    ConnectionEstablisher
-from exasol.analytics.udf.communication.peer_communicator.connection_establisher_factory import \
-    ConnectionEstablisherFactory
-from exasol.analytics.udf.communication.peer_communicator.connection_establisher_timeout_config import \
-    ConnectionEstablisherTimeoutConfig
-from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import \
-    ConnectionIsReadySenderFactory
+from exasol.analytics.udf.communication.peer_communicator.connection_establisher import (
+    ConnectionEstablisher,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_establisher_factory import (
+    ConnectionEstablisherFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_establisher_timeout_config import (
+    ConnectionEstablisherTimeoutConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import (
+    ConnectionIsReadySenderFactory,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
-from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import \
-    SynchronizeConnectionSenderFactory
+from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import (
+    SynchronizeConnectionSenderFactory,
+)
 from exasol.analytics.udf.communication.peer_communicator.timer import TimerFactory
 from exasol.analytics.udf.communication.socket_factory.abstract import Socket
 
 
 class ConnectionEstablisherBuilder:
-    def __init__(self,
-                 timer_factory: TimerFactory,
-                 abort_timeout_sender_factory: AbortTimeoutSenderFactory = AbortTimeoutSenderFactory(),
-                 connection_is_ready_sender_factory: ConnectionIsReadySenderFactory = ConnectionIsReadySenderFactory(),
-                 synchronize_connection_sender_factory: SynchronizeConnectionSenderFactory =
-                 SynchronizeConnectionSenderFactory(),
-                 connection_establisher_factory: ConnectionEstablisherFactory =
-                 ConnectionEstablisherFactory()):
+    def __init__(
+        self,
+        timer_factory: TimerFactory,
+        abort_timeout_sender_factory: AbortTimeoutSenderFactory = AbortTimeoutSenderFactory(),
+        connection_is_ready_sender_factory: ConnectionIsReadySenderFactory = ConnectionIsReadySenderFactory(),
+        synchronize_connection_sender_factory: SynchronizeConnectionSenderFactory = SynchronizeConnectionSenderFactory(),
+        connection_establisher_factory: ConnectionEstablisherFactory = ConnectionEstablisherFactory(),
+    ):
         self._connection_establisher_factory = connection_establisher_factory
         self._timer_factory = timer_factory
-        self._synchronize_connection_sender_factory = synchronize_connection_sender_factory
+        self._synchronize_connection_sender_factory = (
+            synchronize_connection_sender_factory
+        )
         self._connection_is_ready_sender_factory = connection_is_ready_sender_factory
         self._abort_timeout_sender_factory = abort_timeout_sender_factory
 
-    def create(self,
-               peer: Peer,
-               my_connection_info: ConnectionInfo,
-               out_control_socket: Socket,
-               clock: Clock,
-               sender: Sender,
-               timeout_config: ConnectionEstablisherTimeoutConfig) -> ConnectionEstablisher:
+    def create(
+        self,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        out_control_socket: Socket,
+        clock: Clock,
+        sender: Sender,
+        timeout_config: ConnectionEstablisherTimeoutConfig,
+    ) -> ConnectionEstablisher:
         synchronize_connection_sender = self._create_synchronize_connection_sender(
             my_connection_info=my_connection_info,
             peer=peer,
             sender=sender,
             clock=clock,
-            timeout_config=timeout_config)
+            timeout_config=timeout_config,
+        )
         abort_timeout_sender = self._create_abort_timeout_sender(
             my_connection_info=my_connection_info,
             peer=peer,
             out_control_socket=out_control_socket,
             clock=clock,
-            timeout_config=timeout_config)
+            timeout_config=timeout_config,
+        )
         connection_is_ready_sender = self._create_connection_is_ready_sender(
             my_connection_info=my_connection_info,
             peer=peer,
             clock=clock,
             out_control_socket=out_control_socket,
-            timeout_config=timeout_config
+            timeout_config=timeout_config,
         )
         return self._connection_establisher_factory.create(
             peer=peer,
@@ -66,15 +78,21 @@ def create(self,
             sender=sender,
             abort_timeout_sender=abort_timeout_sender,
             connection_is_ready_sender=connection_is_ready_sender,
-            synchronize_connection_sender=synchronize_connection_sender
+            synchronize_connection_sender=synchronize_connection_sender,
         )
 
-    def _create_connection_is_ready_sender(self,
-                                           my_connection_info: ConnectionInfo, peer: Peer,
-                                           clock: Clock, out_control_socket: Socket,
-                                           timeout_config: ConnectionEstablisherTimeoutConfig):
+    def _create_connection_is_ready_sender(
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        clock: Clock,
+        out_control_socket: Socket,
+        timeout_config: ConnectionEstablisherTimeoutConfig,
+    ):
         connection_is_ready_sender_timer = self._timer_factory.create(
-            clock=clock, timeout_in_ms=timeout_config.connection_is_ready_wait_time_in_ms)
+            clock=clock,
+            timeout_in_ms=timeout_config.connection_is_ready_wait_time_in_ms,
+        )
         connection_is_ready_sender = self._connection_is_ready_sender_factory.create(
             out_control_socket=out_control_socket,
             timer=connection_is_ready_sender_timer,
@@ -83,31 +101,43 @@ def _create_connection_is_ready_sender(self,
         )
         return connection_is_ready_sender
 
-    def _create_abort_timeout_sender(self,
-                                     my_connection_info: ConnectionInfo, peer: Peer,
-                                     out_control_socket: Socket, clock: Clock,
-                                     timeout_config: ConnectionEstablisherTimeoutConfig):
+    def _create_abort_timeout_sender(
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        out_control_socket: Socket,
+        clock: Clock,
+        timeout_config: ConnectionEstablisherTimeoutConfig,
+    ):
         abort_timeout_sender_timer = self._timer_factory.create(
-            clock=clock, timeout_in_ms=timeout_config.abort_timeout_in_ms)
+            clock=clock, timeout_in_ms=timeout_config.abort_timeout_in_ms
+        )
         abort_timeout_sender = self._abort_timeout_sender_factory.create(
             out_control_socket=out_control_socket,
             timer=abort_timeout_sender_timer,
             my_connection_info=my_connection_info,
             peer=peer,
-            reason="Timeout occurred during establishing connection."
+            reason="Timeout occurred during establishing connection.",
         )
         return abort_timeout_sender
 
-    def _create_synchronize_connection_sender(self,
-                                              my_connection_info: ConnectionInfo, peer: Peer,
-                                              sender: Sender, clock: Clock,
-                                              timeout_config: ConnectionEstablisherTimeoutConfig):
+    def _create_synchronize_connection_sender(
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        sender: Sender,
+        clock: Clock,
+        timeout_config: ConnectionEstablisherTimeoutConfig,
+    ):
         synchronize_connection_sender_timer = self._timer_factory.create(
-            clock=clock, timeout_in_ms=timeout_config.synchronize_retry_timeout_in_ms)
-        synchronize_connection_sender = self._synchronize_connection_sender_factory.create(
-            my_connection_info=my_connection_info,
-            peer=peer,
-            sender=sender,
-            timer=synchronize_connection_sender_timer
+            clock=clock, timeout_in_ms=timeout_config.synchronize_retry_timeout_in_ms
+        )
+        synchronize_connection_sender = (
+            self._synchronize_connection_sender_factory.create(
+                my_connection_info=my_connection_info,
+                peer=peer,
+                sender=sender,
+                timer=synchronize_connection_sender_timer,
+            )
         )
         return synchronize_connection_sender
diff --git a/exasol/analytics/udf/communication/peer_communicator/connection_establisher_factory.py b/exasol/analytics/udf/communication/peer_communicator/connection_establisher_factory.py
index e1ef5b5e..c7ebd5c4 100644
--- a/exasol/analytics/udf/communication/peer_communicator/connection_establisher_factory.py
+++ b/exasol/analytics/udf/communication/peer_communicator/connection_establisher_factory.py
@@ -1,31 +1,36 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \
-    AbortTimeoutSender
-from exasol.analytics.udf.communication.peer_communicator.connection_establisher import \
-    ConnectionEstablisher
-from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import \
-    ConnectionIsReadySender
+from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import (
+    AbortTimeoutSender,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_establisher import (
+    ConnectionEstablisher,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import (
+    ConnectionIsReadySender,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
-from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import \
-    SynchronizeConnectionSender
+from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import (
+    SynchronizeConnectionSender,
+)
 
 
 class ConnectionEstablisherFactory:
-    def create(self,
-               peer: Peer,
-               my_connection_info: ConnectionInfo,
-               sender: Sender,
-               abort_timeout_sender: AbortTimeoutSender,
-               connection_is_ready_sender: ConnectionIsReadySender,
-               synchronize_connection_sender: SynchronizeConnectionSender
-               ) -> ConnectionEstablisher:
+    def create(
+        self,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        sender: Sender,
+        abort_timeout_sender: AbortTimeoutSender,
+        connection_is_ready_sender: ConnectionIsReadySender,
+        synchronize_connection_sender: SynchronizeConnectionSender,
+    ) -> ConnectionEstablisher:
         return ConnectionEstablisher(
             peer=peer,
             my_connection_info=my_connection_info,
             sender=sender,
             abort_timeout_sender=abort_timeout_sender,
             connection_is_ready_sender=connection_is_ready_sender,
-            synchronize_connection_sender=synchronize_connection_sender
+            synchronize_connection_sender=synchronize_connection_sender,
         )
diff --git a/exasol/analytics/udf/communication/peer_communicator/connection_establisher_timeout_config.py b/exasol/analytics/udf/communication/peer_communicator/connection_establisher_timeout_config.py
index f9451f85..f9e83000 100644
--- a/exasol/analytics/udf/communication/peer_communicator/connection_establisher_timeout_config.py
+++ b/exasol/analytics/udf/communication/peer_communicator/connection_establisher_timeout_config.py
@@ -5,4 +5,4 @@
 class ConnectionEstablisherTimeoutConfig:
     synchronize_retry_timeout_in_ms: int = 1000
     abort_timeout_in_ms: int = 100000
-    connection_is_ready_wait_time_in_ms: int = 10000
\ No newline at end of file
+    connection_is_ready_wait_time_in_ms: int = 10000
diff --git a/exasol/analytics/udf/communication/peer_communicator/connection_is_ready_sender.py b/exasol/analytics/udf/communication/peer_communicator/connection_is_ready_sender.py
index 7117fbdc..20cf61a9 100644
--- a/exasol/analytics/udf/communication/peer_communicator/connection_is_ready_sender.py
+++ b/exasol/analytics/udf/communication/peer_communicator/connection_is_ready_sender.py
@@ -21,11 +21,13 @@ class _States(IntFlag):
 
 
 class ConnectionIsReadySender:
-    def __init__(self,
-                 out_control_socket: Socket,
-                 peer: Peer,
-                 my_connection_info: ConnectionInfo,
-                 timer: Timer):
+    def __init__(
+        self,
+        out_control_socket: Socket,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        timer: Timer,
+    ):
         self._timer = timer
         self._peer = peer
         self._out_control_socket = out_control_socket
@@ -57,19 +59,17 @@ def _should_we_send(self):
         send_time_dependent = _States.RECEIVED_SYNCHRONIZE_CONNECTION in self._states
         send_time_independent = _States.RECEIVED_ACKKNOWLEDGE_CONNECTION in self._states
         finished = _States.FINISHED in self._states
-        result = (
-                not finished
-                and (
-                        (is_time and send_time_dependent) or
-                        send_time_independent
-                )
+        result = not finished and (
+            (is_time and send_time_dependent) or send_time_independent
+        )
+        self._logger.debug(
+            "_should_we_send",
+            result=result,
+            is_time=is_time,
+            send_time_dependent=send_time_dependent,
+            send_time_independent=send_time_independent,
+            states=self._states,
         )
-        self._logger.debug("_should_we_send",
-                           result=result,
-                           is_time=is_time,
-                           send_time_dependent=send_time_dependent,
-                           send_time_independent=send_time_independent,
-                           states=self._states)
         return result
 
     def _send_connection_is_ready_to_frontend(self):
@@ -83,11 +83,13 @@ def is_ready_to_stop(self):
 
 
 class ConnectionIsReadySenderFactory:
-    def create(self,
-               out_control_socket: Socket,
-               peer: Peer,
-               my_connection_info: ConnectionInfo,
-               timer: Timer) -> ConnectionIsReadySender:
+    def create(
+        self,
+        out_control_socket: Socket,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        timer: Timer,
+    ) -> ConnectionIsReadySender:
         peer_is_ready_sender = ConnectionIsReadySender(
             out_control_socket=out_control_socket,
             timer=timer,
diff --git a/exasol/analytics/udf/communication/peer_communicator/frontend_peer_state.py b/exasol/analytics/udf/communication/peer_communicator/frontend_peer_state.py
index d7e50d51..e66961bf 100644
--- a/exasol/analytics/udf/communication/peer_communicator/frontend_peer_state.py
+++ b/exasol/analytics/udf/communication/peer_communicator/frontend_peer_state.py
@@ -1,5 +1,5 @@
 from collections import deque
-from typing import List, Deque
+from typing import Deque, List
 
 import structlog
 from structlog.typing import FilteringBoundLogger
@@ -7,21 +7,26 @@
 from exasol.analytics.udf.communication import messages
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.background_listener_interface import \
-    BackgroundListenerInterface
-from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory, \
-    Frame
+from exasol.analytics.udf.communication.peer_communicator.background_listener_interface import (
+    BackgroundListenerInterface,
+)
+from exasol.analytics.udf.communication.socket_factory.abstract import (
+    Frame,
+    SocketFactory,
+)
 
 LOGGER: FilteringBoundLogger = structlog.getLogger()
 
 
 class FrontendPeerState:
-    def __init__(self,
-                 my_connection_info: ConnectionInfo,
-                 socket_factory: SocketFactory,
-                 background_listener: BackgroundListenerInterface,
-                 peer: Peer):
+    def __init__(
+        self,
+        my_connection_info: ConnectionInfo,
+        socket_factory: SocketFactory,
+        background_listener: BackgroundListenerInterface,
+        peer: Peer,
+    ):
         self._connection_is_closed = False
         self._received_messages: Deque[List[Frame]] = deque()
         self._background_listener = background_listener
@@ -31,7 +36,9 @@ def __init__(self,
         self._connection_is_ready = False
         self._peer_register_forwarder_is_ready = False
         self._sequence_number = 0
-        self._logger = LOGGER.bind(peer=peer.dict(), my_connection_info=my_connection_info.dict())
+        self._logger = LOGGER.bind(
+            peer=peer.dict(), my_connection_info=my_connection_info.dict()
+        )
 
     def _next_sequence_number(self):
         result = self._sequence_number
@@ -44,11 +51,15 @@ def received_connection_is_ready(self):
     def received_peer_register_forwarder_is_ready(self):
         self._peer_register_forwarder_is_ready = True
 
-    def received_payload_message(self, message_obj: messages.Payload, frames: List[Frame]):
+    def received_payload_message(
+        self, message_obj: messages.Payload, frames: List[Frame]
+    ):
         if message_obj.source != self._peer:
-            raise RuntimeError(f"Received message from wrong peer. "
-                               f"Expected peer is {self._peer}, but got {message_obj.source}."
-                               f"Message was: {message_obj}")
+            raise RuntimeError(
+                f"Received message from wrong peer. "
+                f"Expected peer is {self._peer}, but got {message_obj.source}."
+                f"Message was: {message_obj}"
+            )
         self._received_messages.append(frames[1:])
 
     @property
@@ -56,9 +67,11 @@ def peer_is_ready(self) -> bool:
         return self._connection_is_ready and self._peer_register_forwarder_is_ready
 
     def send(self, payload: List[Frame]):
-        message = messages.Payload(source=Peer(connection_info=self._my_connection_info),
-                                   destination=self._peer,
-                                   sequence_number=self._next_sequence_number())
+        message = messages.Payload(
+            source=Peer(connection_info=self._my_connection_info),
+            destination=self._peer,
+            sequence_number=self._next_sequence_number(),
+        )
         self._logger.debug("send", message=message.dict())
         self._background_listener.send_payload(message=message, payload=payload)
         return message.sequence_number
@@ -79,5 +92,7 @@ def received_connection_is_closed(self):
     def connection_is_closed(self) -> bool:
         return self._connection_is_closed
 
-    def received_acknowledge_payload_message(self, acknowledge_payload: messages.AcknowledgePayload):
-        """ Not yet implemented and for that reason we ignore the input"""
+    def received_acknowledge_payload_message(
+        self, acknowledge_payload: messages.AcknowledgePayload
+    ):
+        """Not yet implemented and for that reason we ignore the input"""
diff --git a/exasol/analytics/udf/communication/peer_communicator/get_peer_receive_socket_name.py b/exasol/analytics/udf/communication/peer_communicator/get_peer_receive_socket_name.py
index 569d7dac..e9638906 100644
--- a/exasol/analytics/udf/communication/peer_communicator/get_peer_receive_socket_name.py
+++ b/exasol/analytics/udf/communication/peer_communicator/get_peer_receive_socket_name.py
@@ -4,7 +4,9 @@
 
 
 def get_peer_receive_socket_name(peer: Peer) -> str:
-    quoted_ip_address = urllib.parse.quote_plus(peer.connection_info.ipaddress.ip_address)
+    quoted_ip_address = urllib.parse.quote_plus(
+        peer.connection_info.ipaddress.ip_address
+    )
     quoted_port = urllib.parse.quote_plus(str(peer.connection_info.port.port))
     quoted_group_identifier = peer.connection_info.group_identifier
     return f"inproc://peer/{quoted_group_identifier}/{quoted_ip_address}/{quoted_port}"
diff --git a/exasol/analytics/udf/communication/peer_communicator/payload_handler.py b/exasol/analytics/udf/communication/peer_communicator/payload_handler.py
index 1454034d..fdd06203 100644
--- a/exasol/analytics/udf/communication/peer_communicator/payload_handler.py
+++ b/exasol/analytics/udf/communication/peer_communicator/payload_handler.py
@@ -1,15 +1,19 @@
 from typing import List
 
 from exasol.analytics.udf.communication import messages
-from exasol.analytics.udf.communication.peer_communicator.payload_receiver import PayloadReceiver
-from exasol.analytics.udf.communication.peer_communicator.payload_sender import PayloadSender
+from exasol.analytics.udf.communication.peer_communicator.payload_receiver import (
+    PayloadReceiver,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_sender import (
+    PayloadSender,
+)
 from exasol.analytics.udf.communication.socket_factory.abstract import Frame
 
 
 class PayloadHandler:
-    def __init__(self,
-                 payload_sender: PayloadSender,
-                 payload_receiver: PayloadReceiver):
+    def __init__(
+        self, payload_sender: PayloadSender, payload_receiver: PayloadReceiver
+    ):
         self._payload_receiver = payload_receiver
         self._payload_sender = payload_sender
diff --git a/exasol/analytics/udf/communication/peer_communicator/payload_handler_builder.py b/exasol/analytics/udf/communication/peer_communicator/payload_handler_builder.py
index 4f2ce989..bad4b502 100644
--- a/exasol/analytics/udf/communication/peer_communicator/payload_handler_builder.py
+++ b/exasol/analytics/udf/communication/peer_communicator/payload_handler_builder.py
@@ -1,50 +1,62 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator.clock import Clock
-from exasol.analytics.udf.communication.peer_communicator.payload_handler import PayloadHandler
-from exasol.analytics.udf.communication.peer_communicator.payload_handler_factory import \
-    PayloadHandlerFactory
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config \
-    import PayloadMessageSenderTimeoutConfig
-from exasol.analytics.udf.communication.peer_communicator.payload_receiver_factory import \
-    PayloadReceiverFactory
-from exasol.analytics.udf.communication.peer_communicator.payload_sender_factory import \
-    PayloadSenderFactory
+from exasol.analytics.udf.communication.peer_communicator.payload_handler import (
+    PayloadHandler,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_handler_factory import (
+    PayloadHandlerFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config import (
+    PayloadMessageSenderTimeoutConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_receiver_factory import (
+    PayloadReceiverFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_sender_factory import (
+    PayloadSenderFactory,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
-from exasol.analytics.udf.communication.socket_factory.abstract import Socket, \
-    SocketFactory
+from exasol.analytics.udf.communication.socket_factory.abstract import (
+    Socket,
+    SocketFactory,
+)
 
 
 class PayloadHandlerBuilder:
-    def __init__(self,
-                 payload_sender_factory: PayloadSenderFactory,
-                 payload_receiver_factory: PayloadReceiverFactory = PayloadReceiverFactory(),
-                 payload_handler_factory: PayloadHandlerFactory = PayloadHandlerFactory()):
+    def __init__(
+        self,
+        payload_sender_factory: PayloadSenderFactory,
+        payload_receiver_factory: PayloadReceiverFactory = PayloadReceiverFactory(),
+        payload_handler_factory: PayloadHandlerFactory = PayloadHandlerFactory(),
+    ):
         self._payload_handler_factory = payload_handler_factory
         self._payload_receiver_factory = payload_receiver_factory
         self._payload_sender_factory = payload_sender_factory
 
-    def create(self,
-               my_connection_info: ConnectionInfo,
-               peer: Peer,
-               out_control_socket: Socket,
-               socket_factory: SocketFactory,
-               sender: Sender,
-               clock: Clock,
-               payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig) -> PayloadHandler:
+    def create(
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        out_control_socket: Socket,
+        socket_factory: SocketFactory,
+        sender: Sender,
+        clock: Clock,
+        payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig,
+    ) -> PayloadHandler:
         payload_sender = self._payload_sender_factory.create(
             my_connection_info=my_connection_info,
             peer=peer,
             sender=sender,
             out_control_socket=out_control_socket,
             clock=clock,
-            payload_message_sender_timeout_config=payload_message_sender_timeout_config
+            payload_message_sender_timeout_config=payload_message_sender_timeout_config,
         )
         payload_receiver = self._payload_receiver_factory.create(
             my_connection_info=my_connection_info,
             peer=peer,
             sender=sender,
-            out_control_socket=out_control_socket
+            out_control_socket=out_control_socket,
         )
         payload_handler = self._payload_handler_factory.create(
             payload_sender=payload_sender,
diff --git a/exasol/analytics/udf/communication/peer_communicator/payload_handler_factory.py b/exasol/analytics/udf/communication/peer_communicator/payload_handler_factory.py
index 5d944fb2..f4d2b3e3 100644
--- a/exasol/analytics/udf/communication/peer_communicator/payload_handler_factory.py
+++ b/exasol/analytics/udf/communication/peer_communicator/payload_handler_factory.py
@@ -1,13 +1,18 @@
-from exasol.analytics.udf.communication.peer_communicator.payload_handler import PayloadHandler
-from exasol.analytics.udf.communication.peer_communicator.payload_receiver import PayloadReceiver
-from exasol.analytics.udf.communication.peer_communicator.payload_sender import PayloadSender
+from exasol.analytics.udf.communication.peer_communicator.payload_handler import (
+    PayloadHandler,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_receiver import (
+    PayloadReceiver,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_sender import (
+    PayloadSender,
+)
 
 
 class PayloadHandlerFactory:
-    def create(self,
-               payload_sender: PayloadSender,
-               payload_receiver: PayloadReceiver) -> PayloadHandler:
+    def create(
+        self, payload_sender: PayloadSender, payload_receiver: PayloadReceiver
+    ) -> PayloadHandler:
         return PayloadHandler(
-            payload_sender=payload_sender,
-            payload_receiver=payload_receiver
+            payload_sender=payload_sender, payload_receiver=payload_receiver
         )
diff --git a/exasol/analytics/udf/communication/peer_communicator/payload_message_sender.py b/exasol/analytics/udf/communication/peer_communicator/payload_message_sender.py
index 1feab672..4aaceec3 100644
--- a/exasol/analytics/udf/communication/peer_communicator/payload_message_sender.py
+++ b/exasol/analytics/udf/communication/peer_communicator/payload_message_sender.py
@@ -13,13 +13,15 @@
 
 
 class PayloadMessageSender:
-    def __init__(self,
-                 message: messages.Payload,
-                 frames: List[Frame],
-                 retry_timer: Timer,
-                 abort_timer: Timer,
-                 sender: Sender,
-                 out_control_socket: Socket):
+    def __init__(
+        self,
+        message: messages.Payload,
+        frames: List[Frame],
+        retry_timer: Timer,
+        abort_timer: Timer,
+        sender: Sender,
+        out_control_socket: Socket,
+    ):
         self._logger = LOGGER.bind(message=message)
         self._abort_timer = abort_timer
         self._out_control_socket = out_control_socket
@@ -68,8 +70,7 @@ def _should_we_send_payload(self):
 
     def _send_abort(self):
         abort_payload_message = messages.AbortPayload(
-            payload=self._message,
-            reason="Send timeout reached"
+            payload=self._message, reason="Send timeout reached"
         )
         serialized_message = serialize_message(abort_payload_message)
         self._out_control_socket.send(serialized_message)
diff --git a/exasol/analytics/udf/communication/peer_communicator/payload_message_sender_factory.py b/exasol/analytics/udf/communication/peer_communicator/payload_message_sender_factory.py
index e2b9ab87..4d76d160 100644
--- a/exasol/analytics/udf/communication/peer_communicator/payload_message_sender_factory.py
+++ b/exasol/analytics/udf/communication/peer_communicator/payload_message_sender_factory.py
@@ -2,33 +2,41 @@
 
 from exasol.analytics.udf.communication.messages import Payload
 from exasol.analytics.udf.communication.peer_communicator.clock import Clock
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender import \
-    PayloadMessageSender
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config \
-    import PayloadMessageSenderTimeoutConfig
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender import (
+    PayloadMessageSender,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config import (
+    PayloadMessageSenderTimeoutConfig,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
 from exasol.analytics.udf.communication.peer_communicator.timer import TimerFactory
-from exasol.analytics.udf.communication.socket_factory.abstract import Socket, Frame
+from exasol.analytics.udf.communication.socket_factory.abstract import Frame, Socket
 
 
 class PayloadMessageSenderFactory:
     def __init__(self, timer_factory: TimerFactory):
         self._timer_factory = timer_factory
 
-    def create(self,
-               clock: Clock,
-               sender: Sender,
-               message: Payload,
-               frames: List[Frame],
-               payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig,
-               out_control_socket: Socket) -> PayloadMessageSender:
+    def create(
+        self,
+        clock: Clock,
+        sender: Sender,
+        message: Payload,
+        frames: List[Frame],
+        payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig,
+        out_control_socket: Socket,
+    ) -> PayloadMessageSender:
         retry_timer = self._timer_factory.create(
-            clock, payload_message_sender_timeout_config.retry_timeout_in_ms)
+            clock, payload_message_sender_timeout_config.retry_timeout_in_ms
+        )
         abort_timer = self._timer_factory.create(
-            clock, payload_message_sender_timeout_config.abort_timeout_in_ms)
-        return PayloadMessageSender(message=message,
-                                    frames=frames,
-                                    retry_timer=retry_timer,
-                                    abort_timer=abort_timer,
-                                    sender=sender,
-                                    out_control_socket=out_control_socket)
+            clock, payload_message_sender_timeout_config.abort_timeout_in_ms
+        )
+        return PayloadMessageSender(
+            message=message,
+            frames=frames,
+            retry_timer=retry_timer,
+            abort_timer=abort_timer,
+            sender=sender,
+            out_control_socket=out_control_socket,
+        )
diff --git a/exasol/analytics/udf/communication/peer_communicator/payload_receiver.py b/exasol/analytics/udf/communication/peer_communicator/payload_receiver.py
index 55e3ea3b..461c0f1b 100644
--- a/exasol/analytics/udf/communication/peer_communicator/payload_receiver.py
+++ b/exasol/analytics/udf/communication/peer_communicator/payload_receiver.py
@@ -7,19 +7,18 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
-from exasol.analytics.udf.communication.socket_factory.abstract \
-    import Frame, Socket
+from exasol.analytics.udf.communication.socket_factory.abstract import Frame, Socket
 
 LOGGER: FilteringBoundLogger = structlog.get_logger()
 
 
 class PayloadReceiver:
     def __init__(
-            self,
-            peer: Peer,
-            my_connection_info: ConnectionInfo,
-            out_control_socket: Socket,
-            sender: Sender
+        self,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        out_control_socket: Socket,
+        sender: Sender,
     ):
         self._peer = peer
         self._my_connection_info = my_connection_info
@@ -41,30 +40,45 @@ def received_payload(self, message: messages.Payload, frames: List[Frame]):
         elif message.sequence_number > self._next_received_payload_sequence_number:
             self._add_new_message_to_buffer(message, frames)
 
-    def _add_new_message_to_buffer(self, message: messages.Payload, frames: List[Frame]):
+    def _add_new_message_to_buffer(
+        self, message: messages.Payload, frames: List[Frame]
+    ):
         self._logger.info("put_to_buffer", message=message.dict())
         self._received_payload_dict[message.sequence_number] = frames
 
-    def _forward_new_message_directly(self, message: messages.Payload, frames: List[Frame]):
+    def _forward_new_message_directly(
+        self, message: messages.Payload, frames: List[Frame]
+    ):
         self._logger.info("forward_from_message", message=message.dict())
         self._forward_received_payload(frames)
 
     def _forward_messages_from_buffer(self):
-        while self._next_received_payload_sequence_number in self._received_payload_dict:
-            self._logger.info("forward_from_buffer",
-                              _next_recieved_payload_sequence_number=self._next_received_payload_sequence_number,
-                              _received_payload_dict_keys=list(self._received_payload_dict.keys()))
-            next_frames = self._received_payload_dict.pop(self._next_received_payload_sequence_number)
+        while (
+            self._next_received_payload_sequence_number in self._received_payload_dict
+        ):
+            self._logger.info(
+                "forward_from_buffer",
+                _next_recieved_payload_sequence_number=self._next_received_payload_sequence_number,
+                _received_payload_dict_keys=list(self._received_payload_dict.keys()),
+            )
+            next_frames = self._received_payload_dict.pop(
+                self._next_received_payload_sequence_number
+            )
             self._forward_received_payload(next_frames)
 
     def _send_acknowledge_payload_message(self, sequence_number: int):
         acknowledge_payload_message = messages.AcknowledgePayload(
             source=Peer(connection_info=self._my_connection_info),
             sequence_number=sequence_number,
-            destination=self._peer
+            destination=self._peer,
+        )
+        self._logger.info(
+            "_send_acknowledge_payload_message",
+            message=acknowledge_payload_message.dict(),
+        )
+        self._sender.send(
+            message=messages.Message(__root__=acknowledge_payload_message)
         )
-        self._logger.info("_send_acknowledge_payload_message", message=acknowledge_payload_message.dict())
-        self._sender.send(message=messages.Message(__root__=acknowledge_payload_message))
 
     def _forward_received_payload(self, frames: List[Frame]):
         self._out_control_socket.send_multipart(frames)
diff --git a/exasol/analytics/udf/communication/peer_communicator/payload_receiver_factory.py b/exasol/analytics/udf/communication/peer_communicator/payload_receiver_factory.py
index a59643e9..34bef4dd 100644
--- a/exasol/analytics/udf/communication/peer_communicator/payload_receiver_factory.py
+++ b/exasol/analytics/udf/communication/peer_communicator/payload_receiver_factory.py
@@ -3,25 +3,26 @@
 
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.payload_receiver import PayloadReceiver
+from exasol.analytics.udf.communication.peer_communicator.payload_receiver import (
+    PayloadReceiver,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
-from exasol.analytics.udf.communication.socket_factory.abstract \
-    import Socket
+from exasol.analytics.udf.communication.socket_factory.abstract import Socket
 
 LOGGER: FilteringBoundLogger = structlog.get_logger()
 
 
 class PayloadReceiverFactory:
     def create(
-            self,
-            peer: Peer,
-            my_connection_info: ConnectionInfo,
-            sender: Sender,
-            out_control_socket: Socket,
+        self,
+        peer: Peer,
+        my_connection_info: ConnectionInfo,
+        sender: Sender,
+        out_control_socket: Socket,
     ) -> PayloadReceiver:
         return PayloadReceiver(
             peer=peer,
             my_connection_info=my_connection_info,
             sender=sender,
-            out_control_socket=out_control_socket
+            out_control_socket=out_control_socket,
         )
diff --git a/exasol/analytics/udf/communication/peer_communicator/payload_sender.py b/exasol/analytics/udf/communication/peer_communicator/payload_sender.py
index 329b7b53..bb46d710 100644
--- a/exasol/analytics/udf/communication/peer_communicator/payload_sender.py
+++ b/exasol/analytics/udf/communication/peer_communicator/payload_sender.py
@@ -8,12 +8,15 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator.clock import Clock
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender import \
-    PayloadMessageSender
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_factory import \
-    PayloadMessageSenderFactory
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config \
-    import PayloadMessageSenderTimeoutConfig
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender import (
+    PayloadMessageSender,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_factory import (
+    PayloadMessageSenderFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config import (
+    PayloadMessageSenderTimeoutConfig,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
 from exasol.analytics.udf.communication.serialization import serialize_message
 from exasol.analytics.udf.communication.socket_factory.abstract import Frame, Socket
@@ -22,16 +25,20 @@
 
 
 class PayloadSender:
-    def __init__(self,
-                 my_connection_info: ConnectionInfo,
-                 peer: Peer,
-                 sender: Sender,
-                 clock: Clock,
-                 out_control_socket: Socket,
-                 payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig,
-                 payload_message_sender_factory: PayloadMessageSenderFactory):
+    def __init__(
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        sender: Sender,
+        clock: Clock,
+        out_control_socket: Socket,
+        payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig,
+        payload_message_sender_factory: PayloadMessageSenderFactory,
+    ):
         self._out_control_socket = out_control_socket
-        self._payload_message_sender_timeout_config = payload_message_sender_timeout_config
+        self._payload_message_sender_timeout_config = (
+            payload_message_sender_timeout_config
+        )
         self._clock = clock
         self._peer = peer
         self._my_connection_info = my_connection_info
@@ -42,7 +49,9 @@ def __init__(self,
             my_connection_info=self._my_connection_info.dict(),
         )
         self._next_send_payload_sequence_number = 0
-        self._payload_message_sender_dict: Dict[int, PayloadMessageSender] = OrderedDict()
+        self._payload_message_sender_dict: Dict[int, PayloadMessageSender] = (
+            OrderedDict()
+        )
 
     def try_send(self):
         for payload_sender in self._payload_message_sender_dict.values():
@@ -53,11 +62,13 @@ def received_acknowledge_payload(self, message: messages.AcknowledgePayload):
         if message.sequence_number in self._payload_message_sender_dict:
             self._payload_message_sender_dict[message.sequence_number].stop()
             del self._payload_message_sender_dict[message.sequence_number]
-        self._out_control_socket.send(serialize_message(messages.Message(__root__=message)))
+        self._out_control_socket.send(
+            serialize_message(messages.Message(__root__=message))
+        )
 
     def send_payload(self, message: messages.Payload, frames: List[Frame]):
         self._logger.info("send_payload", message=message.dict())
-        self._payload_message_sender_dict[message.sequence_number] = \
+        self._payload_message_sender_dict[message.sequence_number] = (
             self._payload_message_sender_factory.create(
                 message=message,
                 frames=frames,
@@ -66,6 +77,7 @@ def send_payload(self, message: messages.Payload, frames: List[Frame]):
                 clock=self._clock,
                 payload_message_sender_timeout_config=self._payload_message_sender_timeout_config,
             )
+        )
 
     def is_ready_to_stop(self):
         return len(self._payload_message_sender_dict) == 0
diff --git a/exasol/analytics/udf/communication/peer_communicator/payload_sender_factory.py b/exasol/analytics/udf/communication/peer_communicator/payload_sender_factory.py
index 039b88fb..3647f6ff 100644
--- a/exasol/analytics/udf/communication/peer_communicator/payload_sender_factory.py
+++ b/exasol/analytics/udf/communication/peer_communicator/payload_sender_factory.py
@@ -1,11 +1,15 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator.clock import Clock
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_factory import \
-    PayloadMessageSenderFactory
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config \
-    import PayloadMessageSenderTimeoutConfig
-from exasol.analytics.udf.communication.peer_communicator.payload_sender import PayloadSender
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_factory import (
+    PayloadMessageSenderFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config import (
+    PayloadMessageSenderTimeoutConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_sender import (
+    PayloadSender,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
 from exasol.analytics.udf.communication.socket_factory.abstract import Socket
 
@@ -15,13 +19,15 @@ class PayloadSenderFactory:
     def __init__(self, payload_message_sender_factory: PayloadMessageSenderFactory):
         self._payload_message_sender_factory = payload_message_sender_factory
 
-    def create(self,
-               my_connection_info: ConnectionInfo,
-               peer: Peer,
-               sender: Sender,
-               clock: Clock,
-               payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig,
-               out_control_socket: Socket) -> PayloadSender:
+    def create(
+        self,
+        my_connection_info: ConnectionInfo,
+        peer: Peer,
+        sender: Sender,
+        clock: Clock,
+        payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig,
+        out_control_socket: Socket,
+    ) -> PayloadSender:
         return PayloadSender(
             my_connection_info=my_connection_info,
             peer=peer,
@@ -29,5 +35,5 @@ def create(self,
             clock=clock,
             payload_message_sender_timeout_config=payload_message_sender_timeout_config,
             out_control_socket=out_control_socket,
-            payload_message_sender_factory=self._payload_message_sender_factory
+            payload_message_sender_factory=self._payload_message_sender_factory,
         )
diff --git a/exasol/analytics/udf/communication/peer_communicator/peer_communicator.py b/exasol/analytics/udf/communication/peer_communicator/peer_communicator.py
index 673883d9..30020fcb 100644
--- a/exasol/analytics/udf/communication/peer_communicator/peer_communicator.py
+++ b/exasol/analytics/udf/communication/peer_communicator/peer_communicator.py
@@ -1,6 +1,6 @@
 import time
 from dataclasses import asdict
-from typing import Optional, Dict, List, Callable
+from typing import Callable, Dict, List, Optional
import structlog from structlog.types import FilteringBoundLogger @@ -9,43 +9,57 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.background_listener_interface import \ - BackgroundListenerInterface +from exasol.analytics.udf.communication.peer_communicator.background_listener_interface import ( + BackgroundListenerInterface, +) from exasol.analytics.udf.communication.peer_communicator.clock import Clock -from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \ - ForwardRegisterPeerConfig -from exasol.analytics.udf.communication.peer_communicator.frontend_peer_state import \ - FrontendPeerState -from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \ - PeerCommunicatorConfig -from exasol.analytics.udf.communication.socket_factory.abstract \ - import SocketFactory, Frame +from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import ( + ForwardRegisterPeerConfig, +) +from exasol.analytics.udf.communication.peer_communicator.frontend_peer_state import ( + FrontendPeerState, +) +from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import ( + PeerCommunicatorConfig, +) +from exasol.analytics.udf.communication.socket_factory.abstract import ( + Frame, + SocketFactory, +) LOGGER: FilteringBoundLogger = structlog.getLogger() def key_for_peer(peer: Peer): - return peer.connection_info.ipaddress.ip_address + "_" + str(peer.connection_info.port.port) + return ( + peer.connection_info.ipaddress.ip_address + + "_" + + str(peer.connection_info.port.port) + ) -def _compute_handle_message_timeout(start_time_ns: int, timeout_in_milliseconds: Optional[int] = None) -> int: +def _compute_handle_message_timeout( + start_time_ns: int, timeout_in_milliseconds: Optional[int] = None +) -> int: time_difference_ns = time.monotonic_ns() - start_time_ns - time_difference_ms = time_difference_ns // 10 ** 6 + time_difference_ms = time_difference_ns // 10**6 handle_message_timeout_ms = timeout_in_milliseconds - time_difference_ms return handle_message_timeout_ms class PeerCommunicator: - def __init__(self, - name: str, - number_of_peers: int, - listen_ip: IPAddress, - group_identifier: str, - socket_factory: SocketFactory, - config: PeerCommunicatorConfig = PeerCommunicatorConfig(), - clock: Clock = Clock(), - trace_logging: bool = False): + def __init__( + self, + name: str, + number_of_peers: int, + listen_ip: IPAddress, + group_identifier: str, + socket_factory: SocketFactory, + config: PeerCommunicatorConfig = PeerCommunicatorConfig(), + clock: Clock = Clock(), + trace_logging: bool = False, + ): self._config = config self._socket_factory = socket_factory self._name = name @@ -55,7 +69,7 @@ def __init__(self, name=self._name, group_identifier=self._group_identifier, number_of_peers=self._number_of_peers, - config=asdict(config) + config=asdict(config), ) self._logger.info("init") self._background_listener = BackgroundListenerInterface( @@ -69,12 +83,16 @@ def __init__(self, trace_logging=trace_logging, ) self._my_connection_info = self._background_listener.my_connection_info - self._logger = self._logger.bind(my_connection_info=self._my_connection_info.dict()) + self._logger = self._logger.bind( + my_connection_info=self._my_connection_info.dict() + ) 
self._logger.info("my_connection_info") self._peer_states: Dict[Peer, FrontendPeerState] = {} def _handle_messages(self, timeout_in_milliseconds: Optional[int] = 0): - for message_obj, frames in self._background_listener.receive_messages(timeout_in_milliseconds): + for message_obj, frames in self._background_listener.receive_messages( + timeout_in_milliseconds + ): specific_message_obj = message_obj.__root__ if isinstance(specific_message_obj, messages.ConnectionIsReady): peer = specific_message_obj.peer @@ -84,23 +102,28 @@ def _handle_messages(self, timeout_in_milliseconds: Optional[int] = 0): peer = specific_message_obj.peer self._add_peer_state(peer) self._peer_states[peer].received_connection_is_closed() - elif isinstance(specific_message_obj, messages.PeerRegisterForwarderIsReady): + elif isinstance( + specific_message_obj, messages.PeerRegisterForwarderIsReady + ): peer = specific_message_obj.peer self._add_peer_state(peer) self._peer_states[peer].received_peer_register_forwarder_is_ready() elif isinstance(specific_message_obj, messages.Timeout): raise TimeoutError(specific_message_obj.reason) elif isinstance(specific_message_obj, messages.Payload): - self._peer_states[specific_message_obj.source].received_payload_message(specific_message_obj, frames) + self._peer_states[specific_message_obj.source].received_payload_message( + specific_message_obj, frames + ) elif isinstance(specific_message_obj, messages.AcknowledgePayload): - self._peer_states[specific_message_obj.source].received_acknowledge_payload_message( - specific_message_obj) + self._peer_states[ + specific_message_obj.source + ].received_acknowledge_payload_message(specific_message_obj) elif isinstance(specific_message_obj, messages.AbortPayload): raise TimeoutError(specific_message_obj.reason) else: self._logger.error( - "Unknown message", - message_obj=specific_message_obj.dict()) + "Unknown message", message_obj=specific_message_obj.dict() + ) def _add_peer_state(self, peer: Peer): if peer not in self._peer_states: @@ -108,16 +131,21 @@ def _add_peer_state(self, peer: Peer): my_connection_info=self.my_connection_info, socket_factory=self._socket_factory, peer=peer, - background_listener=self._background_listener + background_listener=self._background_listener, ) - def _wait_for_condition(self, condition: Callable[[], bool], - timeout_in_milliseconds: Optional[int] = None) -> bool: + def _wait_for_condition( + self, + condition: Callable[[], bool], + timeout_in_milliseconds: Optional[int] = None, + ) -> bool: start_time_ns = time.monotonic_ns() self._handle_messages(timeout_in_milliseconds=0) while not condition(): if timeout_in_milliseconds is not None: - handle_message_timeout_ms = _compute_handle_message_timeout(start_time_ns, timeout_in_milliseconds) + handle_message_timeout_ms = _compute_handle_message_timeout( + start_time_ns, timeout_in_milliseconds + ) if handle_message_timeout_ms < 0: break else: @@ -126,22 +154,32 @@ def _wait_for_condition(self, condition: Callable[[], bool], return condition() def wait_for_peers(self, timeout_in_milliseconds: Optional[int] = None) -> bool: - return self._wait_for_condition(self._are_all_peers_connected, timeout_in_milliseconds) + return self._wait_for_condition( + self._are_all_peers_connected, timeout_in_milliseconds + ) - def peers(self, timeout_in_milliseconds: Optional[int] = None) -> Optional[List[Peer]]: + def peers( + self, timeout_in_milliseconds: Optional[int] = None + ) -> Optional[List[Peer]]: self.wait_for_peers(timeout_in_milliseconds) if 
self._are_all_peers_connected(): - peers = [peer for peer in self._peer_states.keys()] + \ - [Peer(connection_info=self._my_connection_info)] + peers = [peer for peer in self._peer_states.keys()] + [ + Peer(connection_info=self._my_connection_info) + ] return sorted(peers, key=key_for_peer) else: return None def register_peer(self, peer_connection_info: ConnectionInfo): - self._logger.info("register_peer", peer_connection_info=peer_connection_info.dict()) + self._logger.info( + "register_peer", peer_connection_info=peer_connection_info.dict() + ) self._handle_messages() - if (peer_connection_info.group_identifier == self.my_connection_info.group_identifier - and peer_connection_info != self.my_connection_info): + if ( + peer_connection_info.group_identifier + == self.my_connection_info.group_identifier + and peer_connection_info != self.my_connection_info + ): peer = Peer(connection_info=peer_connection_info) if peer not in self._peer_states: self._add_peer_state(peer) @@ -178,7 +216,9 @@ def are_all_peers_connected(self) -> bool: return result def _are_all_peers_connected(self): - all_peers_ready = all(peer_state.peer_is_ready for peer_state in self._peer_states.values()) + all_peers_ready = all( + peer_state.peer_is_ready for peer_state in self._peer_states.values() + ) result = len(self._peer_states) == self._number_of_peers - 1 and all_peers_ready return result @@ -186,19 +226,23 @@ def send(self, peer: Peer, message: List[Frame]): self.wait_for_peers() self._peer_states[peer].send(message) - def recv(self, peer: Peer, timeout_in_milliseconds: Optional[int] = None) -> List[Frame]: + def recv( + self, peer: Peer, timeout_in_milliseconds: Optional[int] = None + ) -> List[Frame]: self.wait_for_peers() - peer_has_received_messages = \ - self._wait_for_condition(self._peer_states[peer].has_received_messages, - timeout_in_milliseconds=timeout_in_milliseconds) + peer_has_received_messages = self._wait_for_condition( + self._peer_states[peer].has_received_messages, + timeout_in_milliseconds=timeout_in_milliseconds, + ) if peer_has_received_messages: return self._peer_states[peer].recv() else: raise TimeoutError("Timeout occurred during waiting for messages.") def poll_peers( - self, peers: Optional[List[Peer]] = None, - timeout_in_milliseconds: Optional[int] = None + self, + peers: Optional[List[Peer]] = None, + timeout_in_milliseconds: Optional[int] = None, ) -> List[Peer]: self.wait_for_peers() @@ -206,12 +250,18 @@ def poll_peers( peers = self._peer_states.keys() def have_peers_received_messages() -> bool: - result = any(self._peer_states[peer].has_received_messages() for peer in peers) + result = any( + self._peer_states[peer].has_received_messages() for peer in peers + ) return result - self._wait_for_condition(have_peers_received_messages, - timeout_in_milliseconds=timeout_in_milliseconds) - return [peer for peer in peers if self._peer_states[peer].has_received_messages()] + self._wait_for_condition( + have_peers_received_messages, + timeout_in_milliseconds=timeout_in_milliseconds, + ) + return [ + peer for peer in peers if self._peer_states[peer].has_received_messages() + ] def stop(self): self._logger.info("stop") @@ -225,17 +275,22 @@ def _stop_background_listener(self): self._logger.info("stop background_listener") self._background_listener.prepare_to_stop() try: - is_ready_to_stop = \ - self._wait_for_condition(self._are_all_peers_disconnected, - timeout_in_milliseconds=self._config.close_timeout_in_ms) + is_ready_to_stop = self._wait_for_condition( + 
self._are_all_peers_disconnected, + timeout_in_milliseconds=self._config.close_timeout_in_ms, + ) if not is_ready_to_stop: - raise TimeoutError("Timeout expired, could not gracefully stop PeerCommuincator.") + raise TimeoutError( + "Timeout expired, could not gracefully stop PeerCommunicator." + ) finally: self._background_listener.stop() self._background_listener = None def _are_all_peers_disconnected(self): - all_peers_ready = all(peer_state.connection_is_closed for peer_state in self._peer_states.values()) + all_peers_ready = all( + peer_state.connection_is_closed for peer_state in self._peer_states.values() + ) result = len(self._peer_states) == self._number_of_peers - 1 and all_peers_ready return result diff --git a/exasol/analytics/udf/communication/peer_communicator/peer_communicator_config.py b/exasol/analytics/udf/communication/peer_communicator/peer_communicator_config.py index 419a1138..1172ee66 100644 --- a/exasol/analytics/udf/communication/peer_communicator/peer_communicator_config.py +++ b/exasol/analytics/udf/communication/peer_communicator/peer_communicator_config.py @@ -1,24 +1,39 @@ import dataclasses -from exasol.analytics.udf.communication.peer_communicator. \ - background_thread.connection_closer.connection_closer_timeout_config import ConnectionCloserTimeoutConfig -from exasol.analytics.udf.communication.peer_communicator. \ - connection_establisher_timeout_config import ConnectionEstablisherTimeoutConfig -from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \ - ForwardRegisterPeerConfig -from exasol.analytics.udf.communication.peer_communicator. \ - payload_message_sender_timeout_config import PayloadMessageSenderTimeoutConfig -from exasol.analytics.udf.communication.peer_communicator. 
\ - register_peer_forwarder_timeout_config import RegisterPeerForwarderTimeoutConfig +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_timeout_config import ( + ConnectionCloserTimeoutConfig, +) +from exasol.analytics.udf.communication.peer_communicator.connection_establisher_timeout_config import ( + ConnectionEstablisherTimeoutConfig, +) +from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import ( + ForwardRegisterPeerConfig, +) +from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config import ( + PayloadMessageSenderTimeoutConfig, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_timeout_config import ( + RegisterPeerForwarderTimeoutConfig, +) @dataclasses.dataclass class PeerCommunicatorConfig: - connection_establisher_timeout_config: ConnectionEstablisherTimeoutConfig = ConnectionEstablisherTimeoutConfig() - connection_closer_timeout_config: ConnectionCloserTimeoutConfig = ConnectionCloserTimeoutConfig() - register_peer_forwarder_timeout_config: RegisterPeerForwarderTimeoutConfig = RegisterPeerForwarderTimeoutConfig() - forward_register_peer_config: ForwardRegisterPeerConfig = ForwardRegisterPeerConfig() - payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig = PayloadMessageSenderTimeoutConfig() + connection_establisher_timeout_config: ConnectionEstablisherTimeoutConfig = ( + ConnectionEstablisherTimeoutConfig() + ) + connection_closer_timeout_config: ConnectionCloserTimeoutConfig = ( + ConnectionCloserTimeoutConfig() + ) + register_peer_forwarder_timeout_config: RegisterPeerForwarderTimeoutConfig = ( + RegisterPeerForwarderTimeoutConfig() + ) + forward_register_peer_config: ForwardRegisterPeerConfig = ( + ForwardRegisterPeerConfig() + ) + payload_message_sender_timeout_config: PayloadMessageSenderTimeoutConfig = ( + PayloadMessageSenderTimeoutConfig() + ) poll_timeout_in_ms: int = 200 send_socket_linger_time_in_ms: int = 100 close_timeout_in_ms: int = 100000 diff --git a/exasol/analytics/udf/communication/peer_communicator/register_peer_connection.py b/exasol/analytics/udf/communication/peer_communicator/register_peer_connection.py index cc53a407..d23ec86e 100644 --- a/exasol/analytics/udf/communication/peer_communicator/register_peer_connection.py +++ b/exasol/analytics/udf/communication/peer_communicator/register_peer_connection.py @@ -6,8 +6,9 @@ from exasol.analytics.udf.communication import messages from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.send_socket_factory import \ - SendSocketFactory +from exasol.analytics.udf.communication.peer_communicator.send_socket_factory import ( + SendSocketFactory, +) from exasol.analytics.udf.communication.serialization import serialize_message from exasol.analytics.udf.communication.socket_factory.abstract import Socket @@ -16,22 +17,28 @@ class RegisterPeerConnection: - def __init__(self, - predecessor: Optional[Peer], - predecessor_send_socket_factory: Optional[SendSocketFactory], - successor: Peer, - successor_send_socket_factory: SendSocketFactory, - my_connection_info: ConnectionInfo): - self._logger = LOGGER.bind(successor=successor.dict(), - predecessor=None if predecessor is None else predecessor.dict(), - my_connection_info=my_connection_info.dict()) + def __init__( + self, + predecessor: 
Optional[Peer], + predecessor_send_socket_factory: Optional[SendSocketFactory], + successor: Peer, + successor_send_socket_factory: SendSocketFactory, + my_connection_info: ConnectionInfo, + ): + self._logger = LOGGER.bind( + successor=successor.dict(), + predecessor=None if predecessor is None else predecessor.dict(), + my_connection_info=my_connection_info.dict(), + ) self._successor = successor self._predecessor = predecessor self._my_connection_info = my_connection_info self._successor_socket = successor_send_socket_factory.create_send_socket() self._predecessor_socket: Optional[Socket] = None if predecessor_send_socket_factory is not None: - self._predecessor_socket = predecessor_send_socket_factory.create_send_socket() + self._predecessor_socket = ( + predecessor_send_socket_factory.create_send_socket() + ) @property def successor(self) -> Peer: @@ -44,8 +51,7 @@ def predecessor(self) -> Optional[Peer]: def forward(self, peer: Peer): self._logger.debug("forward", peer=peer.dict()) message = messages.RegisterPeer( - peer=peer, - source=Peer(connection_info=self._my_connection_info) + peer=peer, source=Peer(connection_info=self._my_connection_info) ) serialized_message = serialize_message(message) self._successor_socket.send(serialized_message) @@ -54,8 +60,7 @@ def ack(self, peer: Peer): self._logger.debug("ack", peer=peer.dict()) if self._predecessor_socket is not None: message = messages.AcknowledgeRegisterPeer( - peer=peer, - source=Peer(connection_info=self._my_connection_info) + peer=peer, source=Peer(connection_info=self._my_connection_info) ) serialized_message = serialize_message(message) self._predecessor_socket.send(serialized_message) @@ -63,8 +68,7 @@ def ack(self, peer: Peer): def complete(self, peer: Peer): self._logger.debug("complete", peer=peer.dict()) message = messages.RegisterPeerComplete( - peer=peer, - source=Peer(connection_info=self._my_connection_info) + peer=peer, source=Peer(connection_info=self._my_connection_info) ) serialized_message = serialize_message(message) self._successor_socket.send(serialized_message) diff --git a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder.py b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder.py index 944d3d70..7db66451 100644 --- a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder.py +++ b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder.py @@ -3,16 +3,21 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \ - AbortTimeoutSender -from exasol.analytics.udf.communication.peer_communicator.acknowledge_register_peer_sender import \ - AcknowledgeRegisterPeerSender -from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import \ - RegisterPeerConnection -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_is_ready_sender import \ - RegisterPeerForwarderIsReadySender -from exasol.analytics.udf.communication.peer_communicator.register_peer_sender import \ - RegisterPeerSender +from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import ( + AbortTimeoutSender, +) +from exasol.analytics.udf.communication.peer_communicator.acknowledge_register_peer_sender import ( + AcknowledgeRegisterPeerSender, +) +from 
exasol.analytics.udf.communication.peer_communicator.register_peer_connection import ( + RegisterPeerConnection, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_is_ready_sender import ( + RegisterPeerForwarderIsReadySender, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_sender import ( + RegisterPeerSender, +) from exasol.analytics.udf.communication.peer_communicator.sender import Sender LOGGER: FilteringBoundLogger = structlog.get_logger() @@ -20,16 +25,20 @@ class RegisterPeerForwarder: - def __init__(self, - peer: Peer, - my_connection_info: ConnectionInfo, - sender: Sender, - register_peer_connection: RegisterPeerConnection, - abort_timeout_sender: AbortTimeoutSender, - acknowledge_register_peer_sender: AcknowledgeRegisterPeerSender, - register_peer_sender: RegisterPeerSender, - register_peer_forwarder_is_ready_sender: RegisterPeerForwarderIsReadySender): - self._register_peer_forwarder_is_ready_sender = register_peer_forwarder_is_ready_sender + def __init__( + self, + peer: Peer, + my_connection_info: ConnectionInfo, + sender: Sender, + register_peer_connection: RegisterPeerConnection, + abort_timeout_sender: AbortTimeoutSender, + acknowledge_register_peer_sender: AcknowledgeRegisterPeerSender, + register_peer_sender: RegisterPeerSender, + register_peer_forwarder_is_ready_sender: RegisterPeerForwarderIsReadySender, + ): + self._register_peer_forwarder_is_ready_sender = ( + register_peer_forwarder_is_ready_sender + ) self._register_peer_sender = register_peer_sender self._acknowledge_register_peer_sender = acknowledge_register_peer_sender self._abort_timeout_sender = abort_timeout_sender @@ -68,5 +77,3 @@ def try_send(self): def is_ready_to_stop(self): return self._register_peer_forwarder_is_ready_sender.is_ready_to_stop() - - diff --git a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_builder.py b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_builder.py index 43de58f3..6470c860 100644 --- a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_builder.py +++ b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_builder.py @@ -1,20 +1,31 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \ - AbortTimeoutSenderFactory, AbortTimeoutSender -from exasol.analytics.udf.communication.peer_communicator.acknowledge_register_peer_sender import \ - AcknowledgeRegisterPeerSenderFactory, AcknowledgeRegisterPeerSender +from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import ( + AbortTimeoutSender, + AbortTimeoutSenderFactory, +) +from exasol.analytics.udf.communication.peer_communicator.acknowledge_register_peer_sender import ( + AcknowledgeRegisterPeerSender, + AcknowledgeRegisterPeerSenderFactory, +) from exasol.analytics.udf.communication.peer_communicator.clock import Clock -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import \ - RegisterPeerForwarder -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder_parameter import \ - RegisterPeerForwarderBuilderParameter -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_factory import \ - RegisterPeerForwarderFactory -from 
exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_is_ready_sender import \ - RegisterPeerForwarderIsReadySenderFactory, RegisterPeerForwarderIsReadySender -from exasol.analytics.udf.communication.peer_communicator.register_peer_sender import \ - RegisterPeerSenderFactory, RegisterPeerSender +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import ( + RegisterPeerForwarder, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_builder_parameter import ( + RegisterPeerForwarderBuilderParameter, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_factory import ( + RegisterPeerForwarderFactory, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_is_ready_sender import ( + RegisterPeerForwarderIsReadySender, + RegisterPeerForwarderIsReadySenderFactory, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_sender import ( + RegisterPeerSender, + RegisterPeerSenderFactory, +) from exasol.analytics.udf.communication.peer_communicator.sender import Sender from exasol.analytics.udf.communication.peer_communicator.timer import TimerFactory from exasol.analytics.udf.communication.socket_factory.abstract import Socket @@ -22,55 +33,64 @@ class RegisterPeerForwarderBuilder: - def __init__(self, - timer_factory: TimerFactory, - abort_timeout_sender_factory: AbortTimeoutSenderFactory = AbortTimeoutSenderFactory(), - acknowledge_register_peer_sender_factory: AcknowledgeRegisterPeerSenderFactory = - AcknowledgeRegisterPeerSenderFactory(), - register_peer_forwarder_is_ready_sender_factory: RegisterPeerForwarderIsReadySenderFactory = - RegisterPeerForwarderIsReadySenderFactory(), - register_peer_sender_factory: RegisterPeerSenderFactory = RegisterPeerSenderFactory(), - register_peer_forwarder_factory: RegisterPeerForwarderFactory = - RegisterPeerForwarderFactory()): + def __init__( + self, + timer_factory: TimerFactory, + abort_timeout_sender_factory: AbortTimeoutSenderFactory = AbortTimeoutSenderFactory(), + acknowledge_register_peer_sender_factory: AcknowledgeRegisterPeerSenderFactory = AcknowledgeRegisterPeerSenderFactory(), + register_peer_forwarder_is_ready_sender_factory: RegisterPeerForwarderIsReadySenderFactory = RegisterPeerForwarderIsReadySenderFactory(), + register_peer_sender_factory: RegisterPeerSenderFactory = RegisterPeerSenderFactory(), + register_peer_forwarder_factory: RegisterPeerForwarderFactory = RegisterPeerForwarderFactory(), + ): self._register_peer_forwarder_factory = register_peer_forwarder_factory self._timer_factory = timer_factory self._register_peer_sender_factory = register_peer_sender_factory - self._register_peer_forwarder_is_ready_sender_factory = register_peer_forwarder_is_ready_sender_factory - self._acknowledge_register_peer_sender_factory = acknowledge_register_peer_sender_factory + self._register_peer_forwarder_is_ready_sender_factory = ( + register_peer_forwarder_is_ready_sender_factory + ) + self._acknowledge_register_peer_sender_factory = ( + acknowledge_register_peer_sender_factory + ) self._abort_timeout_sender_factory = abort_timeout_sender_factory - def create(self, - peer: Peer, - my_connection_info: ConnectionInfo, - out_control_socket: Socket, - clock: Clock, - sender: Sender, - parameter: RegisterPeerForwarderBuilderParameter - ) -> RegisterPeerForwarder: + def create( + self, + peer: Peer, + my_connection_info: ConnectionInfo, + out_control_socket: Socket, + clock: Clock, + 
sender: Sender, + parameter: RegisterPeerForwarderBuilderParameter, + ) -> RegisterPeerForwarder: abort_timeout_sender = self._create_abort_timeout_sender( my_connection_info=my_connection_info, peer=peer, out_control_socket=out_control_socket, clock=clock, - parameter=parameter) + parameter=parameter, + ) register_peer_sender = self._create_register_peer_sender( my_connection_info=my_connection_info, peer=peer, clock=clock, - parameter=parameter + parameter=parameter, ) - acknowledge_register_peer_sender = self._create_acknowledge_register_peer_sender( - my_connection_info=my_connection_info, - peer=peer, - clock=clock, - parameter=parameter + acknowledge_register_peer_sender = ( + self._create_acknowledge_register_peer_sender( + my_connection_info=my_connection_info, + peer=peer, + clock=clock, + parameter=parameter, + ) ) - register_peer_forwarder_is_ready_sender = self._create_register_peer_forwarder_is_ready_sender( - my_connection_info=my_connection_info, - peer=peer, - out_control_socket=out_control_socket, - clock=clock, - parameter=parameter + register_peer_forwarder_is_ready_sender = ( + self._create_register_peer_forwarder_is_ready_sender( + my_connection_info=my_connection_info, + peer=peer, + out_control_socket=out_control_socket, + clock=clock, + parameter=parameter, + ) ) return self._register_peer_forwarder_factory.create( peer=peer, @@ -84,12 +104,16 @@ def create(self, ) def _create_acknowledge_register_peer_sender( - self, - my_connection_info: ConnectionInfo, peer: Peer, - clock: Clock, parameter: RegisterPeerForwarderBuilderParameter) -> AcknowledgeRegisterPeerSender: + self, + my_connection_info: ConnectionInfo, + peer: Peer, + clock: Clock, + parameter: RegisterPeerForwarderBuilderParameter, + ) -> AcknowledgeRegisterPeerSender: acknowledge_register_peer_sender_timer = self._timer_factory.create( clock=clock, - timeout_in_ms=parameter.timeout_config.acknowledge_register_peer_retry_timeout_in_ms) + timeout_in_ms=parameter.timeout_config.acknowledge_register_peer_retry_timeout_in_ms, + ) acknowledge_register_peer_sender = self._acknowledge_register_peer_sender_factory.create( register_peer_connection=parameter.register_peer_connection, needs_to_send_for_peer=parameter.behavior_config.needs_to_send_acknowledge_register_peer, @@ -100,11 +124,16 @@ def _create_acknowledge_register_peer_sender( return acknowledge_register_peer_sender def _create_register_peer_sender( - self, - my_connection_info: ConnectionInfo, peer: Peer, - clock: Clock, parameter: RegisterPeerForwarderBuilderParameter) -> RegisterPeerSender: + self, + my_connection_info: ConnectionInfo, + peer: Peer, + clock: Clock, + parameter: RegisterPeerForwarderBuilderParameter, + ) -> RegisterPeerSender: register_peer_sender_timer = self._timer_factory.create( - clock=clock, timeout_in_ms=parameter.timeout_config.register_peer_retry_timeout_in_ms) + clock=clock, + timeout_in_ms=parameter.timeout_config.register_peer_retry_timeout_in_ms, + ) register_peer_sender = self._register_peer_sender_factory.create( register_peer_connection=parameter.register_peer_connection, needs_to_send_for_peer=parameter.behavior_config.needs_to_send_register_peer, @@ -115,35 +144,45 @@ def _create_register_peer_sender( return register_peer_sender def _create_register_peer_forwarder_is_ready_sender( - self, - my_connection_info: ConnectionInfo, - peer: Peer, - out_control_socket: Socket, - clock: Clock, - parameter: RegisterPeerForwarderBuilderParameter) -> RegisterPeerForwarderIsReadySender: + self, + my_connection_info: ConnectionInfo, 
+ peer: Peer, + out_control_socket: Socket, + clock: Clock, + parameter: RegisterPeerForwarderBuilderParameter, + ) -> RegisterPeerForwarderIsReadySender: register_peer_forwarder_is_ready_sender_timer = self._timer_factory.create( - clock=clock, timeout_in_ms=parameter.timeout_config.register_peer_forwarder_is_ready_wait_time_in_ms) - register_peer_forwarder_is_ready_sender = self._register_peer_forwarder_is_ready_sender_factory.create( - peer=peer, my_connection_info=my_connection_info, - out_control_socket=out_control_socket, - timer=register_peer_forwarder_is_ready_sender_timer, - behavior_config=parameter.behavior_config + clock=clock, + timeout_in_ms=parameter.timeout_config.register_peer_forwarder_is_ready_wait_time_in_ms, + ) + register_peer_forwarder_is_ready_sender = ( + self._register_peer_forwarder_is_ready_sender_factory.create( + peer=peer, + my_connection_info=my_connection_info, + out_control_socket=out_control_socket, + timer=register_peer_forwarder_is_ready_sender_timer, + behavior_config=parameter.behavior_config, + ) ) return register_peer_forwarder_is_ready_sender def _create_abort_timeout_sender( - self, - my_connection_info: ConnectionInfo, peer: Peer, - out_control_socket: Socket, clock: Clock, - parameter: RegisterPeerForwarderBuilderParameter) -> AbortTimeoutSender: + self, + my_connection_info: ConnectionInfo, + peer: Peer, + out_control_socket: Socket, + clock: Clock, + parameter: RegisterPeerForwarderBuilderParameter, + ) -> AbortTimeoutSender: abort_timeout_sender_timer = self._timer_factory.create( - clock=clock, timeout_in_ms=parameter.timeout_config.abort_timeout_in_ms) + clock=clock, timeout_in_ms=parameter.timeout_config.abort_timeout_in_ms + ) abort_timeout_sender = self._abort_timeout_sender_factory.create( out_control_socket=out_control_socket, timer=abort_timeout_sender_timer, my_connection_info=my_connection_info, peer=peer, - reason="Timeout occurred during sending register peer." 
+ reason="Timeout occurred during sending register peer.", ) if not parameter.behavior_config.needs_to_send_register_peer: abort_timeout_sender.stop() diff --git a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_builder_parameter.py b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_builder_parameter.py index 7cef9c7f..16ad4f3f 100644 --- a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_builder_parameter.py +++ b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_builder_parameter.py @@ -1,11 +1,14 @@ import dataclasses -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_behavior_config \ - import RegisterPeerForwarderBehaviorConfig -from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import \ - RegisterPeerConnection -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_timeout_config \ - import RegisterPeerForwarderTimeoutConfig +from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import ( + RegisterPeerConnection, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_behavior_config import ( + RegisterPeerForwarderBehaviorConfig, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_timeout_config import ( + RegisterPeerForwarderTimeoutConfig, +) @dataclasses.dataclass(frozen=True) diff --git a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_factory.py b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_factory.py index d440117d..2d21524a 100644 --- a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_factory.py +++ b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_factory.py @@ -1,31 +1,39 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \ - AbortTimeoutSender -from exasol.analytics.udf.communication.peer_communicator.acknowledge_register_peer_sender import \ - AcknowledgeRegisterPeerSender -from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import \ - RegisterPeerConnection -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import \ - RegisterPeerForwarder -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_is_ready_sender import \ - RegisterPeerForwarderIsReadySender -from exasol.analytics.udf.communication.peer_communicator.register_peer_sender import \ - RegisterPeerSender +from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import ( + AbortTimeoutSender, +) +from exasol.analytics.udf.communication.peer_communicator.acknowledge_register_peer_sender import ( + AcknowledgeRegisterPeerSender, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import ( + RegisterPeerConnection, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import ( + RegisterPeerForwarder, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_is_ready_sender import ( + RegisterPeerForwarderIsReadySender, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_sender import ( + RegisterPeerSender, +) 
from exasol.analytics.udf.communication.peer_communicator.sender import Sender class RegisterPeerForwarderFactory: - def create(self, - peer: Peer, - my_connection_info: ConnectionInfo, - sender: Sender, - register_peer_connection: RegisterPeerConnection, - abort_timeout_sender: AbortTimeoutSender, - acknowledge_register_peer_sender: AcknowledgeRegisterPeerSender, - register_peer_sender: RegisterPeerSender, - register_peer_forwarder_is_ready_sender: RegisterPeerForwarderIsReadySender)->RegisterPeerForwarder: + def create( + self, + peer: Peer, + my_connection_info: ConnectionInfo, + sender: Sender, + register_peer_connection: RegisterPeerConnection, + abort_timeout_sender: AbortTimeoutSender, + acknowledge_register_peer_sender: AcknowledgeRegisterPeerSender, + register_peer_sender: RegisterPeerSender, + register_peer_forwarder_is_ready_sender: RegisterPeerForwarderIsReadySender, + ) -> RegisterPeerForwarder: return RegisterPeerForwarder( peer=peer, my_connection_info=my_connection_info, @@ -34,5 +42,5 @@ def create(self, abort_timeout_sender=abort_timeout_sender, acknowledge_register_peer_sender=acknowledge_register_peer_sender, register_peer_sender=register_peer_sender, - register_peer_forwarder_is_ready_sender=register_peer_forwarder_is_ready_sender + register_peer_forwarder_is_ready_sender=register_peer_forwarder_is_ready_sender, ) diff --git a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_is_ready_sender.py b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_is_ready_sender.py index 7154a7df..bc46085a 100644 --- a/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_is_ready_sender.py +++ b/exasol/analytics/udf/communication/peer_communicator/register_peer_forwarder_is_ready_sender.py @@ -7,8 +7,9 @@ from exasol.analytics.udf.communication import messages from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_behavior_config \ - import RegisterPeerForwarderBehaviorConfig +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_behavior_config import ( + RegisterPeerForwarderBehaviorConfig, +) from exasol.analytics.udf.communication.peer_communicator.timer import Timer from exasol.analytics.udf.communication.serialization import serialize_message from exasol.analytics.udf.communication.socket_factory.abstract import Socket @@ -25,12 +26,14 @@ class _States(IntFlag): class RegisterPeerForwarderIsReadySender: - def __init__(self, - peer: Peer, - my_connection_info: ConnectionInfo, - timer: Timer, - out_control_socket: Socket, - behavior_config: RegisterPeerForwarderBehaviorConfig): + def __init__( + self, + peer: Peer, + my_connection_info: ConnectionInfo, + timer: Timer, + out_control_socket: Socket, + behavior_config: RegisterPeerForwarderBehaviorConfig, + ): self._behavior_config = behavior_config self._peer = peer self._timer = timer @@ -40,7 +43,8 @@ def __init__(self, self._logger = LOGGER.bind( peer=self._peer.dict(), my_connection_info=my_connection_info.dict(), - behavior_config=asdict(self._behavior_config)) + behavior_config=asdict(self._behavior_config), + ) def received_acknowledge_register_peer(self): self._logger.debug("received_acknowledge_register_peer") @@ -64,40 +68,37 @@ def _should_we_send(self) -> bool: is_time = self._timer.is_time() send_time_dependent = self._is_send_time_dependent() 
send_time_independent = self._is_send_time_independent() - result = ( - not _States.FINISHED in self._states - and ( - (is_time and send_time_dependent) or - send_time_independent - ) + result = _States.FINISHED not in self._states and ( + (is_time and send_time_dependent) or send_time_independent + ) + self._logger.debug( + "_should_we_send", + result=result, + is_time=is_time, + send_time_dependent=send_time_dependent, + send_time_independent=send_time_independent, + states=self._states, ) - self._logger.debug("_should_we_send", - result=result, - is_time=is_time, - send_time_dependent=send_time_dependent, - send_time_independent=send_time_independent, - states=self._states) return result def _is_send_time_independent(self): received_acknowledge_register_peer = ( - not self._behavior_config.needs_to_send_register_peer - or _States.REGISTER_PEER_ACKNOWLEDGED in self._states + not self._behavior_config.needs_to_send_register_peer + or _States.REGISTER_PEER_ACKNOWLEDGED in self._states ) received_register_peer_complete = ( - not self._behavior_config.needs_to_send_acknowledge_register_peer - or _States.REGISTER_PEER_COMPLETED in self._states + not self._behavior_config.needs_to_send_acknowledge_register_peer + or _States.REGISTER_PEER_COMPLETED in self._states ) send_independent_of_time = ( - received_acknowledge_register_peer - and received_register_peer_complete + received_acknowledge_register_peer and received_register_peer_complete ) return send_independent_of_time def _is_send_time_dependent(self): received_acknowledge_register_peer = ( - not self._behavior_config.needs_to_send_register_peer - or _States.REGISTER_PEER_ACKNOWLEDGED in self._states + not self._behavior_config.needs_to_send_register_peer + or _States.REGISTER_PEER_ACKNOWLEDGED in self._states ) return received_acknowledge_register_peer @@ -113,16 +114,18 @@ def is_ready_to_stop(self) -> bool: class RegisterPeerForwarderIsReadySenderFactory: - def create(self, - peer: Peer, - my_connection_info: ConnectionInfo, - timer: Timer, - out_control_socket: Socket, - behavior_config: RegisterPeerForwarderBehaviorConfig): + def create( + self, + peer: Peer, + my_connection_info: ConnectionInfo, + timer: Timer, + out_control_socket: Socket, + behavior_config: RegisterPeerForwarderBehaviorConfig, + ): return RegisterPeerForwarderIsReadySender( peer=peer, my_connection_info=my_connection_info, behavior_config=behavior_config, timer=timer, - out_control_socket=out_control_socket + out_control_socket=out_control_socket, ) diff --git a/exasol/analytics/udf/communication/peer_communicator/register_peer_sender.py b/exasol/analytics/udf/communication/peer_communicator/register_peer_sender.py index 82681c84..a53b3b91 100644 --- a/exasol/analytics/udf/communication/peer_communicator/register_peer_sender.py +++ b/exasol/analytics/udf/communication/peer_communicator/register_peer_sender.py @@ -5,24 +5,29 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import \ - RegisterPeerConnection +from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import ( + RegisterPeerConnection, +) from exasol.analytics.udf.communication.peer_communicator.timer import Timer LOGGER: FilteringBoundLogger = structlog.get_logger() -class RegisterPeerSender(): - def __init__(self, - register_peer_connection: Optional[RegisterPeerConnection], - needs_to_send_for_peer: bool, 
- my_connection_info: ConnectionInfo, - peer: Peer, - timer: Timer): +class RegisterPeerSender: + def __init__( + self, + register_peer_connection: Optional[RegisterPeerConnection], + needs_to_send_for_peer: bool, + my_connection_info: ConnectionInfo, + peer: Peer, + timer: Timer, + ): self._needs_to_send_for_peer = needs_to_send_for_peer self._register_peer_connection = register_peer_connection if needs_to_send_for_peer and self._register_peer_connection is None: - raise ValueError("_register_peer_connection is None while needs_to_send_for_peer is true") + raise ValueError( + "_register_peer_connection is None while needs_to_send_for_peer is true" + ) self._my_connection_info = my_connection_info self._timer = timer self._finished = False @@ -31,7 +36,7 @@ def __init__(self, self._logger = LOGGER.bind( peer=peer.dict(), my_connection_info=my_connection_info.dict(), - needs_to_send_for_peer=self._needs_to_send_for_peer + needs_to_send_for_peer=self._needs_to_send_for_peer, ) self._logger.debug("init") @@ -60,13 +65,15 @@ def is_ready_to_stop(self) -> bool: return self._finished or not self._needs_to_send_for_peer -class RegisterPeerSenderFactory(): - def create(self, - register_peer_connection: Optional[RegisterPeerConnection], - needs_to_send_for_peer: bool, - my_connection_info: ConnectionInfo, - peer: Peer, - timer: Timer) -> RegisterPeerSender: +class RegisterPeerSenderFactory: + def create( + self, + register_peer_connection: Optional[RegisterPeerConnection], + needs_to_send_for_peer: bool, + my_connection_info: ConnectionInfo, + peer: Peer, + timer: Timer, + ) -> RegisterPeerSender: register_peer_sender = RegisterPeerSender( register_peer_connection=register_peer_connection, needs_to_send_for_peer=needs_to_send_for_peer, diff --git a/exasol/analytics/udf/communication/peer_communicator/send_socket_factory.py b/exasol/analytics/udf/communication/peer_communicator/send_socket_factory.py index d426421a..6677b042 100644 --- a/exasol/analytics/udf/communication/peer_communicator/send_socket_factory.py +++ b/exasol/analytics/udf/communication/peer_communicator/send_socket_factory.py @@ -5,17 +5,22 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory, \ - Socket, SocketType +from exasol.analytics.udf.communication.socket_factory.abstract import ( + Socket, + SocketFactory, + SocketType, +) LOGGER: FilteringBoundLogger = structlog.get_logger(__name__) class SendSocketFactory: - def __init__(self, - my_connection_info: ConnectionInfo, - socket_factory: SocketFactory, - peer: Peer): + def __init__( + self, + my_connection_info: ConnectionInfo, + socket_factory: SocketFactory, + peer: Peer, + ): self._my_connection_info = my_connection_info self._peer = peer self._socket_factory = socket_factory @@ -30,7 +35,8 @@ def create_send_socket(self) -> Socket: try: send_socket = self._socket_factory.create_socket(SocketType.DEALER) send_socket.connect( - f"tcp://{self._peer.connection_info.ipaddress.ip_address}:{self._peer.connection_info.port.port}") + f"tcp://{self._peer.connection_info.ipaddress.ip_address}:{self._peer.connection_info.port.port}" + ) return send_socket except Exception: self._logger.exception("Error during connect") diff --git a/exasol/analytics/udf/communication/peer_communicator/sender.py b/exasol/analytics/udf/communication/peer_communicator/sender.py index af3b97de..9915aee4 100644 --- 
a/exasol/analytics/udf/communication/peer_communicator/sender.py +++ b/exasol/analytics/udf/communication/peer_communicator/sender.py @@ -6,25 +6,31 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.messages import Message from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.send_socket_factory import \ - SendSocketFactory +from exasol.analytics.udf.communication.peer_communicator.send_socket_factory import ( + SendSocketFactory, +) from exasol.analytics.udf.communication.serialization import serialize_message -from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory, Frame +from exasol.analytics.udf.communication.socket_factory.abstract import ( + Frame, + SocketFactory, +) LOGGER: FilteringBoundLogger = structlog.get_logger(__name__) class Sender: - def __init__(self, - my_connection_info: ConnectionInfo, - socket_factory: SocketFactory, - peer: Peer, - send_socket_linger_time_in_ms: int): + def __init__( + self, + my_connection_info: ConnectionInfo, + socket_factory: SocketFactory, + peer: Peer, + send_socket_linger_time_in_ms: int, + ): self._send_socket_linger_time_in_ms = send_socket_linger_time_in_ms self._send_socket_factory = SendSocketFactory( my_connection_info=my_connection_info, socket_factory=socket_factory, - peer=peer + peer=peer, ) def send(self, message: Message): @@ -40,13 +46,17 @@ def send_multipart(self, frames: List[Frame]): class SenderFactory: - def create(self, - my_connection_info: ConnectionInfo, - socket_factory: SocketFactory, - peer: Peer, - send_socket_linger_time_in_ms: int) -> Sender: - sender = Sender(my_connection_info=my_connection_info, - socket_factory=socket_factory, - peer=peer, - send_socket_linger_time_in_ms=send_socket_linger_time_in_ms) + def create( + self, + my_connection_info: ConnectionInfo, + socket_factory: SocketFactory, + peer: Peer, + send_socket_linger_time_in_ms: int, + ) -> Sender: + sender = Sender( + my_connection_info=my_connection_info, + socket_factory=socket_factory, + peer=peer, + send_socket_linger_time_in_ms=send_socket_linger_time_in_ms, + ) return sender diff --git a/exasol/analytics/udf/communication/peer_communicator/synchronize_connection_sender.py b/exasol/analytics/udf/communication/peer_communicator/synchronize_connection_sender.py index 43e07bd2..964bbd1e 100644 --- a/exasol/analytics/udf/communication/peer_communicator/synchronize_connection_sender.py +++ b/exasol/analytics/udf/communication/peer_communicator/synchronize_connection_sender.py @@ -11,11 +11,13 @@ class SynchronizeConnectionSender: - def __init__(self, - my_connection_info: ConnectionInfo, - peer: Peer, - sender: Sender, - timer: Timer): + def __init__( + self, + my_connection_info: ConnectionInfo, + peer: Peer, + sender: Sender, + timer: Timer, + ): self._my_connection_info = my_connection_info self._peer = peer self._timer = timer @@ -23,8 +25,8 @@ def __init__(self, self._finished = False self._send_attempt_count = 0 self._logger = LOGGER.bind( - peer=peer.dict(), - my_connection_info=my_connection_info.dict()) + peer=peer.dict(), my_connection_info=my_connection_info.dict() + ) self._logger.debug("init") def stop(self): @@ -48,8 +50,9 @@ def _send(self): __root__=messages.SynchronizeConnection( source=self._my_connection_info, destination=self._peer, - attempt=self._send_attempt_count - )) + attempt=self._send_attempt_count, + ) + ) self._sender.send(message) def _should_we_send(self): @@ 
-59,15 +62,14 @@ def _should_we_send(self): class SynchronizeConnectionSenderFactory: - def create(self, - my_connection_info: ConnectionInfo, - peer: Peer, - sender: Sender, - timer: Timer) -> SynchronizeConnectionSender: + def create( + self, + my_connection_info: ConnectionInfo, + peer: Peer, + sender: Sender, + timer: Timer, + ) -> SynchronizeConnectionSender: synchronize_connection_sender = SynchronizeConnectionSender( - my_connection_info=my_connection_info, - peer=peer, - sender=sender, - timer=timer + my_connection_info=my_connection_info, peer=peer, sender=sender, timer=timer ) return synchronize_connection_sender diff --git a/exasol/analytics/udf/communication/peer_communicator/timer.py b/exasol/analytics/udf/communication/peer_communicator/timer.py index 13a1cd63..2a9213c5 100644 --- a/exasol/analytics/udf/communication/peer_communicator/timer.py +++ b/exasol/analytics/udf/communication/peer_communicator/timer.py @@ -3,9 +3,7 @@ class Timer: - def __init__(self, - clock: Clock, - timeout_in_ms: int): + def __init__(self, clock: Clock, timeout_in_ms: int): self._timeout_in_ms = timeout_in_ms self._clock = clock self._last_send_timestamp_in_ms = clock.current_timestamp_in_ms() @@ -21,7 +19,5 @@ def is_time(self): class TimerFactory: - def create(self, - clock: Clock, - timeout_in_ms: int): + def create(self, clock: Clock, timeout_in_ms: int): return Timer(clock=clock, timeout_in_ms=timeout_in_ms) diff --git a/exasol/analytics/udf/communication/serialization.py b/exasol/analytics/udf/communication/serialization.py index 97220643..b211efef 100644 --- a/exasol/analytics/udf/communication/serialization.py +++ b/exasol/analytics/udf/communication/serialization.py @@ -8,7 +8,7 @@ def serialize_message(obj: BaseModel) -> bytes: return json_str.encode("UTF-8") -T = TypeVar('T', bound=BaseModel) +T = TypeVar("T", bound=BaseModel) def deserialize_message(message: bytes, base_model_class: Type[T]) -> T: diff --git a/exasol/analytics/udf/communication/socket_factory/abstract.py b/exasol/analytics/udf/communication/socket_factory/abstract.py index d8f5dab4..f742168d 100644 --- a/exasol/analytics/udf/communication/socket_factory/abstract.py +++ b/exasol/analytics/udf/communication/socket_factory/abstract.py @@ -1,6 +1,6 @@ import abc from enum import Enum, auto -from typing import Union, List, Set, Optional, Dict +from typing import Dict, List, Optional, Set, Union class Frame(abc.ABC): @@ -49,10 +49,11 @@ def connect(self, address: str): """Connect to the given address""" @abc.abstractmethod - def poll(self, - flags: Union[PollerFlag, Set[PollerFlag]], - timeout_in_ms: Optional[int] = None) \ - -> Optional[Set[PollerFlag]]: + def poll( + self, + flags: Union[PollerFlag, Set[PollerFlag]], + timeout_in_ms: Optional[int] = None, + ) -> Optional[Set[PollerFlag]]: """ Checks if the socket can receive or send without blocking or if timeout is set, it waits until a requested event occurred. @@ -83,12 +84,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): class Poller(abc.ABC): @abc.abstractmethod - def register(self, socket: Socket, flags: Union[PollerFlag, Set[PollerFlag]]) -> None: + def register( + self, socket: Socket, flags: Union[PollerFlag, Set[PollerFlag]] + ) -> None: """Register a socket with the events we want to poll.""" @abc.abstractmethod - def poll(self, timeout_in_ms: Optional[int] = None) -> Dict[Socket, Set[PollerFlag]]: - """Poll if an event occurred for the registered sockets or wait until an event occurred, if timeout is set. 
""" + def poll( + self, timeout_in_ms: Optional[int] = None + ) -> Dict[Socket, Set[PollerFlag]]: + """Poll if an event occurred for the registered sockets or wait until an event occurred, if timeout is set.""" class SocketType(Enum): diff --git a/exasol/analytics/udf/communication/socket_factory/fault_injection.py b/exasol/analytics/udf/communication/socket_factory/fault_injection.py index dd4a3c06..6ad485bb 100644 --- a/exasol/analytics/udf/communication/socket_factory/fault_injection.py +++ b/exasol/analytics/udf/communication/socket_factory/fault_injection.py @@ -1,4 +1,4 @@ -from typing import Union, List, Set, Optional, Dict, cast +from typing import Dict, List, Optional, Set, Union, cast from warnings import warn import structlog @@ -28,15 +28,18 @@ def _is_address_inproc(address): class Socket(abstract.Socket): - def __init__(self, internal_socket: abstract.Socket, send_fault_probability: float, - random_state: RandomState): + def __init__( + self, + internal_socket: abstract.Socket, + send_fault_probability: float, + random_state: RandomState, + ): self._random_state = random_state if not (0 <= send_fault_probability < 1): raise ValueError( - f"send_fault_probability needs to be between 0 and 1 (exclusive) was {send_fault_probability}.") - self._logger = LOGGER.bind( - socket=str(self) - ) + f"send_fault_probability needs to be between 0 and 1 (exclusive) was {send_fault_probability}." + ) + self._logger = LOGGER.bind(socket=str(self)) self._send_fault_probability = send_fault_probability self._internal_socket = internal_socket self._is_inproc = False @@ -87,10 +90,11 @@ def connect(self, address: str): self._is_inproc = _is_address_inproc(address) self._internal_socket.connect(address) - def poll(self, - flags: Union[abstract.PollerFlag, Set[abstract.PollerFlag]], - timeout_in_ms: Optional[int] = None) \ - -> Optional[Set[abstract.PollerFlag]]: + def poll( + self, + flags: Union[abstract.PollerFlag, Set[abstract.PollerFlag]], + timeout_in_ms: Optional[int] = None, + ) -> Optional[Set[abstract.PollerFlag]]: return self._internal_socket.poll(flags, timeout_in_ms) def close(self, linger=None): @@ -126,16 +130,20 @@ def __init__(self, internal_poller: abstract.Poller): self._internal_poller = internal_poller self._socket_map = {} - def register(self, socket: abstract.Socket, - flags: Union[abstract.PollerFlag, Set[abstract.PollerFlag]]) -> None: + def register( + self, + socket: abstract.Socket, + flags: Union[abstract.PollerFlag, Set[abstract.PollerFlag]], + ) -> None: if not isinstance(socket, Socket): raise TypeError(f"Socket type not supported {socket}") internal_socket = cast(Socket, socket)._internal_socket self._socket_map[internal_socket] = socket self._internal_poller.register(internal_socket, flags) - def poll(self, timeout_in_ms: Optional[int] = None) \ - -> Dict[abstract.Socket, Set[abstract.PollerFlag]]: + def poll( + self, timeout_in_ms: Optional[int] = None + ) -> Dict[abstract.Socket, Set[abstract.PollerFlag]]: poll_result = self._internal_poller.poll(timeout_in_ms) return { self._socket_map[internal_socket]: flags @@ -145,18 +153,26 @@ def poll(self, timeout_in_ms: Optional[int] = None) \ class FaultInjectionSocketFactory(abstract.SocketFactory): - def __init__(self, socket_factory: abstract.SocketFactory, send_fault_probability: float, - random_state: RandomState): + def __init__( + self, + socket_factory: abstract.SocketFactory, + send_fault_probability: float, + random_state: RandomState, + ): if not (0 <= send_fault_probability < 1): raise ValueError( - 
f"send_fault_probability needs to be between 0 and 1 (exclusive) was {send_fault_probability}.") + f"send_fault_probability needs to be between 0 and 1 (exclusive) was {send_fault_probability}." + ) self._send_fault_probability = send_fault_probability self._random_state = random_state self._socket_factory = socket_factory def create_socket(self, socket_type: abstract.SocketType) -> abstract.Socket: - return Socket(self._socket_factory.create_socket(socket_type), self._send_fault_probability, - self._random_state) + return Socket( + self._socket_factory.create_socket(socket_type), + self._send_fault_probability, + self._random_state, + ) def create_frame(self, message_part: bytes) -> abstract.Frame: return Frame(self._socket_factory.create_frame(message_part)) diff --git a/exasol/analytics/udf/communication/socket_factory/zmq_wrapper.py b/exasol/analytics/udf/communication/socket_factory/zmq_wrapper.py index 5ac1801b..20cc8cb5 100644 --- a/exasol/analytics/udf/communication/socket_factory/zmq_wrapper.py +++ b/exasol/analytics/udf/communication/socket_factory/zmq_wrapper.py @@ -1,9 +1,16 @@ -from typing import Union, List, Set, Optional, Dict +from typing import Dict, List, Optional, Set, Union from warnings import warn + import zmq -from exasol.analytics.udf.communication.socket_factory.abstract import Frame, Socket, \ - PollerFlag, Poller, SocketFactory, SocketType +from exasol.analytics.udf.communication.socket_factory.abstract import ( + Frame, + Poller, + PollerFlag, + Socket, + SocketFactory, + SocketType, +) def _flags_to_bitmask(flags: Union[PollerFlag, Set[PollerFlag]]) -> int: @@ -80,10 +87,15 @@ def bind_to_random_port(self, address: str) -> int: def connect(self, address: str): self._internal_socket.connect(address) - def poll(self, flags: Union[PollerFlag, Set[PollerFlag]], timeout_in_ms: Optional[int] = None) \ - -> Set[PollerFlag]: + def poll( + self, + flags: Union[PollerFlag, Set[PollerFlag]], + timeout_in_ms: Optional[int] = None, + ) -> Set[PollerFlag]: input_bitmask = _flags_to_bitmask(flags) - result_bitmask = self._internal_socket.poll(flags=input_bitmask, timeout=timeout_in_ms) + result_bitmask = self._internal_socket.poll( + flags=input_bitmask, timeout=timeout_in_ms + ) result_set = _bitmask_to_flags(result_bitmask) return result_set @@ -119,7 +131,9 @@ def __init__(self): self._internal_poller = zmq.Poller() self._sockets_map: Dict[zmq.Socket, ZMQSocket] = {} - def register(self, socket: Socket, flags: Union[PollerFlag, Set[PollerFlag]]) -> None: + def register( + self, socket: Socket, flags: Union[PollerFlag, Set[PollerFlag]] + ) -> None: if isinstance(socket, ZMQSocket): self._sockets_map[socket._internal_socket] = socket bitmask = _flags_to_bitmask(flags) @@ -127,10 +141,14 @@ def register(self, socket: Socket, flags: Union[PollerFlag, Set[PollerFlag]]) -> else: raise ValueError(f"Socket not supported: {socket}") - def poll(self, timeout_in_ms: Optional[int] = None) -> Dict[Socket, Set[PollerFlag]]: + def poll( + self, timeout_in_ms: Optional[int] = None + ) -> Dict[Socket, Set[PollerFlag]]: poll_result = dict(self._internal_poller.poll(timeout_in_ms)) - result = {self._sockets_map[zmq_socket]: _bitmask_to_flags(bitmask) - for zmq_socket, bitmask in poll_result.items()} + result = { + self._sockets_map[zmq_socket]: _bitmask_to_flags(bitmask) + for zmq_socket, bitmask in poll_result.items() + } return result diff --git a/exasol/analytics/udf/utils/context_wrapper.py b/exasol/analytics/udf/utils/context_wrapper.py index ce0f9a65..13e2146f 100644 --- 
a/exasol/analytics/udf/utils/context_wrapper.py +++ b/exasol/analytics/udf/utils/context_wrapper.py @@ -1,14 +1,16 @@ from collections import OrderedDict -from typing import Union, Optional, Mapping +from typing import Mapping, Optional, Union import pandas as pd -class UDFContextWrapper(): +class UDFContextWrapper: def __init__(self, ctx, column_mapping: Mapping[str, str], start_col: int = 0): self.start_col = start_col if not isinstance(column_mapping, OrderedDict): - raise ValueError(f"column_mapping needs to be a OrderedDict, got {type(column_mapping)}") + raise ValueError( + f"column_mapping needs to be a OrderedDict, got {type(column_mapping)}" + ) self.column_mapping = column_mapping self.original_columns = list(self.column_mapping.keys()) self.new_columns = list(self.column_mapping.values()) @@ -17,15 +19,21 @@ def __init__(self, ctx, column_mapping: Mapping[str, str], start_col: int = 0): def _get_mapped_column(self, original_name: str) -> str: if original_name in self.column_mapping: return self.column_mapping[original_name] - raise ValueError(f"Column {original_name} does not exists in mapping {self.column_mapping}") + raise ValueError( + f"Column {original_name} does not exists in mapping {self.column_mapping}" + ) def __getattr__(self, name): return self.ctx[self._get_mapped_column(name)] - def get_dataframe(self, num_rows: Union[str, int], start_col: int = 0) -> Optional[pd.DataFrame]: + def get_dataframe( + self, num_rows: Union[str, int], start_col: int = 0 + ) -> Optional[pd.DataFrame]: df = self.ctx.get_dataframe(num_rows, start_col=self.start_col) filtered_df = df[self.original_columns] - filtered_df.columns = [self._get_mapped_column(column) for column in filtered_df.columns] + filtered_df.columns = [ + self._get_mapped_column(column) for column in filtered_df.columns + ] filtered_df_from_start_col = filtered_df.iloc[:, start_col:] return filtered_df_from_start_col diff --git a/exasol/analytics/udf/utils/iterators.py b/exasol/analytics/udf/utils/iterators.py index 10c81d88..a6a2d609 100644 --- a/exasol/analytics/udf/utils/iterators.py +++ b/exasol/analytics/udf/utils/iterators.py @@ -1,13 +1,16 @@ -from typing import Callable, Any +from typing import Any, Callable import pandas as pd -def iterate_trough_dataset(ctx, batch_size: int, - map_function: Callable[[pd.DataFrame], Any], - init_function: Callable[[], Any], - aggregate_function: Callable[[Any, Any], Any], - reset_function: Callable[[], Any]): +def iterate_trough_dataset( + ctx, + batch_size: int, + map_function: Callable[[pd.DataFrame], Any], + init_function: Callable[[], Any], + aggregate_function: Callable[[Any, Any], Any], + reset_function: Callable[[], Any], +): reset_function() number_of_tuples_left = ctx.size() state = init_function() diff --git a/exasol/analytics/utils/data_classes_runtime_type_check.py b/exasol/analytics/utils/data_classes_runtime_type_check.py index 9fdd1b28..eca48879 100644 --- a/exasol/analytics/utils/data_classes_runtime_type_check.py +++ b/exasol/analytics/utils/data_classes_runtime_type_check.py @@ -3,10 +3,12 @@ import typeguard from typeguard import TypeCheckError + def check_dataclass_types(datacls): for field in fields(datacls): try: - typeguard.check_type(value=datacls.__dict__[field.name], - expected_type=field.type) + typeguard.check_type( + value=datacls.__dict__[field.name], expected_type=field.type + ) except TypeCheckError as e: raise TypeCheckError(f"Field '{field.name}' has wrong type: {e}") diff --git a/exasol/analytics/utils/dynamic_modules.py 
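
Note: iterate_trough_dataset (spelling as in the source) folds a UDF context batch-wise: reset, then repeatedly map a DataFrame batch and aggregate the results into one running state. A usage sketch, assuming ctx is an Exasol UDF context offering size(), get_dataframe() and reset(); the loop body is only partially shown above, so the fold contract is inferred from the signature:

    from exasol.analytics.udf.utils.iterators import iterate_trough_dataset

    row_count = iterate_trough_dataset(
        ctx,
        batch_size=1000,
        map_function=lambda df: len(df),            # per-batch result
        init_function=lambda: 0,                    # initial aggregate state
        aggregate_function=lambda acc, n: acc + n,  # fold batch results
        reset_function=ctx.reset,                   # rewind the iterator first
    )
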
b/exasol/analytics/utils/dynamic_modules.py index add6dde9..7b87ef4a 100644 --- a/exasol/analytics/utils/dynamic_modules.py +++ b/exasol/analytics/utils/dynamic_modules.py @@ -1,7 +1,7 @@ -import sys import importlib -from typing import Any +import sys from types import ModuleType +from typing import Any def _create_module(name: str) -> ModuleType: @@ -13,7 +13,6 @@ def _register_module_for_import(name: str, mod: ModuleType): sys.modules[name] = mod - class ModuleExistsException(Exception): """ When trying create a module that already exists. diff --git a/exasol/analytics/utils/hash_generation_for_object.py b/exasol/analytics/utils/hash_generation_for_object.py index 266b093f..2f8620a2 100644 --- a/exasol/analytics/utils/hash_generation_for_object.py +++ b/exasol/analytics/utils/hash_generation_for_object.py @@ -1,9 +1,8 @@ -from typing import Hashable, Dict, Any, Iterable, Set +from typing import Any, Dict, Hashable, Iterable, Set def generate_hash_for_object(obj: Any) -> int: - return hash(tuple(_hash_object(v, set()) - for k, v in sorted(obj.__dict__.items()))) + return hash(tuple(_hash_object(v, set()) for k, v in sorted(obj.__dict__.items()))) def _hash_object(obj: Any, already_seen: Set[int]) -> int: @@ -23,13 +22,12 @@ def _hash_object(obj: Any, already_seen: Set[int]) -> int: if isinstance(obj, Hashable): return hash(obj) elif isinstance(obj, Dict): - return \ - hash( - ( - _hash_object(obj.keys(), already_seen), - _hash_object(obj.values(), already_seen) - ) + return hash( + ( + _hash_object(obj.keys(), already_seen), + _hash_object(obj.values(), already_seen), ) + ) elif isinstance(obj, Iterable): return hash(tuple(_hash_object(item, already_seen) for item in obj)) else: diff --git a/exasol/analytics/utils/repr_generation_for_object.py b/exasol/analytics/utils/repr_generation_for_object.py index e3f35eff..f75d824b 100644 --- a/exasol/analytics/utils/repr_generation_for_object.py +++ b/exasol/analytics/utils/repr_generation_for_object.py @@ -1,3 +1,3 @@ def generate_repr_for_object(obj): parameters = ",".join(f"{k}: {v}" for k, v in obj.__dict__.items()) - return f"{obj.__class__.__name__}({parameters})" \ No newline at end of file + return f"{obj.__class__.__name__}({parameters})" diff --git a/noxconfig.py b/noxconfig.py index 91a064a5..212f57a2 100644 --- a/noxconfig.py +++ b/noxconfig.py @@ -10,7 +10,7 @@ class Config: root: Path = ROOT_DIR doc: Path = ROOT_DIR / "doc" version_file: Path = ROOT_DIR / "version.py" - path_filters: Iterable[str] = ("dist", ".eggs", "venv") + path_filters: Iterable[str] = ("dist", ".eggs", "venv", ".conda_env") PROJECT_CONFIG = Config() diff --git a/noxfile.py b/noxfile.py index b9ede710..df1ceb0e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,14 +1,15 @@ import json import os -from pathlib import Path -from exasol.analytics.query_handler.deployment.slc import custom_slc_builder from datetime import datetime +from pathlib import Path import nox -from nox import Session # imports all nox task provided by the toolbox from exasol.toolbox.nox.tasks import * +from nox import Session + +from exasol.analytics.query_handler.deployment.slc import custom_slc_builder from noxconfig import ROOT_DIR # default actions to be run if nothing is explicitly specified with the -s option @@ -54,7 +55,14 @@ def install_dev_env(session: Session): @nox.session(python=False) def amalgate_lua_scripts(session: Session): - script = ROOT_DIR / "exasol" / "analytics" / "query_handler" / "deployment" / "regenerate_scripts.py" + script = ( + ROOT_DIR + / "exasol" + / "analytics" 
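
Note: the two helpers above derive hash and repr from an instance's __dict__, so value-equal objects hash alike and the repr lists attributes in insertion order. For instance:

    from exasol.analytics.utils.hash_generation_for_object import generate_hash_for_object
    from exasol.analytics.utils.repr_generation_for_object import generate_repr_for_object

    class Point:
        def __init__(self, x: int, y: int):
            self.x = x
            self.y = y

    assert generate_hash_for_object(Point(1, 2)) == generate_hash_for_object(Point(1, 2))
    print(generate_repr_for_object(Point(1, 2)))  # -> Point(x: 1,y: 2)
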
+ / "query_handler" + / "deployment" + / "regenerate_scripts.py" + ) _run_in_dev_env_poetry_call(session, "python", str(script)) @@ -71,10 +79,7 @@ def run_python_unit_tests(session: Session): def _generate_test_matrix_entry(test_file: Path): - return { - "name": str(test_file.name), - "path": str(test_file) - } + return {"name": str(test_file.name), "path": str(test_file)} def _generate_github_integration_tests_without_db_matrix() -> str: @@ -94,7 +99,7 @@ def generate_github_integration_tests_without_db_matrix_json(session: Session): @nox.session(python=False) def write_github_integration_tests_without_db_matrix(session: Session): json_str = _generate_github_integration_tests_without_db_matrix() - github_output_definition = f'matrix={json_str}' + github_output_definition = f"matrix={json_str}" if "GITHUB_OUTPUT" in os.environ: with open(os.environ["GITHUB_OUTPUT"], "a") as fh: print(github_output_definition, file=fh) diff --git a/tests/conftest.py b/tests/conftest.py index d79194e1..e0e84410 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ pytest_plugins = [ "tests.integration_tests.with_db.fixtures.setup_database_fixture", - "tests.unit_tests.query_handler.fixtures" + "tests.unit_tests.query_handler.fixtures", ] diff --git a/tests/deployment/test_aaf_exasol_lua_script_generator.py b/tests/deployment/test_aaf_exasol_lua_script_generator.py index e062a83a..1f5ba01a 100644 --- a/tests/deployment/test_aaf_exasol_lua_script_generator.py +++ b/tests/deployment/test_aaf_exasol_lua_script_generator.py @@ -1,6 +1,8 @@ from io import StringIO -from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import get_aaf_query_loop_lua_script_generator +from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import ( + get_aaf_query_loop_lua_script_generator, +) def test_get_aaf_query_loop_lua_script_generator(): @@ -9,6 +11,9 @@ def test_get_aaf_query_loop_lua_script_generator(): generator.generate_script(output_buffer) script = output_buffer.getvalue() - assert """CREATE OR REPLACE LUA SCRIPT "AAF_RUN_QUERY_HANDLER"(json_str) RETURNS TABLE AS""" in script and \ - "function query_handler_runner_main" in script and \ - "query_handler_runner_main(json_str, exa)" + assert ( + """CREATE OR REPLACE LUA SCRIPT "AAF_RUN_QUERY_HANDLER"(json_str) RETURNS TABLE AS""" + in script + and "function query_handler_runner_main" in script + and "query_handler_runner_main(json_str, exa)" + ) diff --git a/tests/deployment/test_exasol_lua_script_generator.py b/tests/deployment/test_exasol_lua_script_generator.py index 2418f3aa..50c790d4 100644 --- a/tests/deployment/test_exasol_lua_script_generator.py +++ b/tests/deployment/test_exasol_lua_script_generator.py @@ -3,24 +3,26 @@ from jinja2 import Environment, FileSystemLoader, select_autoescape -from exasol.analytics.query_handler.deployment.exasol_lua_script_generator import ExasolLuaScriptGenerator +from exasol.analytics.query_handler.deployment.exasol_lua_script_generator import ( + ExasolLuaScriptGenerator, +) from exasol.analytics.query_handler.deployment.lua_script_bundle import LuaScriptBundle def test(): resource_dir = Path(__file__).parent / "resources" lua_bundle_dir = resource_dir / "lua_bundle" - bundle = LuaScriptBundle(lua_main_file=lua_bundle_dir / "main.lua", - lua_source_files=[lua_bundle_dir / "test_module_1.lua"], - lua_modules=["test_module_1"]) + bundle = LuaScriptBundle( + lua_main_file=lua_bundle_dir / "main.lua", + lua_source_files=[lua_bundle_dir / "test_module_1.lua"], + 
lua_modules=["test_module_1"], + ) env = Environment( - loader=FileSystemLoader(resource_dir), - autoescape=select_autoescape() + loader=FileSystemLoader(resource_dir), autoescape=select_autoescape() ) template = env.get_template("create_script.jinja") generator = ExasolLuaScriptGenerator( - lua_script_bundle=bundle, - jinja_template=template + lua_script_bundle=bundle, jinja_template=template ) output_buffer = StringIO() generator.generate_script(output_buffer) diff --git a/tests/deployment/test_lua_script_bundle.py b/tests/deployment/test_lua_script_bundle.py index 794118b5..03c3cba5 100644 --- a/tests/deployment/test_lua_script_bundle.py +++ b/tests/deployment/test_lua_script_bundle.py @@ -6,13 +6,15 @@ def test(tmp_path): resource_dir = Path(__file__).parent / "resources" / "lua_bundle" - bundle = LuaScriptBundle(lua_main_file=resource_dir / "main.lua", - lua_source_files=[resource_dir / "test_module_1.lua"], - lua_modules=["test_module_1"]) + bundle = LuaScriptBundle( + lua_main_file=resource_dir / "main.lua", + lua_source_files=[resource_dir / "test_module_1.lua"], + lua_modules=["test_module_1"], + ) bundle_file_name = "bundle.lua" bundle_lua = tmp_path / bundle_file_name with bundle_lua.open("w") as file: bundle.bundle_lua_scripts(file) output = subprocess.check_output(["lua", bundle_file_name], cwd=tmp_path) output_decode = output.decode("utf-8").strip() - assert output_decode == "TEST_OUTPUT" \ No newline at end of file + assert output_decode == "TEST_OUTPUT" diff --git a/tests/deployment/test_scripts_deployer.py b/tests/deployment/test_scripts_deployer.py index 34ac2fb5..632216ca 100644 --- a/tests/deployment/test_scripts_deployer.py +++ b/tests/deployment/test_scripts_deployer.py @@ -1,5 +1,7 @@ -from exasol.analytics.query_handler.deployment.scripts_deployer import ScriptsDeployer -from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import save_aaf_query_loop_lua_script +from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import ( + save_aaf_query_loop_lua_script, +) +from exasol.analytics.query_handler.deployment.scripts_deployer import ScriptsDeployer from tests.utils.db_queries import DBQueries @@ -12,5 +14,4 @@ def test_scripts_deployer(deployed_slc, language_alias, pyexasol_connection, req schema_name, pyexasol_connection, ).deploy_scripts() - assert DBQueries.check_all_scripts_deployed( - pyexasol_connection, schema_name) + assert DBQueries.check_all_scripts_deployed(pyexasol_connection, schema_name) diff --git a/tests/deployment/test_scripts_deployer_cli.py b/tests/deployment/test_scripts_deployer_cli.py index 378e9b71..3aaad46f 100644 --- a/tests/deployment/test_scripts_deployer_cli.py +++ b/tests/deployment/test_scripts_deployer_cli.py @@ -1,12 +1,16 @@ from click.testing import CliRunner -from exasol.analytics import deploy -from tests.utils.db_queries import DBQueries + +from exasol.analytics.query_handler.deployment import deploy from exasol.analytics.query_handler.deployment.slc import LANGUAGE_ALIAS +from tests.utils.db_queries import DBQueries -def test_scripts_deployer_cli(upload_language_container, - backend_aware_database_params, - pyexasol_connection, request): +def test_scripts_deployer_cli( + upload_language_container, + backend_aware_database_params, + pyexasol_connection, + request, +): schema_name = request.node.name pyexasol_connection.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE;") dsn = backend_aware_database_params["dsn"] @@ -14,17 +18,18 @@ def 
test_scripts_deployer_cli(upload_language_container, password = backend_aware_database_params["password"] args_list = [ "scripts", - "--dsn", dns, - "--user", user, - "--pass", password, - "--schema", schema_name, - "--language-alias", LANGUAGE_ALIAS + "--dsn", + dns, + "--user", + user, + "--pass", + password, + "--schema", + schema_name, + "--language-alias", + LANGUAGE_ALIAS, ] runner = CliRunner() result = runner.invoke(deploy.main, args_list) assert result.exit_code == 0 - assert DBQueries.check_all_scripts_deployed( - pyexasol_connection, schema_name) - - - + assert DBQueries.check_all_scripts_deployed(pyexasol_connection, schema_name) diff --git a/tests/integration_tests/data_science_utils/test_pyexasol_sql_executor.py b/tests/integration_tests/data_science_utils/test_pyexasol_sql_executor.py index 62cac96c..9167456e 100644 --- a/tests/integration_tests/data_science_utils/test_pyexasol_sql_executor.py +++ b/tests/integration_tests/data_science_utils/test_pyexasol_sql_executor.py @@ -1,11 +1,7 @@ import pyexasol import pytest -from exasol.analytics.schema import ( - Column, - ColumnType, - ColumnNameBuilder, -) +from exasol.analytics.schema import Column, ColumnNameBuilder, ColumnType from exasol.analytics.sql_executor.pyexasol_impl import PyexasolSQLExecutor @@ -24,17 +20,24 @@ def pyexasol_sql_executor(): @pytest.fixture() def pyexasol_result_set(pyexasol_sql_executor): row_count = 100000 - expected_result = [(1, "a", '1.1')] * row_count + expected_result = [(1, "a", "1.1")] * row_count expected_columns = [ - Column(ColumnNameBuilder.create("c1"), - ColumnType(name="DECIMAL", precision=1, scale=0)), - Column(ColumnNameBuilder.create("c2"), - ColumnType(name="CHAR", size=1, characterSet="ASCII")), - Column(ColumnNameBuilder.create("c3"), - ColumnType(name="DECIMAL", precision=2, scale=1)), + Column( + ColumnNameBuilder.create("c1"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ), + Column( + ColumnNameBuilder.create("c2"), + ColumnType(name="CHAR", size=1, characterSet="ASCII"), + ), + Column( + ColumnNameBuilder.create("c3"), + ColumnType(name="DECIMAL", precision=2, scale=1), + ), ] result_set = pyexasol_sql_executor.execute( - f"""SELECT 1 as "c1", 'a' as "c2", 1.1 as "c3" FROM VALUES BETWEEN 1 and {row_count} as t(i);""") + f"""SELECT 1 as "c1", 'a' as "c2", 1.1 as "c3" FROM VALUES BETWEEN 1 and {row_count} as t(i);""" + ) return result_set, expected_result, expected_columns diff --git a/tests/integration_tests/with_db/conftest.py b/tests/integration_tests/with_db/conftest.py index 7925180f..69260c19 100644 --- a/tests/integration_tests/with_db/conftest.py +++ b/tests/integration_tests/with_db/conftest.py @@ -1,12 +1,12 @@ import pytest +from exasol.python_extension_common.deployment.language_container_builder import ( + LanguageContainerBuilder, + find_path_backwards, +) from exasol.analytics.query_handler.deployment.slc import ( - custom_slc_builder, LANGUAGE_ALIAS, -) -from exasol.python_extension_common.deployment.language_container_builder import ( - find_path_backwards, - LanguageContainerBuilder, + custom_slc_builder, ) diff --git a/tests/integration_tests/with_db/fixtures/setup_database_fixture.py b/tests/integration_tests/with_db/fixtures/setup_database_fixture.py index 04e30319..4b43b0b0 100644 --- a/tests/integration_tests/with_db/fixtures/setup_database_fixture.py +++ b/tests/integration_tests/with_db/fixtures/setup_database_fixture.py @@ -1,9 +1,12 @@ -import pytest +from typing import Any, Callable, Tuple + import pyexasol -from typing import Any, Tuple, 
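
Note: test_scripts_deployer_cli drives the click entry point in-process with CliRunner. (Incidentally, the test stores the connection string in dsn but passes a variable spelled dns to --dsn; that spelling comes from the sources on both sides of this reformatting diff.) The invocation pattern in isolation, using --help so the sketch needs no database:

    from click.testing import CliRunner

    from exasol.analytics.query_handler.deployment import deploy

    runner = CliRunner()
    # "scripts" is the subcommand the test exercises
    result = runner.invoke(deploy.main, ["scripts", "--help"])
    assert result.exit_code == 0
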
Callable -from exasol.analytics.query_handler.deployment.scripts_deployer import ScriptsDeployer -from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import save_aaf_query_loop_lua_script +import pytest +from exasol.analytics.query_handler.deployment.aaf_exasol_lua_script_generator import ( + save_aaf_query_loop_lua_script, +) +from exasol.analytics.query_handler.deployment.scripts_deployer import ScriptsDeployer BUCKETFS_CONNECTION_NAME = "TEST_AAF_BFS_CONN" @@ -28,10 +31,10 @@ def deployed_scripts(pyexasol_connection, itest_db_schema, language_alias) -> No @pytest.fixture(scope="module") def database_with_slc( - deployed_scripts, - itest_db_schema, - bucketfs_connection_factory, - deployed_slc, + deployed_scripts, + itest_db_schema, + bucketfs_connection_factory, + deployed_slc, ) -> Tuple[str, str]: bucketfs_connection_factory(BUCKETFS_CONNECTION_NAME, "my-folder") return BUCKETFS_CONNECTION_NAME, itest_db_schema diff --git a/tests/integration_tests/with_db/test_query_loop_integration.py b/tests/integration_tests/with_db/test_query_loop_integration.py index 4f867cec..799d308d 100644 --- a/tests/integration_tests/with_db/test_query_loop_integration.py +++ b/tests/integration_tests/with_db/test_query_loop_integration.py @@ -1,12 +1,15 @@ import json import textwrap -from typing import Tuple, List +from typing import List, Tuple import pyexasol import pytest -from tests.test_package.test_query_handlers.query_handler_test import \ - FINAL_RESULT, QUERY_LIST, TEST_INPUT +from tests.test_package.test_query_handlers.query_handler_test import ( + FINAL_RESULT, + QUERY_LIST, + TEST_INPUT, +) QUERY_FLUSH_STATS = """FLUSH STATISTICS""" QUERY_AUDIT_LOGS = """ @@ -19,29 +22,28 @@ def test_query_loop_integration_with_one_iteration( - database_with_slc, pyexasol_connection): + database_with_slc, pyexasol_connection +): bucketfs_connection_name, schema_name = database_with_slc args = json.dumps( { "query_handler": { "factory_class": { "name": "QueryHandlerTestWithOneIterationFactory", - "module": "test_query_handlers.query_handler_test" + "module": "test_query_handlers.query_handler_test", }, - "udf": { - "schema": schema_name, - "name": "AAF_QUERY_HANDLER_UDF" - }, - "parameter": TEST_INPUT + "udf": {"schema": schema_name, "name": "AAF_QUERY_HANDLER_UDF"}, + "parameter": TEST_INPUT, }, "temporary_output": { "bucketfs_location": { "directory": "directory", - "connection_name": bucketfs_connection_name + "connection_name": bucketfs_connection_name, }, - "schema_name": schema_name - } - }) + "schema_name": schema_name, + }, + } + ) query = f"EXECUTE SCRIPT {schema_name}.AAF_RUN_QUERY_HANDLER('{args}')" result = pyexasol_connection.execute(textwrap.dedent(query)).fetchall() @@ -50,7 +52,8 @@ def test_query_loop_integration_with_one_iteration( def test_query_loop_integration_with_one_iteration_with_not_released_child_query_handler_context( - database_with_slc, backend_aware_database_params): + database_with_slc, backend_aware_database_params +): # start a new db session, to isolate the EXECUTE SCRIPT and the QueryHandler queries # into its own session, for easier retrieval conn = pyexasol.connect(**backend_aware_database_params) @@ -62,31 +65,34 @@ def test_query_loop_integration_with_one_iteration_with_not_released_child_query "query_handler": { "factory_class": { "name": "QueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContextFactory", - "module": "test_query_handlers.query_handler_test" + "module": "test_query_handlers.query_handler_test", }, - "udf": { - "schema": 
schema_name, - "name": "AAF_QUERY_HANDLER_UDF" - }, - "parameter": TEST_INPUT + "udf": {"schema": schema_name, "name": "AAF_QUERY_HANDLER_UDF"}, + "parameter": TEST_INPUT, }, "temporary_output": { "bucketfs_location": { "directory": "directory", - "connection_name": bucketfs_connection_name + "connection_name": bucketfs_connection_name, }, - "schema_name": schema_name - } - }) + "schema_name": schema_name, + }, + } + ) with pytest.raises(pyexasol.ExaQueryError) as caught_exception: query = f"EXECUTE SCRIPT {schema_name}.AAF_RUN_QUERY_HANDLER('{args}')" result = conn.execute(textwrap.dedent(query)).fetchall() - assert "E-AAF-4: Error occurred while calling the query handler." in caught_exception.value.message and \ - "The following child contexts were not released" in caught_exception.value.message + assert ( + "E-AAF-4: Error occurred while calling the query handler." + in caught_exception.value.message + and "The following child contexts were not released" + in caught_exception.value.message + ) def test_query_loop_integration_with_one_iteration_with_not_released_temporary_object( - database_with_slc, backend_aware_database_params): + database_with_slc, backend_aware_database_params +): # start a new db session, to isolate the EXECUTE SCRIPT and the QueryHandler queries # into its own session, for easier retrieval of the audit log conn = pyexasol.connect(**backend_aware_database_params) @@ -98,42 +104,49 @@ def test_query_loop_integration_with_one_iteration_with_not_released_temporary_o "query_handler": { "factory_class": { "name": "QueryHandlerWithOneIterationWithNotReleasedTemporaryObjectFactory", - "module": "test_query_handlers.query_handler_test" + "module": "test_query_handlers.query_handler_test", }, - "udf": { - "schema": schema_name, - "name": "AAF_QUERY_HANDLER_UDF" - }, - "parameter": TEST_INPUT + "udf": {"schema": schema_name, "name": "AAF_QUERY_HANDLER_UDF"}, + "parameter": TEST_INPUT, }, "temporary_output": { "bucketfs_location": { "directory": "directory", - "connection_name": bucketfs_connection_name + "connection_name": bucketfs_connection_name, }, - "schema_name": schema_name - } - }) + "schema_name": schema_name, + }, + } + ) with pytest.raises(pyexasol.ExaQueryError) as caught_exception: query = f"EXECUTE SCRIPT {schema_name}.AAF_RUN_QUERY_HANDLER('{args}')" result = conn.execute(textwrap.dedent(query)).fetchall() - assert "E-AAF-4: Error occurred while calling the query handler." in caught_exception.value.message and \ - "The following child contexts were not released" in caught_exception.value.message + assert ( + "E-AAF-4: Error occurred while calling the query handler." 
+ in caught_exception.value.message + and "The following child contexts were not released" + in caught_exception.value.message + ) # get audit logs after executing query loop conn.execute(QUERY_FLUSH_STATS) - audit_logs: List[Tuple[str]] = conn.execute(textwrap.dedent(QUERY_AUDIT_LOGS)) \ - .fetchmany(N_FETCHED_ROWS) + audit_logs: List[Tuple[str]] = conn.execute( + textwrap.dedent(QUERY_AUDIT_LOGS) + ).fetchmany(N_FETCHED_ROWS) executed_queries = [row[0] for row in audit_logs] - table_cleanup_query = [query for query in executed_queries if - query.startswith(f'DROP TABLE IF EXISTS "{schema_name}"."DB1_')] + table_cleanup_query = [ + query + for query in executed_queries + if query.startswith(f'DROP TABLE IF EXISTS "{schema_name}"."DB1_') + ] for query in executed_queries: print("executed_query: ", query) assert table_cleanup_query def test_query_loop_integration_with_two_iteration( - database_with_slc, backend_aware_database_params): + database_with_slc, backend_aware_database_params +): # start a new db session, to isolate the EXECUTE SCRIPT and the QueryHandler queries # into its own session, for easier retrieval of the audit log conn = pyexasol.connect(**backend_aware_database_params) @@ -145,41 +158,45 @@ def test_query_loop_integration_with_two_iteration( "query_handler": { "factory_class": { "name": "QueryHandlerTestWithTwoIterationFactory", - "module": "test_query_handlers.query_handler_test" + "module": "test_query_handlers.query_handler_test", }, - "udf": { - "schema": schema_name, - "name": "AAF_QUERY_HANDLER_UDF" - }, - "parameter": TEST_INPUT + "udf": {"schema": schema_name, "name": "AAF_QUERY_HANDLER_UDF"}, + "parameter": TEST_INPUT, }, "temporary_output": { "bucketfs_location": { "directory": "directory", - "connection_name": bucketfs_connection_name + "connection_name": bucketfs_connection_name, }, - "schema_name": schema_name - } - }) + "schema_name": schema_name, + }, + } + ) query = f"EXECUTE SCRIPT {schema_name}.AAF_RUN_QUERY_HANDLER('{args}')" result = conn.execute(textwrap.dedent(query)).fetchall() # get audit logs after executing query loop conn.execute(QUERY_FLUSH_STATS) - audit_logs: List[Tuple[str]] = conn.execute(textwrap.dedent(QUERY_AUDIT_LOGS)) \ - .fetchmany(N_FETCHED_ROWS) + audit_logs: List[Tuple[str]] = conn.execute( + textwrap.dedent(QUERY_AUDIT_LOGS) + ).fetchmany(N_FETCHED_ROWS) executed_queries = [row[0] for row in audit_logs] - view_cleanup_query = [query for query in executed_queries if - query.startswith(f'DROP VIEW IF EXISTS "{schema_name}"."DB1_')] + view_cleanup_query = [ + query + for query in executed_queries + if query.startswith(f'DROP VIEW IF EXISTS "{schema_name}"."DB1_') + ] expected_query_list = {query.query_string for query in QUERY_LIST} - select_queries_from_query_handler = {query for query in executed_queries - if query in expected_query_list} + select_queries_from_query_handler = { + query for query in executed_queries if query in expected_query_list + } # TODO build an assert which can find a list of regex as a subsequence of a list of strings, # see https://kalnytskyi.com/posts/assert-str-matches-regex-in-pytest/ # asserts for query in executed_queries: print("executed_query: ", query) - assert result[0][0] == FINAL_RESULT \ - and select_queries_from_query_handler == expected_query_list \ - and view_cleanup_query, \ - f"Not all required queries where executed {executed_queries}" + assert ( + result[0][0] == FINAL_RESULT + and select_queries_from_query_handler == expected_query_list + and view_cleanup_query + ), f"Not all required 
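
Note: several of these integration tests verify cleanup by reading the audit log: flush statistics, fetch the statements of the session, then look for the generated DROP queries. The shape of that check, given the pyexasol connection conn the tests create; the SELECT below is an illustrative stand-in for the QUERY_AUDIT_LOGS statement defined further up:

    conn.execute("FLUSH STATISTICS")
    audit_logs = conn.execute(
        "SELECT SQL_TEXT FROM EXA_DBA_AUDIT_SQL WHERE SESSION_ID = CURRENT_SESSION"
    ).fetchmany(50)
    executed_queries = [row[0] for row in audit_logs]
    cleanup_queries = [
        q for q in executed_queries if q.startswith('DROP TABLE IF EXISTS "')
    ]
    assert cleanup_queries
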
queries where executed {executed_queries}" diff --git a/tests/integration_tests/with_db/test_user_guide_example.py b/tests/integration_tests/with_db/test_user_guide_example.py index 6129d1ce..dcf13ebe 100644 --- a/tests/integration_tests/with_db/test_user_guide_example.py +++ b/tests/integration_tests/with_db/test_user_guide_example.py @@ -1,11 +1,12 @@ import importlib.resources -import pytest import re - from contextlib import ExitStack -from exasol.analytics.query_handler.deployment import constants + +import pytest from exasol.python_extension_common.deployment.temp_schema import temp_schema +from exasol.analytics.query_handler.deployment import constants + @pytest.fixture def example_db_schemas(pyexasol_connection): @@ -22,8 +23,14 @@ def test_user_guide_example(database_with_slc, pyexasol_connection, example_db_s own python module. """ bucketfs_connection_name, schema_name = database_with_slc - dir = importlib.resources.files(constants.BASE_PACKAGE) \ - / ".." / ".." / "doc" / "user_guide" / "example-udf-script" + dir = ( + importlib.resources.files(constants.BASE_PACKAGE) + / ".." + / ".." + / "doc" + / "user_guide" + / "example-udf-script" + ) statement = ( (dir / "create.sql") diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/analyze_log.py b/tests/integration_tests/without_db/udf_communication/peer_communication/analyze_log.py index 0666c6a5..6e3a6a10 100644 --- a/tests/integration_tests/without_db/udf_communication/peer_communication/analyze_log.py +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/analyze_log.py @@ -1,11 +1,13 @@ import json import sys -from collections import defaultdict, Counter +from collections import Counter, defaultdict from pathlib import Path -from typing import Dict, List, Callable +from typing import Callable, Dict, List -def is_log_sequence_ok(lines: List[Dict[str, str]], line_predicate: Callable[[Dict[str, str]], bool]): +def is_log_sequence_ok( + lines: List[Dict[str, str]], line_predicate: Callable[[Dict[str, str]], bool] +): result = False for line in lines: if line_predicate(line): @@ -18,23 +20,39 @@ def is_peer_ready(line: Dict[str, str]): def is_connection_acknowledged(line: Dict[str, str]): - return line["module"] == "background_peer_state" and line["event"] == "received_acknowledge_connection" + return ( + line["module"] == "background_peer_state" + and line["event"] == "received_acknowledge_connection" + ) def is_connection_synchronized(line: Dict[str, str]): - return line["module"] == "background_peer_state" and line["event"] == "received_synchronize_connection" + return ( + line["module"] == "background_peer_state" + and line["event"] == "received_synchronize_connection" + ) + def is_register_peer_acknowledged(line: Dict[str, str]): - return line["module"] == "background_peer_state" and line["event"] == "received_acknowledge_register_peer" + return ( + line["module"] == "background_peer_state" + and line["event"] == "received_acknowledge_register_peer" + ) + def is_register_peer_complete(line: Dict[str, str]): - return line["module"] == "background_peer_state" and line["event"] == "received_register_peer_complete" + return ( + line["module"] == "background_peer_state" + and line["event"] == "received_register_peer_complete" + ) def analyze_source_target_interaction(log_file_path: Path): print("analyze_source_target_interaction") with open(log_file_path) as f: - group_source_target_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + group_source_target_map = 
defaultdict( + lambda: defaultdict(lambda: defaultdict(list)) + ) collect_source_target_interaction(f, group_source_target_map) print_source_target_interaction(group_source_target_map) @@ -45,7 +63,7 @@ def print_source_target_interaction(group_source_target_map): "is_connection_acknowledged": is_connection_acknowledged, "is_connection_synchronized": is_connection_synchronized, "is_register_peer_acknowledged": is_register_peer_acknowledged, - "is_register_peer_complete": is_register_peer_complete + "is_register_peer_complete": is_register_peer_complete, } for group, sources in group_source_target_map.items(): ok = Counter() @@ -56,32 +74,39 @@ def print_source_target_interaction(group_source_target_map): if not is_log_sequence_ok(lines, predicate): not_ok.update((predicate_name,)) if predicate_name == "is_peer_ready": - print(f"========== {predicate_name}-{group}-{source}-{target} ============") + print( + f"========== {predicate_name}-{group}-{source}-{target} ============" + ) else: ok.update((predicate_name,)) for predicate_name in predicates.keys(): - print(f"{group} {predicate_name} ok {ok[predicate_name]} not_ok {not_ok[predicate_name]}") + print( + f"{group} {predicate_name} ok {ok[predicate_name]} not_ok {not_ok[predicate_name]}" + ) print() def collect_source_target_interaction(f, group_source_target_map): for line in iter(f.readline, ""): json_line = json.loads(line) - if ("peer" in json_line - and "my_connection_info" in json_line - and "event" in json_line - and "module" in json_line - and json_line["event"] != "try_send" + if ( + "peer" in json_line + and "my_connection_info" in json_line + and "event" in json_line + and "module" in json_line + and json_line["event"] != "try_send" ): try: group = json_line["my_connection_info"]["group_identifier"] source = json_line["my_connection_info"]["name"] target = json_line["peer"]["connection_info"]["name"] - group_source_target_map[group][source][target].append({ - "event": json_line["event"], - "module": json_line["module"], - "timestamp": json_line["timestamp"], - }) + group_source_target_map[group][source][target].append( + { + "event": json_line["event"], + "module": json_line["module"], + "timestamp": json_line["timestamp"], + } + ) except Exception as e: raise Exception("Could not parse line: " + str(json_line)) from e @@ -89,17 +114,17 @@ def collect_source_target_interaction(f, group_source_target_map): def collect_close(f, group_source_map): for line in iter(f.readline, ""): json_line = json.loads(line) - if ("name" in json_line - and "group_identifier" in json_line - and "event" in json_line - and json_line["event"].startswith("after") + if ( + "name" in json_line + and "group_identifier" in json_line + and "event" in json_line + and json_line["event"].startswith("after") ): group = json_line["group_identifier"] source = json_line["name"] - group_source_map[group][source].append({ - "timestamp": json_line["timestamp"], - "event": json_line["event"] - }) + group_source_map[group][source].append( + {"timestamp": json_line["timestamp"], "event": json_line["event"]} + ) def print_close(group_source_map): @@ -125,6 +150,6 @@ def analyze_close(log_file_path: Path): if __name__ == "__main__": - log_file_path=Path(sys.argv[1]).absolute() + log_file_path = Path(sys.argv[1]).absolute() analyze_source_target_interaction(log_file_path) analyze_close(log_file_path) diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/insert_elk.py 
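
Note: the triple-nested defaultdict above lets collect_source_target_interaction append events under group/source/target without ever initialising keys; every level materialises on first access:

    from collections import defaultdict

    group_source_target_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    group_source_target_map["g1"]["i0"]["i1"].append(
        {"event": "received_synchronize_connection", "module": "background_peer_state"}
    )
    assert len(group_source_target_map["g1"]["i0"]["i1"]) == 1
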
b/tests/integration_tests/without_db/udf_communication/peer_communication/insert_elk.py index 20b244d5..4392bd28 100644 --- a/tests/integration_tests/without_db/udf_communication/peer_communication/insert_elk.py +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/insert_elk.py @@ -19,8 +19,5 @@ def insert(log_file_path: Path, es: Elasticsearch): if __name__ == "__main__": root = Path(__file__).parent log_file_path = root / "test_add_peer_forward.log" - es = Elasticsearch( - "http://localhost:9200", - basic_auth=("elastic", "changeme") - ) + es = Elasticsearch("http://localhost:9200", basic_auth=("elastic", "changeme")) insert(log_file_path, es) diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_close.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_close.py index 9b40b3ce..00d8f4c6 100644 --- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_close.py +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_close.py @@ -15,43 +15,61 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator -from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \ - ForwardRegisterPeerConfig -from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \ - PeerCommunicatorConfig -from exasol.analytics.udf.communication.socket_factory.fault_injection import \ - FaultInjectionSocketFactory -from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory -from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ - ConditionalMethodDropper -from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \ - PeerCommunicatorTestProcessParameter, BidirectionalQueue, TestProcess, assert_processes_finish +from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import ( + ForwardRegisterPeerConfig, +) +from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import ( + PeerCommunicatorConfig, +) +from exasol.analytics.udf.communication.socket_factory.fault_injection import ( + FaultInjectionSocketFactory, +) +from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ( + ZMQSocketFactory, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import ( + ConditionalMethodDropper, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import ( + BidirectionalQueue, + PeerCommunicatorTestProcessParameter, + TestProcess, + assert_processes_finish, +) structlog.configure( context_class=dict, - logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")), + logger_factory=WriteLoggerFactory( + file=Path(__file__).with_suffix(".log").open("wt") + ), processors=[ structlog.contextvars.merge_contextvars, ConditionalMethodDropper(method_name="debug"), ConditionalMethodDropper(method_name="info"), structlog.processors.add_log_level, structlog.processors.TimeStamper(), - structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)), + 
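
Note: the body of insert() in insert_elk.py is elided above. With the elasticsearch 8.x client configured as shown (basic_auth), indexing the structlog JSON lines would typically look like this hypothetical sketch; the index name is invented and this is not the project's implementation:

    import json
    from pathlib import Path

    from elasticsearch import Elasticsearch

    def insert(log_file_path: Path, es: Elasticsearch) -> None:
        with open(log_file_path) as f:
            for line in iter(f.readline, ""):
                # one document per JSON log line; index name is made up
                es.index(index="peer-communication-log", document=json.loads(line))
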
structlog.processors.ExceptionRenderer( + exception_formatter=ExceptionDictTransformer(locals_max_string=320) + ), structlog.processors.CallsiteParameterAdder(), - structlog.processors.JSONRenderer() - ] + structlog.processors.JSONRenderer(), + ], ) LOGGER: FilteringBoundLogger = structlog.get_logger() def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQueue): - logger = LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name) + logger = LOGGER.bind( + group_identifier=parameter.group_identifier, name=parameter.instance_name + ) try: listen_ip = IPAddress(ip_address=f"127.1.0.1") context = zmq.Context() socket_factory = ZMQSocketFactory(context) - socket_factory = FaultInjectionSocketFactory(socket_factory, 0.01, RandomState(parameter.seed)) + socket_factory = FaultInjectionSocketFactory( + socket_factory, 0.01, RandomState(parameter.seed) + ) com = PeerCommunicator( name=parameter.instance_name, number_of_peers=parameter.number_of_instances, @@ -60,8 +78,7 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue socket_factory=socket_factory, config=PeerCommunicatorConfig( forward_register_peer_config=ForwardRegisterPeerConfig( - is_leader=False, - is_enabled=False + is_leader=False, is_enabled=False ) ), ) @@ -86,7 +103,9 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue logger.exception("Exception during test") -@pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)]) +@pytest.mark.parametrize( + "number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)] +) def test_reliability(number_of_instances: int, repetitions: int): run_test_with_repetitions(number_of_instances, repetitions) @@ -112,32 +131,42 @@ def test_functionality_25(): def run_test_with_repetitions(number_of_instances: int, repetitions: int): for i in range(repetitions): - LOGGER.info(f"Start iteration", - iteration=i + 1, - repetitions=repetitions, - number_of_instances=number_of_instances) + LOGGER.info( + f"Start iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances, + ) start_time = time.monotonic() group = f"{time.monotonic_ns()}" - expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, seed=i) + expected_peers_of_threads, peers_of_threads = run_test( + group, number_of_instances, seed=i + ) assert expected_peers_of_threads == peers_of_threads end_time = time.monotonic() - LOGGER.info(f"Finish iteration", - iteration=i + 1, - repetitions=repetitions, - number_of_instances=number_of_instances, - duration=end_time - start_time) + LOGGER.info( + f"Finish iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances, + duration=end_time - start_time, + ) def run_test(group: str, number_of_instances: int, seed: int): connection_infos: Dict[int, ConnectionInfo] = {} parameters = [ PeerCommunicatorTestProcessParameter( - instance_name=f"i{i}", group_identifier=group, + instance_name=f"i{i}", + group_identifier=group, number_of_instances=number_of_instances, - seed=seed + i) - for i in range(number_of_instances)] - processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \ - [TestProcess(parameter, run=run) for parameter in parameters] + seed=seed + i, + ) + for i in range(number_of_instances) + ] + processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = [ + TestProcess(parameter, run=run) for parameter in parameters + ] for i in 
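
Note: ConditionalMethodDropper is a test-local processor used above to silence debug/info output; structlog's processor protocol makes that a few lines. A sketch with the equivalent effect, not the project's implementation:

    import structlog

    class DropMethod:
        """Drop every event emitted via the given log method, e.g. "debug"."""

        def __init__(self, method_name: str):
            self._method_name = method_name

        def __call__(self, logger, method_name, event_dict):
            if method_name == self._method_name:
                raise structlog.DropEvent  # structlog swallows the event
            return event_dict

    # usage mirrors the configuration above: DropMethod("debug") in processors
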
range(number_of_instances): processes[i].start() connection_infos[i] = processes[i].get() @@ -147,8 +176,5 @@ def run_test(group: str, number_of_instances: int, seed: int): result_of_threads: Dict[int, List[ConnectionInfo]] = {} for i in range(number_of_instances): result_of_threads[i] = processes[i].get() - expected_results_of_threads = { - i: "Success" - for i in range(number_of_instances) - } + expected_results_of_threads = {i: "Success" for i in range(number_of_instances)} return expected_results_of_threads, result_of_threads diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_close.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_close.py index 5dcbe8f5..9dba628b 100644 --- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_close.py +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_close.py @@ -13,36 +13,51 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator -from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \ - ForwardRegisterPeerConfig -from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \ - PeerCommunicatorConfig -from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory -from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ - ConditionalMethodDropper -from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \ - PeerCommunicatorTestProcessParameter, TestProcess, assert_processes_finish, BidirectionalQueue +from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import ( + ForwardRegisterPeerConfig, +) +from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import ( + PeerCommunicatorConfig, +) +from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ( + ZMQSocketFactory, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import ( + ConditionalMethodDropper, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import ( + BidirectionalQueue, + PeerCommunicatorTestProcessParameter, + TestProcess, + assert_processes_finish, +) structlog.configure( context_class=dict, - logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")), + logger_factory=WriteLoggerFactory( + file=Path(__file__).with_suffix(".log").open("wt") + ), processors=[ structlog.contextvars.merge_contextvars, ConditionalMethodDropper(method_name="debug"), ConditionalMethodDropper(method_name="info"), structlog.processors.add_log_level, structlog.processors.TimeStamper(), - structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)), + structlog.processors.ExceptionRenderer( + exception_formatter=ExceptionDictTransformer(locals_max_string=320) + ), structlog.processors.CallsiteParameterAdder(), - structlog.processors.JSONRenderer() - ] + structlog.processors.JSONRenderer(), + ], ) LOGGER: FilteringBoundLogger = structlog.get_logger() def run(parameter: PeerCommunicatorTestProcessParameter, queue: 
BidirectionalQueue):
-    logger = LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name)
+    logger = LOGGER.bind(
+        group_identifier=parameter.group_identifier, name=parameter.instance_name
+    )
     try:
         listen_ip = IPAddress(ip_address=f"127.1.0.1")
         context = zmq.Context()
@@ -55,8 +70,7 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
             socket_factory=socket_factory,
             config=PeerCommunicatorConfig(
                 forward_register_peer_config=ForwardRegisterPeerConfig(
-                    is_leader=False,
-                    is_enabled=False
+                    is_leader=False, is_enabled=False
                 )
             ),
         )
@@ -91,32 +105,42 @@ def test_functionality_2():


 def run_test_with_repetitions(number_of_instances: int, repetitions: int):
     for i in range(repetitions):
-        LOGGER.info(f"Start iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances)
+        LOGGER.info(
+            f"Start iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+        )
         start_time = time.monotonic()
         group = f"{time.monotonic_ns()}"
-        expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, seed=i)
+        expected_peers_of_threads, peers_of_threads = run_test(
+            group, number_of_instances, seed=i
+        )
         assert expected_peers_of_threads == peers_of_threads
         end_time = time.monotonic()
-        LOGGER.info(f"Finish iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances,
-                    duration=end_time - start_time)
+        LOGGER.info(
+            f"Finish iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+            duration=end_time - start_time,
+        )


 def run_test(group: str, number_of_instances: int, seed: int):
     connection_infos: Dict[int, ConnectionInfo] = {}
     parameters = [
         PeerCommunicatorTestProcessParameter(
-            instance_name=f"i{i}", group_identifier=group,
+            instance_name=f"i{i}",
+            group_identifier=group,
             number_of_instances=number_of_instances,
-            seed=seed + i)
-        for i in range(number_of_instances)]
-    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \
-        [TestProcess(parameter, run=run) for parameter in parameters]
+            seed=seed + i,
+        )
+        for i in range(number_of_instances)
+    ]
+    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = [
+        TestProcess(parameter, run=run) for parameter in parameters
+    ]
     for i in range(number_of_instances):
         processes[i].start()
         connection_infos[i] = processes[i].get()
@@ -126,8 +150,5 @@ def run_test(group: str, number_of_instances: int, seed: int):
     result_of_threads: Dict[int, List[ConnectionInfo]] = {}
     for i in range(number_of_instances):
         result_of_threads[i] = processes[i].get()
-    expected_results_of_threads = {
-        i: "Success"
-        for i in range(number_of_instances)
-    }
+    expected_results_of_threads = {i: "Success" for i in range(number_of_instances)}
     return expected_results_of_threads, result_of_threads
diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_for_peers.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_for_peers.py
index 8051a91d..45495a5a 100644
--- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_for_peers.py
+++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_for_peers.py
@@ -16,44 +16,64 @@ from exasol.analytics.udf.communication.ip_address import IPAddress
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator
-from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \
-    ForwardRegisterPeerConfig
-from exasol.analytics.udf.communication.peer_communicator.peer_communicator import key_for_peer
-from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \
-    PeerCommunicatorConfig
-from exasol.analytics.udf.communication.socket_factory.fault_injection import \
-    FaultInjectionSocketFactory
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory
-from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \
-    ConditionalMethodDropper
-from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \
-    PeerCommunicatorTestProcessParameter, BidirectionalQueue, TestProcess, assert_processes_finish
+from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import (
+    ForwardRegisterPeerConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.peer_communicator import (
+    key_for_peer,
+)
+from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import (
+    PeerCommunicatorConfig,
+)
+from exasol.analytics.udf.communication.socket_factory.fault_injection import (
+    FaultInjectionSocketFactory,
+)
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQSocketFactory,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import (
+    ConditionalMethodDropper,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.utils import (
+    BidirectionalQueue,
+    PeerCommunicatorTestProcessParameter,
+    TestProcess,
+    assert_processes_finish,
+)

 structlog.configure(
     context_class=dict,
-    logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
+    logger_factory=WriteLoggerFactory(
+        file=Path(__file__).with_suffix(".log").open("wt")
+    ),
     processors=[
         structlog.contextvars.merge_contextvars,
         ConditionalMethodDropper(method_name="debug"),
         ConditionalMethodDropper(method_name="info"),
         structlog.processors.add_log_level,
         structlog.processors.TimeStamper(),
-        structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
+        structlog.processors.ExceptionRenderer(
+            exception_formatter=ExceptionDictTransformer(locals_max_string=320)
+        ),
         structlog.processors.CallsiteParameterAdder(),
-        structlog.processors.JSONRenderer()
-    ]
+        structlog.processors.JSONRenderer(),
+    ],
 )

 LOGGER: FilteringBoundLogger = structlog.get_logger()


 def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQueue):
-    logger = LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name)
+    logger = LOGGER.bind(
+        group_identifier=parameter.group_identifier, name=parameter.instance_name
+    )
     try:
         listen_ip = IPAddress(ip_address=f"127.1.0.1")
         context = zmq.Context()
         socket_factory = ZMQSocketFactory(context)
-        socket_factory = FaultInjectionSocketFactory(socket_factory, 0.01, RandomState(parameter.seed))
+        socket_factory = FaultInjectionSocketFactory(
+            socket_factory, 0.01, RandomState(parameter.seed)
+        )
         com = PeerCommunicator(
             name=parameter.instance_name,
             number_of_peers=parameter.number_of_instances,
@@ -62,8 +82,7 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
             socket_factory=socket_factory,
             config=PeerCommunicatorConfig(
                 forward_register_peer_config=ForwardRegisterPeerConfig(
-                    is_leader=False,
-                    is_enabled=False
+                    is_leader=False, is_enabled=False
                 )
             ),
         )
@@ -88,7 +107,9 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
         logger.exception("Exception during test")


-@pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)])
+@pytest.mark.parametrize(
+    "number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)]
+)
 def test_reliability(number_of_instances: int, repetitions: int):
     run_test_with_repetitions(number_of_instances, repetitions)

@@ -114,32 +135,42 @@ def test_functionality_25():

 def run_test_with_repetitions(number_of_instances: int, repetitions: int):
     for i in range(repetitions):
-        LOGGER.info(f"Start iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances)
+        LOGGER.info(
+            f"Start iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+        )
         start_time = time.monotonic()
         group = f"{time.monotonic_ns()}"
-        expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, seed=i)
+        expected_peers_of_threads, peers_of_threads = run_test(
+            group, number_of_instances, seed=i
+        )
         assert expected_peers_of_threads == peers_of_threads
         end_time = time.monotonic()
-        LOGGER.info(f"Finish iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances,
-                    duration=end_time - start_time)
+        LOGGER.info(
+            f"Finish iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+            duration=end_time - start_time,
+        )


 def run_test(group: str, number_of_instances: int, seed: int):
     connection_infos: Dict[int, ConnectionInfo] = {}
     parameters = [
         PeerCommunicatorTestProcessParameter(
-            instance_name=f"i{i}", group_identifier=group,
+            instance_name=f"i{i}",
+            group_identifier=group,
             number_of_instances=number_of_instances,
-            seed=seed + i)
-        for i in range(number_of_instances)]
-    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \
-        [TestProcess(parameter, run=run) for parameter in parameters]
+            seed=seed + i,
+        )
+        for i in range(number_of_instances)
+    ]
+    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = [
+        TestProcess(parameter, run=run) for parameter in parameters
+    ]
     for i in range(number_of_instances):
         processes[i].start()
         connection_infos[i] = processes[i].get()
@@ -150,10 +181,13 @@ def run_test(group: str, number_of_instances: int, seed: int):
     for i in range(number_of_instances):
         peers_of_threads[i] = processes[i].get()
     expected_peers_of_threads = {
-        i: sorted([
-            Peer(connection_info=connection_info)
-            for index, connection_info in connection_infos.items()
-        ], key=key_for_peer)
+        i: sorted(
+            [
+                Peer(connection_info=connection_info)
+                for index, connection_info in connection_infos.items()
+            ],
+            key=key_for_peer,
+        )
         for i in range(number_of_instances)
     }
     return expected_peers_of_threads, peers_of_threads
diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_close.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_close.py
index 2f7a90c0..25c5eed4 100644
--- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_close.py
+++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_close.py
@@ -15,43 +15,61 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.ip_address import IPAddress
 from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator
-from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \
-    ForwardRegisterPeerConfig
-from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \
-    PeerCommunicatorConfig
-from exasol.analytics.udf.communication.socket_factory.fault_injection import \
-    FaultInjectionSocketFactory
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory
-from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \
-    ConditionalMethodDropper
-from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \
-    PeerCommunicatorTestProcessParameter, BidirectionalQueue, TestProcess, assert_processes_finish
+from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import (
+    ForwardRegisterPeerConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import (
+    PeerCommunicatorConfig,
+)
+from exasol.analytics.udf.communication.socket_factory.fault_injection import (
+    FaultInjectionSocketFactory,
+)
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQSocketFactory,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import (
+    ConditionalMethodDropper,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.utils import (
+    BidirectionalQueue,
+    PeerCommunicatorTestProcessParameter,
+    TestProcess,
+    assert_processes_finish,
+)

 structlog.configure(
     context_class=dict,
-    logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
+    logger_factory=WriteLoggerFactory(
+        file=Path(__file__).with_suffix(".log").open("wt")
+    ),
     processors=[
         structlog.contextvars.merge_contextvars,
         ConditionalMethodDropper(method_name="debug"),
         ConditionalMethodDropper(method_name="info"),
         structlog.processors.add_log_level,
         structlog.processors.TimeStamper(),
-        structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
+        structlog.processors.ExceptionRenderer(
+            exception_formatter=ExceptionDictTransformer(locals_max_string=320)
+        ),
         structlog.processors.CallsiteParameterAdder(),
-        structlog.processors.JSONRenderer()
-    ]
+        structlog.processors.JSONRenderer(),
+    ],
 )

 LOGGER: FilteringBoundLogger = structlog.get_logger()


 def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQueue):
-    logger = LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name)
+    logger = LOGGER.bind(
+        group_identifier=parameter.group_identifier, name=parameter.instance_name
+    )
     try:
         listen_ip = IPAddress(ip_address=f"127.1.0.1")
         context = zmq.Context()
         socket_factory = ZMQSocketFactory(context)
-        socket_factory = FaultInjectionSocketFactory(socket_factory, 0.01, RandomState(parameter.seed))
+        socket_factory = FaultInjectionSocketFactory(
+            socket_factory, 0.01, RandomState(parameter.seed)
+        )
         leader_name = "i0"
         leader = True if parameter.instance_name == leader_name else False
         com = PeerCommunicator(
@@ -61,11 +79,10 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
             group_identifier=parameter.group_identifier,
             config=PeerCommunicatorConfig(
                 forward_register_peer_config=ForwardRegisterPeerConfig(
-                    is_leader=leader,
-                    is_enabled=True
+                    is_leader=leader, is_enabled=True
                 ),
             ),
-            socket_factory=socket_factory
+            socket_factory=socket_factory,
         )
         try:
             queue.put(com.my_connection_info)
@@ -89,7 +106,9 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
         queue.put(f"Failed: {e}")


-@pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)])
+@pytest.mark.parametrize(
+    "number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)]
+)
 def test_reliability(number_of_instances: int, repetitions: int):
     run_test_with_repetitions(number_of_instances, repetitions)

@@ -115,32 +134,42 @@ def test_functionality_25():

 def run_test_with_repetitions(number_of_instances: int, repetitions: int):
     for i in range(repetitions):
-        LOGGER.info(f"Start iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances)
+        LOGGER.info(
+            f"Start iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+        )
         start_time = time.monotonic()
         group = f"{time.monotonic_ns()}"
-        expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, seed=i)
+        expected_peers_of_threads, peers_of_threads = run_test(
+            group, number_of_instances, seed=i
+        )
         assert expected_peers_of_threads == peers_of_threads
         end_time = time.monotonic()
-        LOGGER.info(f"Finish iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances,
-                    duration=end_time - start_time)
+        LOGGER.info(
+            f"Finish iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+            duration=end_time - start_time,
+        )


 def run_test(group: str, number_of_instances: int, seed: int):
     connection_infos: Dict[int, ConnectionInfo] = {}
     parameters = [
         PeerCommunicatorTestProcessParameter(
-            instance_name=f"i{i}", group_identifier=group,
+            instance_name=f"i{i}",
+            group_identifier=group,
             number_of_instances=number_of_instances,
-            seed=seed + i)
-        for i in range(number_of_instances)]
-    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \
-        [TestProcess(parameter, run=run) for parameter in parameters]
+            seed=seed + i,
+        )
+        for i in range(number_of_instances)
+    ]
+    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = [
+        TestProcess(parameter, run=run) for parameter in parameters
+    ]
     for i in range(number_of_instances):
         processes[i].start()
         connection_infos[i] = processes[i].get()
@@ -150,8 +179,5 @@ def run_test(group: str, number_of_instances: int, seed: int):
     result_of_threads: Dict[int, List[ConnectionInfo]] = {}
     for i in range(number_of_instances):
         result_of_threads[i] = processes[i].get()
-    expected_results_of_threads = {
-        i: "Success"
-        for i in range(number_of_instances)
-    }
+    expected_results_of_threads = {i: "Success" for i in range(number_of_instances)}
     return expected_results_of_threads, result_of_threads
diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_wait_for_peers.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_wait_for_peers.py
index 97d8592b..59548744 100644
--- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_wait_for_peers.py
+++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_wait_for_peers.py
@@ -15,44 +15,64 @@ from exasol.analytics.udf.communication.ip_address import IPAddress
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator
-from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \
-    ForwardRegisterPeerConfig
-from exasol.analytics.udf.communication.peer_communicator.peer_communicator import key_for_peer
-from exasol.analytics.udf.communication.socket_factory.fault_injection import \
-    FaultInjectionSocketFactory
-from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \
-    PeerCommunicatorConfig
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory
-from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \
-    ConditionalMethodDropper
-from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \
-    PeerCommunicatorTestProcessParameter, BidirectionalQueue, TestProcess, assert_processes_finish
+from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import (
+    ForwardRegisterPeerConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.peer_communicator import (
+    key_for_peer,
+)
+from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import (
+    PeerCommunicatorConfig,
+)
+from exasol.analytics.udf.communication.socket_factory.fault_injection import (
+    FaultInjectionSocketFactory,
+)
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQSocketFactory,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import (
+    ConditionalMethodDropper,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.utils import (
+    BidirectionalQueue,
+    PeerCommunicatorTestProcessParameter,
+    TestProcess,
+    assert_processes_finish,
+)

 structlog.configure(
     context_class=dict,
-    logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
+    logger_factory=WriteLoggerFactory(
+        file=Path(__file__).with_suffix(".log").open("wt")
+    ),
     processors=[
         structlog.contextvars.merge_contextvars,
         ConditionalMethodDropper(method_name="debug"),
         ConditionalMethodDropper(method_name="info"),
         structlog.processors.add_log_level,
         structlog.processors.TimeStamper(fmt="ISO"),
-        structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
+        structlog.processors.ExceptionRenderer(
+            exception_formatter=ExceptionDictTransformer(locals_max_string=320)
+        ),
         structlog.processors.CallsiteParameterAdder(),
-        structlog.processors.JSONRenderer()
-    ]
+        structlog.processors.JSONRenderer(),
+    ],
 )

 LOGGER: FilteringBoundLogger = structlog.get_logger()


 def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQueue):
-    logger = LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name)
+    logger = LOGGER.bind(
+        group_identifier=parameter.group_identifier, name=parameter.instance_name
+    )
     try:
         listen_ip = IPAddress(ip_address=f"127.1.0.1")
         context = zmq.Context()
         socket_factory = ZMQSocketFactory(context)
-        socket_factory = FaultInjectionSocketFactory(socket_factory, 0.01, RandomState(parameter.seed))
+        socket_factory = FaultInjectionSocketFactory(
+            socket_factory, 0.01, RandomState(parameter.seed)
+        )
         leader = False
         leader_name = "i0"
         leader = parameter.instance_name == leader_name
@@ -63,11 +83,10 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
             group_identifier=parameter.group_identifier,
             config=PeerCommunicatorConfig(
                 forward_register_peer_config=ForwardRegisterPeerConfig(
-                    is_leader=leader,
-                    is_enabled=True
+                    is_leader=leader, is_enabled=True
                 ),
             ),
-            socket_factory=socket_factory
+            socket_factory=socket_factory,
         )
         try:
             queue.put(com.my_connection_info)
@@ -88,7 +107,9 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
         queue.put([])


-@pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)])
+@pytest.mark.parametrize(
+    "number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)]
+)
 def test_reliability(number_of_instances: int, repetitions: int):
     run_test_with_repetitions(number_of_instances, repetitions)

@@ -122,32 +143,42 @@ def test_functionality_25():

 def run_test_with_repetitions(number_of_instances: int, repetitions: int):
     for i in range(repetitions):
-        LOGGER.info(f"Start iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances)
+        LOGGER.info(
+            f"Start iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+        )
         start_time = time.monotonic()
         group = f"{time.monotonic_ns()}"
-        expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, seed=i)
+        expected_peers_of_threads, peers_of_threads = run_test(
+            group, number_of_instances, seed=i
+        )
         assert expected_peers_of_threads == peers_of_threads
         end_time = time.monotonic()
-        LOGGER.info(f"Finish iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances,
-                    duration=end_time - start_time)
+        LOGGER.info(
+            f"Finish iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+            duration=end_time - start_time,
+        )


 def run_test(group: str, number_of_instances: int, seed: int):
     connection_infos: Dict[int, ConnectionInfo] = {}
     parameters = [
         PeerCommunicatorTestProcessParameter(
-            instance_name=f"i{i}", group_identifier=group,
+            instance_name=f"i{i}",
+            group_identifier=group,
             number_of_instances=number_of_instances,
-            seed=seed + i)
-        for i in range(number_of_instances)]
-    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \
-        [TestProcess(parameter, run=run) for parameter in parameters]
+            seed=seed + i,
+        )
+        for i in range(number_of_instances)
+    ]
+    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = [
+        TestProcess(parameter, run=run) for parameter in parameters
+    ]
     for i in range(number_of_instances):
         processes[i].start()
         connection_infos[i] = processes[i].get()
@@ -158,10 +189,13 @@ def run_test(group: str, number_of_instances: int, seed: int):
     for i in range(number_of_instances):
         peers_of_threads[i] = processes[i].get()
     expected_peers_of_threads = {
-        i: sorted([
-            Peer(connection_info=connection_info)
-            for index, connection_info in connection_infos.items()
-        ], key=key_for_peer)
+        i: sorted(
+            [
+                Peer(connection_info=connection_info)
+                for index, connection_info in connection_infos.items()
+            ],
+            key=key_for_peer,
+        )
         for i in range(number_of_instances)
     }
     return expected_peers_of_threads, peers_of_threads
diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_poll.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_poll.py
index 63f4e328..96ea8a04 100644
--- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_poll.py
+++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_poll.py
@@ -2,7 +2,7 @@
 import time
 import traceback
 from pathlib import Path
-from typing import Dict, Set, List
+from typing import Dict, List, Set

 import structlog
 import zmq
@@ -13,37 +13,51 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.ip_address import IPAddress
 from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator
-from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \
-    ForwardRegisterPeerConfig
-from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \
-    PeerCommunicatorConfig
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory
-from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \
-    ConditionalMethodDropper
-from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \
-    BidirectionalQueue, assert_processes_finish, \
-    PeerCommunicatorTestProcessParameter
+from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import (
+    ForwardRegisterPeerConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import (
+    PeerCommunicatorConfig,
+)
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQSocketFactory,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import (
+    ConditionalMethodDropper,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.utils import (
+    BidirectionalQueue,
+    PeerCommunicatorTestProcessParameter,
+    TestProcess,
+    assert_processes_finish,
+)

 structlog.configure(
     context_class=dict,
-    logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
+    logger_factory=WriteLoggerFactory(
+        file=Path(__file__).with_suffix(".log").open("wt")
+    ),
     processors=[
         structlog.contextvars.merge_contextvars,
         ConditionalMethodDropper(method_name="debug"),
         ConditionalMethodDropper(method_name="info"),
         structlog.processors.add_log_level,
         structlog.processors.TimeStamper(),
-        structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
+        structlog.processors.ExceptionRenderer(
+            exception_formatter=ExceptionDictTransformer(locals_max_string=320)
+        ),
         structlog.processors.CallsiteParameterAdder(),
-        structlog.processors.JSONRenderer()
-    ]
+        structlog.processors.JSONRenderer(),
+    ],
 )

 LOGGER: FilteringBoundLogger = structlog.get_logger()


 def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQueue):
-    logger = LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name)
+    logger = LOGGER.bind(
+        group_identifier=parameter.group_identifier, name=parameter.instance_name
+    )
     received_values: Set[str] = set()
     try:
         listen_ip = IPAddress(ip_address=f"127.1.0.1")
@@ -57,8 +71,7 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
             socket_factory=socket_factory,
             config=PeerCommunicatorConfig(
                 forward_register_peer_config=ForwardRegisterPeerConfig(
-                    is_leader=False,
-                    is_enabled=False
+                    is_leader=False, is_enabled=False
                 ),
             ),
         )
@@ -71,8 +84,17 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
         LOGGER.info("Peer is ready", name=com.peer.connection_info.name)
         if com.peer.connection_info.name != "i0":
             time.sleep(10)
-            peer_i0 = next(peer for peer in com.peers() if peer.connection_info.name == "i0")
-            com.send(peer_i0, [socket_factory.create_frame(com.peer.connection_info.name.encode("utf8"))])
+            peer_i0 = next(
+                peer for peer in com.peers() if peer.connection_info.name == "i0"
+            )
+            com.send(
+                peer_i0,
+                [
+                    socket_factory.create_frame(
+                        com.peer.connection_info.name.encode("utf8")
+                    )
+                ],
+            )
         else:
             while len(received_values) < parameter.number_of_instances - 1:
                 poll_peers = com.poll_peers(timeout_in_milliseconds=100)
@@ -101,7 +123,9 @@ def test_functionality_5():

 def run_test_with_assert(number_of_instances: int):
     group = f"{time.monotonic_ns()}"
-    expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, 0)
+    expected_peers_of_threads, peers_of_threads = run_test(
+        group, number_of_instances, 0
+    )
     assert expected_peers_of_threads == peers_of_threads

@@ -109,12 +133,16 @@ def run_test(group: str, number_of_instances: int, seed: int):
     connection_infos: Dict[int, ConnectionInfo] = {}
     parameters = [
         PeerCommunicatorTestProcessParameter(
-            instance_name=f"i{i}", group_identifier=group,
+            instance_name=f"i{i}",
+            group_identifier=group,
             number_of_instances=number_of_instances,
-            seed=seed + i)
-        for i in range(number_of_instances)]
-    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \
-        [TestProcess(parameter, run=run) for parameter in parameters]
+            seed=seed + i,
+        )
+        for i in range(number_of_instances)
+    ]
+    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = [
+        TestProcess(parameter, run=run) for parameter in parameters
+    ]
     for i in range(number_of_instances):
         processes[i].start()
     for i in range(number_of_instances):
diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_send_recv.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_send_recv.py
index bc5f2e33..ed23ff78 100644
--- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_send_recv.py
+++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_send_recv.py
@@ -3,7 +3,7 @@
 import time
 import traceback
 from pathlib import Path
-from typing import Dict, Set, List
+from typing import Dict, List, Set

 import pytest
 import structlog
@@ -17,45 +17,62 @@ from exasol.analytics.udf.communication.ip_address import IPAddress
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator
-from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import \
-    ForwardRegisterPeerConfig
-from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import \
-    PeerCommunicatorConfig
-from exasol.analytics.udf.communication.socket_factory.fault_injection import \
-    FaultInjectionSocketFactory
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory
-from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \
-    ConditionalMethodDropper
-from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \
-    BidirectionalQueue, assert_processes_finish, \
-    PeerCommunicatorTestProcessParameter
+from exasol.analytics.udf.communication.peer_communicator.forward_register_peer_config import (
+    ForwardRegisterPeerConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.peer_communicator_config import (
+    PeerCommunicatorConfig,
+)
+from exasol.analytics.udf.communication.socket_factory.fault_injection import (
+    FaultInjectionSocketFactory,
+)
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQSocketFactory,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import (
+    ConditionalMethodDropper,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.utils import (
+    BidirectionalQueue,
+    PeerCommunicatorTestProcessParameter,
+    TestProcess,
+    assert_processes_finish,
+)

 structlog.configure(
     context_class=dict,
-    logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
+    logger_factory=WriteLoggerFactory(
+        file=Path(__file__).with_suffix(".log").open("wt")
+    ),
     processors=[
         structlog.contextvars.merge_contextvars,
         ConditionalMethodDropper(method_name="debug"),
         ConditionalMethodDropper(method_name="info"),
         structlog.processors.add_log_level,
         structlog.processors.TimeStamper(),
-        structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
+        structlog.processors.ExceptionRenderer(
+            exception_formatter=ExceptionDictTransformer(locals_max_string=320)
+        ),
         structlog.processors.CallsiteParameterAdder(),
-        structlog.processors.JSONRenderer()
-    ]
+        structlog.processors.JSONRenderer(),
+    ],
 )

 LOGGER: FilteringBoundLogger = structlog.get_logger()


 def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQueue):
-    logger = LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name)
+    logger = LOGGER.bind(
+        group_identifier=parameter.group_identifier, name=parameter.instance_name
+    )
     received_values: Set[str] = set()
     try:
         listen_ip = IPAddress(ip_address=f"127.1.0.1")
         context = zmq.Context()
         socket_factory = ZMQSocketFactory(context)
-        socket_factory = FaultInjectionSocketFactory(socket_factory, 0.01, RandomState(parameter.seed))
+        socket_factory = FaultInjectionSocketFactory(
+            socket_factory, 0.01, RandomState(parameter.seed)
+        )
         com = PeerCommunicator(
             name=parameter.instance_name,
             number_of_peers=parameter.number_of_instances,
@@ -64,8 +81,7 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
             socket_factory=socket_factory,
             config=PeerCommunicatorConfig(
                 forward_register_peer_config=ForwardRegisterPeerConfig(
-                    is_leader=False,
-                    is_enabled=False
+                    is_leader=False, is_enabled=False
                 ),
             ),
         )
@@ -78,7 +94,14 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
         LOGGER.info("Peer is ready", name=parameter.instance_name)
         for peer in com.peers():
             if peer != Peer(connection_info=com.my_connection_info):
-                com.send(peer, [socket_factory.create_frame(parameter.instance_name.encode("utf8"))])
+                com.send(
+                    peer,
+                    [
+                        socket_factory.create_frame(
+                            parameter.instance_name.encode("utf8")
+                        )
+                    ],
+                )
         for peer in com.peers():
             if peer != Peer(connection_info=com.my_connection_info):
                 value = com.recv(peer)
@@ -99,9 +122,7 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue
     queue.put(received_values)


-@pytest.mark.parametrize("number_of_instances, repetitions", [
-    (2, 1000), (10, 100)
-])
+@pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100)])
 def test_reliability(number_of_instances: int, repetitions: int):
     run_test_with_repetitions(number_of_instances, repetitions)

@@ -120,41 +141,54 @@ def test_functionality_5():
 def test_functionality_10():
     run_test_with_repetitions(10, REPETITIONS_FOR_FUNCTIONALITY)

-@pytest.mark.skipif("GITHUB_ACTIONS" in os.environ,
-                    reason="This test is unstable on Github Action, "
-                           "because of the limited number of cores on the default runners.")
+
+@pytest.mark.skipif(
+    "GITHUB_ACTIONS" in os.environ,
+    reason="This test is unstable on Github Action, "
+    "because of the limited number of cores on the default runners.",
+)
 def test_functionality_25():
     run_test_with_repetitions(25, REPETITIONS_FOR_FUNCTIONALITY)


 def run_test_with_repetitions(number_of_instances: int, repetitions: int):
     for i in range(repetitions):
-        LOGGER.info(f"Start iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances)
+        LOGGER.info(
+            f"Start iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+        )
         start_time = time.monotonic()
         group = f"{time.monotonic_ns()}"
-        expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, seed=i)
+        expected_peers_of_threads, peers_of_threads = run_test(
+            group, number_of_instances, seed=i
+        )
         assert expected_peers_of_threads == peers_of_threads
         end_time = time.monotonic()
-        LOGGER.info(f"Finish iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    number_of_instances=number_of_instances,
-                    duration=end_time - start_time)
+        LOGGER.info(
+            f"Finish iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            number_of_instances=number_of_instances,
+            duration=end_time - start_time,
+        )


 def run_test(group: str, number_of_instances: int, seed: int):
     connection_infos: Dict[int, ConnectionInfo] = {}
     parameters = [
         PeerCommunicatorTestProcessParameter(
-            instance_name=f"i{i}", group_identifier=group,
+            instance_name=f"i{i}",
+            group_identifier=group,
             number_of_instances=number_of_instances,
-            seed=seed + i)
-        for i in range(number_of_instances)]
-    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \
-        [TestProcess(parameter, run=run) for parameter in parameters]
+            seed=seed + i,
+        )
+        for i in range(number_of_instances)
+    ]
+    processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = [
+        TestProcess(parameter, run=run) for parameter in parameters
+    ]
     for i in range(number_of_instances):
         processes[i].start()
     for i in range(number_of_instances):
diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/utils.py b/tests/integration_tests/without_db/udf_communication/peer_communication/utils.py
index 26d21ab3..fb228a3b 100644
--- a/tests/integration_tests/without_db/udf_communication/peer_communication/utils.py
+++ b/tests/integration_tests/without_db/udf_communication/peer_communication/utils.py
@@ -4,14 +4,14 @@
 from abc import ABC
 from multiprocessing import Process
 from queue import Queue
-from typing import Any, Callable, List, TypeVar, Generic
+from typing import Any, Callable, Generic, List, TypeVar

 import structlog
 from structlog.typing import FilteringBoundLogger

 from exasol.analytics.udf.communication.ip_address import Port

-NANOSECONDS_PER_SECOND = 10 ** 9
+NANOSECONDS_PER_SECOND = 10**9

 LOGGER: FilteringBoundLogger = structlog.get_logger(__name__)
@@ -30,6 +30,7 @@ def get(self) -> Any:


 class TestProcessParameter(ABC):
+    __test__ = False

     def __init__(self, seed: int):
         self.seed = seed
@@ -38,7 +39,13 @@ def __repr__(self):


 class PeerCommunicatorTestProcessParameter(TestProcessParameter):
-    def __init__(self, instance_name: str, group_identifier: str, number_of_instances: int, seed: int):
+    def __init__(
+        self,
+        instance_name: str,
+        group_identifier: str,
+        number_of_instances: int,
+        seed: int,
+    ):
         super().__init__(seed)
         self.number_of_instances = number_of_instances
         self.group_identifier = group_identifier
@@ -46,14 +53,16 @@ def __init__(self, instance_name: str, group_identifier: str, number_of_instance


 class CommunicatorTestProcessParameter(TestProcessParameter):
-    def __init__(self,
-                 node_name: str,
-                 instance_name: str,
-                 group_identifier: str,
-                 number_of_nodes: int,
-                 number_of_instances_per_node: int,
-                 local_discovery_port: Port,
-                 seed: int):
+    def __init__(
+        self,
+        node_name: str,
+        instance_name: str,
+        group_identifier: str,
+        number_of_nodes: int,
+        number_of_instances_per_node: int,
+        local_discovery_port: Port,
+        seed: int,
+    ):
         super().__init__(seed)
         self.local_discovery_port = local_discovery_port
         self.number_of_instances_per_node = number_of_instances_per_node
@@ -63,16 +72,18 @@ def __init__(self,
         self.instance_name = instance_name


-T = TypeVar('T')
+T = TypeVar("T")


 class TestProcess(Generic[T]):
-    def __init__(self, parameter: T,
-                 run: Callable[[T, BidirectionalQueue], None]):
+    __test__ = False
+    def __init__(self, parameter: T, run: Callable[[T, BidirectionalQueue], None]):
         self.parameter = parameter
         put_queue = multiprocessing.Queue()
         get_queue = multiprocessing.Queue()
-        self._main_thread_queue = BidirectionalQueue(put_queue=get_queue, get_queue=put_queue)
+        self._main_thread_queue = BidirectionalQueue(
+            put_queue=get_queue, get_queue=put_queue
+        )
         thread_queue = BidirectionalQueue(put_queue=put_queue, get_queue=get_queue)
         self._process = Process(target=run, args=(self.parameter, thread_queue))
@@ -110,7 +121,9 @@ def assert_processes_finish(processes: List[TestProcess], timeout_in_seconds: in
         if difference_ns > timeout_in_ns:
             break
         time.sleep(0.01)
-    alive_processes_before_kill = [process.parameter for process in get_alive_processes(processes)]
+    alive_processes_before_kill = [
+        process.parameter for process in get_alive_processes(processes)
+    ]
     kill_alive_processes(processes)
     if len(get_alive_processes(processes)) > 0:
         time.sleep(2)
diff --git a/tests/integration_tests/without_db/udf_communication/test_broadcast.py b/tests/integration_tests/without_db/udf_communication/test_broadcast.py
index d170e0f7..0bb03f20 100644
--- a/tests/integration_tests/without_db/udf_communication/test_broadcast.py
+++ b/tests/integration_tests/without_db/udf_communication/test_broadcast.py
@@ -1,6 +1,6 @@
 import time
 from pathlib import Path
-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple

 import structlog
 import zmq
@@ -9,34 +9,43 @@ from structlog.types import FilteringBoundLogger

 from exasol.analytics.udf.communication.communicator import Communicator
-from exasol.analytics.udf.communication.ip_address import Port, IPAddress
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory
-from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \
-    ConditionalMethodDropper
-from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \
-    BidirectionalQueue, assert_processes_finish, \
-    CommunicatorTestProcessParameter
+from exasol.analytics.udf.communication.ip_address import IPAddress, Port
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQSocketFactory,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import (
+    ConditionalMethodDropper,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.utils import (
+    BidirectionalQueue,
+    CommunicatorTestProcessParameter,
+    TestProcess,
+    assert_processes_finish,
+)

 structlog.configure(
     context_class=dict,
-    logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
+    logger_factory=WriteLoggerFactory(
+        file=Path(__file__).with_suffix(".log").open("wt")
+    ),
     processors=[
         structlog.contextvars.merge_contextvars,
         ConditionalMethodDropper(method_name="debug"),
         ConditionalMethodDropper(method_name="info"),
         structlog.processors.add_log_level,
         structlog.processors.TimeStamper(),
-        structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
+        structlog.processors.ExceptionRenderer(
+            exception_formatter=ExceptionDictTransformer(locals_max_string=320)
+        ),
         structlog.processors.CallsiteParameterAdder(),
-        structlog.processors.JSONRenderer()
-    ]
+        structlog.processors.JSONRenderer(),
+    ],
 )

 LOGGER: FilteringBoundLogger = structlog.get_logger(__name__)


-def run(parameter: CommunicatorTestProcessParameter,
-        queue: BidirectionalQueue):
+def run(parameter: CommunicatorTestProcessParameter, queue: BidirectionalQueue):
     try:
         is_discovery_leader_node = parameter.node_name == "n0"
         context = zmq.Context()
@@ -52,13 +61,18 @@ def run(parameter: CommunicatorTestProcessParameter,
             number_of_nodes=parameter.number_of_nodes,
             number_of_instances_per_node=parameter.number_of_instances_per_node,
             is_discovery_leader_node=is_discovery_leader_node,
-            socket_factory=socket_factory
+            socket_factory=socket_factory,
         )
         value = None
         if communicator.is_multi_node_leader():
             value = b"Success"
         result = communicator.broadcast(value)
-        LOGGER.info("result", result=result, instance_name=parameter.instance_name, node_name=parameter.node_name)
+        LOGGER.info(
+            "result",
+            result=result,
+            instance_name=parameter.instance_name,
+            node_name=parameter.node_name,
+        )
         queue.put(result.decode("utf-8"))
     except Exception as e:
         LOGGER.exception("Exception during test")
@@ -69,55 +83,72 @@ def run(parameter: CommunicatorTestProcessParameter,


 def test_functionality_2_1():
-    run_test_with_repetitions(number_of_nodes=2,
-                              number_of_instances_per_node=1,
-                              repetitions=REPETITIONS_FOR_FUNCTIONALITY)
+    run_test_with_repetitions(
+        number_of_nodes=2,
+        number_of_instances_per_node=1,
+        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
+    )


 def test_functionality_1_2():
-    run_test_with_repetitions(number_of_nodes=1,
-                              number_of_instances_per_node=2,
-                              repetitions=REPETITIONS_FOR_FUNCTIONALITY)
+    run_test_with_repetitions(
+        number_of_nodes=1,
+        number_of_instances_per_node=2,
+        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
+    )


 def test_functionality_2_2():
-    run_test_with_repetitions(number_of_nodes=2,
-                              number_of_instances_per_node=2,
-                              repetitions=REPETITIONS_FOR_FUNCTIONALITY)
+    run_test_with_repetitions(
+        number_of_nodes=2,
+        number_of_instances_per_node=2,
+        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
+    )


 def test_functionality_3_3():
-    run_test_with_repetitions(number_of_nodes=3,
-                              number_of_instances_per_node=3,
-                              repetitions=REPETITIONS_FOR_FUNCTIONALITY)
+    run_test_with_repetitions(
+        number_of_nodes=3,
+        number_of_instances_per_node=3,
+        repetitions=REPETITIONS_FOR_FUNCTIONALITY,
+    )


-def run_test_with_repetitions(number_of_nodes: int, number_of_instances_per_node: int, repetitions: int):
+def run_test_with_repetitions(
+    number_of_nodes: int, number_of_instances_per_node: int, repetitions: int
+):
     for i in range(repetitions):
         group = f"{time.monotonic_ns()}"
-        LOGGER.info(f"Start iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    group_identifier=group,
-                    number_of_nodes=number_of_nodes,
-                    number_of_instances_per_node=number_of_instances_per_node)
+        LOGGER.info(
+            f"Start iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            group_identifier=group,
+            number_of_nodes=number_of_nodes,
+            number_of_instances_per_node=number_of_instances_per_node,
+        )
         start_time = time.monotonic()
-        expected_result_of_threads, actual_result_of_threads = \
-            run_test(group_identifier=group,
-                     number_of_nodes=number_of_nodes,
-                     number_of_instances_per_node=number_of_instances_per_node)
+        expected_result_of_threads, actual_result_of_threads = run_test(
+            group_identifier=group,
+            number_of_nodes=number_of_nodes,
+            number_of_instances_per_node=number_of_instances_per_node,
+        )
         assert expected_result_of_threads == actual_result_of_threads
         end_time = time.monotonic()
-        LOGGER.info(f"Finish iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    group_identifier=group,
-                    number_of_nodes=number_of_nodes,
-                    number_of_instances_per_node=number_of_instances_per_node,
-                    duration=end_time - start_time)
+        LOGGER.info(
+            f"Finish iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            group_identifier=group,
+            number_of_nodes=number_of_nodes,
+            number_of_instances_per_node=number_of_instances_per_node,
+            duration=end_time - start_time,
+        )


-def run_test(group_identifier: str, number_of_nodes: int, number_of_instances_per_node: int):
+def run_test(
+    group_identifier: str, number_of_nodes: int, number_of_instances_per_node: int
+):
     parameters = [
         CommunicatorTestProcessParameter(
             node_name=f"n{n}",
@@ -126,11 +157,14 @@ def run_test(group_identifier: str, number_of_nodes: int, number_of_instances_pe
             number_of_nodes=number_of_nodes,
             number_of_instances_per_node=number_of_instances_per_node,
             local_discovery_port=Port(port=44445 + n),
-            seed=0)
+            seed=0,
+        )
         for n in range(number_of_nodes)
-        for i in range(number_of_instances_per_node)]
-    processes: List[TestProcess[CommunicatorTestProcessParameter]] = \
-        [TestProcess(parameter, run=run) for parameter in parameters]
+        for i in range(number_of_instances_per_node)
+    ]
+    processes: List[TestProcess[CommunicatorTestProcessParameter]] = [
+        TestProcess(parameter, run=run) for parameter in parameters
+    ]
     for process in processes:
         process.start()
     assert_processes_finish(processes, timeout_in_seconds=180)
diff --git a/tests/integration_tests/without_db/udf_communication/test_communicator.py b/tests/integration_tests/without_db/udf_communication/test_communicator.py
index e1f20579..87681a1e 100644
--- a/tests/integration_tests/without_db/udf_communication/test_communicator.py
+++ b/tests/integration_tests/without_db/udf_communication/test_communicator.py
@@ -1,6 +1,6 @@
 import time
 from pathlib import Path
-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple

 import pytest
 import structlog
@@ -10,32 +10,43 @@ from structlog.types import FilteringBoundLogger

 from exasol.analytics.udf.communication.communicator import Communicator
-from exasol.analytics.udf.communication.ip_address import Port, IPAddress
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory
-from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import ConditionalMethodDropper
-from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, BidirectionalQueue, assert_processes_finish, \
-    CommunicatorTestProcessParameter
+from exasol.analytics.udf.communication.ip_address import IPAddress, Port
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQSocketFactory,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import (
+    ConditionalMethodDropper,
+)
+from tests.integration_tests.without_db.udf_communication.peer_communication.utils import (
+    BidirectionalQueue,
+    CommunicatorTestProcessParameter,
+    TestProcess,
+    assert_processes_finish,
+)

 structlog.configure(
     context_class=dict,
-    logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")),
+    logger_factory=WriteLoggerFactory(
+        file=Path(__file__).with_suffix(".log").open("wt")
+    ),
     processors=[
         structlog.contextvars.merge_contextvars,
         ConditionalMethodDropper(method_name="debug"),
         ConditionalMethodDropper(method_name="info"),
         structlog.processors.add_log_level,
         structlog.processors.TimeStamper(),
-        structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)),
+        structlog.processors.ExceptionRenderer(
+            exception_formatter=ExceptionDictTransformer(locals_max_string=320)
+        ),
         structlog.processors.CallsiteParameterAdder(),
-        structlog.processors.JSONRenderer()
-    ]
+        structlog.processors.JSONRenderer(),
+    ],
 )

 LOGGER: FilteringBoundLogger = structlog.get_logger(__name__)


-def run(parameter: CommunicatorTestProcessParameter,
-        queue: BidirectionalQueue):
+def run(parameter: CommunicatorTestProcessParameter, queue: BidirectionalQueue):
     is_discovery_leader_node = parameter.node_name == "n0"
     context = zmq.Context()
     socket_factory = ZMQSocketFactory(context)
@@ -50,23 +61,31 @@ def run(parameter: CommunicatorTestProcessParameter,
         number_of_nodes=parameter.number_of_nodes,
         number_of_instances_per_node=parameter.number_of_instances_per_node,
         is_discovery_leader_node=is_discovery_leader_node,
-        socket_factory=socket_factory
+        socket_factory=socket_factory,
     )
     queue.put("Finished")


-@pytest.mark.parametrize("number_of_nodes, number_of_instances_per_node, repetitions",
-                         [
-                             (2, 2, 100),
-                             (3, 3, 20),
-                         ])
-def test_reliability(number_of_nodes: int, number_of_instances_per_node: int, repetitions: int):
-    run_test_with_repetitions(number_of_nodes=number_of_nodes,
-                              number_of_instances_per_node=number_of_instances_per_node,
-                              repetitions=repetitions)
+
+@pytest.mark.parametrize(
+    "number_of_nodes, number_of_instances_per_node, repetitions",
+    [
+        (2, 2, 100),
+        (3, 3, 20),
+    ],
+)
+def test_reliability(
+    number_of_nodes: int, number_of_instances_per_node: int, repetitions: int
+):
+    run_test_with_repetitions(
+        number_of_nodes=number_of_nodes,
+        number_of_instances_per_node=number_of_instances_per_node,
+        repetitions=repetitions,
+    )


 REPETITIONS_FOR_FUNCTIONALITY = 1

+
 def test_functionality_2_1():
     run_test_with_repetitions(2, 1, REPETITIONS_FOR_FUNCTIONALITY)
@@ -83,32 +102,41 @@ def test_functionality_3_3():
     run_test_with_repetitions(3, 3, REPETITIONS_FOR_FUNCTIONALITY)


-def run_test_with_repetitions(number_of_nodes: int, number_of_instances_per_node: int, repetitions: int):
+def run_test_with_repetitions(
+    number_of_nodes: int, number_of_instances_per_node: int, repetitions: int
+):
     for i in range(repetitions):
         group = f"{time.monotonic_ns()}"
-        LOGGER.info(f"Start iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    group_identifier=group,
-                    number_of_nodes=number_of_nodes,
-                    number_of_instances_per_node=number_of_instances_per_node)
+        LOGGER.info(
+            f"Start iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            group_identifier=group,
+            number_of_nodes=number_of_nodes,
+            number_of_instances_per_node=number_of_instances_per_node,
+        )
         start_time = time.monotonic()
-        expected_result_of_threads, actual_result_of_threads = \
-            run_test(group_identifier=group,
-                     number_of_nodes=number_of_nodes,
-                     number_of_instances_per_node=number_of_instances_per_node)
+        expected_result_of_threads, actual_result_of_threads = run_test(
+            group_identifier=group,
+            number_of_nodes=number_of_nodes,
+            number_of_instances_per_node=number_of_instances_per_node,
+        )
         assert expected_result_of_threads == actual_result_of_threads
         end_time = time.monotonic()
-        LOGGER.info(f"Finish iteration",
-                    iteration=i + 1,
-                    repetitions=repetitions,
-                    group_identifier=group,
-                    number_of_nodes=number_of_nodes,
-                    number_of_instances_per_node=number_of_instances_per_node,
-                    duration=end_time - start_time)
+        LOGGER.info(
+            f"Finish iteration",
+            iteration=i + 1,
+            repetitions=repetitions,
+            group_identifier=group,
+            number_of_nodes=number_of_nodes,
+            number_of_instances_per_node=number_of_instances_per_node,
+            duration=end_time - start_time,
+        )


-def run_test(group_identifier: str, number_of_nodes: int, number_of_instances_per_node: int):
+def run_test(
+    group_identifier: str, number_of_nodes: int, number_of_instances_per_node: int
+):
     parameters = [
         CommunicatorTestProcessParameter(
             node_name=f"n{n}",
@@ -117,11 +145,14 @@ def run_test(group_identifier: str, number_of_nodes: int, number_of_instances_pe
             number_of_nodes=number_of_nodes,
             number_of_instances_per_node=number_of_instances_per_node,
             local_discovery_port=Port(port=44445 + n),
-            seed=0)
+            seed=0,
+        )
         for n in range(number_of_nodes)
-        for i in range(number_of_instances_per_node)]
-    processes: List[TestProcess[CommunicatorTestProcessParameter]] = \
-        [TestProcess(parameter, run=run) for parameter in parameters]
+        for i in range(number_of_instances_per_node)
+    ]
+    processes: List[TestProcess[CommunicatorTestProcessParameter]] = [
+        TestProcess(parameter, run=run) for parameter in parameters
+    ]
     for process in processes:
         process.start()
     assert_processes_finish(processes, timeout_in_seconds=180)
diff --git a/tests/integration_tests/without_db/udf_communication/test_gather.py b/tests/integration_tests/without_db/udf_communication/test_gather.py
index a81cab84..d9d0ca6a 100644
--- a/tests/integration_tests/without_db/udf_communication/test_gather.py
+++ b/tests/integration_tests/without_db/udf_communication/test_gather.py
@@ -1,6 +1,6 @@
 import time
 from pathlib import Path
-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple

 import structlog
 import zmq
@@ -9,34 +9,43 @@ from structlog.types import FilteringBoundLogger

 from exasol.analytics.udf.communication.communicator import Communicator
-from exasol.analytics.udf.communication.ip_address import Port, IPAddress
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory
-from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \
\ - ConditionalMethodDropper -from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \ - BidirectionalQueue, assert_processes_finish, \ - CommunicatorTestProcessParameter +from exasol.analytics.udf.communication.ip_address import IPAddress, Port +from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ( + ZMQSocketFactory, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import ( + ConditionalMethodDropper, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import ( + BidirectionalQueue, + CommunicatorTestProcessParameter, + TestProcess, + assert_processes_finish, +) structlog.configure( context_class=dict, - logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")), + logger_factory=WriteLoggerFactory( + file=Path(__file__).with_suffix(".log").open("wt") + ), processors=[ structlog.contextvars.merge_contextvars, ConditionalMethodDropper(method_name="debug"), ConditionalMethodDropper(method_name="info"), structlog.processors.add_log_level, structlog.processors.TimeStamper(), - structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)), + structlog.processors.ExceptionRenderer( + exception_formatter=ExceptionDictTransformer(locals_max_string=320) + ), structlog.processors.CallsiteParameterAdder(), - structlog.processors.JSONRenderer() - ] + structlog.processors.JSONRenderer(), + ], ) LOGGER: FilteringBoundLogger = structlog.get_logger(__name__) -def run(parameter: CommunicatorTestProcessParameter, - queue: BidirectionalQueue): +def run(parameter: CommunicatorTestProcessParameter, queue: BidirectionalQueue): try: is_discovery_leader_node = parameter.node_name == "n0" context = zmq.Context() @@ -52,16 +61,21 @@ def run(parameter: CommunicatorTestProcessParameter, number_of_nodes=parameter.number_of_nodes, number_of_instances_per_node=parameter.number_of_instances_per_node, is_discovery_leader_node=is_discovery_leader_node, - socket_factory=socket_factory + socket_factory=socket_factory, ) value = f"{parameter.node_name}_{parameter.instance_name}" names = { - f"n{node}_i{instance}".encode("utf-8") + f"n{node}_i{instance}".encode() for instance in range(parameter.number_of_instances_per_node) for node in range(parameter.number_of_nodes) } result = communicator.gather(value.encode("utf-8")) - LOGGER.info("result", result=result, instance_name=parameter.instance_name, node_name=parameter.node_name) + LOGGER.info( + "result", + result=result, + instance_name=parameter.instance_name, + node_name=parameter.node_name, + ) if communicator.is_multi_node_leader(): if isinstance(result, List): if names != set(result): @@ -84,55 +98,72 @@ def run(parameter: CommunicatorTestProcessParameter, def test_functionality_2_1(): - run_test_with_repetitions(number_of_nodes=2, - number_of_instances_per_node=1, - repetitions=REPETITIONS_FOR_FUNCTIONALITY) + run_test_with_repetitions( + number_of_nodes=2, + number_of_instances_per_node=1, + repetitions=REPETITIONS_FOR_FUNCTIONALITY, + ) def test_functionality_1_2(): - run_test_with_repetitions(number_of_nodes=1, - number_of_instances_per_node=2, - repetitions=REPETITIONS_FOR_FUNCTIONALITY) + run_test_with_repetitions( + number_of_nodes=1, + number_of_instances_per_node=2, + repetitions=REPETITIONS_FOR_FUNCTIONALITY, + ) def test_functionality_2_2(): - run_test_with_repetitions(number_of_nodes=2, - number_of_instances_per_node=2, - 
repetitions=REPETITIONS_FOR_FUNCTIONALITY) + run_test_with_repetitions( + number_of_nodes=2, + number_of_instances_per_node=2, + repetitions=REPETITIONS_FOR_FUNCTIONALITY, + ) def test_functionality_3_3(): - run_test_with_repetitions(number_of_nodes=3, - number_of_instances_per_node=3, - repetitions=REPETITIONS_FOR_FUNCTIONALITY) + run_test_with_repetitions( + number_of_nodes=3, + number_of_instances_per_node=3, + repetitions=REPETITIONS_FOR_FUNCTIONALITY, + ) -def run_test_with_repetitions(number_of_nodes: int, number_of_instances_per_node: int, repetitions: int): +def run_test_with_repetitions( + number_of_nodes: int, number_of_instances_per_node: int, repetitions: int +): for i in range(repetitions): group = f"{time.monotonic_ns()}" - LOGGER.info(f"Start iteration", - iteration=i + 1, - repetitions=repetitions, - group_identifier=group, - number_of_nodes=number_of_nodes, - number_of_instances_per_node=number_of_instances_per_node) + LOGGER.info( + f"Start iteration", + iteration=i + 1, + repetitions=repetitions, + group_identifier=group, + number_of_nodes=number_of_nodes, + number_of_instances_per_node=number_of_instances_per_node, + ) start_time = time.monotonic() - expected_result_of_threads, actual_result_of_threads = \ - run_test(group_identifier=group, - number_of_nodes=number_of_nodes, - number_of_instances_per_node=number_of_instances_per_node) + expected_result_of_threads, actual_result_of_threads = run_test( + group_identifier=group, + number_of_nodes=number_of_nodes, + number_of_instances_per_node=number_of_instances_per_node, + ) assert expected_result_of_threads == actual_result_of_threads end_time = time.monotonic() - LOGGER.info(f"Finish iteration", - iteration=i + 1, - repetitions=repetitions, - group_identifier=group, - number_of_nodes=number_of_nodes, - number_of_instances_per_node=number_of_instances_per_node, - duration=end_time - start_time) + LOGGER.info( + f"Finish iteration", + iteration=i + 1, + repetitions=repetitions, + group_identifier=group, + number_of_nodes=number_of_nodes, + number_of_instances_per_node=number_of_instances_per_node, + duration=end_time - start_time, + ) -def run_test(group_identifier: str, number_of_nodes: int, number_of_instances_per_node: int): +def run_test( + group_identifier: str, number_of_nodes: int, number_of_instances_per_node: int +): parameters = [ CommunicatorTestProcessParameter( node_name=f"n{n}", @@ -141,11 +172,14 @@ def run_test(group_identifier: str, number_of_nodes: int, number_of_instances_pe number_of_nodes=number_of_nodes, number_of_instances_per_node=number_of_instances_per_node, local_discovery_port=Port(port=44445 + n), - seed=0) + seed=0, + ) for n in range(number_of_nodes) - for i in range(number_of_instances_per_node)] - processes: List[TestProcess[CommunicatorTestProcessParameter]] = \ - [TestProcess(parameter, run=run) for parameter in parameters] + for i in range(number_of_instances_per_node) + ] + processes: List[TestProcess[CommunicatorTestProcessParameter]] = [ + TestProcess(parameter, run=run) for parameter in parameters + ] for process in processes: process.start() assert_processes_finish(processes, timeout_in_seconds=180) diff --git a/tests/integration_tests/without_db/udf_communication/test_localhost_discovery.py b/tests/integration_tests/without_db/udf_communication/test_localhost_discovery.py index 28895760..c0a93fa4 100644 --- a/tests/integration_tests/without_db/udf_communication/test_localhost_discovery.py +++ b/tests/integration_tests/without_db/udf_communication/test_localhost_discovery.py @@ 
-11,31 +11,44 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.discovery import localhost -from exasol.analytics.udf.communication.discovery.localhost.communicator import \ - CommunicatorFactory -from exasol.analytics.udf.communication.ip_address import Port, IPAddress +from exasol.analytics.udf.communication.discovery.localhost.communicator import ( + CommunicatorFactory, +) +from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.peer_communicator import key_for_peer -from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory -from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ - ConditionalMethodDropper -from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \ - BidirectionalQueue, assert_processes_finish, \ - PeerCommunicatorTestProcessParameter +from exasol.analytics.udf.communication.peer_communicator.peer_communicator import ( + key_for_peer, +) +from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ( + ZMQSocketFactory, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import ( + ConditionalMethodDropper, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import ( + BidirectionalQueue, + PeerCommunicatorTestProcessParameter, + TestProcess, + assert_processes_finish, +) structlog.configure( context_class=dict, - logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")), + logger_factory=WriteLoggerFactory( + file=Path(__file__).with_suffix(".log").open("wt") + ), processors=[ structlog.contextvars.merge_contextvars, ConditionalMethodDropper(method_name="debug"), ConditionalMethodDropper(method_name="info"), structlog.processors.add_log_level, structlog.processors.TimeStamper(), - structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)), + structlog.processors.ExceptionRenderer( + exception_formatter=ExceptionDictTransformer(locals_max_string=320) + ), structlog.processors.CallsiteParameterAdder(), - structlog.processors.JSONRenderer() - ] + structlog.processors.JSONRenderer(), + ], ) LOGGER: FilteringBoundLogger = structlog.get_logger(__name__) @@ -54,7 +67,8 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue listen_ip=listen_ip, discovery_port=discovery_port, socket_factory=socket_factory, - discovery_socket_factory=discovery_socket_factory) + discovery_socket_factory=discovery_socket_factory, + ) queue.put(peer_communicator.my_connection_info) if peer_communicator.are_all_peers_connected(): peers = peer_communicator.peers() @@ -63,7 +77,9 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue queue.put([]) -@pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)]) +@pytest.mark.parametrize( + "number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)] +) def test_reliability(number_of_instances: int, repetitions: int): run_test_with_repetitions(number_of_instances, repetitions) @@ -85,32 +101,42 @@ def test_functionality_25(): def run_test_with_repetitions(number_of_instances: int, repetitions: int): for i in range(repetitions): - 
LOGGER.info(f"Start iteration", - iteration=i + 1, - repetitions=repetitions, - number_of_instances=number_of_instances) + LOGGER.info( + f"Start iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances, + ) start_time = time.monotonic() group = f"{time.monotonic_ns()}" - expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances) + expected_peers_of_threads, peers_of_threads = run_test( + group, number_of_instances + ) assert expected_peers_of_threads == peers_of_threads end_time = time.monotonic() - LOGGER.info(f"Finish iteration", - iteration=i + 1, - repetitions=repetitions, - number_of_instances=number_of_instances, - duration=end_time - start_time) + LOGGER.info( + f"Finish iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances, + duration=end_time - start_time, + ) def run_test(group: str, number_of_instances: int): connection_infos: Dict[int, ConnectionInfo] = {} parameters = [ PeerCommunicatorTestProcessParameter( - instance_name=f"i{i}", group_identifier=group, + instance_name=f"i{i}", + group_identifier=group, number_of_instances=number_of_instances, - seed=0) - for i in range(number_of_instances)] - processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \ - [TestProcess(parameter, run=run) for parameter in parameters] + seed=0, + ) + for i in range(number_of_instances) + ] + processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = [ + TestProcess(parameter, run=run) for parameter in parameters + ] for i in range(number_of_instances): processes[i].start() for i in range(number_of_instances): @@ -120,10 +146,13 @@ def run_test(group: str, number_of_instances: int): for i in range(number_of_instances): peers_of_threads[i] = processes[i].get() expected_peers_of_threads = { - i: sorted([ - Peer(connection_info=connection_info) - for index, connection_info in connection_infos.items() - ], key=key_for_peer) + i: sorted( + [ + Peer(connection_info=connection_info) + for index, connection_info in connection_infos.items() + ], + key=key_for_peer, + ) for i in range(number_of_instances) } return expected_peers_of_threads, peers_of_threads diff --git a/tests/integration_tests/without_db/udf_communication/test_multi_node_discovery.py b/tests/integration_tests/without_db/udf_communication/test_multi_node_discovery.py index 1bae9f7a..88efa0a4 100644 --- a/tests/integration_tests/without_db/udf_communication/test_multi_node_discovery.py +++ b/tests/integration_tests/without_db/udf_communication/test_multi_node_discovery.py @@ -10,32 +10,47 @@ from structlog.types import FilteringBoundLogger from exasol.analytics.udf.communication.connection_info import ConnectionInfo -from exasol.analytics.udf.communication.discovery.multi_node import DiscoverySocketFactory -from exasol.analytics.udf.communication.discovery.multi_node.communicator import \ - CommunicatorFactory -from exasol.analytics.udf.communication.ip_address import Port, IPAddress +from exasol.analytics.udf.communication.discovery.multi_node import ( + DiscoverySocketFactory, +) +from exasol.analytics.udf.communication.discovery.multi_node.communicator import ( + CommunicatorFactory, +) +from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.peer_communicator import key_for_peer -from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory 
-from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ - ConditionalMethodDropper -from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \ - BidirectionalQueue, assert_processes_finish, \ - PeerCommunicatorTestProcessParameter +from exasol.analytics.udf.communication.peer_communicator.peer_communicator import ( + key_for_peer, +) +from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ( + ZMQSocketFactory, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import ( + ConditionalMethodDropper, +) +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import ( + BidirectionalQueue, + PeerCommunicatorTestProcessParameter, + TestProcess, + assert_processes_finish, +) structlog.configure( context_class=dict, - logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")), + logger_factory=WriteLoggerFactory( + file=Path(__file__).with_suffix(".log").open("wt") + ), processors=[ structlog.contextvars.merge_contextvars, ConditionalMethodDropper(method_name="debug"), ConditionalMethodDropper(method_name="info"), structlog.processors.add_log_level, structlog.processors.TimeStamper(), - structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)), + structlog.processors.ExceptionRenderer( + exception_formatter=ExceptionDictTransformer(locals_max_string=320) + ), structlog.processors.CallsiteParameterAdder(), - structlog.processors.JSONRenderer() - ] + structlog.processors.JSONRenderer(), + ], ) LOGGER: FilteringBoundLogger = structlog.get_logger(__name__) @@ -60,7 +75,8 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue discovery_ip=listen_ip, discovery_port=discovery_port, socket_factory=socket_factory, - discovery_socket_factory=discovery_socket_factory) + discovery_socket_factory=discovery_socket_factory, + ) queue.put(peer_communicator.my_connection_info) if peer_communicator.are_all_peers_connected(): peers = peer_communicator.peers() @@ -69,7 +85,9 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue queue.put([]) -@pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)]) +@pytest.mark.parametrize( + "number_of_instances, repetitions", [(2, 1000), (10, 100), (25, 10)] +) def test_reliability(number_of_instances: int, repetitions: int): run_test_with_repetitions(number_of_instances, repetitions) @@ -99,32 +117,42 @@ def test_functionality_25(): def run_test_with_repetitions(number_of_instances: int, repetitions: int): for i in range(repetitions): - LOGGER.info(f"Start iteration", - iteration=i + 1, - repetitions=repetitions, - number_of_instances=number_of_instances) + LOGGER.info( + f"Start iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances, + ) start_time = time.monotonic() group = f"{time.monotonic_ns()}" - expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances) + expected_peers_of_threads, peers_of_threads = run_test( + group, number_of_instances + ) assert expected_peers_of_threads == peers_of_threads end_time = time.monotonic() - LOGGER.info(f"Finish iteration", - iteration=i + 1, - repetitions=repetitions, - number_of_instances=number_of_instances, - duration=end_time - start_time) + LOGGER.info( + f"Finish iteration", + iteration=i + 
1, + repetitions=repetitions, + number_of_instances=number_of_instances, + duration=end_time - start_time, + ) def run_test(group: str, number_of_instances: int): connection_infos: Dict[int, ConnectionInfo] = {} parameters = [ PeerCommunicatorTestProcessParameter( - instance_name=f"i{i}", group_identifier=group, + instance_name=f"i{i}", + group_identifier=group, number_of_instances=number_of_instances, - seed=0) - for i in range(number_of_instances)] - processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \ - [TestProcess(parameter, run=run) for parameter in parameters] + seed=0, + ) + for i in range(number_of_instances) + ] + processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = [ + TestProcess(parameter, run=run) for parameter in parameters + ] for i in range(number_of_instances): processes[i].start() for i in range(number_of_instances): @@ -134,10 +162,13 @@ def run_test(group: str, number_of_instances: int): for i in range(number_of_instances): peers_of_threads[i] = processes[i].get() expected_peers_of_threads = { - i: sorted([ - Peer(connection_info=connection_info) - for index, connection_info in connection_infos.items() - ], key=key_for_peer) + i: sorted( + [ + Peer(connection_info=connection_info) + for index, connection_info in connection_infos.items() + ], + key=key_for_peer, + ) for i in range(number_of_instances) } return expected_peers_of_threads, peers_of_threads diff --git a/tests/mock_cast.py b/tests/mock_cast.py deleted file mode 100644 index c6fc3042..00000000 --- a/tests/mock_cast.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import cast, Any -from unittest.mock import Mock - - -def mock_cast(obj: Any) -> Mock: - return cast(Mock, obj) diff --git a/tests/test_package/test_query_handlers/query_handler_test.py b/tests/test_package/test_query_handlers/query_handler_test.py index 35891626..e43cbbda 100644 --- a/tests/test_package/test_query_handlers/query_handler_test.py +++ b/tests/test_package/test_query_handlers/query_handler_test.py @@ -1,22 +1,23 @@ from typing import Union -from exasol.analytics.schema.column import \ - Column -from exasol.analytics.schema.column_name \ - import ColumnName -from exasol.analytics.schema.column_type \ - import ColumnType - -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.query.select import SelectQuery, SelectQueryWithColumnDefinition -from exasol.analytics.query_handler.query_handler import ResultType -from exasol.analytics.query_handler.result import Finish, Continue -from exasol.analytics.query_handler.query.result.interface import QueryResult -from exasol.analytics.query_handler.udf.interface import UDFQueryHandler -from exasol.analytics.query_handler.udf.interface import UDFQueryHandlerFactory +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.query.result.interface import QueryResult +from exasol.analytics.query_handler.query.select import ( + SelectQuery, + SelectQueryWithColumnDefinition, +) +from exasol.analytics.query_handler.query_handler import ResultType +from exasol.analytics.query_handler.result import Continue, Finish +from exasol.analytics.query_handler.udf.interface import ( + UDFQueryHandler, + UDFQueryHandlerFactory, +) +from exasol.analytics.schema.column import Column +from exasol.analytics.schema.column_name import ColumnName +from exasol.analytics.schema.column_type import ColumnType TEST_INPUT = "<>" -FINAL_RESULT = '<>' +FINAL_RESULT = "<>" 
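For orientation in the reformatted handler classes below: a query handler's `start()` returns either `Finish(result=...)` to terminate or `Continue(query_list=..., input_query=...)` to request another round trip, and `handle_query_result()` is then called with the result rows of that input query. A minimal sketch using the interfaces imported above; the handler body is illustrative, not part of this patch:

```python
from typing import Union

from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext
from exasol.analytics.query_handler.query.result.interface import QueryResult
from exasol.analytics.query_handler.result import Continue, Finish
from exasol.analytics.query_handler.udf.interface import UDFQueryHandler


class EchoQueryHandler(UDFQueryHandler):
    """Illustrative handler that finishes immediately, echoing its parameter."""

    def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerContext):
        super().__init__(parameter, query_handler_context)
        self._parameter = parameter

    def start(self) -> Union[Continue, Finish[str]]:
        # A multi-iteration handler would instead return
        # Continue(query_list=..., input_query=...) here, as
        # QueryHandlerTestWithTwoIteration does below.
        return Finish(result=self._parameter)

    def handle_query_result(
        self, query_result: QueryResult
    ) -> Union[Continue, Finish[str]]:
        return Finish(result=self._parameter)
```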
QUERY_LIST = [SelectQuery("SELECT 1 FROM DUAL"), SelectQuery("SELECT 2 FROM DUAL")] @@ -27,18 +28,24 @@ def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerConte if not isinstance(parameter, str): raise AssertionError(f"Expected parameter={parameter} to be a string.") if parameter != TEST_INPUT: - raise AssertionError(f"Expected parameter={parameter} to be '{TEST_INPUT}'.") + raise AssertionError( + f"Expected parameter={parameter} to be '{TEST_INPUT}'." + ) def start(self) -> Union[Continue, Finish[ResultType]]: return Finish(result=FINAL_RESULT) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: pass class QueryHandlerTestWithOneIterationFactory(UDFQueryHandlerFactory): - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: return QueryHandlerTestWithOneIteration(parameter, query_handler_context) @@ -51,16 +58,19 @@ def start(self) -> Union[Continue, Finish[str]]: return_query = 'SELECT 1 AS "a", 2 AS "b" FROM DUAL' return_query_columns = [ Column(ColumnName("a"), ColumnType("INTEGER")), - Column(ColumnName("b"), ColumnType("INTEGER"))] + Column(ColumnName("b"), ColumnType("INTEGER")), + ] query_handler_return_query = SelectQueryWithColumnDefinition( - query_string=return_query, - output_columns=return_query_columns) + query_string=return_query, output_columns=return_query_columns + ) query_handler_result = Continue( - query_list=QUERY_LIST, - input_query=query_handler_return_query) + query_list=QUERY_LIST, input_query=query_handler_return_query + ) return query_handler_result - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: a = query_result.a if a != 1: raise AssertionError(f"Expected query_result.a={a} to be 1.") @@ -75,11 +85,15 @@ def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Fini class QueryHandlerTestWithTwoIterationFactory(UDFQueryHandlerFactory): - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: return QueryHandlerTestWithTwoIteration(parameter, query_handler_context) -class QueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext(UDFQueryHandler): +class QueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext( + UDFQueryHandler +): def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerContext): super().__init__(parameter, query_handler_context) self.child = None @@ -88,14 +102,22 @@ def start(self) -> Union[Continue, Finish[str]]: self.child = self._query_handler_context.get_child_query_handler_context() return Finish(result=FINAL_RESULT) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: pass -class QueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContextFactory(UDFQueryHandlerFactory): +class QueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContextFactory( + UDFQueryHandlerFactory +): - def create(self, parameter: str, 
query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return QueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext(parameter, query_handler_context) + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: + return QueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext( + parameter, query_handler_context + ) class QueryHandlerWithOneIterationWithNotReleasedTemporaryObject(UDFQueryHandler): @@ -110,11 +132,19 @@ def start(self) -> Union[Continue, Finish[str]]: self.proxy = self.child.get_temporary_table_name() return Finish(result=FINAL_RESULT) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: pass -class QueryHandlerWithOneIterationWithNotReleasedTemporaryObjectFactory(UDFQueryHandlerFactory): +class QueryHandlerWithOneIterationWithNotReleasedTemporaryObjectFactory( + UDFQueryHandlerFactory +): - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return QueryHandlerWithOneIterationWithNotReleasedTemporaryObject(parameter, query_handler_context) + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: + return QueryHandlerWithOneIterationWithNotReleasedTemporaryObject( + parameter, query_handler_context + ) diff --git a/tests/unit_tests/data_science_utils/schema/test_column.py b/tests/unit_tests/data_science_utils/schema/test_column.py index 39c0a76c..3388b5a5 100644 --- a/tests/unit_tests/data_science_utils/schema/test_column.py +++ b/tests/unit_tests/data_science_utils/schema/test_column.py @@ -1,12 +1,8 @@ import pytest - -from exasol.analytics.schema import ( - Column, - ColumnType, - ColumnNameBuilder, -) from typeguard import TypeCheckError +from exasol.analytics.schema import Column, ColumnNameBuilder, ColumnType + def test_set_new_type_fail(): column = Column(ColumnNameBuilder.create("abc"), ColumnType("INTEGER")) @@ -24,16 +20,19 @@ def test_wrong_types_in_constructor(): with pytest.raises(TypeCheckError) as c: column = Column("abc", "INTEGER") + def test_equality(): column1 = Column(ColumnNameBuilder.create("abc"), ColumnType("INTEGER")) column2 = Column(ColumnNameBuilder.create("abc"), ColumnType("INTEGER")) assert column1 == column2 + def test_inequality_name(): column1 = Column(ColumnNameBuilder.create("abc"), ColumnType("INTEGER")) column2 = Column(ColumnNameBuilder.create("def"), ColumnType("INTEGER")) assert column1 != column2 + def test_inequality_type(): column1 = Column(ColumnNameBuilder.create("abc"), ColumnType("INTEGER")) column2 = Column(ColumnNameBuilder.create("def"), ColumnType("VARCHAR")) @@ -56,4 +55,3 @@ def test_hash_inequality_type(): column1 = Column(ColumnNameBuilder.create("abc"), ColumnType("INTEGER")) column2 = Column(ColumnNameBuilder.create("abc"), ColumnType("VARCHAR")) assert hash(column1) != hash(column2) - diff --git a/tests/unit_tests/data_science_utils/schema/test_column_builder.py b/tests/unit_tests/data_science_utils/schema/test_column_builder.py index 8285605c..8dc34f3c 100644 --- a/tests/unit_tests/data_science_utils/schema/test_column_builder.py +++ b/tests/unit_tests/data_science_utils/schema/test_column_builder.py @@ -1,12 +1,9 @@ import pytest - -from exasol.analytics.schema import ( - ColumnBuilder, - ColumnType, - ColumnNameBuilder, -) from typeguard import TypeCheckError +from exasol.analytics.schema 
import ColumnBuilder, ColumnNameBuilder, ColumnType + + def test_create_column_with_name_only(): with pytest.raises(TypeCheckError): column = ColumnBuilder().with_name(ColumnNameBuilder.create("column")).build() @@ -18,5 +15,10 @@ def test_create_column_with_type_only(): def test_create_column_with_name_and_type(): - column = ColumnBuilder().with_name(ColumnNameBuilder.create("column")).with_type(type=ColumnType("INTEGER")).build() + column = ( + ColumnBuilder() + .with_name(ColumnNameBuilder.create("column")) + .with_type(type=ColumnType("INTEGER")) + .build() + ) assert column.name.name == "column" and column.type.name == "INTEGER" diff --git a/tests/unit_tests/data_science_utils/schema/test_column_name.py b/tests/unit_tests/data_science_utils/schema/test_column_name.py index 6b7468c5..0caa8da4 100644 --- a/tests/unit_tests/data_science_utils/schema/test_column_name.py +++ b/tests/unit_tests/data_science_utils/schema/test_column_name.py @@ -1,10 +1,6 @@ import pytest -from exasol.analytics.schema import ( - TableNameBuilder, - SchemaName, - ColumnName, -) +from exasol.analytics.schema import ColumnName, SchemaName, TableNameBuilder def test_fully_qualified(): @@ -18,7 +14,9 @@ def test_fully_qualified_with_table(): def test_fully_qualified_with_table_and_schema(): - column = ColumnName("column", TableNameBuilder.create("table", schema=SchemaName("schema"))) + column = ColumnName( + "column", TableNameBuilder.create("table", schema=SchemaName("schema")) + ) assert column.fully_qualified == '"schema"."table"."column"' diff --git a/tests/unit_tests/data_science_utils/schema/test_column_name_builder.py b/tests/unit_tests/data_science_utils/schema/test_column_name_builder.py index e2429acc..f4b17d51 100644 --- a/tests/unit_tests/data_science_utils/schema/test_column_name_builder.py +++ b/tests/unit_tests/data_science_utils/schema/test_column_name_builder.py @@ -1,12 +1,8 @@ import pytest - -from exasol.analytics.schema import ( - TableNameImpl, - ColumnName, - ColumnNameBuilder, -) from typeguard import TypeCheckError +from exasol.analytics.schema import ColumnName, ColumnNameBuilder, TableNameImpl + def test_using_empty_constructor(): with pytest.raises(TypeCheckError) as ex: @@ -19,8 +15,10 @@ def test_using_constructor_name_only(): def test_using_constructor_table(): - column_name = ColumnNameBuilder(name="column", table_like_name=TableNameImpl("table")).build() - assert column_name.name == "column" and column_name.table_like_name.name is "table" + column_name = ColumnNameBuilder( + name="column", table_like_name=TableNameImpl("table") + ).build() + assert column_name.name == "column" and column_name.table_like_name.name == "table" def test_using_with_name_only(): @@ -29,43 +27,64 @@ def test_using_with_name_only(): def test_using_with_table(): - column_name = ColumnNameBuilder().with_name("table").with_table_like_name(TableNameImpl("table")).build() + column_name = ( + ColumnNameBuilder() + .with_name("table") + .with_table_like_name(TableNameImpl("table")) + .build() + ) assert column_name.name == "table" and column_name.table_like_name.name == "table" def test_from_existing_using_with_table(): source_column_name = ColumnName("column") - column_name = ColumnNameBuilder(column_name=source_column_name).with_table_like_name(TableNameImpl("table")).build() - assert source_column_name.name == "column" \ - and source_column_name.table_like_name is None \ - and column_name.name == "column" \ - and column_name.table_like_name.name == "table" + column_name = ( + 
ColumnNameBuilder(column_name=source_column_name) + .with_table_like_name(TableNameImpl("table")) + .build() + ) + assert ( + source_column_name.name == "column" + and source_column_name.table_like_name is None + and column_name.name == "column" + and column_name.table_like_name.name == "table" + ) def test_from_existing_using_with_name(): source_column_name = ColumnName("column", TableNameImpl("table")) - column_name = ColumnNameBuilder(column_name=source_column_name).with_name("column1").build() - assert source_column_name.name == "column" \ - and source_column_name.table_like_name.name == "table" \ - and column_name.table_like_name.name == "table" \ - and column_name.name == "column1" + column_name = ( + ColumnNameBuilder(column_name=source_column_name).with_name("column1").build() + ) + assert ( + source_column_name.name == "column" + and source_column_name.table_like_name.name == "table" + and column_name.table_like_name.name == "table" + and column_name.name == "column1" + ) def test_from_existing_and_new_table_in_constructor(): source_column_name = ColumnName("column") - column_name = ColumnNameBuilder(table_like_name=TableNameImpl("table"), - column_name=source_column_name).build() - assert source_column_name.name == "column" \ - and source_column_name.table_like_name is None \ - and column_name.name == "column" \ - and column_name.table_like_name.name == "table" + column_name = ColumnNameBuilder( + table_like_name=TableNameImpl("table"), column_name=source_column_name + ).build() + assert ( + source_column_name.name == "column" + and source_column_name.table_like_name is None + and column_name.name == "column" + and column_name.table_like_name.name == "table" + ) def test_from_existing_and_new_name_in_constructor(): source_column_name = ColumnName("column", TableNameImpl("table")) - column_name = ColumnNameBuilder(name="column1", - column_name=source_column_name).build() - assert source_column_name.name == "column" \ - and source_column_name.table_like_name.name == "table" \ - and column_name.table_like_name.name == "table" \ - and column_name.name == "column1" + column_name = ColumnNameBuilder( + name="column1", column_name=source_column_name + ).build() + assert ( + source_column_name.name == "column" + and source_column_name.table_like_name.name == "table" + and column_name.table_like_name.name == "table" + and column_name.name == "column1" + ) diff --git a/tests/unit_tests/data_science_utils/schema/test_column_type.py b/tests/unit_tests/data_science_utils/schema/test_column_type.py index f5c4f89d..a3f9416f 100644 --- a/tests/unit_tests/data_science_utils/schema/test_column_type.py +++ b/tests/unit_tests/data_science_utils/schema/test_column_type.py @@ -1,24 +1,26 @@ import pytest +from typeguard import TypeCheckError from exasol.analytics.schema import ( - SchemaName, - TableNameBuilder, ColumnName, ColumnType, + SchemaName, + TableNameBuilder, TableNameImpl, ) -from typeguard import TypeCheckError def test_correct_types(): - ColumnType(name="COLUMN", - precision=0, - scale=0, - size=0, - characterSet="UTF-8", - withLocalTimeZone=True, - fraction=0, - srid=0) + ColumnType( + name="COLUMN", + precision=0, + scale=0, + size=0, + characterSet="UTF-8", + withLocalTimeZone=True, + fraction=0, + srid=0, + ) def test_optionals(): @@ -71,22 +73,26 @@ def test_srid_wrong_type(): def test_equality(): - column1 = ColumnType(name="COLUMN", - precision=0, - scale=0, - size=0, - characterSet="UTF-8", - withLocalTimeZone=True, - fraction=0, - srid=0) - column2 = ColumnType(name="COLUMN", - 
precision=0, - scale=0, - size=0, - characterSet="UTF-8", - withLocalTimeZone=True, - fraction=0, - srid=0) + column1 = ColumnType( + name="COLUMN", + precision=0, + scale=0, + size=0, + characterSet="UTF-8", + withLocalTimeZone=True, + fraction=0, + srid=0, + ) + column2 = ColumnType( + name="COLUMN", + precision=0, + scale=0, + size=0, + characterSet="UTF-8", + withLocalTimeZone=True, + fraction=0, + srid=0, + ) assert column1 == column2 @@ -116,7 +122,10 @@ def test_inequality_size(): def test_inequality_characterSet(): column1 = ColumnType(name="COLUMN", characterSet="UTF-8") - column2 = ColumnType(name="COLUMN", characterSet="ASCII", ) + column2 = ColumnType( + name="COLUMN", + characterSet="ASCII", + ) assert column1 != column2 @@ -139,22 +148,26 @@ def test_inequality_srid(): def test_hash_equality(): - column1 = ColumnType(name="COLUMN", - precision=0, - scale=0, - size=0, - characterSet="UTF-8", - withLocalTimeZone=True, - fraction=0, - srid=0) - column2 = ColumnType(name="COLUMN", - precision=0, - scale=0, - size=0, - characterSet="UTF-8", - withLocalTimeZone=True, - fraction=0, - srid=0) + column1 = ColumnType( + name="COLUMN", + precision=0, + scale=0, + size=0, + characterSet="UTF-8", + withLocalTimeZone=True, + fraction=0, + srid=0, + ) + column2 = ColumnType( + name="COLUMN", + precision=0, + scale=0, + size=0, + characterSet="UTF-8", + withLocalTimeZone=True, + fraction=0, + srid=0, + ) assert hash(column1) == hash(column2) @@ -184,7 +197,10 @@ def test_hash_inequality_size(): def test_hash_inequality_characterSet(): column1 = ColumnType(name="COLUMN", characterSet="UTF-8") - column2 = ColumnType(name="COLUMN", characterSet="ASCII", ) + column2 = ColumnType( + name="COLUMN", + characterSet="ASCII", + ) assert hash(column1) != hash(column2) diff --git a/tests/unit_tests/data_science_utils/schema/test_connection_object_name.py b/tests/unit_tests/data_science_utils/schema/test_connection_object_name.py index db3f8a66..0d4bb54d 100644 --- a/tests/unit_tests/data_science_utils/schema/test_connection_object_name.py +++ b/tests/unit_tests/data_science_utils/schema/test_connection_object_name.py @@ -1,6 +1,5 @@ from exasol.analytics.schema import ConnectionObjectNameImpl - CONNECTION_UPPER = "CONNECTION" CONNECTION = "connection" diff --git a/tests/unit_tests/data_science_utils/schema/test_connection_object_name_builder.py b/tests/unit_tests/data_science_utils/schema/test_connection_object_name_builder.py index 53689b86..14bf420e 100644 --- a/tests/unit_tests/data_science_utils/schema/test_connection_object_name_builder.py +++ b/tests/unit_tests/data_science_utils/schema/test_connection_object_name_builder.py @@ -1,10 +1,10 @@ import pytest from exasol.analytics.schema import ( - TableNameImpl, - ConnectionObjectNameBuilder, ColumnName, ColumnNameBuilder, + ConnectionObjectNameBuilder, + TableNameImpl, ) diff --git a/tests/unit_tests/data_science_utils/schema/test_identifier.py b/tests/unit_tests/data_science_utils/schema/test_identifier.py index fd130a7d..a542f0bc 100644 --- a/tests/unit_tests/data_science_utils/schema/test_identifier.py +++ b/tests/unit_tests/data_science_utils/schema/test_identifier.py @@ -1,9 +1,6 @@ import pytest -from exasol.analytics.schema import ( - ExasolIdentifierImpl, - ExasolIdentifier, -) +from exasol.analytics.schema import ExasolIdentifier, ExasolIdentifierImpl class TestSchemaElement(ExasolIdentifierImpl): @@ -23,47 +20,28 @@ def __repr__(self): raise NotImplemented() -@pytest.mark.parametrize("test_name", - [ - "A", - "a", - "B_", - "Z1", - 
"Q\uFE33", - "Ü", - "1" - ]) +@pytest.mark.parametrize("test_name", ["A", "a", "B_", "Z1", "Q\uFE33", "Ü", "1"]) def test_name_valid(test_name): TestSchemaElement(test_name) -@pytest.mark.parametrize("test_name", - [ - ".", - "A.s" - "_", - ",", - ";", - ":", - "\uFE33", - '"', - 'A"', - "A'", - "A,", - "A;", - "A:" - ]) +@pytest.mark.parametrize( + "test_name", + [".", "A.s" "_", ",", ";", ":", "\uFE33", '"', 'A"', "A'", "A,", "A;", "A:"], +) def test_name_invalid(test_name): with pytest.raises(ValueError): TestSchemaElement(test_name) -@pytest.mark.parametrize("name,expected_quoted_name", - [ - ('ABC', '"ABC"'), - # ('A"BC', '"A""BC"'), names with double quotes at the moment not valid - ('abc', '"abc"') - ]) +@pytest.mark.parametrize( + "name,expected_quoted_name", + [ + ("ABC", '"ABC"'), + # ('A"BC', '"A""BC"'), names with double quotes at the moment not valid + ("abc", '"abc"'), + ], +) def test_quote(name, expected_quoted_name): quoted_name = TestSchemaElement(name).quoted_name assert quoted_name == expected_quoted_name diff --git a/tests/unit_tests/data_science_utils/schema/test_schema_name.py b/tests/unit_tests/data_science_utils/schema/test_schema_name.py index affcd2e5..702a1b4c 100644 --- a/tests/unit_tests/data_science_utils/schema/test_schema_name.py +++ b/tests/unit_tests/data_science_utils/schema/test_schema_name.py @@ -1,7 +1,6 @@ from exasol.analytics.schema import SchemaName - def test_fully_qualified(): schema = SchemaName("schema") assert schema.fully_qualified == '"schema"' diff --git a/tests/unit_tests/data_science_utils/schema/test_table.py b/tests/unit_tests/data_science_utils/schema/test_table.py index fa2571d5..cdfd1ac8 100644 --- a/tests/unit_tests/data_science_utils/schema/test_table.py +++ b/tests/unit_tests/data_science_utils/schema/test_table.py @@ -1,20 +1,23 @@ import pytest +from typeguard import TypeCheckError from exasol.analytics.schema import ( Column, ColumnNameBuilder, ColumnType, - TableNameImpl, Table, + TableNameImpl, ) -from typeguard import TypeCheckError def test_valid(): - table = Table(TableNameImpl("table"), [ - Column(ColumnNameBuilder.create("column1"), ColumnType("INTEGER")), - Column(ColumnNameBuilder.create("column2"), ColumnType("VACHAR")), - ]) + table = Table( + TableNameImpl("table"), + [ + Column(ColumnNameBuilder.create("column1"), ColumnType("INTEGER")), + Column(ColumnNameBuilder.create("column2"), ColumnType("VACHAR")), + ], + ) def test_no_columns_fail(): @@ -24,22 +27,33 @@ def test_no_columns_fail(): def test_duplicate_column_names_fail(): with pytest.raises(ValueError, match="Column names are not unique.") as c: - table = Table(TableNameImpl("table"), [ - Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), - Column(ColumnNameBuilder.create("column"), ColumnType("VACHAR")), - ]) + table = Table( + TableNameImpl("table"), + [ + Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), + Column(ColumnNameBuilder.create("column"), ColumnType("VACHAR")), + ], + ) def test_set_new_name_fail(): - table = Table(TableNameImpl("table"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + table = Table( + TableNameImpl("table"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) with pytest.raises(AttributeError) as c: table.name = "edf" def test_set_new_columns_fail(): - table = Table(TableNameImpl("table"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + table = Table( + TableNameImpl("table"), + [Column(ColumnNameBuilder.create("column"), 
ColumnType("INTEGER"))], + ) with pytest.raises(AttributeError) as c: - table.columns = [Column(ColumnNameBuilder.create("column1"), ColumnType("INTEGER"))] + table.columns = [ + Column(ColumnNameBuilder.create("column1"), ColumnType("INTEGER")) + ] def test_wrong_types_in_constructor(): @@ -48,51 +62,88 @@ def test_wrong_types_in_constructor(): def test_columns_list_is_immutable(): - table = Table(TableNameImpl("table"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + table = Table( + TableNameImpl("table"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) columns = table.columns columns.append(Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))) assert len(columns) == 2 and len(table.columns) == 1 def test_equality(): - table1 = Table(TableNameImpl("table"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - table2 = Table(TableNameImpl("table"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + table1 = Table( + TableNameImpl("table"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + table2 = Table( + TableNameImpl("table"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) assert table1 == table2 def test_inequality_name(): - table1 = Table(TableNameImpl("table1"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - table2 = Table(TableNameImpl("table2"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + table1 = Table( + TableNameImpl("table1"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + table2 = Table( + TableNameImpl("table2"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) assert table1 != table2 def test_inequality_columns(): - table1 = Table(TableNameImpl("table1"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - table2 = Table(TableNameImpl("table1"), - [ - Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), - Column(ColumnNameBuilder.create("column2"), ColumnType("INTEGER")) - ]) + table1 = Table( + TableNameImpl("table1"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + table2 = Table( + TableNameImpl("table1"), + [ + Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), + Column(ColumnNameBuilder.create("column2"), ColumnType("INTEGER")), + ], + ) assert table1 != table2 def test_hash_equality(): - table1 = Table(TableNameImpl("table"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - table2 = Table(TableNameImpl("table"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + table1 = Table( + TableNameImpl("table"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + table2 = Table( + TableNameImpl("table"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) assert hash(table1) == hash(table2) def test_hash_inequality_name(): - table1 = Table(TableNameImpl("table1"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - table2 = Table(TableNameImpl("table2"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + table1 = Table( + TableNameImpl("table1"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + table2 = Table( + TableNameImpl("table2"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) assert hash(table1) != hash(table2) def test_hash_inequality_columns(): - 
table1 = Table(TableNameImpl("table1"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - table2 = Table(TableNameImpl("table1"), - [ - Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), - Column(ColumnNameBuilder.create("column2"), ColumnType("INTEGER")) - ]) + table1 = Table( + TableNameImpl("table1"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + table2 = Table( + TableNameImpl("table1"), + [ + Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), + Column(ColumnNameBuilder.create("column2"), ColumnType("INTEGER")), + ], + ) assert hash(table1) != hash(table2) diff --git a/tests/unit_tests/data_science_utils/schema/test_table_builder.py b/tests/unit_tests/data_science_utils/schema/test_table_builder.py index 5206065e..cdb8cb69 100644 --- a/tests/unit_tests/data_science_utils/schema/test_table_builder.py +++ b/tests/unit_tests/data_science_utils/schema/test_table_builder.py @@ -1,4 +1,5 @@ import pytest +from typeguard import TypeCheckError from exasol.analytics.schema import ( Column, @@ -7,7 +8,6 @@ TableBuilder, TableNameImpl, ) -from typeguard import TypeCheckError def test_create_table_with_name_only_fail(): @@ -17,10 +17,22 @@ def test_create_table_with_name_only_fail(): def test_create_table_with_columns_only_fail(): with pytest.raises(TypeCheckError): - column = TableBuilder().with_columns([Column(ColumnNameBuilder.create("abc"), ColumnType("INTEGER"))]).build() + column = ( + TableBuilder() + .with_columns( + [Column(ColumnNameBuilder.create("abc"), ColumnType("INTEGER"))] + ) + .build() + ) def test_create_table_with_name_and_columns(): - table = TableBuilder().with_name(TableNameImpl("table")).with_columns( - [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]).build() + table = ( + TableBuilder() + .with_name(TableNameImpl("table")) + .with_columns( + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))] + ) + .build() + ) assert table.name.name == "table" and table.columns[0].name.name == "column" diff --git a/tests/unit_tests/data_science_utils/schema/test_table_name.py b/tests/unit_tests/data_science_utils/schema/test_table_name.py index 52c75238..43ed75e8 100644 --- a/tests/unit_tests/data_science_utils/schema/test_table_name.py +++ b/tests/unit_tests/data_science_utils/schema/test_table_name.py @@ -1,9 +1,6 @@ import pytest -from exasol.analytics.schema import ( - SchemaName, - TableNameImpl, -) +from exasol.analytics.schema import SchemaName, TableNameImpl def test_fully_qualified(): diff --git a/tests/unit_tests/data_science_utils/schema/test_table_name_builder.py b/tests/unit_tests/data_science_utils/schema/test_table_name_builder.py index f3e19f7e..09c915dc 100644 --- a/tests/unit_tests/data_science_utils/schema/test_table_name_builder.py +++ b/tests/unit_tests/data_science_utils/schema/test_table_name_builder.py @@ -1,12 +1,12 @@ import pytest +from typeguard import TypeCheckError from exasol.analytics.schema import ( - TableNameBuilder, SchemaName, - TableNameImpl, TableName, + TableNameBuilder, + TableNameImpl, ) -from typeguard import TypeCheckError def test_using_empty_constructor(): @@ -16,82 +16,110 @@ def test_using_empty_constructor(): def test_using_constructor_name_only(): table_name = TableNameBuilder(name="table").build() - assert table_name.name == "table" \ - and table_name.schema_name is None \ - and isinstance(table_name, TableName) + assert ( + table_name.name == "table" + and table_name.schema_name is None + and isinstance(table_name, 
TableName) + ) def test_using_constructor_schema(): table_name = TableNameBuilder(name="table", schema=SchemaName("schema")).build() - assert table_name.name == "table" \ - and table_name.schema_name.name is "schema" \ - and isinstance(table_name, TableName) + assert ( + table_name.name == "table" + and table_name.schema_name.name == "schema" + and isinstance(table_name, TableName) + ) def test_using_with_name_only(): table_name = TableNameBuilder().with_name("table").build() - assert table_name.name == "table" \ - and table_name.schema_name is None \ - and isinstance(table_name, TableName) + assert ( + table_name.name == "table" + and table_name.schema_name is None + and isinstance(table_name, TableName) + ) def test_using_with_schema(): - table_name = TableNameBuilder().with_name("table").with_schema_name(SchemaName("schema")).build() - assert table_name.name == "table" \ - and table_name.schema_name.name == "schema" \ - and isinstance(table_name, TableName) + table_name = ( + TableNameBuilder() + .with_name("table") + .with_schema_name(SchemaName("schema")) + .build() + ) + assert ( + table_name.name == "table" + and table_name.schema_name.name == "schema" + and isinstance(table_name, TableName) + ) def test_from_existing_using_with_schema(): source_table_name = TableNameImpl("table") - table_name = TableNameBuilder(table_name=source_table_name).with_schema_name(SchemaName("schema")).build() - assert source_table_name.name == "table" \ - and source_table_name.schema_name is None \ - and table_name.name == "table" \ - and table_name.schema_name.name == "schema" \ - and isinstance(table_name, TableName) + table_name = ( + TableNameBuilder(table_name=source_table_name) + .with_schema_name(SchemaName("schema")) + .build() + ) + assert ( + source_table_name.name == "table" + and source_table_name.schema_name is None + and table_name.name == "table" + and table_name.schema_name.name == "schema" + and isinstance(table_name, TableName) + ) def test_from_existing_using_with_name(): source_table_name = TableNameImpl("table", SchemaName("schema")) - table_name = TableNameBuilder(table_name=source_table_name).with_name("table1").build() - assert source_table_name.name == "table" \ - and source_table_name.schema_name.name == "schema" \ - and table_name.schema_name.name == "schema" \ - and table_name.name == "table1" \ - and isinstance(table_name, TableName) + table_name = ( + TableNameBuilder(table_name=source_table_name).with_name("table1").build() + ) + assert ( + source_table_name.name == "table" + and source_table_name.schema_name.name == "schema" + and table_name.schema_name.name == "schema" + and table_name.name == "table1" + and isinstance(table_name, TableName) + ) def test_from_existing_and_new_schema_in_constructor(): source_table_name = TableNameImpl("table") - table_name = TableNameBuilder(schema=SchemaName("schema"), - table_name=source_table_name).build() - assert source_table_name.name == "table" \ - and source_table_name.schema_name is None \ - and table_name.name == "table" \ - and table_name.schema_name.name == "schema" \ - and isinstance(table_name, TableName) + table_name = TableNameBuilder( + schema=SchemaName("schema"), table_name=source_table_name + ).build() + assert ( + source_table_name.name == "table" + and source_table_name.schema_name is None + and table_name.name == "table" + and table_name.schema_name.name == "schema" + and isinstance(table_name, TableName) + ) def test_from_existing_and_new_name_in_constructor(): source_table_name = TableNameImpl("table", 
SchemaName("schema")) - table_name = TableNameBuilder(name="table1", - table_name=source_table_name).build() - assert source_table_name.name == "table" \ - and source_table_name.schema_name.name == "schema" \ - and table_name.schema_name.name == "schema" \ - and table_name.name == "table1" \ - and isinstance(table_name, TableName) + table_name = TableNameBuilder(name="table1", table_name=source_table_name).build() + assert ( + source_table_name.name == "table" + and source_table_name.schema_name.name == "schema" + and table_name.schema_name.name == "schema" + and table_name.name == "table1" + and isinstance(table_name, TableName) + ) def test_using_create_name_using_only_name(): table_name = TableNameBuilder.create(name="table") - assert table_name.name == "table" \ - and isinstance(table_name, TableName) + assert table_name.name == "table" and isinstance(table_name, TableName) def test_using_create_name_using_schema(): table_name = TableNameBuilder.create(name="table", schema=SchemaName("schema")) - assert table_name.name == "table" \ - and table_name.schema_name.name == "schema" \ - and isinstance(table_name, TableName) + assert ( + table_name.name == "table" + and table_name.schema_name.name == "schema" + and isinstance(table_name, TableName) + ) diff --git a/tests/unit_tests/data_science_utils/schema/test_udf_name.py b/tests/unit_tests/data_science_utils/schema/test_udf_name.py index a037d5dc..27838569 100644 --- a/tests/unit_tests/data_science_utils/schema/test_udf_name.py +++ b/tests/unit_tests/data_science_utils/schema/test_udf_name.py @@ -1,9 +1,6 @@ import pytest -from exasol.analytics.schema import ( - UDFNameImpl, - SchemaName, -) +from exasol.analytics.schema import SchemaName, UDFNameImpl def test_fully_qualified(): diff --git a/tests/unit_tests/data_science_utils/schema/test_udf_name_builder.py b/tests/unit_tests/data_science_utils/schema/test_udf_name_builder.py index 6a423000..23449cc2 100644 --- a/tests/unit_tests/data_science_utils/schema/test_udf_name_builder.py +++ b/tests/unit_tests/data_science_utils/schema/test_udf_name_builder.py @@ -1,13 +1,8 @@ import pytest - -from exasol.analytics.schema import ( - UDFNameImpl, - UDFNameBuilder, - SchemaName, - UDFName, -) from typeguard import TypeCheckError +from exasol.analytics.schema import SchemaName, UDFName, UDFNameBuilder, UDFNameImpl + def test_using_empty_constructor(): with pytest.raises(TypeCheckError): @@ -16,84 +11,110 @@ def test_using_empty_constructor(): def test_using_constructor_name_only(): udf_name = UDFNameBuilder(name="udf").build() - assert udf_name.name == "udf" \ - and udf_name.schema_name is None \ - and isinstance(udf_name, UDFName) + assert ( + udf_name.name == "udf" + and udf_name.schema_name is None + and isinstance(udf_name, UDFName) + ) def test_using_constructor_schema(): udf_name = UDFNameBuilder(name="udf", schema=SchemaName("schema")).build() - assert udf_name.name == "udf" \ - and udf_name.schema_name.name is "schema" \ - and isinstance(udf_name, UDFName) + assert ( + udf_name.name == "udf" + and udf_name.schema_name.name == "schema" + and isinstance(udf_name, UDFName) + ) def test_using_with_name_only(): udf_name = UDFNameBuilder().with_name("udf").build() - assert udf_name.name == "udf" \ - and udf_name.schema_name is None \ - and isinstance(udf_name, UDFName) + assert ( + udf_name.name == "udf" + and udf_name.schema_name is None + and isinstance(udf_name, UDFName) + ) def test_using_with_schema(): - udf_name = 
UDFNameBuilder().with_name("udf").with_schema_name(SchemaName("schema")).build() - assert udf_name.name == "udf" \ - and udf_name.schema_name.name == "schema" \ - and isinstance(udf_name, UDFName) + udf_name = ( + UDFNameBuilder().with_name("udf").with_schema_name(SchemaName("schema")).build() + ) + assert ( + udf_name.name == "udf" + and udf_name.schema_name.name == "schema" + and isinstance(udf_name, UDFName) + ) def test_from_existing_using_with_schema(): source_udf_name = UDFNameImpl("udf") - udf_name = UDFNameBuilder(udf_name=source_udf_name).with_schema_name(SchemaName("schema")).build() - assert source_udf_name.name == "udf" \ - and source_udf_name.schema_name is None \ - and udf_name.name == "udf" \ - and udf_name.schema_name.name == "schema" \ - and isinstance(udf_name, UDFName) + udf_name = ( + UDFNameBuilder(udf_name=source_udf_name) + .with_schema_name(SchemaName("schema")) + .build() + ) + assert ( + source_udf_name.name == "udf" + and source_udf_name.schema_name is None + and udf_name.name == "udf" + and udf_name.schema_name.name == "schema" + and isinstance(udf_name, UDFName) + ) def test_from_existing_using_with_name(): source_udf_name = UDFNameImpl("udf", SchemaName("schema")) udf_name = UDFNameBuilder(udf_name=source_udf_name).with_name("udf1").build() - assert source_udf_name.name == "udf" \ - and source_udf_name.schema_name.name == "schema" \ - and udf_name.schema_name.name == "schema" \ - and udf_name.name == "udf1" \ - and isinstance(udf_name, UDFName) + assert ( + source_udf_name.name == "udf" + and source_udf_name.schema_name.name == "schema" + and udf_name.schema_name.name == "schema" + and udf_name.name == "udf1" + and isinstance(udf_name, UDFName) + ) def test_from_existing_and_new_schema_in_constructor(): source_udf_name = UDFNameImpl("udf") - udf_name = UDFNameBuilder(schema=SchemaName("schema"), - udf_name=source_udf_name).build() - assert source_udf_name.name == "udf" \ - and source_udf_name.schema_name is None \ - and udf_name.name == "udf" \ - and udf_name.schema_name.name == "schema" \ - and isinstance(udf_name, UDFName) + udf_name = UDFNameBuilder( + schema=SchemaName("schema"), udf_name=source_udf_name + ).build() + assert ( + source_udf_name.name == "udf" + and source_udf_name.schema_name is None + and udf_name.name == "udf" + and udf_name.schema_name.name == "schema" + and isinstance(udf_name, UDFName) + ) def test_from_existing_and_new_name_in_constructor(): source_udf_name = UDFNameImpl("udf", SchemaName("schema")) - udf_name = UDFNameBuilder(name="udf1", - udf_name=source_udf_name).build() - assert source_udf_name.name == "udf" \ - and source_udf_name.schema_name.name == "schema" \ - and udf_name.schema_name.name == "schema" \ - and udf_name.name == "udf1" \ - and isinstance(udf_name, UDFName) + udf_name = UDFNameBuilder(name="udf1", udf_name=source_udf_name).build() + assert ( + source_udf_name.name == "udf" + and source_udf_name.schema_name.name == "schema" + and udf_name.schema_name.name == "schema" + and udf_name.name == "udf1" + and isinstance(udf_name, UDFName) + ) def test_using_create_name_using_only_name(): udf_name = UDFNameBuilder.create(name="udf") - assert udf_name.name == "udf" \ - and isinstance(udf_name, UDFName) \ - and isinstance(udf_name, UDFName) + assert ( + udf_name.name == "udf" + and isinstance(udf_name, UDFName) + and isinstance(udf_name, UDFName) + ) def test_using_create_name_using_schema(): udf_name = UDFNameBuilder.create(name="udf", schema=SchemaName("schema")) - assert udf_name.name == "udf" \ - and 
udf_name.schema_name.name == "schema" \ - and isinstance(udf_name, UDFName) \ - and isinstance(udf_name, UDFName) + assert ( + udf_name.name == "udf" + and udf_name.schema_name.name == "schema" + and isinstance(udf_name, UDFName) + and isinstance(udf_name, UDFName) + ) diff --git a/tests/unit_tests/data_science_utils/schema/test_view.py b/tests/unit_tests/data_science_utils/schema/test_view.py index e91dbd66..a140a8e2 100644 --- a/tests/unit_tests/data_science_utils/schema/test_view.py +++ b/tests/unit_tests/data_science_utils/schema/test_view.py @@ -1,4 +1,5 @@ import pytest +from typeguard import TypeCheckError from exasol.analytics.schema import ( Column, @@ -7,14 +8,16 @@ View, ViewNameImpl, ) -from typeguard import TypeCheckError def test_valid(): - table = View(ViewNameImpl("view_name"), [ - Column(ColumnNameBuilder.create("column1"), ColumnType("INTEGER")), - Column(ColumnNameBuilder.create("column2"), ColumnType("VACHAR")), - ]) + table = View( + ViewNameImpl("view_name"), + [ + Column(ColumnNameBuilder.create("column1"), ColumnType("INTEGER")), + Column(ColumnNameBuilder.create("column2"), ColumnType("VARCHAR")), + ], + ) def test_no_columns_fail(): @@ -24,22 +27,33 @@ def test_duplicate_column_names_fail(): with pytest.raises(ValueError, match="Column names are not unique.") as c: - table = View(ViewNameImpl("view_name"), [ - Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), - Column(ColumnNameBuilder.create("column"), ColumnType("VACHAR")), - ]) + table = View( + ViewNameImpl("view_name"), + [ + Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), + Column(ColumnNameBuilder.create("column"), ColumnType("VARCHAR")), + ], + ) def test_set_new_name_fail(): - view = View(ViewNameImpl("view"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + view = View( + ViewNameImpl("view"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) with pytest.raises(AttributeError) as c: view.name = "edf" def test_set_new_columns_fail(): - view = View(ViewNameImpl("view"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + view = View( + ViewNameImpl("view"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) with pytest.raises(AttributeError) as c: - view.columns = [Column(ColumnNameBuilder.create("column1"), ColumnType("INTEGER"))] + view.columns = [ + Column(ColumnNameBuilder.create("column1"), ColumnType("INTEGER")) + ] def test_wrong_types_in_constructor(): @@ -48,51 +62,88 @@ def test_columns_list_is_immutable(): - view = View(ViewNameImpl("view"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + view = View( + ViewNameImpl("view"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) columns = view.columns columns.append(Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))) assert len(columns) == 2 and len(view.columns) == 1 def test_equality(): - view1 = View(ViewNameImpl("view"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - view2 = View(ViewNameImpl("view"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + view1 = View( + ViewNameImpl("view"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + view2 = View( + ViewNameImpl("view"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) assert view1 == view2 def test_inequality_name(): - view1 =
View(ViewNameImpl("view1"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - view2 = View(ViewNameImpl("view2"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + view1 = View( + ViewNameImpl("view1"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + view2 = View( + ViewNameImpl("view2"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) assert view1 != view2 def test_inequality_columns(): - view1 = View(ViewNameImpl("view1"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - view2 = View(ViewNameImpl("view1"), - [ - Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), - Column(ColumnNameBuilder.create("column2"), ColumnType("INTEGER")) - ]) + view1 = View( + ViewNameImpl("view1"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + view2 = View( + ViewNameImpl("view1"), + [ + Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), + Column(ColumnNameBuilder.create("column2"), ColumnType("INTEGER")), + ], + ) assert view1 != view2 def test_hash_equality(): - view1 = View(ViewNameImpl("view"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - view2 = View(ViewNameImpl("view"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + view1 = View( + ViewNameImpl("view"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + view2 = View( + ViewNameImpl("view"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) assert hash(view1) == hash(view2) def test_hash_inequality_name(): - view1 = View(ViewNameImpl("view1"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - view2 = View(ViewNameImpl("view2"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) + view1 = View( + ViewNameImpl("view1"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + view2 = View( + ViewNameImpl("view2"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) assert hash(view1) != hash(view2) def test_hash_inequality_columns(): - view1 = View(ViewNameImpl("view1"), [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))]) - view2 = View(ViewNameImpl("view1"), - [ - Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), - Column(ColumnNameBuilder.create("column2"), ColumnType("INTEGER")) - ]) + view1 = View( + ViewNameImpl("view1"), + [Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER"))], + ) + view2 = View( + ViewNameImpl("view1"), + [ + Column(ColumnNameBuilder.create("column"), ColumnType("INTEGER")), + Column(ColumnNameBuilder.create("column2"), ColumnType("INTEGER")), + ], + ) assert hash(view1) != hash(view2) diff --git a/tests/unit_tests/data_science_utils/schema/test_view_name.py b/tests/unit_tests/data_science_utils/schema/test_view_name.py index 44a10d67..418ce446 100644 --- a/tests/unit_tests/data_science_utils/schema/test_view_name.py +++ b/tests/unit_tests/data_science_utils/schema/test_view_name.py @@ -1,11 +1,6 @@ import pytest -from exasol.analytics.schema import ( - SchemaName, - ViewNameImpl, - TableNameImpl, - TableName, -) +from exasol.analytics.schema import SchemaName, TableName, TableNameImpl, ViewNameImpl def test_fully_qualified(): diff --git a/tests/unit_tests/data_science_utils/schema/test_view_name_builder.py b/tests/unit_tests/data_science_utils/schema/test_view_name_builder.py index f242e4bd..95c76b0c 100644 --- 
a/tests/unit_tests/data_science_utils/schema/test_view_name_builder.py
+++ b/tests/unit_tests/data_science_utils/schema/test_view_name_builder.py
@@ -1,13 +1,8 @@
 import pytest
-
-from exasol.analytics.schema import (
-    SchemaName,
-    ViewNameBuilder,
-    ViewNameImpl,
-    ViewName,
-)
 from typeguard import TypeCheckError
+from exasol.analytics.schema import SchemaName, ViewName, ViewNameBuilder, ViewNameImpl
+

 def test_using_empty_constructor():
     with pytest.raises(TypeCheckError):
@@ -16,82 +11,108 @@ def test_using_empty_constructor():

 def test_using_constructor_name_only():
     view_name = ViewNameBuilder(name="view").build()
-    assert view_name.name == "view" \
-        and view_name.schema_name is None \
-        and isinstance(view_name, ViewName)
+    assert (
+        view_name.name == "view"
+        and view_name.schema_name is None
+        and isinstance(view_name, ViewName)
+    )


 def test_using_constructor_schema():
     view_name = ViewNameBuilder(name="view", schema=SchemaName("schema")).build()
-    assert view_name.name == "view" \
-        and view_name.schema_name.name is "schema" \
-        and isinstance(view_name, ViewName)
+    assert (
+        view_name.name == "view"
+        and view_name.schema_name.name == "schema"
+        and isinstance(view_name, ViewName)
+    )


 def test_using_with_name_only():
     view_name = ViewNameBuilder().with_name("view").build()
-    assert view_name.name == "view" \
-        and view_name.schema_name is None \
-        and isinstance(view_name, ViewName)
+    assert (
+        view_name.name == "view"
+        and view_name.schema_name is None
+        and isinstance(view_name, ViewName)
+    )


 def test_using_with_schema():
-    view_name = ViewNameBuilder().with_name("view").with_schema_name(SchemaName("schema")).build()
-    assert view_name.name == "view" \
-        and view_name.schema_name.name == "schema" \
-        and isinstance(view_name, ViewName)
+    view_name = (
+        ViewNameBuilder()
+        .with_name("view")
+        .with_schema_name(SchemaName("schema"))
+        .build()
+    )
+    assert (
+        view_name.name == "view"
+        and view_name.schema_name.name == "schema"
+        and isinstance(view_name, ViewName)
+    )


 def test_from_existing_using_with_schema():
     source_view_name = ViewNameImpl("view")
-    view_name = ViewNameBuilder(view_name=source_view_name).with_schema_name(SchemaName("schema")).build()
-    assert source_view_name.name == "view" \
-        and source_view_name.schema_name is None \
-        and view_name.name == "view" \
-        and view_name.schema_name.name == "schema" \
-        and isinstance(view_name, ViewName)
+    view_name = (
+        ViewNameBuilder(view_name=source_view_name)
+        .with_schema_name(SchemaName("schema"))
+        .build()
+    )
+    assert (
+        source_view_name.name == "view"
+        and source_view_name.schema_name is None
+        and view_name.name == "view"
+        and view_name.schema_name.name == "schema"
+        and isinstance(view_name, ViewName)
+    )


 def test_from_existing_using_with_name():
     source_view_name = ViewNameImpl("view", SchemaName("schema"))
     view_name = ViewNameBuilder(view_name=source_view_name).with_name("view1").build()
-    assert source_view_name.name == "view" \
-        and source_view_name.schema_name.name == "schema" \
-        and view_name.schema_name.name == "schema" \
-        and view_name.name == "view1" \
-        and isinstance(view_name, ViewName)
+    assert (
+        source_view_name.name == "view"
+        and source_view_name.schema_name.name == "schema"
+        and view_name.schema_name.name == "schema"
+        and view_name.name == "view1"
+        and isinstance(view_name, ViewName)
+    )


 def test_from_existing_and_new_schema_in_constructor():
     source_view_name = ViewNameImpl("view")
-    view_name = ViewNameBuilder(schema=SchemaName("schema"),
-                                view_name=source_view_name).build()
-    assert source_view_name.name == "view" \
-        and source_view_name.schema_name is None \
-        and view_name.name == "view" \
-        and view_name.schema_name.name == "schema" \
-        and isinstance(view_name, ViewName)
+    view_name = ViewNameBuilder(
+        schema=SchemaName("schema"), view_name=source_view_name
+    ).build()
+    assert (
+        source_view_name.name == "view"
+        and source_view_name.schema_name is None
+        and view_name.name == "view"
+        and view_name.schema_name.name == "schema"
+        and isinstance(view_name, ViewName)
+    )


 def test_from_existing_and_new_name_in_constructor():
     source_view_name = ViewNameImpl("view", SchemaName("schema"))
-    view_name = ViewNameBuilder(name="view1",
-                                view_name=source_view_name).build()
-    assert source_view_name.name == "view" \
-        and source_view_name.schema_name.name == "schema" \
-        and view_name.schema_name.name == "schema" \
-        and view_name.name == "view1" \
-        and isinstance(view_name, ViewName)
+    view_name = ViewNameBuilder(name="view1", view_name=source_view_name).build()
+    assert (
+        source_view_name.name == "view"
+        and source_view_name.schema_name.name == "schema"
+        and view_name.schema_name.name == "schema"
+        and view_name.name == "view1"
+        and isinstance(view_name, ViewName)
+    )


 def test_using_create_name_using_only_name():
     view_name = ViewNameBuilder.create(name="view")
-    assert view_name.name == "view" \
-        and isinstance(view_name, ViewName)
+    assert view_name.name == "view" and isinstance(view_name, ViewName)


 def test_using_create_name_using_schema():
     view_name = ViewNameBuilder.create(name="view", schema=SchemaName("schema"))
-    assert view_name.name == "view" \
-        and view_name.schema_name.name == "schema" \
-        and isinstance(view_name, ViewName)
+    assert (
+        view_name.name == "view"
+        and view_name.schema_name.name == "schema"
+        and isinstance(view_name, ViewName)
+    )
diff --git a/tests/unit_tests/data_science_utils/test_mock_result_set.py b/tests/unit_tests/data_science_utils/test_mock_result_set.py
index 3939c516..914e4242 100644
--- a/tests/unit_tests/data_science_utils/test_mock_result_set.py
+++ b/tests/unit_tests/data_science_utils/test_mock_result_set.py
@@ -1,10 +1,6 @@
 import pytest

-from exasol.analytics.schema import (
-    Column,
-    ColumnType,
-    ColumnNameBuilder,
-)
+from exasol.analytics.schema import Column, ColumnNameBuilder, ColumnType
 from exasol.analytics.sql_executor.testing.mock_result_set import MockResultSet

@@ -75,14 +71,18 @@ def test_fetchmany():


 def test_columns():
     input = [("a", 1), ("b", 2), ("c", 4)]
-    columns = [Column(ColumnNameBuilder.create("t1"), ColumnType(name="VARCHAR(200000)")),
-               Column(ColumnNameBuilder.create("t2"), ColumnType(name="INTEGER"))]
+    columns = [
+        Column(ColumnNameBuilder.create("t1"), ColumnType(name="VARCHAR(200000)")),
+        Column(ColumnNameBuilder.create("t2"), ColumnType(name="INTEGER")),
+    ]
     result_set = MockResultSet(rows=input, columns=columns)
     assert columns == result_set.columns()


 def test_rows_and_columns_different_length():
     input = [("a", 1), ("b", 2), ("c", 4)]
-    columns = [Column(ColumnNameBuilder.create("t1"), ColumnType(name="VARCHAR(200000)"))]
+    columns = [
+        Column(ColumnNameBuilder.create("t1"), ColumnType(name="VARCHAR(200000)"))
+    ]
     with pytest.raises(AssertionError):
         result_set = MockResultSet(rows=input, columns=columns)
diff --git a/tests/unit_tests/data_science_utils/test_mock_sql_executor.py b/tests/unit_tests/data_science_utils/test_mock_sql_executor.py
index 3f6d5a8d..d9e003bf 100644
--- a/tests/unit_tests/data_science_utils/test_mock_sql_executor.py
+++
b/tests/unit_tests/data_science_utils/test_mock_sql_executor.py @@ -1,5 +1,8 @@ from exasol.analytics.sql_executor.testing.mock_result_set import MockResultSet -from exasol.analytics.sql_executor.testing.mock_sql_executor import MockSQLExecutor, ExpectedQuery +from exasol.analytics.sql_executor.testing.mock_sql_executor import ( + ExpectedQuery, + MockSQLExecutor, +) def test_no_resultset(): diff --git a/tests/unit_tests/data_science_utils/udf_utils/test_ctx_iterator.py b/tests/unit_tests/data_science_utils/udf_utils/test_ctx_iterator.py index 9a7b63c2..759b6a0f 100644 --- a/tests/unit_tests/data_science_utils/udf_utils/test_ctx_iterator.py +++ b/tests/unit_tests/data_science_utils/udf_utils/test_ctx_iterator.py @@ -8,6 +8,7 @@ def udf_wrapper(): from exasol_udf_mock_python.udf_context import UDFContext + from exasol.analytics.udf.utils.iterators import ctx_iterator def run(ctx: UDFContext): @@ -22,11 +23,12 @@ def test_ctx_iterator(input_size): meta = MockMetaData( script_code_wrapper_function=udf_wrapper, input_type="SET", - input_columns=[Column("t1", int, "INTEGER"), - Column("t2", float, "FLOAT"), ], + input_columns=[ + Column("t1", int, "INTEGER"), + Column("t2", float, "FLOAT"), + ], output_type="EMITS", - output_columns=[Column("t1", int, "INTEGER"), - Column("t2", float, "FLOAT")] + output_columns=[Column("t1", int, "INTEGER"), Column("t2", float, "FLOAT")], ) exa = MockExaEnvironment(meta) input_data = [(i, 1.0 * i) for i in range(input_size)] diff --git a/tests/unit_tests/data_science_utils/udf_utils/test_iterate_trough_dataset.py b/tests/unit_tests/data_science_utils/udf_utils/test_iterate_trough_dataset.py index cfb92e0c..cbed7622 100644 --- a/tests/unit_tests/data_science_utils/udf_utils/test_iterate_trough_dataset.py +++ b/tests/unit_tests/data_science_utils/udf_utils/test_iterate_trough_dataset.py @@ -12,23 +12,28 @@ def udf_wrapper(): from exasol.analytics.udf.utils.iterators import iterate_trough_dataset def run(ctx: UDFContext): - iterate_trough_dataset(ctx, 10, - lambda x: x, - lambda: None, - lambda state, x: ctx.emit(x), - lambda: ctx.reset()) + iterate_trough_dataset( + ctx, + 10, + lambda x: x, + lambda: None, + lambda state, x: ctx.emit(x), + lambda: ctx.reset(), + ) -@pytest.mark.parametrize("input_size", [i for i in range(1,150)]) + +@pytest.mark.parametrize("input_size", [i for i in range(1, 150)]) def test_iterate_through_dataset(input_size): executor = UDFMockExecutor() meta = MockMetaData( script_code_wrapper_function=udf_wrapper, input_type="SET", - input_columns=[Column("t1", int, "INTEGER"), - Column("t2", float, "FLOAT"), ], + input_columns=[ + Column("t1", int, "INTEGER"), + Column("t2", float, "FLOAT"), + ], output_type="EMITS", - output_columns=[Column("t1", int, "INTEGER"), - Column("t2", float, "FLOAT")] + output_columns=[Column("t1", int, "INTEGER"), Column("t2", float, "FLOAT")], ) exa = MockExaEnvironment(meta) input_data = [(i, 1.0 * i) for i in range(input_size)] diff --git a/tests/unit_tests/data_science_utils/udf_utils/test_udf_context_wrapper.py b/tests/unit_tests/data_science_utils/udf_utils/test_udf_context_wrapper.py index 2ea745c9..de78ffe1 100644 --- a/tests/unit_tests/data_science_utils/udf_utils/test_udf_context_wrapper.py +++ b/tests/unit_tests/data_science_utils/udf_utils/test_udf_context_wrapper.py @@ -6,10 +6,12 @@ def udf_wrapper(): - from exasol_udf_mock_python.udf_context import UDFContext from collections import OrderedDict - from exasol.analytics.udf.utils.context_wrapper import UDFContextWrapper + import numpy as np + from 
exasol_udf_mock_python.udf_context import UDFContext + + from exasol.analytics.udf.utils.context_wrapper import UDFContextWrapper def run(ctx: UDFContext): wrapper = UDFContextWrapper(ctx, OrderedDict([("t2", "a"), ("t1", "b")])) @@ -25,7 +27,16 @@ def run(ctx: UDFContext): assert list(df.dtypes) == [np.int64] wrapper.reset() - wrapper = UDFContextWrapper(ctx, OrderedDict([("t3", "d"), ("t2", "c"), ]), start_col=1) + wrapper = UDFContextWrapper( + ctx, + OrderedDict( + [ + ("t3", "d"), + ("t2", "c"), + ] + ), + start_col=1, + ) df = wrapper.get_dataframe(10) assert len(df) == 10 assert list(df.columns) == ["d", "c"] @@ -33,7 +44,6 @@ def run(ctx: UDFContext): assert all(df["c"] < 1.0) assert all(np.logical_or(df["d"] >= 1.0, df["d"] == 0.0)) - wrapper.reset() df = wrapper.get_dataframe(10, start_col=1) assert len(df) == 10 @@ -53,8 +63,7 @@ def test_partial_fit_iterator(): Column("t3", float, "FLOAT"), ], output_type="EMITS", - output_columns=[Column("t1", int, "INTEGER"), - Column("t2", float, "FLOAT")] + output_columns=[Column("t1", int, "INTEGER"), Column("t2", float, "FLOAT")], ) exa = MockExaEnvironment(meta) input_data = [(i, (1.0 * i) / 105, (1.0 * i)) for i in range(105)] diff --git a/tests/unit_tests/data_science_utils/utils/test_hash_generation_for_object.py b/tests/unit_tests/data_science_utils/utils/test_hash_generation_for_object.py index 6a2816ad..ef58cb53 100644 --- a/tests/unit_tests/data_science_utils/utils/test_hash_generation_for_object.py +++ b/tests/unit_tests/data_science_utils/utils/test_hash_generation_for_object.py @@ -1,5 +1,5 @@ import dataclasses -from typing import List, Dict +from typing import Dict, List from exasol.analytics.utils.hash_generation_for_object import generate_hash_for_object @@ -30,7 +30,9 @@ def test_object_with_list_equal(): def test_object_with_list_different_values_not_equal(): - test1 = ObjectWithList(test=[HashableTestObject(1), HashableTestObject(2), HashableTestObject(3)]) + test1 = ObjectWithList( + test=[HashableTestObject(1), HashableTestObject(2), HashableTestObject(3)] + ) test2 = ObjectWithList(test=[HashableTestObject(1), HashableTestObject(2)]) result1 = generate_hash_for_object(test1) result2 = generate_hash_for_object(test2) @@ -38,8 +40,12 @@ def test_object_with_list_different_values_not_equal(): def test_object_with_list_different_order_not_equal(): - test1 = ObjectWithList(test=[HashableTestObject(1), HashableTestObject(2), HashableTestObject(3)]) - test2 = ObjectWithList(test=[HashableTestObject(3), HashableTestObject(1), HashableTestObject(2)]) + test1 = ObjectWithList( + test=[HashableTestObject(1), HashableTestObject(2), HashableTestObject(3)] + ) + test2 = ObjectWithList( + test=[HashableTestObject(3), HashableTestObject(1), HashableTestObject(2)] + ) result1 = generate_hash_for_object(test1) result2 = generate_hash_for_object(test2) assert result1 != result2 @@ -52,16 +58,24 @@ class ObjectWithMutlipleAttributes: def test_object_with_multiple_attributes_equal(): - test1 = ObjectWithMutlipleAttributes(test1=HashableTestObject(1), test2=HashableTestObject(2)) - test2 = ObjectWithMutlipleAttributes(test1=HashableTestObject(1), test2=HashableTestObject(2)) + test1 = ObjectWithMutlipleAttributes( + test1=HashableTestObject(1), test2=HashableTestObject(2) + ) + test2 = ObjectWithMutlipleAttributes( + test1=HashableTestObject(1), test2=HashableTestObject(2) + ) result1 = generate_hash_for_object(test1) result2 = generate_hash_for_object(test2) assert result1 == result2 def test_object_with_multiple_attributes_not_equal(): - 
test1 = ObjectWithMutlipleAttributes(test1=HashableTestObject(1), test2=HashableTestObject(2)) - test2 = ObjectWithMutlipleAttributes(test1=HashableTestObject(1), test2=HashableTestObject(3)) + test1 = ObjectWithMutlipleAttributes( + test1=HashableTestObject(1), test2=HashableTestObject(2) + ) + test2 = ObjectWithMutlipleAttributes( + test1=HashableTestObject(1), test2=HashableTestObject(3) + ) result1 = generate_hash_for_object(test1) result2 = generate_hash_for_object(test2) assert result1 != result2 @@ -73,10 +87,18 @@ class ObjectWithDict: def test_object_with_dict_equal(): - test1 = ObjectWithDict(test={HashableTestObject(1): HashableTestObject(1), - HashableTestObject(2): HashableTestObject(2)}) - test2 = ObjectWithDict(test={HashableTestObject(1): HashableTestObject(1), - HashableTestObject(2): HashableTestObject(2)}) + test1 = ObjectWithDict( + test={ + HashableTestObject(1): HashableTestObject(1), + HashableTestObject(2): HashableTestObject(2), + } + ) + test2 = ObjectWithDict( + test={ + HashableTestObject(1): HashableTestObject(1), + HashableTestObject(2): HashableTestObject(2), + } + ) result1 = generate_hash_for_object(test1) result2 = generate_hash_for_object(test2) @@ -84,8 +106,12 @@ def test_object_with_dict_equal(): def test_object_with_dict_different_values_not_equal(): - test1 = ObjectWithDict(test={HashableTestObject(1): HashableTestObject(1), - HashableTestObject(2): HashableTestObject(2)}) + test1 = ObjectWithDict( + test={ + HashableTestObject(1): HashableTestObject(1), + HashableTestObject(2): HashableTestObject(2), + } + ) test2 = ObjectWithDict(test={HashableTestObject(1): HashableTestObject(1)}) result1 = generate_hash_for_object(test1) @@ -94,10 +120,18 @@ def test_object_with_dict_different_values_not_equal(): def test_object_with_dict_different_order_not_equal(): - test1 = ObjectWithDict(test={HashableTestObject(1): HashableTestObject(1), - HashableTestObject(2): HashableTestObject(2)}) - test2 = ObjectWithDict(test={HashableTestObject(2): HashableTestObject(2), - HashableTestObject(1): HashableTestObject(1)}) + test1 = ObjectWithDict( + test={ + HashableTestObject(1): HashableTestObject(1), + HashableTestObject(2): HashableTestObject(2), + } + ) + test2 = ObjectWithDict( + test={ + HashableTestObject(2): HashableTestObject(2), + HashableTestObject(1): HashableTestObject(1), + } + ) result1 = generate_hash_for_object(test1) result2 = generate_hash_for_object(test2) diff --git a/tests/unit_tests/query_handler/fixtures.py b/tests/unit_tests/query_handler/fixtures.py index 431d3c15..cbdd2777 100644 --- a/tests/unit_tests/query_handler/fixtures.py +++ b/tests/unit_tests/query_handler/fixtures.py @@ -1,8 +1,14 @@ -import pytest import exasol.bucketfs as bfs +import pytest -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext, Connection -from exasol.analytics.query_handler.context.top_level_query_handler_context import TopLevelQueryHandlerContext, ConnectionLookup +from exasol.analytics.query_handler.context.scope import ( + Connection, + ScopeQueryHandlerContext, +) +from exasol.analytics.query_handler.context.top_level_query_handler_context import ( + ConnectionLookup, + TopLevelQueryHandlerContext, +) PREFIX = "PREFIX" @@ -72,10 +78,11 @@ def mocked_temporary_bucketfs_location(tmp_path): @pytest.fixture def top_level_query_handler_context_mock( - sample_bucketfs_location: bfs.path.PathLike, - tmp_db_obj_prefix: str, - aaf_pytest_db_schema: str, - connection_lookup_mock: ConnectionLookup) -> TopLevelQueryHandlerContext: + 
sample_bucketfs_location: bfs.path.PathLike, + tmp_db_obj_prefix: str, + aaf_pytest_db_schema: str, + connection_lookup_mock: ConnectionLookup, +) -> TopLevelQueryHandlerContext: query_handler_context = TopLevelQueryHandlerContext( temporary_bucketfs_location=sample_bucketfs_location, temporary_db_object_name_prefix=tmp_db_obj_prefix, @@ -87,8 +94,8 @@ def top_level_query_handler_context_mock( @pytest.fixture(params=["top", "child"]) def scope_query_handler_context_mock( - top_level_query_handler_context_mock: TopLevelQueryHandlerContext, - request) -> ScopeQueryHandlerContext: + top_level_query_handler_context_mock: TopLevelQueryHandlerContext, request +) -> ScopeQueryHandlerContext: if request.param == "top": return top_level_query_handler_context_mock else: diff --git a/tests/unit_tests/query_handler/test_query_handler_interface.py b/tests/unit_tests/query_handler/test_query_handler_interface.py index 302f7c22..8d2f29dc 100644 --- a/tests/unit_tests/query_handler/test_query_handler_interface.py +++ b/tests/unit_tests/query_handler/test_query_handler_interface.py @@ -1,33 +1,40 @@ -from typing import Union, Dict, Any +from typing import Any, Dict, Union from unittest.mock import MagicMock import pytest -from exasol.analytics.schema import ( - Column, - ColumnType, - ColumnName, -) -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.query.result.interface import QueryResult +from exasol.analytics.query_handler.query.result.python_query_result import ( + PythonQueryResult, +) from exasol.analytics.query_handler.query.select import SelectQueryWithColumnDefinition from exasol.analytics.query_handler.query_handler import QueryHandler -from exasol.analytics.query_handler.result import Continue, Finish -from exasol.analytics.query_handler.query.result.python_query_result import PythonQueryResult -from exasol.analytics.query_handler.query.result.interface import QueryResult +from exasol.analytics.query_handler.result import Continue, Finish +from exasol.analytics.schema import Column, ColumnName, ColumnType class TestQueryHandler(QueryHandler[Dict[str, Any], int]): __test__ = False - def __init__(self, parameter: Dict[str, Any], query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: Dict[str, Any], query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[int]]: - return Continue([], SelectQueryWithColumnDefinition(f'SELECT {self._parameter["a"]} as "A"', - [Column(ColumnName("A"), ColumnType("DECIMAL(12,0)"))])) - - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[int]]: + return Continue( + [], + SelectQueryWithColumnDefinition( + f'SELECT {self._parameter["a"]} as "A"', + [Column(ColumnName("A"), ColumnType("DECIMAL(12,0)"))], + ), + ) + + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[int]]: return Finish(query_result.A) diff --git a/tests/unit_tests/query_handler/test_scope_query_handler_context.py b/tests/unit_tests/query_handler/test_scope_query_handler_context.py index eaaa43d6..ebb2a897 100644 --- a/tests/unit_tests/query_handler/test_scope_query_handler_context.py +++ b/tests/unit_tests/query_handler/test_scope_query_handler_context.py @@ -1,21 +1,26 @@ from contextlib import contextmanager -import pytest import 
exasol.bucketfs as bfs +import pytest + +from exasol.analytics.query_handler.context.connection_name import ConnectionName +from exasol.analytics.query_handler.context.scope import ( + Connection, + ScopeQueryHandlerContext, +) +from exasol.analytics.query_handler.context.top_level_query_handler_context import ( + ChildContextNotReleasedError, +) from exasol.analytics.schema import ( - SchemaName, ColumnBuilder, ColumnName, - UDFName, ColumnType, - View, + SchemaName, Table, + UDFName, + View, ) -from exasol.analytics.query_handler.context.connection_name import ConnectionName -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext, Connection -from exasol.analytics.query_handler.context.top_level_query_handler_context import ChildContextNotReleasedError - @pytest.fixture def context_mock(scope_query_handler_context_mock) -> ScopeQueryHandlerContext: @@ -37,7 +42,7 @@ def test_temporary_table_temporary_schema(context_mock, aaf_pytest_db_schema: st assert proxy.schema_name.name == aaf_pytest_db_schema -def test_temporary_view_prefix_in_name(context_mock,prefix): +def test_temporary_view_prefix_in_name(context_mock, prefix): proxy = context_mock.get_temporary_view_name() assert proxy.name.startswith(prefix) @@ -53,16 +58,18 @@ def test_temporary_connection_temporary(context_mock: ScopeQueryHandlerContext): def test_temporary_udf_temporary( - context_mock: ScopeQueryHandlerContext, - aaf_pytest_db_schema: str): + context_mock: ScopeQueryHandlerContext, aaf_pytest_db_schema: str +): proxy = context_mock.get_temporary_udf_name() - assert isinstance(proxy, UDFName) and \ - proxy.schema_name == SchemaName(aaf_pytest_db_schema) + assert isinstance(proxy, UDFName) and proxy.schema_name == SchemaName( + aaf_pytest_db_schema + ) -def test_temporary_bucketfs_file_prefix_in_name(sample_bucketfs_location: bfs.path.PathLike, - context_mock: ScopeQueryHandlerContext): - proxy = context_mock.get_temporary_bucketfs_location() +def test_temporary_bucketfs_file_prefix_in_name( + sample_bucketfs_location: bfs.path.PathLike, context_mock: ScopeQueryHandlerContext +): + proxy = context_mock.get_temporary_bucketfs_location() actual_path = proxy.bucketfs_location().as_udf_path() expected_prefix_path = sample_bucketfs_location.as_udf_path() assert actual_path.startswith(expected_prefix_path) @@ -80,7 +87,9 @@ def test_two_temporary_view_are_not_equal(context_mock: ScopeQueryHandlerContext assert proxy1.name != proxy2.name -def test_two_temporary_bucketfs_files_are_not_equal(context_mock: ScopeQueryHandlerContext): +def test_two_temporary_bucketfs_files_are_not_equal( + context_mock: ScopeQueryHandlerContext, +): proxy1 = context_mock.get_temporary_bucketfs_location() proxy2 = context_mock.get_temporary_bucketfs_location() path1 = proxy1.bucketfs_location().as_udf_path() @@ -88,14 +97,18 @@ def test_two_temporary_bucketfs_files_are_not_equal(context_mock: ScopeQueryHand assert path1 != path2 -def test_temporary_table_name_proxy_use_name_after_release_fails(context_mock: ScopeQueryHandlerContext): +def test_temporary_table_name_proxy_use_name_after_release_fails( + context_mock: ScopeQueryHandlerContext, +): proxy = context_mock.get_temporary_table_name() context_mock.release() with pytest.raises(RuntimeError, match="TableNameProxy.* already released."): proxy_name = proxy.name -def test_temporary_view_name_proxy_use_name_after_release_fails(context_mock: ScopeQueryHandlerContext): +def test_temporary_view_name_proxy_use_name_after_release_fails( + context_mock: ScopeQueryHandlerContext, 
+): proxy = context_mock.get_temporary_view_name() context_mock.release() with pytest.raises(RuntimeError, match="ViewNameProxy.* already released."): @@ -103,7 +116,8 @@ def test_temporary_view_name_proxy_use_name_after_release_fails(context_mock: Sc def test_temporary_table_name_proxy_use_schema_after_release_fails( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): proxy = context_mock.get_temporary_table_name() context_mock.release() with pytest.raises(RuntimeError, match="TableNameProxy.* already released."): @@ -111,7 +125,8 @@ def test_temporary_table_name_proxy_use_schema_after_release_fails( def test_temporary_view_name_proxy_use_schema_after_release_fails( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): proxy = context_mock.get_temporary_view_name() context_mock.release() with pytest.raises(RuntimeError, match="ViewNameProxy.* already released."): @@ -119,7 +134,8 @@ def test_temporary_view_name_proxy_use_schema_after_release_fails( def test_temporary_table_name_proxy_use_quoted_name_after_release_fails( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): proxy = context_mock.get_temporary_table_name() context_mock.release() with pytest.raises(RuntimeError, match="TableNameProxy.* already released."): @@ -127,7 +143,8 @@ def test_temporary_table_name_proxy_use_quoted_name_after_release_fails( def test_temporary_view_name_proxy_use_quoted_name_after_release_fails( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): proxy = context_mock.get_temporary_view_name() context_mock.release() with pytest.raises(RuntimeError, match="ViewNameProxy.* already released."): @@ -135,7 +152,8 @@ def test_temporary_view_name_proxy_use_quoted_name_after_release_fails( def test_temporary_table_name_proxy_use_fully_qualified_after_release_fails( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): proxy = context_mock.get_temporary_table_name() context_mock.release() with pytest.raises(RuntimeError, match="TableNameProxy.* already released."): @@ -143,7 +161,8 @@ def test_temporary_table_name_proxy_use_fully_qualified_after_release_fails( def test_temporary_view_name_proxy_use_fully_qualified_after_release_fails( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): proxy = context_mock.get_temporary_view_name() context_mock.release() with pytest.raises(RuntimeError, match="ViewNameProxy.* already released."): @@ -156,13 +175,17 @@ def test_get_temporary_view_after_release_fails(context_mock: ScopeQueryHandlerC proxy = context_mock.get_temporary_view_name() -def test_get_temporary_table_after_release_fails(context_mock: ScopeQueryHandlerContext): +def test_get_temporary_table_after_release_fails( + context_mock: ScopeQueryHandlerContext, +): context_mock.release() with pytest.raises(RuntimeError, match="Context already released."): proxy = context_mock.get_temporary_table_name() -def test_get_temporary_bucketfs_file_after_release_fails(context_mock: ScopeQueryHandlerContext): +def test_get_temporary_bucketfs_file_after_release_fails( + context_mock: ScopeQueryHandlerContext, +): context_mock.release() with pytest.raises(RuntimeError, match="Context already released."): proxy = context_mock.get_temporary_bucketfs_location() @@ -183,7 +206,7 @@ def not_raises(exception): try: yield except exception: - raise pytest.fail("DID RAISE {0}".format(exception)) + raise pytest.fail(f"DID 
RAISE {exception}") def test_transfer_between_siblings(context_mock: ScopeQueryHandlerContext): @@ -200,7 +223,9 @@ def test_transfer_between_siblings(context_mock: ScopeQueryHandlerContext): _ = object_proxy2.name -def test_transfer_siblings_check_ownership_transfer_to_target(context_mock: ScopeQueryHandlerContext): +def test_transfer_siblings_check_ownership_transfer_to_target( + context_mock: ScopeQueryHandlerContext, +): child1 = context_mock.get_child_query_handler_context() child2 = context_mock.get_child_query_handler_context() object_proxy1 = child1.get_temporary_table_name() @@ -216,46 +241,58 @@ def test_transfer_siblings_check_ownership_transfer_to_target(context_mock: Scop def test_transfer_child_parent_check_ownership_transfer_to_target( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): parent = context_mock child1 = parent.get_child_query_handler_context() child2 = parent.get_child_query_handler_context() object_proxy1 = child1.get_temporary_table_name() child1.transfer_object_to(object_proxy1, parent) - with pytest.raises(RuntimeError, match="Object not owned by this ScopeQueryHandlerContext."): + with pytest.raises( + RuntimeError, match="Object not owned by this ScopeQueryHandlerContext." + ): child1.transfer_object_to(object_proxy1, child2) def test_transfer_parent_child_check_ownership_transfer_to_target( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): parent = context_mock child1 = parent.get_child_query_handler_context() child2 = parent.get_child_query_handler_context() object_proxy1 = parent.get_temporary_table_name() parent.transfer_object_to(object_proxy1, child1) - with pytest.raises(RuntimeError, match="Object not owned by this ScopeQueryHandlerContext."): + with pytest.raises( + RuntimeError, match="Object not owned by this ScopeQueryHandlerContext." + ): parent.transfer_object_to(object_proxy1, child2) -def test_transfer_siblings_checK_losing_ownership(context_mock: ScopeQueryHandlerContext): +def test_transfer_siblings_checK_losing_ownership( + context_mock: ScopeQueryHandlerContext, +): child1 = context_mock.get_child_query_handler_context() child2 = context_mock.get_child_query_handler_context() child3 = context_mock.get_child_query_handler_context() object_proxy1 = child1.get_temporary_table_name() child1.transfer_object_to(object_proxy1, child2) - with pytest.raises(RuntimeError, match="Object not owned by this ScopeQueryHandlerContext."): + with pytest.raises( + RuntimeError, match="Object not owned by this ScopeQueryHandlerContext." + ): child1.transfer_object_to(object_proxy1, child3) def test_transfer_between_siblings_object_from_different_context( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): child1 = context_mock.get_child_query_handler_context() child2 = context_mock.get_child_query_handler_context() grand_child1 = child1.get_child_query_handler_context() object_proxy = grand_child1.get_temporary_table_name() - with pytest.raises(RuntimeError, - match="Object not owned by this ScopeQueryHandlerContext."): + with pytest.raises( + RuntimeError, match="Object not owned by this ScopeQueryHandlerContext." 
+ ): child1.transfer_object_to(object_proxy, child2) @@ -285,29 +322,37 @@ def test_transfer_between_parent_and_child(context_mock: ScopeQueryHandlerContex def test_illegal_transfer_between_grand_child_and_parent( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): parent = context_mock child = context_mock.get_child_query_handler_context() grand_child = child.get_child_query_handler_context() object_proxy = grand_child.get_temporary_table_name() - with pytest.raises(RuntimeError, match="Given ScopeQueryHandlerContext not a child, parent or sibling."): + with pytest.raises( + RuntimeError, + match="Given ScopeQueryHandlerContext not a child, parent or sibling.", + ): grand_child.transfer_object_to(object_proxy, parent) def test_illegal_transfer_between_parent_and_grand_child( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): parent = context_mock child = context_mock.get_child_query_handler_context() grand_child = child.get_child_query_handler_context() object_proxy = parent.get_temporary_table_name() - with pytest.raises(RuntimeError, - match="Given ScopeQueryHandlerContext not a child, parent or sibling.|" - "Given ScopeQueryHandlerContext not a child."): + with pytest.raises( + RuntimeError, + match="Given ScopeQueryHandlerContext not a child, parent or sibling.|" + "Given ScopeQueryHandlerContext not a child.", + ): parent.transfer_object_to(object_proxy, grand_child) def test_release_parent_before_child_with_temporary_object_expect_exception( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): parent = context_mock child = context_mock.get_child_query_handler_context() _ = child.get_temporary_table_name() @@ -316,7 +361,8 @@ def test_release_parent_before_child_with_temporary_object_expect_exception( def test_release_parent_before_child_without_temporary_object_expect_exception( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): parent = context_mock _ = context_mock.get_child_query_handler_context() with pytest.raises(ChildContextNotReleasedError): @@ -324,7 +370,8 @@ def test_release_parent_before_child_without_temporary_object_expect_exception( def test_release_parent_before_grand_child_with_temporary_object_expect_exception( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): parent = context_mock child = context_mock.get_child_query_handler_context() grand_child = child.get_child_query_handler_context() @@ -334,7 +381,8 @@ def test_release_parent_before_grand_child_with_temporary_object_expect_exceptio def test_release_parent_before_grand_child_without_temporary_object_expect_exception( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): parent = context_mock child = context_mock.get_child_query_handler_context() _ = child.get_child_query_handler_context() @@ -343,7 +391,8 @@ def test_release_parent_before_grand_child_without_temporary_object_expect_excep def test_cleanup_parent_before_grand_child_without_temporary_objects( - context_mock: ScopeQueryHandlerContext): + context_mock: ScopeQueryHandlerContext, +): child1 = context_mock.get_child_query_handler_context() child2 = context_mock.get_child_query_handler_context() _ = child1.get_child_query_handler_context() @@ -360,39 +409,43 @@ def test_cleanup_parent_before_grand_child_without_temporary_objects( def test_using_table_name_proxy_in_table(context_mock: ScopeQueryHandlerContext): table_name = 
context_mock.get_temporary_table_name()
-    table = Table(table_name,
-                  columns=[
-                      (
-                          ColumnBuilder().
-                          with_name(ColumnName("COLUMN1"))
-                          .with_type(ColumnType("VARCHAR"))
-                          .build()
-                      )
-                  ])
+    table = Table(
+        table_name,
+        columns=[
+            (
+                ColumnBuilder()
+                .with_name(ColumnName("COLUMN1"))
+                .with_type(ColumnType("VARCHAR"))
+                .build()
+            )
+        ],
+    )
     assert table.name is not None


 def test_using_view_name_proxy_in_view(context_mock: ScopeQueryHandlerContext):
     view_name = context_mock.get_temporary_view_name()
-    view = View(view_name, columns=[
-        (
-            ColumnBuilder().
-            with_name(ColumnName("COLUMN1"))
-            .with_type(ColumnType("VARCHAR"))
-            .build()
-        )])
+    view = View(
+        view_name,
+        columns=[
+            (
+                ColumnBuilder()
+                .with_name(ColumnName("COLUMN1"))
+                .with_type(ColumnType("VARCHAR"))
+                .build()
+            )
+        ],
+    )
     assert view.name is not None


 def test_get_connection_existing_connection(
-    context_mock: ScopeQueryHandlerContext,
-    connection_mock: Connection
+    context_mock: ScopeQueryHandlerContext, connection_mock: Connection
 ):
     connection = context_mock.get_connection("existing")
-    assert connection == connection
+    assert connection == connection_mock


-def test_get_connection_not_existing_connection(
-    context_mock: ScopeQueryHandlerContext):
+def test_get_connection_not_existing_connection(context_mock: ScopeQueryHandlerContext):
     with pytest.raises(KeyError):
         context_mock.get_connection("not_existing")
diff --git a/tests/unit_tests/query_handler/test_top_level_query_handler_context.py b/tests/unit_tests/query_handler/test_top_level_query_handler_context.py
index e8f7fe73..792fd3d7 100644
--- a/tests/unit_tests/query_handler/test_top_level_query_handler_context.py
+++ b/tests/unit_tests/query_handler/test_top_level_query_handler_context.py
@@ -1,7 +1,10 @@
-import pytest
 import exasol.bucketfs as bfs
+import pytest

-from exasol.analytics.query_handler.context.top_level_query_handler_context import TopLevelQueryHandlerContext, ChildContextNotReleasedError
+from exasol.analytics.query_handler.context.top_level_query_handler_context import (
+    ChildContextNotReleasedError,
+    TopLevelQueryHandlerContext,
+)
 from exasol.analytics.query_handler.query.drop.table import DropTableQuery
 from exasol.analytics.query_handler.query.drop.view import DropViewQuery
@@ -16,8 +19,11 @@ def test_cleanup_released_temporary_table_proxies(context_mock):
     proxy_fully_qualified = proxy.fully_qualified
     context_mock.release()
     queries = context_mock.cleanup_released_object_proxies()
-    assert len(queries) == 1 and isinstance(queries[0], DropTableQuery) \
-        and queries[0].query_string == f"DROP TABLE IF EXISTS {proxy_fully_qualified};"
+    assert (
+        len(queries) == 1
+        and isinstance(queries[0], DropTableQuery)
+        and queries[0].query_string == f"DROP TABLE IF EXISTS {proxy_fully_qualified};"
+    )


 def test_cleanup_released_temporary_view_proxies(context_mock):
@@ -26,12 +32,16 @@
     context_mock.release()
     queries = context_mock.cleanup_released_object_proxies()

-    assert len(queries) == 1 and isinstance(queries[0], DropViewQuery) \
-        and queries[0].query_string == f"DROP VIEW IF EXISTS {proxy_fully_qualified};"
+    assert (
+        len(queries) == 1
+        and isinstance(queries[0], DropViewQuery)
+        and queries[0].query_string == f"DROP VIEW IF EXISTS {proxy_fully_qualified};"
+    )


-def test_cleanup_released_bucketfs_object_with_uploaded_file_proxies(context_mock,
-                                                                     sample_bucketfs_location: bfs.path.PathLike):
+def test_cleanup_released_bucketfs_object_with_uploaded_file_proxies(
+    context_mock, sample_bucketfs_location:
bfs.path.PathLike +): proxy = context_mock.get_temporary_bucketfs_location() # create dummy file with content "test" (proxy.bucketfs_location() / "test_file.txt").write(b"test") @@ -40,7 +50,9 @@ def test_cleanup_released_bucketfs_object_with_uploaded_file_proxies(context_moc assert not sample_bucketfs_location.is_dir() -def test_cleanup_released_bucketfs_object_without_uploaded_file_proxies_after_release(context_mock): +def test_cleanup_released_bucketfs_object_without_uploaded_file_proxies_after_release( + context_mock, +): _ = context_mock.get_temporary_bucketfs_location() context_mock.release() context_mock.cleanup_released_object_proxies() @@ -52,8 +64,9 @@ def test_cleanup_release_in_reverse_order_at_top_level(context_mock): context_mock.release() query_objects = context_mock.cleanup_released_object_proxies() actual_queries = [query.query_string for query in query_objects] - expected_queries = [f"DROP TABLE IF EXISTS {table_name};" - for table_name in reversed(table_names)] + expected_queries = [ + f"DROP TABLE IF EXISTS {table_name};" for table_name in reversed(table_names) + ] assert expected_queries == actual_queries @@ -66,18 +79,24 @@ def test_cleanup_release_in_reverse_order_at_child(context_mock): child.release() child_query_objects = context_mock.cleanup_released_object_proxies() child_actual_queries = [query.query_string for query in child_query_objects] - child_expected_queries = [f"DROP TABLE IF EXISTS {table_name};" - for table_name in reversed(child_table_names)] + child_expected_queries = [ + f"DROP TABLE IF EXISTS {table_name};" + for table_name in reversed(child_table_names) + ] parent_proxies.extend([context_mock.get_temporary_table_name() for _ in range(10)]) parent_table_names = [proxy.fully_qualified for proxy in parent_proxies] context_mock.release() parent_query_objects = context_mock.cleanup_released_object_proxies() parent_actual_queries = [query.query_string for query in parent_query_objects] - parent_expected_queries = [f"DROP TABLE IF EXISTS {table_name};" - for table_name in reversed(parent_table_names)] - assert child_expected_queries == child_actual_queries and \ - parent_expected_queries == parent_actual_queries + parent_expected_queries = [ + f"DROP TABLE IF EXISTS {table_name};" + for table_name in reversed(parent_table_names) + ] + assert ( + child_expected_queries == child_actual_queries + and parent_expected_queries == parent_actual_queries + ) def test_cleanup_parent_before_grand_child_with_temporary_objects(context_mock): diff --git a/tests/unit_tests/query_handler_runner/test_python_query_handler_runner.py b/tests/unit_tests/query_handler_runner/test_python_query_handler_runner.py index 0d75eb59..9d63a493 100644 --- a/tests/unit_tests/query_handler_runner/test_python_query_handler_runner.py +++ b/tests/unit_tests/query_handler_runner/test_python_query_handler_runner.py @@ -1,48 +1,59 @@ import re +from inspect import cleandoc from pathlib import PurePosixPath from typing import List, Union -from inspect import cleandoc import pytest -from exasol.analytics.schema import ( - Column, - ColumnType, - ColumnName, -) -from exasol.analytics.sql_executor.testing.mock_result_set import MockResultSet -from exasol.analytics.sql_executor.testing.mock_sql_executor import MockSQLExecutor, ExpectedQuery -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.context.top_level_query_handler_context import TopLevelQueryHandlerContext -from exasol.analytics.query_handler.query.select import 
SelectQueryWithColumnDefinition, SelectQuery +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.context.top_level_query_handler_context import ( + TopLevelQueryHandlerContext, +) +from exasol.analytics.query_handler.python_query_handler_runner import ( + PythonQueryHandlerRunner, +) +from exasol.analytics.query_handler.query.result.interface import QueryResult +from exasol.analytics.query_handler.query.select import ( + SelectQuery, + SelectQueryWithColumnDefinition, +) from exasol.analytics.query_handler.query_handler import QueryHandler from exasol.analytics.query_handler.result import Continue, Finish -from exasol.analytics.query_handler.query.result.interface import QueryResult -from exasol.analytics.query_handler.python_query_handler_runner import PythonQueryHandlerRunner +from exasol.analytics.schema import Column, ColumnName, ColumnType +from exasol.analytics.sql_executor.testing.mock_result_set import MockResultSet +from exasol.analytics.sql_executor.testing.mock_sql_executor import ( + ExpectedQuery, + MockSQLExecutor, +) EXPECTED_EXCEPTION = "ExpectedException" -def expect_query(template: str, result_set = MockResultSet()): +def expect_query(template: str, result_set=MockResultSet()): return [template, result_set] def create_sql_executor_1(schema: str, *args): - return MockSQLExecutor([ - ExpectedQuery( - cleandoc(template.format(schema=schema)), - result_set or MockResultSet() - ) for template, result_set in args - ]) + return MockSQLExecutor( + [ + ExpectedQuery( + cleandoc(template.format(schema=schema)), result_set or MockResultSet() + ) + for template, result_set in args + ] + ) def create_sql_executor(schema: str, prefix: str, *args): - return MockSQLExecutor([ - ExpectedQuery( - cleandoc(template.format(schema=schema, prefix=prefix)), - result_set or MockResultSet() - ) for template, result_set in args - ]) + return MockSQLExecutor( + [ + ExpectedQuery( + cleandoc(template.format(schema=schema, prefix=prefix)), + result_set or MockResultSet(), + ) + for template, result_set in args + ] + ) @pytest.fixture() @@ -61,19 +72,24 @@ class TestInput: class TestOutput: __test__ = False + def __init__(self, test_input: TestInput): self.test_input = test_input class StartFinishTestQueryHandler(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[TestOutput]]: return Finish[TestOutput](TestOutput(self._parameter)) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: pass @@ -88,7 +104,7 @@ def test_start_finish(context_mock): sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=StartFinishTestQueryHandler + query_handler_factory=StartFinishTestQueryHandler, ) test_output = query_handler_runner.run() assert test_output.test_input == test_input @@ -96,7 +112,9 @@ def test_start_finish(context_mock): class StartFinishCleanupQueriesTestQueryHandler(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: 
TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter @@ -104,7 +122,9 @@ def start(self) -> Union[Continue, Finish[TestOutput]]: self._query_handler_context.get_temporary_table_name() return Finish[TestOutput](TestOutput(self._parameter)) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: pass @@ -117,23 +137,24 @@ def test_start_finish_cleanup_queries(aaf_pytest_db_schema, prefix, context_mock sql_executor = create_sql_executor( aaf_pytest_db_schema, prefix, - expect_query( - 'DROP TABLE IF EXISTS "{schema}"."{prefix}_1";' - )) + expect_query('DROP TABLE IF EXISTS "{schema}"."{prefix}_1";'), + ) test_input = TestInput() query_handler_runner = PythonQueryHandlerRunner[TestInput, TestOutput]( sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=StartFinishCleanupQueriesTestQueryHandler + query_handler_factory=StartFinishCleanupQueriesTestQueryHandler, ) test_output = query_handler_runner.run() assert test_output.test_input == test_input class StartErrorCleanupQueriesTestQueryHandler(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter @@ -141,7 +162,9 @@ def start(self) -> Union[Continue, Finish[TestOutput]]: self._query_handler_context.get_temporary_table_name() raise Exception("Start failed") - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: pass @@ -154,16 +177,15 @@ def test_start_error_cleanup_queries(aaf_pytest_db_schema, prefix, context_mock) sql_executor = create_sql_executor( aaf_pytest_db_schema, prefix, - expect_query( - 'DROP TABLE IF EXISTS "{schema}"."{prefix}_1";' - )) + expect_query('DROP TABLE IF EXISTS "{schema}"."{prefix}_1";'), + ) test_input = TestInput() query_handler_runner = PythonQueryHandlerRunner[TestInput, TestOutput]( sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=StartErrorCleanupQueriesTestQueryHandler + query_handler_factory=StartErrorCleanupQueriesTestQueryHandler, ) with pytest.raises(Exception, match="Execution of query handler .* failed.") as ex: test_output = query_handler_runner.run() @@ -171,20 +193,23 @@ def test_start_error_cleanup_queries(aaf_pytest_db_schema, prefix, context_mock) class ContinueFinishTestQueryHandler(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[TestOutput]]: column_name = ColumnName("a") - input_query = SelectQueryWithColumnDefinition(f"""SELECT 1 as {column_name.quoted_name}""", - [Column(ColumnName("a"), - ColumnType(name="DECIMAL", - precision=1, - scale=0))]) + input_query = SelectQueryWithColumnDefinition( + f"""SELECT 1 as 
{column_name.quoted_name}""", + [Column(ColumnName("a"), ColumnType(name="DECIMAL", precision=1, scale=0))], + ) return Continue(query_list=[], input_query=input_query) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: if query_result.a != 1: raise AssertionError(f"query_result.a != 1, got {query_result.a}") return Finish[TestOutput](TestOutput(self._parameter)) @@ -214,12 +239,15 @@ def test_continue_finish(aaf_pytest_db_schema, prefix, context_mock): """, MockResultSet( rows=[(1,)], - columns=[Column(ColumnName("a"), ColumnType( - name="DECIMAL", precision=1, scale=0))]) - ), - expect_query( - 'DROP VIEW IF EXISTS "{schema}"."{prefix}_2_1";' + columns=[ + Column( + ColumnName("a"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ), ), + expect_query('DROP VIEW IF EXISTS "{schema}"."{prefix}_2_1";'), ) test_input = TestInput() @@ -227,23 +255,29 @@ def test_continue_finish(aaf_pytest_db_schema, prefix, context_mock): sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=ContinueFinishTestQueryHandler + query_handler_factory=ContinueFinishTestQueryHandler, ) test_output = query_handler_runner.run() assert test_output.test_input == test_input class ContinueWrongColumnsTestQueryHandler(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[TestOutput]]: - input_query = SelectQueryWithColumnDefinition(f"""SELECT 1 as {ColumnName("b").quoted_name}""", - [Column(ColumnName("a"), ColumnType("INTEGER"))]) + input_query = SelectQueryWithColumnDefinition( + f"""SELECT 1 as {ColumnName("b").quoted_name}""", + [Column(ColumnName("a"), ColumnType("INTEGER"))], + ) return Continue(query_list=[], input_query=input_query) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: raise AssertionError("handle_query_result shouldn't be called") @@ -271,19 +305,22 @@ def test_continue_wrong_columns(aaf_pytest_db_schema, prefix, context_mock): """, MockResultSet( rows=[(1,)], - columns=[Column(ColumnName("b"), ColumnType( - name="DECIMAL", precision=1, scale=0))]) - ), - expect_query( - 'DROP VIEW IF EXISTS "{schema}"."{prefix}_2_1";' + columns=[ + Column( + ColumnName("b"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ), ), + expect_query('DROP VIEW IF EXISTS "{schema}"."{prefix}_2_1";'), ) test_input = TestInput() query_handler_runner = PythonQueryHandlerRunner[TestInput, TestOutput]( sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=ContinueWrongColumnsTestQueryHandler + query_handler_factory=ContinueWrongColumnsTestQueryHandler, ) with pytest.raises(RuntimeError) as exception: test_output = query_handler_runner.run() @@ -291,21 +328,24 @@ def test_continue_wrong_columns(aaf_pytest_db_schema, prefix, context_mock): class ContinueQueryListTestQueryHandler(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: 
ScopeQueryHandlerContext): + def __init__( + self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[TestOutput]]: column_name = ColumnName("a") - input_query = SelectQueryWithColumnDefinition(f"""SELECT 1 as {column_name.quoted_name}""", - [Column(ColumnName("a"), - ColumnType(name="DECIMAL", - precision=1, - scale=0))]) + input_query = SelectQueryWithColumnDefinition( + f"""SELECT 1 as {column_name.quoted_name}""", + [Column(ColumnName("a"), ColumnType(name="DECIMAL", precision=1, scale=0))], + ) query_list = [SelectQuery(query_string="SELECT 1")] return Continue(query_list=query_list, input_query=input_query) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: return Finish[TestOutput](TestOutput(self._parameter)) @@ -327,8 +367,13 @@ def test_continue_query_list(aaf_pytest_db_schema, prefix, context_mock): """, MockResultSet( rows=[(1,)], - columns=[Column(ColumnName("a"), ColumnType( - name="DECIMAL", precision=1, scale=0))]) + columns=[ + Column( + ColumnName("a"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ), ), expect_query( """ @@ -338,39 +383,45 @@ def test_continue_query_list(aaf_pytest_db_schema, prefix, context_mock): """, MockResultSet( rows=[(1,)], - columns=[Column(ColumnName("a"), ColumnType( - name="DECIMAL", precision=1, scale=0))]) - ), - expect_query( - 'DROP VIEW IF EXISTS "{schema}"."{prefix}_2_1";' + columns=[ + Column( + ColumnName("a"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ), ), + expect_query('DROP VIEW IF EXISTS "{schema}"."{prefix}_2_1";'), ) test_input = TestInput() query_handler_runner = PythonQueryHandlerRunner[TestInput, TestOutput]( sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=ContinueQueryListTestQueryHandler + query_handler_factory=ContinueQueryListTestQueryHandler, ) test_output = query_handler_runner.run() assert test_output.test_input == test_input class ContinueErrorCleanupQueriesTestQueryHandler(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[TestOutput]]: column_name = ColumnName("a") - input_query = SelectQueryWithColumnDefinition(f"""SELECT 1 as {column_name.quoted_name}""", - [Column(ColumnName("a"), - ColumnType(name="DECIMAL", - precision=1, - scale=0))]) + input_query = SelectQueryWithColumnDefinition( + f"""SELECT 1 as {column_name.quoted_name}""", + [Column(ColumnName("a"), ColumnType(name="DECIMAL", precision=1, scale=0))], + ) return Continue(query_list=[], input_query=input_query) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: self._query_handler_context.get_temporary_table_name() raise Exception("Start failed") @@ -389,7 +440,8 @@ def test_continue_error_cleanup_queries(aaf_pytest_db_schema, prefix, context_mo """ CREATE OR REPLACE VIEW "{schema}"."{prefix}_2_1" AS SELECT 1 
as "a"; - """), + """ + ), expect_query( """ SELECT @@ -398,8 +450,13 @@ def test_continue_error_cleanup_queries(aaf_pytest_db_schema, prefix, context_mo """, MockResultSet( rows=[(1,)], - columns=[Column(ColumnName("a"), ColumnType( - name="DECIMAL", precision=1, scale=0))]) + columns=[ + Column( + ColumnName("a"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ), ), expect_query('DROP TABLE IF EXISTS "{schema}"."{prefix}_3";'), expect_query('DROP VIEW IF EXISTS "{schema}"."{prefix}_2_1";'), @@ -409,7 +466,7 @@ def test_continue_error_cleanup_queries(aaf_pytest_db_schema, prefix, context_mo sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=ContinueErrorCleanupQueriesTestQueryHandler + query_handler_factory=ContinueErrorCleanupQueriesTestQueryHandler, ) with pytest.raises(Exception, match="Execution of query handler .* failed."): test_output = query_handler_runner.run() @@ -417,7 +474,9 @@ def test_continue_error_cleanup_queries(aaf_pytest_db_schema, prefix, context_mo class ContinueContinueFinishTestQueryHandler(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter self._iter = 0 @@ -426,22 +485,25 @@ def start(self) -> Union[Continue, Finish[TestOutput]]: column_name = ColumnName("a") input_query = SelectQueryWithColumnDefinition( f"""SELECT 1 as {column_name.quoted_name}""", - [Column(ColumnName("a"), - ColumnType(name="DECIMAL", - precision=1, - scale=0))]) + [Column(ColumnName("a"), ColumnType(name="DECIMAL", precision=1, scale=0))], + ) return Continue(query_list=[], input_query=input_query) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: if self._iter == 0: self._iter += 1 column_name = ColumnName("b") input_query = SelectQueryWithColumnDefinition( f"""SELECT 1 as {column_name.quoted_name}""", - [Column(ColumnName("b"), - ColumnType(name="DECIMAL", - precision=1, - scale=0))]) + [ + Column( + ColumnName("b"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ) return Continue(query_list=[], input_query=input_query) else: return Finish[TestOutput](TestOutput(self._parameter)) @@ -461,7 +523,8 @@ def test_continue_continue_finish(aaf_pytest_db_schema, prefix, context_mock): """ CREATE OR REPLACE VIEW "{schema}"."{prefix}_2_1" AS SELECT 1 as "a"; - """), + """ + ), expect_query( """ SELECT @@ -470,17 +533,21 @@ def test_continue_continue_finish(aaf_pytest_db_schema, prefix, context_mock): """, MockResultSet( rows=[(1,)], - columns=[Column(ColumnName("a"), ColumnType( - name="DECIMAL", precision=1, scale=0))]) - ), - expect_query( - 'DROP VIEW IF EXISTS "{schema}"."{prefix}_2_1";' + columns=[ + Column( + ColumnName("a"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ), ), + expect_query('DROP VIEW IF EXISTS "{schema}"."{prefix}_2_1";'), expect_query( """ CREATE OR REPLACE VIEW "{schema}"."{prefix}_4_1" AS SELECT 1 as "b"; - """), + """ + ), expect_query( """ SELECT @@ -489,50 +556,65 @@ def test_continue_continue_finish(aaf_pytest_db_schema, prefix, context_mock): """, MockResultSet( rows=[(1,)], - columns=[Column(ColumnName("b"), ColumnType( - name="DECIMAL", 
precision=1, scale=0))]) + columns=[ + Column( + ColumnName("b"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ), ), - expect_query( - 'DROP VIEW IF EXISTS "{schema}"."{prefix}_4_1";'), + expect_query('DROP VIEW IF EXISTS "{schema}"."{prefix}_4_1";'), ) test_input = TestInput() query_handler_runner = PythonQueryHandlerRunner[TestInput, TestOutput]( sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=ContinueContinueFinishTestQueryHandler + query_handler_factory=ContinueContinueFinishTestQueryHandler, ) test_output = query_handler_runner.run() assert test_output.test_input == test_input -class ContinueContinueCleanupFinishTestQueryHandler(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext): +class ContinueContinueCleanupFinishTestQueryHandler( + QueryHandler[TestInput, TestOutput] +): + def __init__( + self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter self._iter = 0 def start(self) -> Union[Continue, Finish[TestOutput]]: - self._child_query_handler_conntext = self._query_handler_context.get_child_query_handler_context() + self._child_query_handler_conntext = ( + self._query_handler_context.get_child_query_handler_context() + ) self._table = self._child_query_handler_conntext.get_temporary_table_name() column_name = ColumnName("a") - input_query = SelectQueryWithColumnDefinition(f"""SELECT 1 as {column_name.quoted_name}""", - [Column(ColumnName("a"), - ColumnType(name="DECIMAL", - precision=1, - scale=0))]) + input_query = SelectQueryWithColumnDefinition( + f"""SELECT 1 as {column_name.quoted_name}""", + [Column(ColumnName("a"), ColumnType(name="DECIMAL", precision=1, scale=0))], + ) return Continue(query_list=[], input_query=input_query) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: if self._iter == 0: self._child_query_handler_conntext.release() self._iter += 1 column_name = ColumnName("b") - input_query = SelectQueryWithColumnDefinition(f"""SELECT 1 as {column_name.quoted_name}""", - [Column(ColumnName("b"), - ColumnType(name="DECIMAL", - precision=1, - scale=0))]) + input_query = SelectQueryWithColumnDefinition( + f"""SELECT 1 as {column_name.quoted_name}""", + [ + Column( + ColumnName("b"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ) return Continue(query_list=[], input_query=input_query) else: return Finish[TestOutput](TestOutput(self._parameter)) @@ -554,7 +636,8 @@ def test_continue_cleanup_continue_finish(aaf_pytest_db_schema, prefix, context_ """ CREATE OR REPLACE VIEW "{schema}"."{prefix}_4_1" AS SELECT 1 as "a"; - """), + """ + ), expect_query( """ SELECT @@ -563,20 +646,22 @@ def test_continue_cleanup_continue_finish(aaf_pytest_db_schema, prefix, context_ """, MockResultSet( rows=[(1,)], - columns=[Column(ColumnName("a"), ColumnType( - name="DECIMAL", precision=1, scale=0))]) - ), - expect_query( - 'DROP VIEW IF EXISTS "{schema}"."{prefix}_4_1";' - ), - expect_query( - 'DROP TABLE IF EXISTS "{schema}"."{prefix}_2_1";' + columns=[ + Column( + ColumnName("a"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ), ), + expect_query('DROP VIEW IF EXISTS "{schema}"."{prefix}_4_1";'), + expect_query('DROP TABLE IF EXISTS 
"{schema}"."{prefix}_2_1";'), expect_query( """ CREATE OR REPLACE VIEW "{schema}"."{prefix}_6_1" AS SELECT 1 as "b"; - """), + """ + ), expect_query( """ SELECT @@ -585,26 +670,31 @@ def test_continue_cleanup_continue_finish(aaf_pytest_db_schema, prefix, context_ """, MockResultSet( rows=[(1,)], - columns=[Column(ColumnName("b"), ColumnType( - name="DECIMAL", precision=1, scale=0))]) - ), - expect_query( - 'DROP VIEW IF EXISTS "{schema}"."{prefix}_6_1";' + columns=[ + Column( + ColumnName("b"), + ColumnType(name="DECIMAL", precision=1, scale=0), + ) + ], + ), ), + expect_query('DROP VIEW IF EXISTS "{schema}"."{prefix}_6_1";'), ) test_input = TestInput() query_handler_runner = PythonQueryHandlerRunner[TestInput, TestOutput]( sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=ContinueContinueCleanupFinishTestQueryHandler + query_handler_factory=ContinueContinueCleanupFinishTestQueryHandler, ) test_output = query_handler_runner.run() assert test_output.test_input == test_input class FailInCleanupAfterException(QueryHandler[TestInput, TestOutput]): - def __init__(self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: TestInput, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter self._iter = 0 @@ -613,7 +703,9 @@ def start(self) -> Union[Continue, Finish[TestOutput]]: self._query_handler_context.get_child_query_handler_context() raise Exception(EXPECTED_EXCEPTION) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[TestOutput]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[TestOutput]]: pass @@ -624,10 +716,12 @@ def test_fail_in_cleanup(aaf_pytest_db_schema, context_mock): sql_executor=sql_executor, top_level_query_handler_context=context_mock, parameter=test_input, - query_handler_factory=FailInCleanupAfterException + query_handler_factory=FailInCleanupAfterException, ) - with pytest.raises(RuntimeError, match="Execution of query handler .* failed.") as e: + with pytest.raises( + RuntimeError, match="Execution of query handler .* failed." 
+    ) as e:
         query_handler_runner.run()

     assert e.value.__cause__.args[0] == EXPECTED_EXCEPTION
diff --git a/tests/unit_tests/query_result/test_python_query_result.py b/tests/unit_tests/query_result/test_python_query_result.py
index 1b536613..e18c30f1 100644
--- a/tests/unit_tests/query_result/test_python_query_result.py
+++ b/tests/unit_tests/query_result/test_python_query_result.py
@@ -1,20 +1,16 @@
-
-from exasol.analytics.schema import (
-    Column,
-    ColumnType,
-    ColumnName,
+from exasol.analytics.query_handler.query.result.python_query_result import (
+    PythonQueryResult,
 )
-
-from exasol.analytics.query_handler.query.result.python_query_result import PythonQueryResult
+from exasol.analytics.schema import Column, ColumnName, ColumnType

 DATA_SIZE = 100
 FETCH_SIZE = 10

-INPUT_DATA = [(i, (1.0 * i / DATA_SIZE), str(2 * i))
-              for i in range(1, DATA_SIZE + 1)]
+INPUT_DATA = [(i, (1.0 * i / DATA_SIZE), str(2 * i)) for i in range(1, DATA_SIZE + 1)]

 INPUT_COLUMNS = [
     Column(ColumnName("t1"), ColumnType("INTEGER")),
     Column(ColumnName("t2"), ColumnType("FLOAT")),
-    Column(ColumnName("t3"), ColumnType("VARCHAR(2000)"))]
+    Column(ColumnName("t3"), ColumnType("VARCHAR(2000)")),
+]


 def test_fetch_as_dataframe_column_names():
diff --git a/tests/unit_tests/query_result/test_udf_query_result.py b/tests/unit_tests/query_result/test_udf_query_result.py
index 0e0a7fdb..711b9005 100644
--- a/tests/unit_tests/query_result/test_udf_query_result.py
+++ b/tests/unit_tests/query_result/test_udf_query_result.py
@@ -6,24 +6,28 @@

 DATA_SIZE = 100
 FETCH_SIZE = 10
-INPUT_DATA = [(i, (1.0 * i / DATA_SIZE), str(2 * i))
-              for i in range(1, DATA_SIZE + 1)]
+INPUT_DATA = [(i, (1.0 * i / DATA_SIZE), str(2 * i)) for i in range(1, DATA_SIZE + 1)]
 INPUT_COLUMNS = [
     Column("t1", int, "INTEGER"),
     Column("t2", float, "FLOAT"),
-    Column("t3", str, "VARCHAR(2000)")]
+    Column("t3", str, "VARCHAR(2000)"),
+]


 def test_fetch_as_dataframe_column_names():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             df = wrapper.fetch_as_dataframe(num_rows=1)
             for column in df.columns:
                 ctx.emit(column)
@@ -34,7 +38,9 @@ def run(ctx: UDFContext):
         input_type="SET",
         input_columns=INPUT_COLUMNS,
         output_type="EMITS",
-        output_columns=[Column("column", str, "VARCHAR(2000000)"), ]
+        output_columns=[
+            Column("column", str, "VARCHAR(2000000)"),
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -44,14 +50,18 @@ def run(ctx: UDFContext):

 def test_fetch_as_dataframe_first_batch():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             df = wrapper.fetch_as_dataframe(num_rows=10)
             ctx.emit(df[["a", "c", "b"]])
@@ -65,7 +75,7 @@ def run(ctx: UDFContext):
             Column("a", int, "INTEGER"),
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -75,14 +85,18 @@ def run(ctx: UDFContext):

 def test_fetch_as_dataframe_second_batch():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             df = wrapper.fetch_as_dataframe(num_rows=10)
             df = wrapper.fetch_as_dataframe(num_rows=20)
             ctx.emit(df[["a", "c", "b"]])
@@ -97,7 +111,7 @@ def run(ctx: UDFContext):
             Column("a", int, "INTEGER"),
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -107,14 +121,18 @@ def run(ctx: UDFContext):

 def test_fetch_as_dataframe_after_last_batch():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             while True:
                 df = wrapper.fetch_as_dataframe(num_rows=10)
                 if df is None:
@@ -131,7 +149,7 @@ def run(ctx: UDFContext):
             Column("a", int, "INTEGER"),
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -141,14 +159,18 @@ def run(ctx: UDFContext):

 def test_fetch_as_dataframe_all_rows():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             df = wrapper.fetch_as_dataframe(num_rows="all")
             ctx.emit(df[["a", "c", "b"]])
@@ -162,7 +184,7 @@ def run(ctx: UDFContext):
             Column("a", int, "INTEGER"),
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -172,14 +194,18 @@ def run(ctx: UDFContext):

 def test_fetch_as_dataframe_start_col():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             df = wrapper.fetch_as_dataframe(num_rows=1, start_col=1)
             ctx.emit(df[["c", "b"]])
@@ -192,7 +218,7 @@ def run(ctx: UDFContext):
         output_columns=[
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -202,14 +228,18 @@ def run(ctx: UDFContext):

 def test_rowcount():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             ctx.emit(wrapper.rowcount())

     executor = UDFMockExecutor()
@@ -220,7 +250,7 @@ def run(ctx: UDFContext):
         output_type="EMITS",
         output_columns=[
             Column("rowcount", int, "INTEGER"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -230,14 +260,18 @@ def run(ctx: UDFContext):

 def test_column_names():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             for column_name in wrapper.column_names():
                 ctx.emit(column_name)
@@ -249,7 +283,7 @@ def run(ctx: UDFContext):
         output_type="EMITS",
         output_columns=[
             Column("column_name", str, "VARCHAR(1000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -259,14 +293,18 @@ def run(ctx: UDFContext):

 def test_columns():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             for column in wrapper.columns():
                 ctx.emit(column.name.name, column.type.name)
@@ -279,7 +317,7 @@ def run(ctx: UDFContext):
         output_columns=[
             Column("column_name", str, "VARCHAR(1000)"),
             Column("sql_type", str, "VARCHAR(1000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -289,14 +327,18 @@ def run(ctx: UDFContext):

 def test_column_get_attr():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             a = wrapper.a
             b = wrapper.b
             c = wrapper.c
@@ -312,7 +354,7 @@ def run(ctx: UDFContext):
             Column("a", int, "INTEGER"),
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -322,14 +364,18 @@ def run(ctx: UDFContext):

 def test_column_get_item():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             a = wrapper["a"]
             b = wrapper["b"]
             c = wrapper["c"]
@@ -345,7 +391,7 @@ def run(ctx: UDFContext):
             Column("a", int, "INTEGER"),
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -355,14 +401,18 @@ def run(ctx: UDFContext):

 def test_column_next_get_item():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             while True:
                 a = wrapper["a"]
                 b = wrapper["b"]
@@ -381,7 +431,7 @@ def run(ctx: UDFContext):
             Column("a", int, "INTEGER"),
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -391,14 +441,18 @@ def run(ctx: UDFContext):

 def test_column_next_get_attr():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             while True:
                 a = wrapper.a
                 b = wrapper.b
@@ -417,7 +471,7 @@ def run(ctx: UDFContext):
             Column("a", int, "INTEGER"),
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
@@ -427,14 +481,18 @@ def run(ctx: UDFContext):

 def test_column_iterator():
     def udf_wrapper():
-        from exasol_udf_mock_python.udf_context import UDFContext
-        from exasol.analytics.query_handler.query.result.udf_query_result \
-            import UDFQueryResult
         from collections import OrderedDict

+        from exasol_udf_mock_python.udf_context import UDFContext
+
+        from exasol.analytics.query_handler.query.result.udf_query_result import (
+            UDFQueryResult,
+        )
+
         def run(ctx: UDFContext):
             wrapper = UDFQueryResult(
-                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")]))
+                ctx, exa, OrderedDict([("t1", "a"), ("t3", "b"), ("t2", "c")])
+            )
             for row in wrapper:
                ctx.emit(row[0], row[2], row[1])
@@ -448,7 +506,7 @@ def run(ctx: UDFContext):
             Column("a", int, "INTEGER"),
             Column("c", float, "FLOAT"),
             Column("b", str, "VARCHAR(2000)"),
-        ]
+        ],
     )

     exa = MockExaEnvironment(meta)
diff --git a/tests/unit_tests/sql_stage_graph/mock_cast.py b/tests/unit_tests/sql_stage_graph/mock_cast.py
deleted file mode 100644
index 3a3866ea..00000000
--- a/tests/unit_tests/sql_stage_graph/mock_cast.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from typing import Any, cast
-from unittest.mock import Mock
-
-
-def mock_cast(obj: Any) -> Mock:
-    return cast(Mock, obj)
diff --git a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/assert_helper.py b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/assert_helper.py
index fad501cb..de1814b0 100644
--- a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/assert_helper.py
+++ b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/assert_helper.py
@@ -1,21 +1,33 @@
 from typing import List

-from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput
-from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import SQLStageTrainQueryHandlerInput
-from tests.mock_cast import mock_cast
-from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import TestSetup
+from exasol.analytics.query_handler.graph.stage.sql.input_output import (
+    SQLStageInputOutput,
+)
+from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import (
+    SQLStageTrainQueryHandlerInput,
+)
+from tests.utils.mock_cast import mock_cast
+from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import (
+    TestSetup,
+)


 def assert_reference_counting_bag_not_called(test_setup: TestSetup):
-    reference_counting_bag_mock_setup = test_setup.state_setup.reference_counting_bag_mock_setup
+    reference_counting_bag_mock_setup = (
+        test_setup.state_setup.reference_counting_bag_mock_setup
+    )
     assert reference_counting_bag_mock_setup.bag.mock_calls == []


 def assert_reference_counting_bag_creation(test_setup: TestSetup):
     parent_query_handler_context = test_setup.state_setup.parent_query_handler_context
-    reference_counting_bag_mock_setup = test_setup.state_setup.reference_counting_bag_mock_setup
+    reference_counting_bag_mock_setup = (
+        test_setup.state_setup.reference_counting_bag_mock_setup
+    )

-    reference_counting_bag_mock_setup.factory.assert_called_once_with(parent_query_handler_context)
+    reference_counting_bag_mock_setup.factory.assert_called_once_with(
+        parent_query_handler_context
+    )
     assert reference_counting_bag_mock_setup.bag.mock_calls == []
@@ -32,24 +44,31 @@ def assert_stage_not_called(test_setup: TestSetup, *, stage_index: int):


 def assert_stage_train_query_handler_created(
-        test_setup: TestSetup, *, stage_index: int, stage_inputs: List[SQLStageInputOutput]):
+    test_setup: TestSetup, *, stage_index: int, stage_inputs: List[SQLStageInputOutput]
+):
     stage_setup = test_setup.stage_setups[stage_index]
-    mock_cast(test_setup.state_setup.result_bucketfs_location.joinpath).assert_called_once_with(str(stage_index))
-    mock_cast(test_setup.state_setup.parent_query_handler_context.get_child_query_handler_context).assert_called_once()
-    result_bucketfs_location = test_setup.stage_setups[stage_index].result_bucketfs_location
+    mock_cast(
+        test_setup.state_setup.result_bucketfs_location.joinpath
+    ).assert_called_once_with(str(stage_index))
+    mock_cast(
+        test_setup.state_setup.parent_query_handler_context.get_child_query_handler_context
+    ).assert_called_once()
+    result_bucketfs_location = test_setup.stage_setups[
+        stage_index
+    ].result_bucketfs_location
     stage_input = SQLStageTrainQueryHandlerInput(
-        result_bucketfs_location=result_bucketfs_location,
-        sql_stage_inputs=stage_inputs
+        result_bucketfs_location=result_bucketfs_location, sql_stage_inputs=stage_inputs
     )
     mock_cast(stage_setup.stage.create_train_query_handler).assert_called_once_with(
-        stage_input,
-        stage_setup.child_query_handler_context
+        stage_input, stage_setup.child_query_handler_context
     )
     assert stage_setup.train_query_handler.mock_calls == []
     assert stage_setup.child_query_handler_context.mock_calls == []


-def assert_release_on_query_handler_context_for_stage(test_setup: TestSetup, *, stage_index: int):
+def assert_release_on_query_handler_context_for_stage(
+    test_setup: TestSetup, *, stage_index: int
+):
     stage_setup = test_setup.stage_setups[stage_index]
     assert stage_setup.train_query_handler.mock_calls == []
     mock_cast(stage_setup.child_query_handler_context.release).assert_called_once()
diff --git a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/state_test_setup.py b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/state_test_setup.py
index 0068086c..618672e6 100644
--- a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/state_test_setup.py
+++ b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/state_test_setup.py
@@ -1,21 +1,33 @@
 import dataclasses
-from typing import Union, List
+from typing import List, Union
 from unittest.mock import MagicMock, Mock, create_autospec

+from exasol_bucketfs_utils_python.abstract_bucketfs_location import (
+    AbstractBucketFSLocation,
+)
+
 from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext
+from exasol.analytics.query_handler.graph.stage.sql.execution.input import (
+    SQLStageGraphExecutionInput,
+)
+from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counting_bag import (
+    ObjectProxyReferenceCountingBag,
+)
+from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import (
+    SQLStageGraphExecutionQueryHandlerState,
+)
+from exasol.analytics.query_handler.graph.stage.sql.input_output import (
+    SQLStageInputOutput,
+)
+from exasol.analytics.query_handler.graph.stage.sql.sql_stage import SQLStage
+from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph
+from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import (
+    SQLStageQueryHandler,
+)
+from exasol.analytics.query_handler.query.result.interface import QueryResult
 from exasol.analytics.query_handler.query_handler import QueryHandler
 from exasol.analytics.query_handler.result import Continue, Finish
-from exasol.analytics.query_handler.query.result.interface import QueryResult
-from exasol_bucketfs_utils_python.abstract_bucketfs_location import AbstractBucketFSLocation
-
-from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph
-from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counting_bag import ObjectProxyReferenceCountingBag
-from exasol.analytics.query_handler.graph.stage.sql.execution.input import SQLStageGraphExecutionInput
-from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import SQLStageGraphExecutionQueryHandlerState
-from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput
-from exasol.analytics.query_handler.graph.stage.sql.sql_stage import SQLStage
-from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import SQLStageQueryHandler
-from tests.mock_cast import mock_cast
+from tests.utils.mock_cast import mock_cast

 MockScopeQueryHandlerContext = Union[ScopeQueryHandlerContext, MagicMock]
 MockSQLStageTrainQueryHandler = Union[SQLStageQueryHandler, MagicMock]
@@ -25,7 +37,9 @@
 MockQueryResult = Union[QueryResult, MagicMock]
 MockSQLStageGraphExecutionInput = Union[SQLStageGraphExecutionInput, MagicMock]
 MockObjectProxyReferenceCountingBag = Union[ObjectProxyReferenceCountingBag, MagicMock]
-MockObjectProxyReferenceCountingBagFactory = Union[ObjectProxyReferenceCountingBag, Mock]
+MockObjectProxyReferenceCountingBagFactory = Union[
+    ObjectProxyReferenceCountingBag, Mock
+]
 MockBucketFSLocation = Union[AbstractBucketFSLocation, MagicMock]
@@ -85,67 +99,90 @@ def reset_mock(self):


 def create_execution_query_handler_state_setup(
-        sql_stage_graph: SQLStageGraph,
-        stage_setups: List[StageSetup]) \
-        -> ExecutionQueryHandlerStateSetup:
-    child_scoped_query_handler_contexts = [stage_setup.child_query_handler_context
-                                           for stage_setup in stage_setups]
-    scoped_query_handler_context = create_mock_query_handler_context(child_scoped_query_handler_contexts)
-    sql_stage_input_output: MockSQLStageInputOutput = create_autospec(SQLStageInputOutput)
-    mock_result_bucketfs_location: MockBucketFSLocation = create_autospec(AbstractBucketFSLocation)
-    stage_result_bucketfs_locations = [stage_setup.result_bucketfs_location
-                                       for stage_setup in stage_setups]
-    mock_cast(mock_result_bucketfs_location.joinpath).side_effect = stage_result_bucketfs_locations
+    sql_stage_graph: SQLStageGraph, stage_setups: List[StageSetup]
+) -> ExecutionQueryHandlerStateSetup:
+    child_scoped_query_handler_contexts = [
+        stage_setup.child_query_handler_context for stage_setup in stage_setups
+    ]
+    scoped_query_handler_context = create_mock_query_handler_context(
+        child_scoped_query_handler_contexts
+    )
+    sql_stage_input_output: MockSQLStageInputOutput = create_autospec(
+        SQLStageInputOutput
+    )
+    mock_result_bucketfs_location: MockBucketFSLocation = create_autospec(
+        AbstractBucketFSLocation
+    )
+    stage_result_bucketfs_locations = [
+        stage_setup.result_bucketfs_location for stage_setup in stage_setups
+    ]
+    mock_cast(mock_result_bucketfs_location.joinpath).side_effect = (
+        stage_result_bucketfs_locations
+    )
     parameter = SQLStageGraphExecutionInput(
         input=sql_stage_input_output,
         sql_stage_graph=sql_stage_graph,
-        result_bucketfs_location=mock_result_bucketfs_location)
+        result_bucketfs_location=mock_result_bucketfs_location,
+    )
     reference_counting_bag_mock_setup = create_reference_counting_bag_mock_setup()
     execution_query_handler_state = SQLStageGraphExecutionQueryHandlerState(
         parameter=parameter,
         query_handler_context=scoped_query_handler_context,
-        reference_counting_bag_factory=reference_counting_bag_mock_setup.factory
+        reference_counting_bag_factory=reference_counting_bag_mock_setup.factory,
     )
     return ExecutionQueryHandlerStateSetup(
         reference_counting_bag_mock_setup=reference_counting_bag_mock_setup,
         parent_query_handler_context=scoped_query_handler_context,
         sql_stage_input_output=sql_stage_input_output,
         execution_query_handler_state=execution_query_handler_state,
-        result_bucketfs_location=mock_result_bucketfs_location
+        result_bucketfs_location=mock_result_bucketfs_location,
     )


 def create_reference_counting_bag_mock_setup() -> ReferenceCountingBagSetup:
-    reference_counting_bag: MockObjectProxyReferenceCountingBag = create_autospec(ObjectProxyReferenceCountingBag)
+    reference_counting_bag: MockObjectProxyReferenceCountingBag = create_autospec(
+        ObjectProxyReferenceCountingBag
+    )
     reference_counting_bag_factory: MockObjectProxyReferenceCountingBagFactory = Mock()
     reference_counting_bag_factory.return_value = reference_counting_bag
-    return ReferenceCountingBagSetup(bag=reference_counting_bag,
-                                     factory=reference_counting_bag_factory)
+    return ReferenceCountingBagSetup(
+        bag=reference_counting_bag, factory=reference_counting_bag_factory
+    )


-def create_mock_query_handler_context(child_scoped_query_handler_contexts: List[ScopeQueryHandlerContext]) \
-        -> MockScopeQueryHandlerContext:
-    scoped_query_handler_context: MockScopeQueryHandlerContext = \
-        create_autospec(ScopeQueryHandlerContext)
-    scoped_query_handler_context.get_child_query_handler_context.side_effect = child_scoped_query_handler_contexts
+def create_mock_query_handler_context(
+    child_scoped_query_handler_contexts: List[ScopeQueryHandlerContext],
+) -> MockScopeQueryHandlerContext:
+    scoped_query_handler_context: MockScopeQueryHandlerContext = create_autospec(
+        ScopeQueryHandlerContext
+    )
+    scoped_query_handler_context.get_child_query_handler_context.side_effect = (
+        child_scoped_query_handler_contexts
+    )
     return scoped_query_handler_context


-def create_mocks_for_stage(result_prototypes: List[Union[Finish, Continue]], *, stage_index: int) -> StageSetup:
-    child_scoped_query_handler_context: MockScopeQueryHandlerContext = \
-        create_autospec(ScopeQueryHandlerContext)
+def create_mocks_for_stage(
+    result_prototypes: List[Union[Finish, Continue]], *, stage_index: int
+) -> StageSetup:
+    child_scoped_query_handler_context: MockScopeQueryHandlerContext = create_autospec(
+        ScopeQueryHandlerContext
+    )
     sql_stage: MockSQLStage = create_autospec(SQLStage)
     sql_stage.__hash__.return_value = stage_index
-    result: List[MockQueryHandlerResult] = [create_autospec(result_prototype)
-                                            for result_prototype in result_prototypes]
+    result: List[MockQueryHandlerResult] = [
+        create_autospec(result_prototype) for result_prototype in result_prototypes
+    ]
     train_query_handler: MockSQLStageTrainQueryHandler = create_autospec(QueryHandler)
     sql_stage.create_train_query_handler.return_value = train_query_handler
-    mock_result_bucketfs_location: MockBucketFSLocation = create_autospec(AbstractBucketFSLocation)
+    mock_result_bucketfs_location: MockBucketFSLocation = create_autospec(
+        AbstractBucketFSLocation
+    )
     return StageSetup(
         index=stage_index,
         child_query_handler_context=child_scoped_query_handler_context,
         train_query_handler=train_query_handler,
         stage=sql_stage,
         results=result,
-        result_bucketfs_location=mock_result_bucketfs_location
+        result_bucketfs_location=mock_result_bucketfs_location,
     )
diff --git a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_integration.py b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_integration.py
index 0058376e..4f069394 100644
--- a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_integration.py
+++ b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_integration.py
@@ -2,55 +2,75 @@
 import enum
 from contextlib import contextmanager
 from pathlib import PurePosixPath
-from typing import List, Union, Callable, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 from unittest.mock import Mock

 import pytest
+from exasol_bucketfs_utils_python.localfs_mock_bucketfs_location import (
+    LocalFSMockBucketFSLocation,
+)

 from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext
-from exasol.analytics.query_handler.context.top_level_query_handler_context import TopLevelQueryHandlerContext
-from exasol.analytics.query_handler.query.select import SelectQueryWithColumnDefinition
-from exasol.analytics.query_handler.result import Finish, Continue
+from exasol.analytics.query_handler.context.top_level_query_handler_context import (
+    TopLevelQueryHandlerContext,
+)
+from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition
+from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset
+from exasol.analytics.query_handler.graph.stage.sql.execution.input import (
+    SQLStageGraphExecutionInput,
+)
+from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler import (
+    SQLStageGraphExecutionQueryHandler,
+)
+from exasol.analytics.query_handler.graph.stage.sql.input_output import (
+    SQLStageInputOutput,
+)
+from exasol.analytics.query_handler.graph.stage.sql.sql_stage import SQLStage
+from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph
+from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import (
+    SQLStageQueryHandler,
+    SQLStageTrainQueryHandlerInput,
+)
 from exasol.analytics.query_handler.query.result.interface import QueryResult
-from exasol_bucketfs_utils_python.localfs_mock_bucketfs_location import LocalFSMockBucketFSLocation
+from exasol.analytics.query_handler.query.select import SelectQueryWithColumnDefinition
+from exasol.analytics.query_handler.result import Continue, Finish
 from exasol.analytics.schema import (
+    ColumnBuilder,
+    ColumnNameBuilder,
+    ColumnType,
     SchemaName,
     TableBuilder,
     TableName,
-    ColumnBuilder,
     TableNameBuilder,
-    ColumnType,
-    ColumnNameBuilder,
 )
-from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph
-from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition
-from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset
-from exasol.analytics.query_handler.graph.stage.sql.execution.input import SQLStageGraphExecutionInput
-from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler import SQLStageGraphExecutionQueryHandler
-from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput
-from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import SQLStageQueryHandler, SQLStageTrainQueryHandlerInput
-from exasol.analytics.query_handler.graph.stage.sql.sql_stage import SQLStage


 class StartOnlyForwardInputTestSQLStageTrainQueryHandler(SQLStageQueryHandler):
-    def __init__(self, parameter: SQLStageTrainQueryHandlerInput,
-                 query_handler_context: ScopeQueryHandlerContext):
+    def __init__(
+        self,
+        parameter: SQLStageTrainQueryHandlerInput,
+        query_handler_context: ScopeQueryHandlerContext,
+    ):
         super().__init__(parameter, query_handler_context)
         self._parameter = parameter

     def start(self) -> Union[Continue, Finish[SQLStageInputOutput]]:
         return Finish[SQLStageInputOutput](self._parameter.sql_stage_inputs[0])

-    def handle_query_result(self, query_result: QueryResult) \
-            -> Union[Continue, Finish[SQLStageInputOutput]]:
+    def handle_query_result(
+        self, query_result: QueryResult
+    ) -> Union[Continue, Finish[SQLStageInputOutput]]:
         raise NotImplementedError()


 class StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler(SQLStageQueryHandler):
-    def __init__(self, parameter: SQLStageTrainQueryHandlerInput,
-                 query_handler_context: ScopeQueryHandlerContext):
+    def __init__(
+        self,
+        parameter: SQLStageTrainQueryHandlerInput,
+        query_handler_context: ScopeQueryHandlerContext,
+    ):
         super().__init__(parameter, query_handler_context)
         self._parameter = parameter
         self.stage_input_output: Optional[SQLStageInputOutput] = None
@@ -58,7 +78,9 @@ def __init__(self, parameter: SQLStageTrainQueryHandlerInput,

     def start(self) -> Union[Continue, Finish[SQLStageInputOutput]]:
         dataset = self._parameter.sql_stage_inputs[0].dataset
-        input_table_like = dataset.data_partitions[TestDatasetPartitionName.TRAIN].table_like
+        input_table_like = dataset.data_partitions[
+            TestDatasetPartitionName.TRAIN
+        ].table_like
         # This tests also, if temporary table names are still valid
         self.input_table_like_name = input_table_like.name.fully_qualified
@@ -66,15 +88,21 @@ def start(self) -> Union[Continue, Finish[SQLStageInputOutput]]:
         self.stage_input_output = create_stage_input_output(output_table_name)
         return Finish[SQLStageInputOutput](self.stage_input_output)

-    def handle_query_result(self, query_result: QueryResult) \
-            -> Union[Continue, Finish[SQLStageInputOutput]]:
+    def handle_query_result(
+        self, query_result: QueryResult
+    ) -> Union[Continue, Finish[SQLStageInputOutput]]:
         raise NotImplementedError()


-class HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler(SQLStageQueryHandler):
+class HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler(
+    SQLStageQueryHandler
+):

-    def __init__(self, parameter: SQLStageTrainQueryHandlerInput,
-                 query_handler_context: ScopeQueryHandlerContext):
+    def __init__(
+        self,
+        parameter: SQLStageTrainQueryHandlerInput,
+        query_handler_context: ScopeQueryHandlerContext,
+    ):
         super().__init__(parameter, query_handler_context)
         self._parameter = parameter
         self.stage_input_output: Optional[SQLStageInputOutput] = None
@@ -87,14 +115,16 @@ def start(self) -> Union[Continue, Finish[SQLStageInputOutput]]:
         table_like_name = table_like.name
         table_like_columns = table_like.columns
         select_query_with_column_definition = SelectQueryWithColumnDefinition(
-            f"{table_like_name.fully_qualified}",
-            table_like_columns
+            f"{table_like_name.fully_qualified}", table_like_columns
+        )
+        self.continue_result = Continue(
+            query_list=[], input_query=select_query_with_column_definition
         )
-        self.continue_result = Continue(query_list=[], input_query=select_query_with_column_definition)
         return self.continue_result

-    def handle_query_result(self, query_result: QueryResult) \
-            -> Union[Continue, Finish[SQLStageInputOutput]]:
+    def handle_query_result(
+        self, query_result: QueryResult
+    ) -> Union[Continue, Finish[SQLStageInputOutput]]:
         self.query_result = query_result
         table_name = self._query_handler_context.get_temporary_table_name()
         self.stage_input_output = create_stage_input_output(table_name)
@@ -102,23 +132,28 @@ def handle_query_result(self, query_result: QueryResult) \


 TrainQueryHandlerFactory = Callable[
-    [SQLStageTrainQueryHandlerInput, ScopeQueryHandlerContext], SQLStageQueryHandler]
+    [SQLStageTrainQueryHandlerInput, ScopeQueryHandlerContext], SQLStageQueryHandler
+]


 class TestSQLStage(SQLStage):
     __test__ = False

-    def __init__(self, *,
-                 index: int,
-                 train_query_handler_factory: TrainQueryHandlerFactory):
+    def __init__(
+        self, *, index: int, train_query_handler_factory: TrainQueryHandlerFactory
+    ):
         self._train_query_handler_factory = train_query_handler_factory
         self.sql_stage_train_query_handler: Optional[SQLStageQueryHandler] = None
         self._index = index

-    def create_train_query_handler(self, query_handler_input: SQLStageTrainQueryHandlerInput,
-                                   query_handler_context: ScopeQueryHandlerContext) -> SQLStageQueryHandler:
-        self.sql_stage_train_query_handler = self._train_query_handler_factory(query_handler_input,
-                                                                               query_handler_context)
+    def create_train_query_handler(
+        self,
+        query_handler_input: SQLStageTrainQueryHandlerInput,
+        query_handler_context: ScopeQueryHandlerContext,
+    ) -> SQLStageQueryHandler:
+        self.sql_stage_train_query_handler = self._train_query_handler_factory(
+            query_handler_input, query_handler_context
+        )
         return self.sql_stage_train_query_handler

     def __eq__(self, other):
@@ -143,23 +178,30 @@ def create_input() -> SQLStageInputOutput:


 def create_stage_input_output(table_name: TableName):
-    identifier_column = ColumnBuilder() \
-        .with_name(ColumnNameBuilder().with_name("ID").build()) \
-        .with_type(ColumnType("INTEGER")).build()
-    sample_column = ColumnBuilder() \
-        .with_name(ColumnNameBuilder().with_name("SAMPLE").build()) \
-        .with_type(ColumnType("INTEGER")).build()
-    target_column = ColumnBuilder() \
-        .with_name(ColumnNameBuilder().with_name("TARGET").build()) \
-        .with_type(ColumnType("INTEGER")).build()
+    identifier_column = (
+        ColumnBuilder()
+        .with_name(ColumnNameBuilder().with_name("ID").build())
+        .with_type(ColumnType("INTEGER"))
+        .build()
+    )
+    sample_column = (
+        ColumnBuilder()
+        .with_name(ColumnNameBuilder().with_name("SAMPLE").build())
+        .with_type(ColumnType("INTEGER"))
+        .build()
+    )
+    target_column = (
+        ColumnBuilder()
+        .with_name(ColumnNameBuilder().with_name("TARGET").build())
+        .with_type(ColumnType("INTEGER"))
+        .build()
+    )
     columns = [
         identifier_column,
         sample_column,
         target_column,
     ]
-    table_like = TableBuilder() \
-        .with_name(table_name) \
-        .with_columns(columns).build()
+    table_like = TableBuilder().with_name(table_name).with_columns(columns).build()
     data_partition = DataPartition(
         table_like=table_like,
     )
@@ -167,11 +209,9 @@ def create_stage_input_output(table_name: TableName):
         data_partitions={TestDatasetPartitionName.TRAIN: data_partition},
         target_columns=[target_column],
         sample_columns=[sample_column],
-        identifier_columns=[identifier_column]
-    )
-    stage_input_output = SQLStageInputOutput(
-        dataset=dataset
+        identifier_columns=[identifier_column],
     )
+    stage_input_output = SQLStageInputOutput(dataset=dataset)
     return stage_input_output
@@ -184,31 +224,36 @@ class TestSetup:
     query_handler: SQLStageGraphExecutionQueryHandler


-def create_test_setup(*, sql_stage_graph: SQLStageGraph,
-                      stages: List[TestSQLStage],
-                      context: TopLevelQueryHandlerContext,
-                      local_fs_mock_bucket_fs_tmp_path: PurePosixPath) -> TestSetup:
+def create_test_setup(
+    *,
+    sql_stage_graph: SQLStageGraph,
+    stages: List[TestSQLStage],
+    context: TopLevelQueryHandlerContext,
+    local_fs_mock_bucket_fs_tmp_path: PurePosixPath,
+) -> TestSetup:
     stage_input_output = create_input()
     parameter = SQLStageGraphExecutionInput(
         sql_stage_graph=sql_stage_graph,
-        result_bucketfs_location=LocalFSMockBucketFSLocation(local_fs_mock_bucket_fs_tmp_path),
-        input=stage_input_output
+        result_bucketfs_location=LocalFSMockBucketFSLocation(
+            local_fs_mock_bucket_fs_tmp_path
+        ),
+        input=stage_input_output,
     )
     child_query_handler_context = context.get_child_query_handler_context()
     query_handler = SQLStageGraphExecutionQueryHandler(
-        parameter=parameter,
-        query_handler_context=child_query_handler_context
+        parameter=parameter, query_handler_context=child_query_handler_context
     )
     return TestSetup(
         stages=stages,
         stage_input_output=stage_input_output,
         child_query_handler_context=child_query_handler_context,
-        query_handler=query_handler
+        query_handler=query_handler,
     )


 def test_start_with_single_stage_with_start_only_forward_train_query_handler(
-        top_level_query_handler_context_mock, tmp_path):
+    top_level_query_handler_context_mock, tmp_path
+):
     """
     This test runs an integration test for the start method of a SQLStageGraphExecutionQueryHandler
     on a SQLStageGraph with a single stage which returns a StartOnlyForwardInputTestSQLStageTrainQueryHandler.
@@ -223,16 +268,15 @@ def test_start_with_single_stage_with_start_only_forward_train_query_handler(
     def arrange() -> TestSetup:
         stage1 = TestSQLStage(
             index=1,
-            train_query_handler_factory=StartOnlyForwardInputTestSQLStageTrainQueryHandler)
-        sql_stage_graph = SQLStageGraph(
-            start_node=stage1,
-            end_node=stage1,
-            edges=[]
+            train_query_handler_factory=StartOnlyForwardInputTestSQLStageTrainQueryHandler,
+        )
+        sql_stage_graph = SQLStageGraph(start_node=stage1, end_node=stage1, edges=[])
+        test_setup = create_test_setup(
+            sql_stage_graph=sql_stage_graph,
+            stages=[stage1],
+            context=top_level_query_handler_context_mock,
+            local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path),
         )
-        test_setup = create_test_setup(sql_stage_graph=sql_stage_graph,
-                                       stages=[stage1],
-                                       context=top_level_query_handler_context_mock,
-                                       local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path))
         return test_setup

     def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
@@ -242,18 +286,24 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     test_setup = arrange()
     result = act(test_setup)

-    assert isinstance(result, Finish) \
-           and isinstance(result.result, SQLStageInputOutput) \
-           and result.result.dataset == test_setup.stage_input_output.dataset \
-           and len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 0
+    assert (
+        isinstance(result, Finish)
+        and isinstance(result.result, SQLStageInputOutput)
+        and result.result.dataset == test_setup.stage_input_output.dataset
+        and len(top_level_query_handler_context_mock.cleanup_released_object_proxies())
+        == 0
+    )

     test_setup.child_query_handler_context.release()
-    assert len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 0
+    assert (
+        len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 0
+    )


 def test_start_with_two_stages_with_start_only_forward_train_query_handler(
-        top_level_query_handler_context_mock, tmp_path):
+    top_level_query_handler_context_mock, tmp_path
+):
     """
     This test runs an integration test for the start method of a SQLStageGraphExecutionQueryHandler
     on a SQLStageGraph with two stages which return a StartOnlyForwardInputTestSQLStageTrainQueryHandler.
@@ -268,19 +318,21 @@ def test_start_with_two_stages_with_start_only_forward_train_query_handler(
     def arrange() -> TestSetup:
         stage1 = TestSQLStage(
             index=1,
-            train_query_handler_factory=StartOnlyForwardInputTestSQLStageTrainQueryHandler)
+            train_query_handler_factory=StartOnlyForwardInputTestSQLStageTrainQueryHandler,
+        )
         stage2 = TestSQLStage(
             index=2,
-            train_query_handler_factory=StartOnlyForwardInputTestSQLStageTrainQueryHandler)
+            train_query_handler_factory=StartOnlyForwardInputTestSQLStageTrainQueryHandler,
+        )
         sql_stage_graph = SQLStageGraph(
-            start_node=stage1,
-            end_node=stage2,
-            edges=[(stage1, stage2)]
+            start_node=stage1, end_node=stage2, edges=[(stage1, stage2)]
+        )
+        test_setup = create_test_setup(
+            sql_stage_graph=sql_stage_graph,
+            stages=[stage1, stage2],
+            context=top_level_query_handler_context_mock,
+            local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path),
         )
-        test_setup = create_test_setup(sql_stage_graph=sql_stage_graph,
-                                       stages=[stage1, stage2],
-                                       context=top_level_query_handler_context_mock,
-                                       local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path))
         return test_setup

     def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
@@ -290,14 +342,19 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     test_setup = arrange()
     result = act(test_setup)

-    assert isinstance(result, Finish) \
-           and isinstance(result.result, SQLStageInputOutput) \
-           and result.result.dataset == test_setup.stage_input_output.dataset \
-           and len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 0
+    assert (
+        isinstance(result, Finish)
+        and isinstance(result.result, SQLStageInputOutput)
+        and result.result.dataset == test_setup.stage_input_output.dataset
+        and len(top_level_query_handler_context_mock.cleanup_released_object_proxies())
+        == 0
+    )

     test_setup.child_query_handler_context.release()
-    assert len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 0
+    assert (
+        len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 0
+    )


 @contextmanager
@@ -305,11 +362,12 @@ def not_raises(exception):
     try:
         yield
     except exception:
-        raise pytest.fail("DID RAISE {0}".format(exception))
+        raise pytest.fail(f"DID RAISE {exception}")


 def test_start_with_single_stage_with_start_only_create_new_output_train_query_handler(
-        top_level_query_handler_context_mock, tmp_path):
+    top_level_query_handler_context_mock, tmp_path
+):
     """
     This test runs an integration test for the start method of a SQLStageGraphExecutionQueryHandler
     on a SQLStageGraph with a single stage which return a StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler.
@@ -326,16 +384,15 @@ def test_start_with_single_stage_with_start_only_create_new_output_train_query_h
     def arrange() -> TestSetup:
         stage1 = TestSQLStage(
             index=1,
-            train_query_handler_factory=StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler)
-        sql_stage_graph = SQLStageGraph(
-            start_node=stage1,
-            end_node=stage1,
-            edges=[]
+            train_query_handler_factory=StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler,
+        )
+        sql_stage_graph = SQLStageGraph(start_node=stage1, end_node=stage1, edges=[])
+        test_setup = create_test_setup(
+            sql_stage_graph=sql_stage_graph,
+            stages=[stage1],
+            context=top_level_query_handler_context_mock,
+            local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path),
        )
-        test_setup = create_test_setup(sql_stage_graph=sql_stage_graph,
-                                       stages=[stage1],
-                                       context=top_level_query_handler_context_mock,
-                                       local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path))
         return test_setup

     def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
@@ -346,24 +403,36 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     result = act(test_setup)
     stage_1_train_query_handler = test_setup.stages[0].sql_stage_train_query_handler

-    assert isinstance(result, Finish) \
-           and isinstance(result.result, SQLStageInputOutput) \
-           and result.result.dataset != test_setup.stage_input_output.dataset \
-           and isinstance(stage_1_train_query_handler, StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler) \
-           and result.result.dataset == stage_1_train_query_handler.stage_input_output.dataset \
-           and stage_1_train_query_handler.input_table_like_name is not None \
-           and len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 0
+    assert (
+        isinstance(result, Finish)
+        and isinstance(result.result, SQLStageInputOutput)
+        and result.result.dataset != test_setup.stage_input_output.dataset
+        and isinstance(
+            stage_1_train_query_handler,
+            StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler,
+        )
+        and result.result.dataset
+        == stage_1_train_query_handler.stage_input_output.dataset
+        and stage_1_train_query_handler.input_table_like_name is not None
+        and len(top_level_query_handler_context_mock.cleanup_released_object_proxies())
+        == 0
+    )

     if isinstance(result, Finish) and isinstance(result.result, SQLStageInputOutput):
         with not_raises(Exception):
-            name = result.result.dataset.data_partitions[TestDatasetPartitionName.TRAIN].table_like.name
+            name = result.result.dataset.data_partitions[
+                TestDatasetPartitionName.TRAIN
+            ].table_like.name

     test_setup.child_query_handler_context.release()
-    assert len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 1
+    assert (
+        len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 1
+    )


 def test_start_with_two_stages_with_start_only_create_new_output_train_query_handler(
-        top_level_query_handler_context_mock, tmp_path):
+    top_level_query_handler_context_mock, tmp_path
+):
     """
     This test runs an integration test for the start method of a SQLStageGraphExecutionQueryHandler
     on a SQLStageGraph with two stages which return a StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler.
@@ -380,19 +449,21 @@ def test_start_with_two_stages_with_start_only_create_new_output_train_query_han
     def arrange() -> TestSetup:
         stage1 = TestSQLStage(
             index=1,
-            train_query_handler_factory=StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler)
+            train_query_handler_factory=StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler,
+        )
         stage2 = TestSQLStage(
             index=2,
-            train_query_handler_factory=StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler)
+            train_query_handler_factory=StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler,
+        )
         sql_stage_graph = SQLStageGraph(
-            start_node=stage1,
-            end_node=stage2,
-            edges=[(stage1, stage2)]
+            start_node=stage1, end_node=stage2, edges=[(stage1, stage2)]
+        )
+        test_setup = create_test_setup(
+            sql_stage_graph=sql_stage_graph,
+            stages=[stage1, stage2],
+            context=top_level_query_handler_context_mock,
+            local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path),
         )
-        test_setup = create_test_setup(sql_stage_graph=sql_stage_graph,
-                                       stages=[stage1, stage2],
-                                       context=top_level_query_handler_context_mock,
-                                       local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path))
         return test_setup

     def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
@@ -404,26 +475,41 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:

     stage_1_train_query_handler = test_setup.stages[0].sql_stage_train_query_handler
     stage_2_train_query_handler = test_setup.stages[1].sql_stage_train_query_handler
-    assert isinstance(result, Finish) \
-           and isinstance(result.result, SQLStageInputOutput) \
-           and result.result.dataset != test_setup.stage_input_output.dataset \
-           and isinstance(stage_1_train_query_handler, StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler) \
-           and isinstance(stage_2_train_query_handler, StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler) \
-           and result.result.dataset == stage_2_train_query_handler.stage_input_output.dataset \
-           and stage_1_train_query_handler.input_table_like_name is not None \
-           and stage_2_train_query_handler.input_table_like_name is not None \
-           and len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 1
+    assert (
+        isinstance(result, Finish)
+        and isinstance(result.result, SQLStageInputOutput)
+        and result.result.dataset != test_setup.stage_input_output.dataset
+        and isinstance(
+            stage_1_train_query_handler,
+            StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler,
+        )
+        and isinstance(
+            stage_2_train_query_handler,
+            StartOnlyCreateNewOutputTestSQLStageTrainQueryHandler,
+        )
+        and result.result.dataset
+        == stage_2_train_query_handler.stage_input_output.dataset
+        and stage_1_train_query_handler.input_table_like_name is not None
+        and stage_2_train_query_handler.input_table_like_name is not None
+        and len(top_level_query_handler_context_mock.cleanup_released_object_proxies())
+        == 1
+    )

     if isinstance(result, Finish) and isinstance(result.result, SQLStageInputOutput):
         with not_raises(Exception):
-            name = result.result.dataset.data_partitions[TestDatasetPartitionName.TRAIN].table_like.name
+            name = result.result.dataset.data_partitions[
+                TestDatasetPartitionName.TRAIN
+            ].table_like.name

     test_setup.child_query_handler_context.release()
-    assert len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 1
+    assert (
+        len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 1
+    )


 def test_start_with_single_stage_with_handle_query_result_create_new_output_train_query_handler_part1(
-        top_level_query_handler_context_mock, tmp_path):
+    top_level_query_handler_context_mock, tmp_path
+):
     """
     This test runs an integration test for the start method of a SQLStageGraphExecutionQueryHandler
     on a SQLStageGraph with a single stage which return a HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler.
@@ -437,16 +523,15 @@ def test_start_with_single_stage_with_handle_query_result_create_new_output_trai
     def arrange() -> TestSetup:
         stage1 = TestSQLStage(
             index=1,
-            train_query_handler_factory=HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler)
-        sql_stage_graph = SQLStageGraph(
-            start_node=stage1,
-            end_node=stage1,
-            edges=[]
+            train_query_handler_factory=HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler,
+        )
+        sql_stage_graph = SQLStageGraph(start_node=stage1, end_node=stage1, edges=[])
+        test_setup = create_test_setup(
+            sql_stage_graph=sql_stage_graph,
+            stages=[stage1],
+            context=top_level_query_handler_context_mock,
+            local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path),
        )
-        test_setup = create_test_setup(sql_stage_graph=sql_stage_graph,
-                                       stages=[stage1],
-                                       context=top_level_query_handler_context_mock,
-                                       local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path))
         return test_setup

     def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
@@ -457,15 +542,22 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     result = act(test_setup)
     stage_1_train_query_handler = test_setup.stages[0].sql_stage_train_query_handler

-    assert isinstance(stage_1_train_query_handler, HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler) and \
-           result == stage_1_train_query_handler.continue_result \
-           and stage_1_train_query_handler.stage_input_output is None \
-           and stage_1_train_query_handler.query_result is None \
-           and len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 0
+    assert (
+        isinstance(
+            stage_1_train_query_handler,
+            HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler,
+        )
+        and result == stage_1_train_query_handler.continue_result
+        and stage_1_train_query_handler.stage_input_output is None
+        and stage_1_train_query_handler.query_result is None
+        and len(top_level_query_handler_context_mock.cleanup_released_object_proxies())
+        == 0
+    )


 def test_handle_query_result_with_single_stage_with_handle_query_result_create_new_output_train_query_handler_part2(
-        top_level_query_handler_context_mock, tmp_path):
+    top_level_query_handler_context_mock, tmp_path
+):
     """
     This test uses test_start_with_single_stage_with_handle_query_result_create_new_output_train_query_handler_part1
     as setup and runs handle_query_result on the SQLStageGraphExecutionQueryHandler.
@@ -482,21 +574,22 @@ def test_handle_query_result_with_single_stage_with_handle_query_result_create_n def arrange() -> Tuple[TestSetup, QueryResult]: stage1 = TestSQLStage( index=1, - train_query_handler_factory=HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler) - sql_stage_graph = SQLStageGraph( - start_node=stage1, - end_node=stage1, - edges=[] + train_query_handler_factory=HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler, + ) + sql_stage_graph = SQLStageGraph(start_node=stage1, end_node=stage1, edges=[]) + test_setup = create_test_setup( + sql_stage_graph=sql_stage_graph, + stages=[stage1], + context=top_level_query_handler_context_mock, + local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path), ) - test_setup = create_test_setup(sql_stage_graph=sql_stage_graph, - stages=[stage1], - context=top_level_query_handler_context_mock, - local_fs_mock_bucket_fs_tmp_path=PurePosixPath(tmp_path)) test_setup.query_handler.start() query_result: QueryResult = Mock() return test_setup, query_result - def act(test_setup: TestSetup, query_result: QueryResult) -> Union[Continue, Finish[SQLStageInputOutput]]: + def act( + test_setup: TestSetup, query_result: QueryResult + ) -> Union[Continue, Finish[SQLStageInputOutput]]: result = test_setup.query_handler.handle_query_result(query_result) return result @@ -504,16 +597,25 @@ def act(test_setup: TestSetup, query_result: QueryResult) -> Union[Continue, Fin result = act(test_setup, query_result) stage_1_train_query_handler = test_setup.stages[0].sql_stage_train_query_handler - assert isinstance(result, Finish) \ - and isinstance(stage_1_train_query_handler, - HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler) and \ - result.result == stage_1_train_query_handler.stage_input_output \ - and query_result == stage_1_train_query_handler.query_result \ - and len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 0 + assert ( + isinstance(result, Finish) + and isinstance( + stage_1_train_query_handler, + HandleQueryResultCreateNewOutputTestSQLStageTrainQueryHandler, + ) + and result.result == stage_1_train_query_handler.stage_input_output + and query_result == stage_1_train_query_handler.query_result + and len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) + == 0 + ) if isinstance(result, Finish) and isinstance(result.result, SQLStageInputOutput): with not_raises(Exception): - name = result.result.dataset.data_partitions[TestDatasetPartitionName.TRAIN].table_like.name + name = result.result.dataset.data_partitions[ + TestDatasetPartitionName.TRAIN + ].table_like.name test_setup.child_query_handler_context.release() - assert len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 1 + assert ( + len(top_level_query_handler_context_mock.cleanup_released_object_proxies()) == 1 + ) diff --git a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_using_mock_state.py b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_using_mock_state.py index c0f187d6..92cf5597 100644 --- a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_using_mock_state.py +++ b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_using_mock_state.py @@ -1,27 +1,42 @@ import dataclasses -from typing import Union, List, Tuple -from unittest.mock import MagicMock, create_autospec, Mock, call +from typing import List, Tuple, Union +from unittest.mock 
diff --git a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_using_mock_state.py b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_using_mock_state.py
index c0f187d6..92cf5597 100644
--- a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_using_mock_state.py
+++ b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_query_handler_using_mock_state.py
@@ -1,27 +1,42 @@
 import dataclasses
-from typing import Union, List, Tuple
-from unittest.mock import MagicMock, create_autospec, Mock, call
+from typing import List, Tuple, Union
+from unittest.mock import MagicMock, Mock, call, create_autospec

 from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext
-from exasol.analytics.query_handler.query_handler import QueryHandler
-from exasol.analytics.query_handler.result import Finish, Continue
+from exasol.analytics.query_handler.graph.stage.sql.execution.input import (
+    SQLStageGraphExecutionInput,
+)
+from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler import (
+    SQLStageGraphExecutionQueryHandler,
+    SQLStageGraphExecutionQueryHandlerStateFactory,
+)
+from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import (
+    ResultHandlerReturnValue,
+    SQLStageGraphExecutionQueryHandlerState,
+)
+from exasol.analytics.query_handler.graph.stage.sql.input_output import (
+    SQLStageInputOutput,
+)
+from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import (
+    SQLStageQueryHandler,
+)
 from exasol.analytics.query_handler.query.result.interface import QueryResult
+from exasol.analytics.query_handler.query_handler import QueryHandler
+from exasol.analytics.query_handler.result import Continue, Finish
+from tests.utils.mock_cast import mock_cast

-from exasol.analytics.query_handler.graph.stage.sql.execution.input import SQLStageGraphExecutionInput
-from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler import SQLStageGraphExecutionQueryHandler, SQLStageGraphExecutionQueryHandlerStateFactory
-from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import SQLStageGraphExecutionQueryHandlerState, ResultHandlerReturnValue
-from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput
-from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import SQLStageQueryHandler
-from tests.mock_cast import mock_cast
-
-MockSQLStageGraphExecutionQueryHandlerState = Union[SQLStageGraphExecutionQueryHandlerState, MagicMock]
+MockSQLStageGraphExecutionQueryHandlerState = Union[
+    SQLStageGraphExecutionQueryHandlerState, MagicMock
+]
 MockScopeQueryHandlerContext = Union[ScopeQueryHandlerContext, MagicMock]
 MockSQLStageTrainQueryHandler = Union[SQLStageQueryHandler, MagicMock]
 MockQueryHandlerResult = Union[Continue, Finish, MagicMock]
 MockSQLStageGraphExecutionInput = Union[SQLStageGraphExecutionInput, MagicMock]
 MockSQLStageInputOutput = Union[SQLStageInputOutput, MagicMock]
 MockQueryResult = Union[QueryResult, MagicMock]
-MockSQLStageGraphExecutionQueryHandlerStateFactory = Union[SQLStageGraphExecutionQueryHandlerStateFactory, Mock]
+MockSQLStageGraphExecutionQueryHandlerStateFactory = Union[
+    SQLStageGraphExecutionQueryHandlerStateFactory, Mock
+]
@@ -40,10 +55,13 @@ class TrainQueryHandlerSetupDefinition:
     result_prototypes: List[Union[Continue, Finish]]

     def create_mock_setup(self) -> TrainQueryHandlerMockSetup:
-        results: List[MockQueryHandlerResult] = [create_autospec(result_prototype)
-                                                 for result_prototype in
-                                                 self.result_prototypes]
-        train_query_handler: MockSQLStageTrainQueryHandler = create_autospec(QueryHandler)
+        results: List[MockQueryHandlerResult] = [
+            create_autospec(result_prototype)
+            for result_prototype in self.result_prototypes
+        ]
+        train_query_handler: MockSQLStageTrainQueryHandler = create_autospec(
+            QueryHandler
+        )
         train_query_handler.start.side_effect = [results[0]]
         train_query_handler.handle_query_result.side_effect = results[1:]
         return TrainQueryHandlerMockSetup(results, train_query_handler)
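The mock setup above scripts an autospec'd QueryHandler so that start() yields the first canned result and each handle_query_result() yields the next one. A self-contained sketch of that side_effect pattern, using a toy handler class instead of the framework's generic QueryHandler:

from unittest.mock import create_autospec

class ToyHandler:
    def start(self): ...
    def handle_query_result(self, query_result): ...

results = ["continue", "finish"]
handler = create_autospec(ToyHandler)
handler.start.side_effect = [results[0]]               # first call returns results[0]
handler.handle_query_result.side_effect = results[1:]  # later calls return the rest, in order

assert handler.start() == "continue"
assert handler.handle_query_result("ignored") == "finish"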
@@ -65,25 +83,34 @@ class StateSetupDefinition:
     train_query_handler_setup_definitions: List[TrainQueryHandlerSetupDefinition]

     def create_mock_setup(self) -> StateMockSetup:
-        train_query_handler_mock_setups = \
-            [train_query_handler_setup_definition.create_mock_setup()
-             for train_query_handler_setup_definition
-             in self.train_query_handler_setup_definitions]
-        state: MockSQLStageGraphExecutionQueryHandlerState = create_autospec(SQLStageGraphExecutionQueryHandlerState)
-        train_query_handlers = [train_query_handler_mock_setup.train_query_handler
-                                for train_query_handler_mock_setup in train_query_handler_mock_setups
-                                for _ in train_query_handler_mock_setup.results]
+        train_query_handler_mock_setups = [
+            train_query_handler_setup_definition.create_mock_setup()
+            for train_query_handler_setup_definition in self.train_query_handler_setup_definitions
+        ]
+        state: MockSQLStageGraphExecutionQueryHandlerState = create_autospec(
+            SQLStageGraphExecutionQueryHandlerState
+        )
+        train_query_handlers = [
+            train_query_handler_mock_setup.train_query_handler
+            for train_query_handler_mock_setup in train_query_handler_mock_setups
+            for _ in train_query_handler_mock_setup.results
+        ]
         state.get_current_query_handler.side_effect = train_query_handlers
-        result_handler_return_values = [self._create_result_handler_return_value(result)
-                                        for train_query_handler_mock_setup in train_query_handler_mock_setups
-                                        for result in train_query_handler_mock_setup.results]
+        result_handler_return_values = [
+            self._create_result_handler_return_value(result)
+            for train_query_handler_mock_setup in train_query_handler_mock_setups
+            for result in train_query_handler_mock_setup.results
+        ]
         result_handler_return_values[-1] = ResultHandlerReturnValue.RETURN_RESULT
         state.handle_result.side_effect = result_handler_return_values
-        return StateMockSetup(train_query_handler_mock_setups=train_query_handler_mock_setups,
-                              state=state)
+        return StateMockSetup(
+            train_query_handler_mock_setups=train_query_handler_mock_setups, state=state
+        )

     @staticmethod
-    def _create_result_handler_return_value(result: MockQueryHandlerResult) -> ResultHandlerReturnValue:
+    def _create_result_handler_return_value(
+        result: MockQueryHandlerResult,
+    ) -> ResultHandlerReturnValue:
         if isinstance(result, Continue):
             return ResultHandlerReturnValue.RETURN_RESULT
         elif isinstance(result, Finish):
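The Mock* aliases at the top of this file are a typing idiom: at runtime the objects are MagicMocks, but annotating them as a Union of the real type and MagicMock lets both the domain attributes and the mock assertion helpers pass a type checker. A tiny sketch with an assumed Engine class:

from typing import Union
from unittest.mock import MagicMock, create_autospec

class Engine:
    def run(self) -> int: ...

MockEngine = Union[Engine, MagicMock]

engine: MockEngine = create_autospec(Engine)
engine.run.return_value = 42   # Mock API, valid under the MagicMock arm
assert engine.run() == 42      # Engine API, valid under the Engine arm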
@@ -110,36 +137,47 @@ def reset_mock(self):

 def create_test_setup(state_setup_definition: StateSetupDefinition) -> TestSetup:
     state_mock_setup = state_setup_definition.create_mock_setup()
-    mock_scope_query_handler_context: MockScopeQueryHandlerContext = create_autospec(ScopeQueryHandlerContext)
-    mock_execution_input: MockSQLStageGraphExecutionInput = create_autospec(SQLStageGraphExecutionInput)
+    mock_scope_query_handler_context: MockScopeQueryHandlerContext = create_autospec(
+        ScopeQueryHandlerContext
+    )
+    mock_execution_input: MockSQLStageGraphExecutionInput = create_autospec(
+        SQLStageGraphExecutionInput
+    )
     mock_state_factory: MockSQLStageGraphExecutionQueryHandlerStateFactory = Mock()
     mock_state_factory.return_value = state_mock_setup.state
-    execution_query_handler = SQLStageGraphExecutionQueryHandler(parameter=mock_execution_input,
-                                                                 query_handler_context=mock_scope_query_handler_context,
-                                                                 query_handler_state_factory=mock_state_factory)
+    execution_query_handler = SQLStageGraphExecutionQueryHandler(
+        parameter=mock_execution_input,
+        query_handler_context=mock_scope_query_handler_context,
+        query_handler_state_factory=mock_state_factory,
+    )
     return TestSetup(
         execution_query_handler=execution_query_handler,
         mock_state_factory=mock_state_factory,
         mock_execution_input=mock_execution_input,
         mock_scope_query_handler_context=mock_scope_query_handler_context,
-        state_mock_setup=state_mock_setup
+        state_mock_setup=state_mock_setup,
     )


-def create_test_setup_with_two_train_query_handler_returning_continue_finish() -> TestSetup:
+def create_test_setup_with_two_train_query_handler_returning_continue_finish() -> (
+    TestSetup
+):
     state_setup_definition = StateSetupDefinition(
         train_query_handler_setup_definitions=[
             TrainQueryHandlerSetupDefinition(
                 result_prototypes=[
                     Continue(query_list=None, input_query=None),
-                    Finish(result=None)
-                ]),
+                    Finish(result=None),
+                ]
+            ),
             TrainQueryHandlerSetupDefinition(
                 result_prototypes=[
                     Continue(query_list=None, input_query=None),
-                    Finish(result=None)
-                ])
-        ])
+                    Finish(result=None),
+                ]
+            ),
+        ]
+    )
     test_setup = create_test_setup(state_setup_definition)
     test_setup.reset_mock()
     return test_setup
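mock_cast is imported from tests/utils/mock_cast.py; its implementation is not part of this diff. A plausible reading, consistent with calls like mock_cast(handler.start).assert_called_once() throughout these tests, is a typing-only cast that narrows an attribute of an autospec'd object to Mock so the assertion helpers type-check:

from typing import Any, cast
from unittest.mock import Mock

def mock_cast(value: Any) -> Mock:
    # No runtime effect; only narrows the static type to Mock.
    return cast(Mock, value)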
@@ -156,8 +194,10 @@ def arrange() -> StateSetupDefinition:
         state_setup_definition = StateSetupDefinition(
             train_query_handler_setup_definitions=[
                 TrainQueryHandlerSetupDefinition(
-                    result_prototypes=[Finish(result=None)])
-            ])
+                    result_prototypes=[Finish(result=None)]
+                )
+            ]
+        )
         return state_setup_definition

     def act(state_setup_definition: StateSetupDefinition) -> TestSetup:
@@ -168,10 +208,12 @@ def act(state_setup_definition: StateSetupDefinition) -> TestSetup:
     test_setup = act(state_setup_definiton)

     test_setup.mock_state_factory.assert_called_once_with(
-        test_setup.mock_execution_input,
-        test_setup.mock_scope_query_handler_context)
+        test_setup.mock_execution_input, test_setup.mock_scope_query_handler_context
+    )
     test_setup.state_mock_setup.state.assert_not_called()
-    test_setup.state_mock_setup.train_query_handler_mock_setups[0].train_query_handler.assert_not_called()
+    test_setup.state_mock_setup.train_query_handler_mock_setups[
+        0
+    ].train_query_handler.assert_not_called()


 def test_start_single_train_query_handler_returning_finish():
@@ -189,8 +231,10 @@ def arrange() -> TestSetup:
         state_setup_definition = StateSetupDefinition(
             train_query_handler_setup_definitions=[
                 TrainQueryHandlerSetupDefinition(
-                    result_prototypes=[Finish(result=None)])
-            ])
+                    result_prototypes=[Finish(result=None)]
+                )
+            ]
+        )
         test_setup = create_test_setup(state_setup_definition)
         test_setup.reset_mock()
         return test_setup
@@ -202,13 +246,19 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     test_setup = arrange()
     result = act(test_setup)

-    train_query_handler_mock_setup = test_setup.state_mock_setup.train_query_handler_mock_setups[0]
-    test_setup.state_mock_setup.state.assert_has_calls([
-        call.get_current_query_handler(),
-        call.handle_result(train_query_handler_mock_setup.results[0])
-    ])
+    train_query_handler_mock_setup = (
+        test_setup.state_mock_setup.train_query_handler_mock_setups[0]
+    )
+    test_setup.state_mock_setup.state.assert_has_calls(
+        [
+            call.get_current_query_handler(),
+            call.handle_result(train_query_handler_mock_setup.results[0]),
+        ]
+    )
     train_query_handler_mock_setup.train_query_handler.assert_has_calls([call.start()])
-    mock_cast(train_query_handler_mock_setup.train_query_handler.handle_query_result).assert_not_called()
+    mock_cast(
+        train_query_handler_mock_setup.train_query_handler.handle_query_result
+    ).assert_not_called()
     assert result == train_query_handler_mock_setup.results[0]


@@ -227,8 +277,12 @@ def arrange() -> TestSetup:
         state_setup_definition = StateSetupDefinition(
             train_query_handler_setup_definitions=[
                 TrainQueryHandlerSetupDefinition(
-                    result_prototypes=[Continue(query_list=None, input_query=None), ])
-            ])
+                    result_prototypes=[
+                        Continue(query_list=None, input_query=None),
+                    ]
+                )
+            ]
+        )
         test_setup = create_test_setup(state_setup_definition)
         test_setup.reset_mock()
         return test_setup
@@ -240,14 +294,21 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     test_setup = arrange()
     result = act(test_setup)

-    train_query_handler_mock_setup = test_setup.state_mock_setup.train_query_handler_mock_setups[0]
-    test_setup.state_mock_setup.state.assert_has_calls([
-        call.get_current_query_handler(),
-        call.handle_result(train_query_handler_mock_setup.results[0])
-    ])
+    train_query_handler_mock_setup = (
+        test_setup.state_mock_setup.train_query_handler_mock_setups[0]
+    )
+    test_setup.state_mock_setup.state.assert_has_calls(
+        [
+            call.get_current_query_handler(),
+            call.handle_result(train_query_handler_mock_setup.results[0]),
+        ]
+    )
+    mock_cast(
+        train_query_handler_mock_setup.train_query_handler.start
+    ).assert_called_once()
     mock_cast(
-        train_query_handler_mock_setup.train_query_handler.start).assert_called_once()
-    mock_cast(train_query_handler_mock_setup.train_query_handler.handle_query_result).assert_not_called()
+        train_query_handler_mock_setup.train_query_handler.handle_query_result
+    ).assert_not_called()
     assert result == train_query_handler_mock_setup.results[0]


@@ -269,31 +330,41 @@ def arrange() -> Tuple[TestSetup, MockQueryResult]:
                 TrainQueryHandlerSetupDefinition(
                     result_prototypes=[
                         Continue(query_list=None, input_query=None),
-                    Finish(result=None)
-                ])
-        ])
+                        Finish(result=None),
+                    ]
+                )
+            ]
+        )
         test_setup = create_test_setup(state_setup_definition)
         test_setup.execution_query_handler.start()
         test_setup.reset_mock()
         query_result: MockQueryResult = create_autospec(QueryResult)
         return test_setup, query_result

-    def act(test_setup: TestSetup, query_result: MockQueryResult) \
-            -> Union[Continue, Finish[SQLStageInputOutput]]:
+    def act(
+        test_setup: TestSetup, query_result: MockQueryResult
+    ) -> Union[Continue, Finish[SQLStageInputOutput]]:
         result = test_setup.execution_query_handler.handle_query_result(query_result)
         return result

     test_setup, query_result = arrange()
     result = act(test_setup, query_result)

-    train_query_handler_mock_setup = test_setup.state_mock_setup.train_query_handler_mock_setups[0]
-    test_setup.state_mock_setup.state.assert_has_calls([
-        call.get_current_query_handler(),
-        call.handle_result(train_query_handler_mock_setup.results[1])
-    ])
-    mock_cast(train_query_handler_mock_setup.train_query_handler.handle_query_result) \
-        .assert_called_once_with(query_result)
-    mock_cast(train_query_handler_mock_setup.train_query_handler.start).assert_not_called()
+    train_query_handler_mock_setup = (
+        test_setup.state_mock_setup.train_query_handler_mock_setups[0]
+    )
+    test_setup.state_mock_setup.state.assert_has_calls(
+        [
+            call.get_current_query_handler(),
+            call.handle_result(train_query_handler_mock_setup.results[1]),
+        ]
+    )
+    mock_cast(
+        train_query_handler_mock_setup.train_query_handler.handle_query_result
+    ).assert_called_once_with(query_result)
+    mock_cast(
+        train_query_handler_mock_setup.train_query_handler.start
+    ).assert_not_called()
     assert result == train_query_handler_mock_setup.results[1]


@@ -314,9 +385,14 @@ def test_start_two_train_query_handler_returning_finish():
     def arrange() -> TestSetup:
         state_setup_definition = StateSetupDefinition(
             train_query_handler_setup_definitions=[
-                TrainQueryHandlerSetupDefinition(result_prototypes=[Finish(result=None)]),
-                TrainQueryHandlerSetupDefinition(result_prototypes=[Finish(result=None)])
-            ])
+                TrainQueryHandlerSetupDefinition(
+                    result_prototypes=[Finish(result=None)]
+                ),
+                TrainQueryHandlerSetupDefinition(
+                    result_prototypes=[Finish(result=None)]
+                ),
+            ]
+        )
         test_setup = create_test_setup(state_setup_definition)
         test_setup.reset_mock()
         return test_setup
@@ -328,15 +404,23 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     test_setup = arrange()
     result = act(test_setup)

-    train_query_handler_mock_setups = test_setup.state_mock_setup.train_query_handler_mock_setups
-    test_setup.state_mock_setup.state.assert_has_calls([
-        call.get_current_query_handler(),
-        call.handle_result(train_query_handler_mock_setups[0].results[0]),
-        call.get_current_query_handler(),
-        call.handle_result(train_query_handler_mock_setups[1].results[0])
-    ])
-    mock_cast(train_query_handler_mock_setups[0].train_query_handler.start).assert_called_once()
-    mock_cast(train_query_handler_mock_setups[1].train_query_handler.start).assert_called_once()
+    train_query_handler_mock_setups = (
+        test_setup.state_mock_setup.train_query_handler_mock_setups
+    )
+    test_setup.state_mock_setup.state.assert_has_calls(
+        [
+            call.get_current_query_handler(),
+            call.handle_result(train_query_handler_mock_setups[0].results[0]),
+            call.get_current_query_handler(),
+            call.handle_result(train_query_handler_mock_setups[1].results[0]),
+        ]
+    )
+    mock_cast(
+        train_query_handler_mock_setups[0].train_query_handler.start
+    ).assert_called_once()
+    mock_cast(
+        train_query_handler_mock_setups[1].train_query_handler.start
+    ).assert_called_once()
     assert result == train_query_handler_mock_setups[1].results[0]


@@ -353,7 +437,9 @@ def test_start_two_train_query_handler_returning_continue_finish_part1():
     """

     def arrange() -> TestSetup:
-        test_setup = create_test_setup_with_two_train_query_handler_returning_continue_finish()
+        test_setup = (
+            create_test_setup_with_two_train_query_handler_returning_continue_finish()
+        )
         test_setup.reset_mock()
         return test_setup

@@ -364,14 +450,24 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     test_setup = arrange()
     result = act(test_setup)

-    train_query_handler_mock_setups = test_setup.state_mock_setup.train_query_handler_mock_setups
-    test_setup.state_mock_setup.state.assert_has_calls([
-        call.get_current_query_handler(),
-        call.handle_result(train_query_handler_mock_setups[0].results[0]),
-    ])
-    mock_cast(train_query_handler_mock_setups[0].train_query_handler.start).assert_called_once()
-    mock_cast(train_query_handler_mock_setups[1].train_query_handler.start).assert_not_called()
-    mock_cast(train_query_handler_mock_setups[1].train_query_handler.handle_query_result).assert_not_called()
+    train_query_handler_mock_setups = (
+        test_setup.state_mock_setup.train_query_handler_mock_setups
+    )
+    test_setup.state_mock_setup.state.assert_has_calls(
+        [
+            call.get_current_query_handler(),
+            call.handle_result(train_query_handler_mock_setups[0].results[0]),
+        ]
+    )
+    mock_cast(
+        train_query_handler_mock_setups[0].train_query_handler.start
+    ).assert_called_once()
+    mock_cast(
+        train_query_handler_mock_setups[1].train_query_handler.start
+    ).assert_not_called()
+    mock_cast(
+        train_query_handler_mock_setups[1].train_query_handler.handle_query_result
+    ).assert_not_called()
     assert result == train_query_handler_mock_setups[0].results[0]
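The assertions above deliberately mix two styles: assert_has_calls checks that the given calls appear consecutively somewhere in mock_calls, while comparing mock_calls with == (as the reference-counting tests later in this patch do) pins down the exact call history. A small self-contained illustration:

from unittest.mock import Mock, call

state = Mock()
state.get_current_query_handler()
state.handle_result("r0")
state.get_current_query_handler()

# Passes: the two calls occur consecutively within state.mock_calls.
state.assert_has_calls(
    [call.get_current_query_handler(), call.handle_result("r0")]
)

# An exact-history check has to compare the full list instead:
assert state.mock_calls[:2] == [
    call.get_current_query_handler(),
    call.handle_result("r0"),
]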
@@ -390,31 +486,46 @@ def test_handle_query_result_two_train_query_handler_returning_continue_finish_p
     """

     def arrange() -> Tuple[TestSetup, QueryResult]:
-        test_setup = create_test_setup_with_two_train_query_handler_returning_continue_finish()
+        test_setup = (
+            create_test_setup_with_two_train_query_handler_returning_continue_finish()
+        )
         test_setup.execution_query_handler.start()
         test_setup.reset_mock()
         query_result: MockQueryResult = create_autospec(QueryResult)
         return test_setup, query_result

-    def act(test_setup: TestSetup, query_result: QueryResult) \
-            -> Union[Continue, Finish[SQLStageInputOutput]]:
+    def act(
+        test_setup: TestSetup, query_result: QueryResult
+    ) -> Union[Continue, Finish[SQLStageInputOutput]]:
         result = test_setup.execution_query_handler.handle_query_result(query_result)
         return result

     test_setup, query_result = arrange()
     result = act(test_setup, query_result)

-    train_query_handler_mock_setups = test_setup.state_mock_setup.train_query_handler_mock_setups
-    test_setup.state_mock_setup.state.assert_has_calls([
-        call.get_current_query_handler(),
-        call.handle_result(train_query_handler_mock_setups[0].results[1]),
-        call.get_current_query_handler(),
-        call.handle_result(train_query_handler_mock_setups[1].results[0])
-    ])
-    mock_cast(train_query_handler_mock_setups[0].train_query_handler.start).assert_not_called()
-    mock_cast(train_query_handler_mock_setups[0].train_query_handler.handle_query_result).assert_called_once()
-    mock_cast(train_query_handler_mock_setups[1].train_query_handler.start).assert_called_once()
-    mock_cast(train_query_handler_mock_setups[1].train_query_handler.handle_query_result).assert_not_called()
+    train_query_handler_mock_setups = (
+        test_setup.state_mock_setup.train_query_handler_mock_setups
+    )
+    test_setup.state_mock_setup.state.assert_has_calls(
+        [
+            call.get_current_query_handler(),
+            call.handle_result(train_query_handler_mock_setups[0].results[1]),
+            call.get_current_query_handler(),
+            call.handle_result(train_query_handler_mock_setups[1].results[0]),
+        ]
+    )
+    mock_cast(
+        train_query_handler_mock_setups[0].train_query_handler.start
+    ).assert_not_called()
+    mock_cast(
+        train_query_handler_mock_setups[0].train_query_handler.handle_query_result
+    ).assert_called_once()
+    mock_cast(
+        train_query_handler_mock_setups[1].train_query_handler.start
+    ).assert_called_once()
+    mock_cast(
+        train_query_handler_mock_setups[1].train_query_handler.handle_query_result
+    ).assert_not_called()
     assert result == train_query_handler_mock_setups[1].results[0]


@@ -430,7 +541,9 @@ def test_handle_query_result_two_train_query_handler_returning_continue_finish_p
     """

     def arrange() -> Tuple[TestSetup, QueryResult]:
-        test_setup = create_test_setup_with_two_train_query_handler_returning_continue_finish()
+        test_setup = (
+            create_test_setup_with_two_train_query_handler_returning_continue_finish()
+        )
         query_result1: MockQueryResult = create_autospec(QueryResult)
         test_setup.execution_query_handler.start()
         test_setup.execution_query_handler.handle_query_result(query_result1)
@@ -438,22 +551,34 @@ def arrange() -> Tuple[TestSetup, QueryResult]:
         query_result2: MockQueryResult = create_autospec(QueryResult)
         return test_setup, query_result2

-    def act(test_setup: TestSetup, query_result: QueryResult) \
-            -> Union[Continue, Finish[SQLStageInputOutput]]:
+    def act(
+        test_setup: TestSetup, query_result: QueryResult
+    ) -> Union[Continue, Finish[SQLStageInputOutput]]:
         result = test_setup.execution_query_handler.handle_query_result(query_result)
         return result

     test_setup, query_result = arrange()
     result = act(test_setup, query_result)

-    train_query_handler_mock_setups = test_setup.state_mock_setup.train_query_handler_mock_setups
-    test_setup.state_mock_setup.state.assert_has_calls([
-        call.get_current_query_handler(),
-        call.handle_result(train_query_handler_mock_setups[1].results[1])
-    ])
-    mock_cast(train_query_handler_mock_setups[0].train_query_handler.start).assert_not_called()
-    mock_cast(train_query_handler_mock_setups[0].train_query_handler.handle_query_result).assert_not_called()
-    mock_cast(train_query_handler_mock_setups[1].train_query_handler.start).assert_not_called()
-    mock_cast(train_query_handler_mock_setups[1].train_query_handler.handle_query_result) \
-        .assert_called_once_with(query_result)
+    train_query_handler_mock_setups = (
+        test_setup.state_mock_setup.train_query_handler_mock_setups
+    )
+    test_setup.state_mock_setup.state.assert_has_calls(
+        [
+            call.get_current_query_handler(),
+            call.handle_result(train_query_handler_mock_setups[1].results[1]),
+        ]
+    )
+    mock_cast(
+        train_query_handler_mock_setups[0].train_query_handler.start
+    ).assert_not_called()
+    mock_cast(
+        train_query_handler_mock_setups[0].train_query_handler.handle_query_result
+    ).assert_not_called()
+    mock_cast(
+        train_query_handler_mock_setups[1].train_query_handler.start
+    ).assert_not_called()
+    mock_cast(
+        train_query_handler_mock_setups[1].train_query_handler.handle_query_result
+    ).assert_called_once_with(query_result)
     assert result == train_query_handler_mock_setups[1].results[1]
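The next file exercises a diamond-shaped stage graph (stage 1 fans out to stages 2 and 3, which join in stage 4). Its compute_dependency_order is mocked to return the stages in order; for a real graph, any topological sort of the edge set produces such an order. A sketch with integers standing in for stage objects (graphlib is in the standard library from Python 3.9):

from graphlib import TopologicalSorter

edges = {(1, 2), (1, 3), (2, 4), (3, 4)}
dependencies: dict = {}
for source, target in edges:
    # TopologicalSorter wants node -> set of predecessors.
    dependencies.setdefault(target, set()).add(source)

order = list(TopologicalSorter(dependencies).static_order())
assert order.index(1) < order.index(2) < order.index(4)
assert order.index(1) < order.index(3) < order.index(4)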
diff --git a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_diamond.py b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_diamond.py
index 7de767b8..8f6ccee6 100644
--- a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_diamond.py
+++ b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_diamond.py
@@ -1,26 +1,38 @@
-from typing import Union, List
+from typing import List, Union
 from unittest.mock import MagicMock, Mock

 import pytest

-from exasol.analytics.query_handler.query_handler import QueryHandler
-from exasol.analytics.query_handler.result import Finish, Continue
+from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import (
+    ResultHandlerReturnValue,
+)
+from exasol.analytics.query_handler.graph.stage.sql.input_output import (
+    SQLStageInputOutput,
+)
 from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph
-from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import ResultHandlerReturnValue
-from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput
-from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.assert_helper import \
-    assert_reference_counting_bag_creation, assert_stage_train_query_handler_created, assert_stage_not_called, \
-    assert_reference_counting_bag_not_called, assert_release_on_query_handler_context_for_stage, \
-    assert_parent_query_handler_context_not_called
-from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import TestSetup, \
-    create_mocks_for_stage, create_execution_query_handler_state_setup
+from exasol.analytics.query_handler.query_handler import QueryHandler
+from exasol.analytics.query_handler.result import Continue, Finish
+from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.assert_helper import (
+    assert_parent_query_handler_context_not_called,
+    assert_reference_counting_bag_creation,
+    assert_reference_counting_bag_not_called,
+    assert_release_on_query_handler_context_for_stage,
+    assert_stage_not_called,
+    assert_stage_train_query_handler_created,
+)
+from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import (
+    TestSetup,
+    create_execution_query_handler_state_setup,
+    create_mocks_for_stage,
+)


 def create_diamond_setup(
-        stage1_result_prototypes: List[Union[Continue, Finish, MagicMock]],
-        stage2_result_prototypes: List[Union[Continue, Finish, MagicMock]],
-        stage3_result_prototypes: List[Union[Continue, Finish, MagicMock]],
-        stage4_result_prototypes: List[Union[Continue, Finish, MagicMock]]) -> TestSetup:
+    stage1_result_prototypes: List[Union[Continue, Finish, MagicMock]],
+    stage2_result_prototypes: List[Union[Continue, Finish, MagicMock]],
+    stage3_result_prototypes: List[Union[Continue, Finish, MagicMock]],
+    stage4_result_prototypes: List[Union[Continue, Finish, MagicMock]],
+) -> TestSetup:
     stage1_setup = create_mocks_for_stage(stage1_result_prototypes, stage_index=0)
     stage2_setup = create_mocks_for_stage(stage2_result_prototypes, stage_index=1)
     stage3_setup = create_mocks_for_stage(stage3_result_prototypes, stage_index=2)
@@ -32,22 +44,22 @@ def create_diamond_setup(
             (stage1_setup.stage, stage2_setup.stage),
             (stage1_setup.stage, stage3_setup.stage),
             (stage2_setup.stage, stage4_setup.stage),
-            (stage3_setup.stage, stage4_setup.stage)
-        }
+            (stage3_setup.stage, stage4_setup.stage),
+        },
     )
     mock_compute_dependency_order = Mock()
-    mock_compute_dependency_order.return_value = [stage1_setup.stage,
-                                                  stage2_setup.stage,
-                                                  stage3_setup.stage,
-                                                  stage4_setup.stage]
+    mock_compute_dependency_order.return_value = [
+        stage1_setup.stage,
+        stage2_setup.stage,
+        stage3_setup.stage,
+        stage4_setup.stage,
+    ]
     sql_stage_graph.compute_dependency_order = mock_compute_dependency_order
     stage_setups = [stage1_setup, stage2_setup, stage3_setup, stage4_setup]
     state_setup = create_execution_query_handler_state_setup(
-        sql_stage_graph, stage_setups)
-    return TestSetup(
-        stage_setups=stage_setups,
-        state_setup=state_setup
+        sql_stage_graph, stage_setups
     )
+    return TestSetup(stage_setups=stage_setups, state_setup=state_setup)


 def create_diamond_setup_with_finish() -> TestSetup:
@@ -55,7 +67,7 @@ def create_diamond_setup_with_finish() -> TestSetup:
         stage1_result_prototypes=[Finish(result=None)],
         stage2_result_prototypes=[Finish(result=None)],
         stage3_result_prototypes=[Finish(result=None)],
-        stage4_result_prototypes=[Finish(result=None)]
+        stage4_result_prototypes=[Finish(result=None)],
     )
     return test_setup

@@ -70,8 +82,12 @@ def arrange() -> TestSetup:
         test_setup = create_diamond_setup_with_finish()
         return test_setup

-    def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]:
-        result = test_setup.state_setup.execution_query_handler_state.get_current_query_handler()
+    def act(
+        test_setup: TestSetup,
+    ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]:
+        result = (
+            test_setup.state_setup.execution_query_handler_state.get_current_query_handler()
+        )
         return result

     test_setup = arrange()
@@ -79,8 +95,10 @@ def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLSta
     assert_reference_counting_bag_creation(test_setup)
     assert_stage_train_query_handler_created(
-        test_setup, stage_index=0,
-        stage_inputs=[test_setup.state_setup.sql_stage_input_output])
+        test_setup,
+        stage_index=0,
+        stage_inputs=[test_setup.state_setup.sql_stage_input_output],
+    )
     assert_stage_not_called(test_setup, stage_index=1)
     assert_stage_not_called(test_setup, stage_index=2)
     assert_stage_not_called(test_setup, stage_index=3)
@@ -99,14 +117,17 @@ def test_handle_result_diamond_return_finish_part2():

     def arrange() -> TestSetup:
         test_setup = create_diamond_setup_with_finish()
-        execution_query_handler_state = test_setup.state_setup.execution_query_handler_state
+        execution_query_handler_state = (
+            test_setup.state_setup.execution_query_handler_state
+        )
         execution_query_handler_state.get_current_query_handler()
         test_setup.reset_mock()
         return test_setup

     def act(test_setup: TestSetup) -> ResultHandlerReturnValue:
         result = test_setup.state_setup.execution_query_handler_state.handle_result(
-            test_setup.stage_setups[0].results[0])
+            test_setup.stage_setups[0].results[0]
+        )
         return result

     test_setup = arrange()
@@ -115,8 +136,10 @@ def act(test_setup: TestSetup) -> ResultHandlerReturnValue:
     assert_reference_counting_bag_not_called(test_setup)
     assert_release_on_query_handler_context_for_stage(test_setup, stage_index=0)
     assert_stage_train_query_handler_created(
-        test_setup, stage_index=1,
-        stage_inputs=[test_setup.stage_setups[0].results[0].result])
+        test_setup,
+        stage_index=1,
+        stage_inputs=[test_setup.stage_setups[0].results[0].result],
+    )
     assert_stage_not_called(test_setup, stage_index=2)
     assert_stage_not_called(test_setup, stage_index=3)
     assert result == ResultHandlerReturnValue.CONTINUE_PROCESSING
@@ -131,15 +154,22 @@ def test_get_current_query_handler_diamond_return_finish_part3():

     def arrange() -> TestSetup:
         test_setup = create_diamond_setup_with_finish()
-        execution_query_handler_state = test_setup.state_setup.execution_query_handler_state
+        execution_query_handler_state = (
+            test_setup.state_setup.execution_query_handler_state
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[0].results[0])
+            test_setup.stage_setups[0].results[0]
+        )
         test_setup.reset_mock()
         return test_setup

-    def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]:
-        result = test_setup.state_setup.execution_query_handler_state.get_current_query_handler()
+    def act(
+        test_setup: TestSetup,
+    ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]:
+        result = (
+            test_setup.state_setup.execution_query_handler_state.get_current_query_handler()
+        )
         return result

     test_setup = arrange()
@@ -166,17 +196,21 @@ def test_handle_result_diamond_return_finish_part4():

     def arrange() -> TestSetup:
         test_setup = create_diamond_setup_with_finish()
-        execution_query_handler_state = test_setup.state_setup.execution_query_handler_state
+        execution_query_handler_state = (
+            test_setup.state_setup.execution_query_handler_state
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[0].results[0])
+            test_setup.stage_setups[0].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         test_setup.reset_mock()
         return test_setup

     def act(test_setup: TestSetup) -> ResultHandlerReturnValue:
         result = test_setup.state_setup.execution_query_handler_state.handle_result(
-            test_setup.stage_setups[1].results[0])
+            test_setup.stage_setups[1].results[0]
+        )
         return result

     test_setup = arrange()
@@ -186,8 +220,10 @@ def act(test_setup: TestSetup) -> ResultHandlerReturnValue:
     assert_stage_not_called(test_setup, stage_index=0)
     assert_release_on_query_handler_context_for_stage(test_setup, stage_index=1)
     assert_stage_train_query_handler_created(
-        test_setup, stage_index=2,
-        stage_inputs=[test_setup.stage_setups[0].results[0].result])
+        test_setup,
+        stage_index=2,
+        stage_inputs=[test_setup.stage_setups[0].results[0].result],
+    )
     assert_stage_not_called(test_setup, stage_index=3)
     assert result == ResultHandlerReturnValue.CONTINUE_PROCESSING
@@ -201,18 +237,26 @@ def test_get_current_query_handler_diamond_return_finish_part5():

     def arrange() -> TestSetup:
         test_setup = create_diamond_setup_with_finish()
-        execution_query_handler_state = test_setup.state_setup.execution_query_handler_state
+        execution_query_handler_state = (
+            test_setup.state_setup.execution_query_handler_state
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[0].results[0])
+            test_setup.stage_setups[0].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[1].results[0])
+            test_setup.stage_setups[1].results[0]
+        )
         test_setup.reset_mock()
         return test_setup

-    def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]:
-        result = test_setup.state_setup.execution_query_handler_state.get_current_query_handler()
+    def act(
+        test_setup: TestSetup,
+    ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]:
+        result = (
+            test_setup.state_setup.execution_query_handler_state.get_current_query_handler()
+        )
         return result

     test_setup = arrange()
@@ -238,20 +282,25 @@ def test_handle_result_diamond_return_finish_part6():

     def arrange() -> TestSetup:
         test_setup = create_diamond_setup_with_finish()
-        execution_query_handler_state = test_setup.state_setup.execution_query_handler_state
+        execution_query_handler_state = (
+            test_setup.state_setup.execution_query_handler_state
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[0].results[0])
+            test_setup.stage_setups[0].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[1].results[0])
+            test_setup.stage_setups[1].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         test_setup.reset_mock()
         return test_setup

     def act(test_setup: TestSetup) -> ResultHandlerReturnValue:
         result = test_setup.state_setup.execution_query_handler_state.handle_result(
-            test_setup.stage_setups[2].results[0])
+            test_setup.stage_setups[2].results[0]
+        )
         return result

     test_setup = arrange()
@@ -262,9 +311,13 @@ def act(test_setup: TestSetup) -> ResultHandlerReturnValue:
     assert_stage_not_called(test_setup, stage_index=1)
     assert_release_on_query_handler_context_for_stage(test_setup, stage_index=2)
     assert_stage_train_query_handler_created(
-        test_setup, stage_index=3,
-        stage_inputs=[test_setup.stage_setups[1].results[0].result,
-                      test_setup.stage_setups[2].results[0].result])
+        test_setup,
+        stage_index=3,
+        stage_inputs=[
+            test_setup.stage_setups[1].results[0].result,
+            test_setup.stage_setups[2].results[0].result,
+        ],
+    )
     assert result == ResultHandlerReturnValue.CONTINUE_PROCESSING
@@ -277,21 +330,30 @@ def test_get_current_query_handler_diamond_return_finish_part7():

     def arrange() -> TestSetup:
         test_setup = create_diamond_setup_with_finish()
-        execution_query_handler_state = test_setup.state_setup.execution_query_handler_state
+        execution_query_handler_state = (
+            test_setup.state_setup.execution_query_handler_state
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[0].results[0])
+            test_setup.stage_setups[0].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[1].results[0])
+            test_setup.stage_setups[1].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[2].results[0])
+            test_setup.stage_setups[2].results[0]
+        )
         test_setup.reset_mock()
         return test_setup

-    def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]:
-        result = test_setup.state_setup.execution_query_handler_state.get_current_query_handler()
+    def act(
+        test_setup: TestSetup,
+    ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]:
+        result = (
+            test_setup.state_setup.execution_query_handler_state.get_current_query_handler()
+        )
         return result

     test_setup = arrange()
@@ -314,23 +376,29 @@ def test_handle_result_diamond_return_finish_part8():

     def arrange() -> TestSetup:
         test_setup = create_diamond_setup_with_finish()
-        execution_query_handler_state = test_setup.state_setup.execution_query_handler_state
+        execution_query_handler_state = (
+            test_setup.state_setup.execution_query_handler_state
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[0].results[0])
+            test_setup.stage_setups[0].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[1].results[0])
+            test_setup.stage_setups[1].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[2].results[0])
+            test_setup.stage_setups[2].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         test_setup.reset_mock()
         return test_setup

     def act(test_setup: TestSetup) -> ResultHandlerReturnValue:
         result = test_setup.state_setup.execution_query_handler_state.handle_result(
-            test_setup.stage_setups[3].results[0])
+            test_setup.stage_setups[3].results[0]
+        )
         return result

     test_setup = arrange()
@@ -352,19 +420,25 @@ def test_get_current_query_handler_diamond_return_finish_part9():

     def arrange() -> TestSetup:
         test_setup = create_diamond_setup_with_finish()
-        execution_query_handler_state = test_setup.state_setup.execution_query_handler_state
+        execution_query_handler_state = (
+            test_setup.state_setup.execution_query_handler_state
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[0].results[0])
+            test_setup.stage_setups[0].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[1].results[0])
+            test_setup.stage_setups[1].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[2].results[0])
+            test_setup.stage_setups[2].results[0]
+        )
         execution_query_handler_state.get_current_query_handler()
         execution_query_handler_state.handle_result(
-            test_setup.stage_setups[3].results[0])
+            test_setup.stage_setups[3].results[0]
+        )
         test_setup.reset_mock()
         return test_setup
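The following file drives a mocked reference-counting bag. Judging from the calls the tests expect (__contains__, add, remove, transfer_back_to_parent_query_handler_context), the bag counts how many downstream stages still need a result object and lets it go once the count drops to zero. A minimal sketch of such a bag, offered as an assumption rather than the framework's real implementation:

from collections import Counter

class RefCountingBag:
    def __init__(self) -> None:
        self._counts: Counter = Counter()

    def __contains__(self, obj) -> bool:
        return self._counts[obj] > 0

    def add(self, obj) -> None:
        # One count per consumer that still needs obj.
        self._counts[obj] += 1

    def remove(self, obj) -> None:
        # A consumer is done; forget obj once nobody needs it.
        self._counts[obj] -= 1
        if self._counts[obj] <= 0:
            del self._counts[obj]

bag = RefCountingBag()
bag.add("proxy")
bag.add("proxy")
bag.remove("proxy")
assert "proxy" in bag
bag.remove("proxy")
assert "proxy" not in bag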
diff --git a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_reference_counting.py b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_reference_counting.py
index 1013be95..06cc9804 100644
--- a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_reference_counting.py
+++ b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_reference_counting.py
@@ -1,22 +1,27 @@
 import dataclasses
-from typing import Union, List, Dict
-from unittest.mock import MagicMock, Mock, create_autospec, call
+from typing import Dict, List, Union
+from unittest.mock import MagicMock, Mock, call, create_autospec

 from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy
-from exasol.analytics.query_handler.result import Finish, Continue
-
+from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import (
+    ResultHandlerReturnValue,
+)
 from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph
-from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import ResultHandlerReturnValue
-from tests.mock_cast import mock_cast
-from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import TestSetup, \
-    create_mocks_for_stage, create_execution_query_handler_state_setup
+from exasol.analytics.query_handler.result import Continue, Finish
+from tests.utils.mock_cast import mock_cast
+from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import (
+    TestSetup,
+    create_execution_query_handler_state_setup,
+    create_mocks_for_stage,
+)


 def create_diamond_setup(
-        stage1_result_prototypes: List[Union[Continue, Finish, MagicMock]],
-        stage2_result_prototypes: List[Union[Continue, Finish, MagicMock]],
-        stage3_result_prototypes: List[Union[Continue, Finish, MagicMock]],
-        stage4_result_prototypes: List[Union[Continue, Finish, MagicMock]]) -> TestSetup:
+    stage1_result_prototypes: List[Union[Continue, Finish, MagicMock]],
+    stage2_result_prototypes: List[Union[Continue, Finish, MagicMock]],
+    stage3_result_prototypes: List[Union[Continue, Finish, MagicMock]],
+    stage4_result_prototypes: List[Union[Continue, Finish, MagicMock]],
+) -> TestSetup:
     stage1_setup = create_mocks_for_stage(stage1_result_prototypes, stage_index=1)
     stage2_setup = create_mocks_for_stage(stage2_result_prototypes, stage_index=2)
     stage3_setup = create_mocks_for_stage(stage3_result_prototypes, stage_index=3)
@@ -28,24 +33,24 @@ def create_diamond_setup(
             (stage1_setup.stage, stage2_setup.stage),
             (stage1_setup.stage, stage3_setup.stage),
             (stage2_setup.stage, stage4_setup.stage),
-            (stage3_setup.stage, stage4_setup.stage)
-        }
+            (stage3_setup.stage, stage4_setup.stage),
+        },
     )
     mock_compute_dependency_order = Mock()
-    mock_compute_dependency_order.return_value = [stage1_setup.stage,
-                                                  stage2_setup.stage,
-                                                  stage3_setup.stage,
-                                                  stage4_setup.stage]
+    mock_compute_dependency_order.return_value = [
+        stage1_setup.stage,
+        stage2_setup.stage,
+        stage3_setup.stage,
+        stage4_setup.stage,
+    ]
     sql_stage_graph.compute_dependency_order = mock_compute_dependency_order
     stage_setups = [stage1_setup, stage2_setup, stage3_setup, stage4_setup]
     state_setup = create_execution_query_handler_state_setup(
-        sql_stage_graph, stage_setups)
-
-    return TestSetup(
-        stage_setups=stage_setups,
-        state_setup=state_setup
+        sql_stage_graph, stage_setups
     )
+    return TestSetup(stage_setups=stage_setups, state_setup=state_setup)
+

 @dataclasses.dataclass
 class ReferenceCountingSetup:
@@ -53,12 +58,14 @@ class ReferenceCountingSetup:
     object_proxy_dict: Dict[ObjectProxy, int]


-def create_diamond_setup_with_finish_with_last_stage_returning_new_result() -> ReferenceCountingSetup:
+def create_diamond_setup_with_finish_with_last_stage_returning_new_result() -> (
+    ReferenceCountingSetup
+):
     test_setup = create_diamond_setup(
         stage1_result_prototypes=[Finish(result=None)],
         stage2_result_prototypes=[Finish(result=None)],
         stage3_result_prototypes=[Finish(result=None)],
-        stage4_result_prototypes=[Finish(result=None)]
+        stage4_result_prototypes=[Finish(result=None)],
     )
     stage1_object_proxy = create_autospec(ObjectProxy)
     stage2_object_proxy = create_autospec(ObjectProxy)
@@ -71,12 +78,14 @@ def create_diamond_setup_with_finish_with_last_stage_returning_new_result() -> R
     return ReferenceCountingSetup(test_setup, object_proxy_dict)


-def create_diamond_setup_with_finish_with_last_stage_returning_existing_result() -> ReferenceCountingSetup:
+def create_diamond_setup_with_finish_with_last_stage_returning_existing_result() -> (
+    ReferenceCountingSetup
+):
     test_setup = create_diamond_setup(
         stage1_result_prototypes=[Finish(result=None)],
         stage2_result_prototypes=[Finish(result=None)],
         stage3_result_prototypes=[Finish(result=None)],
-        stage4_result_prototypes=[Finish(result=None)]
+        stage4_result_prototypes=[Finish(result=None)],
     )
     stage1_object_proxy = create_autospec(ObjectProxy)
     test_setup.stage_setups[0].results[0].result = stage1_object_proxy
@@ -111,32 +120,41 @@ def side_effect_transfer_back(object_proxy):
         del object_proxy_dict[object_proxy]

     mock_cast(
-        test_setup.state_setup.reference_counting_bag_mock_setup.bag.__contains__).side_effect = side_effect_contains
+        test_setup.state_setup.reference_counting_bag_mock_setup.bag.__contains__
+    ).side_effect = side_effect_contains
     mock_cast(
-        test_setup.state_setup.reference_counting_bag_mock_setup.bag.add).side_effect = side_effect_add
+        test_setup.state_setup.reference_counting_bag_mock_setup.bag.add
+    ).side_effect = side_effect_add
     mock_cast(
-        test_setup.state_setup.reference_counting_bag_mock_setup.bag.remove).side_effect = side_effect_remove
+        test_setup.state_setup.reference_counting_bag_mock_setup.bag.remove
+    ).side_effect = side_effect_remove
     mock_cast(
-        test_setup.state_setup.reference_counting_bag_mock_setup.bag.transfer_back_to_parent_query_handler_context) \
-        .side_effect = side_effect_transfer_back
+        test_setup.state_setup.reference_counting_bag_mock_setup.bag.transfer_back_to_parent_query_handler_context
+    ).side_effect = side_effect_transfer_back
     return object_proxy_dict


 def assert_transfer_from_child_to_parent_query_handler_context(
-        ref_count_setup: ReferenceCountingSetup, stage_index: int):
+    ref_count_setup: ReferenceCountingSetup, stage_index: int
+):
     mock_cast(
-        ref_count_setup.test_setup.stage_setups[stage_index]
-        .child_query_handler_context.transfer_object_to).assert_called_once_with(
+        ref_count_setup.test_setup.stage_setups[
+            stage_index
+        ].child_query_handler_context.transfer_object_to
+    ).assert_called_once_with(
         ref_count_setup.test_setup.stage_setups[stage_index].results[0].result,
-        ref_count_setup.test_setup.state_setup.parent_query_handler_context
+        ref_count_setup.test_setup.state_setup.parent_query_handler_context,
     )


 def assert_no_transfer_from_child_to_parent_query_handler_context(
-        ref_count_setup: ReferenceCountingSetup, stage_index: int):
+    ref_count_setup: ReferenceCountingSetup, stage_index: int
+):
-    mock_cast(
-        ref_count_setup.test_setup.stage_setups[stage_index]
-        .child_query_handler_context.transfer_object_to).assert_not_called()
+    mock_cast(
+        ref_count_setup.test_setup.stage_setups[
+            stage_index
+        ].child_query_handler_context.transfer_object_to
+    ).assert_not_called()
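The bag mock above gets its behaviour from plain functions assigned to side_effect; each call is then both recorded by the mock and executed against the shared object_proxy_dict. A compact stand-alone version of that wiring, with a toy dict in place of object_proxy_dict:

from unittest.mock import MagicMock

counts: dict = {}
bag = MagicMock()
bag.__contains__.side_effect = lambda obj: obj in counts
bag.add.side_effect = lambda obj: counts.__setitem__(obj, counts.get(obj, 0) + 1)

assert "p" not in bag            # delegates to the closure over counts
bag.add("p")
assert "p" in bag
assert bag.add.call_count == 1   # the call is still recorded as usual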
""" + def arrange() -> ReferenceCountingSetup: - ref_count_setup = create_diamond_setup_with_finish_with_last_stage_returning_new_result() + ref_count_setup = ( + create_diamond_setup_with_finish_with_last_stage_returning_new_result() + ) ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[0].results[0]) + ref_count_setup.test_setup.stage_setups[0].results[0] + ) ref_count_setup.test_setup.reset_mock() return ref_count_setup def act(ref_count_setup: ReferenceCountingSetup) -> ResultHandlerReturnValue: result = ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[1].results[0]) + ref_count_setup.test_setup.stage_setups[1].results[0] + ) return result ref_count_setup = arrange() result = act(ref_count_setup) - assert ref_count_setup.test_setup.state_setup.reference_counting_bag_mock_setup.bag.mock_calls == \ - [ - call.__contains__(ref_count_setup.test_setup.stage_setups[1].results[0].result), - call.add(ref_count_setup.test_setup.stage_setups[1].results[0].result), - call.__contains__(ref_count_setup.test_setup.stage_setups[0].results[0].result), - call.remove(ref_count_setup.test_setup.stage_setups[0].results[0].result), - ] + assert ( + ref_count_setup.test_setup.state_setup.reference_counting_bag_mock_setup.bag.mock_calls + == [ + call.__contains__( + ref_count_setup.test_setup.stage_setups[1].results[0].result + ), + call.add(ref_count_setup.test_setup.stage_setups[1].results[0].result), + call.__contains__( + ref_count_setup.test_setup.stage_setups[0].results[0].result + ), + call.remove(ref_count_setup.test_setup.stage_setups[0].results[0].result), + ] + ) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 0) assert_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 1) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 2) @@ -207,30 +244,42 @@ def test_handle_result_diamond_return_finish_new_result_part3(): This test use test_handle_result_diamond_return_finish_new_result_part2 as setup and calls handle_result with the result of the third stage. 
""" + def arrange() -> ReferenceCountingSetup: - ref_count_setup = create_diamond_setup_with_finish_with_last_stage_returning_new_result() + ref_count_setup = ( + create_diamond_setup_with_finish_with_last_stage_returning_new_result() + ) ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[0].results[0]) + ref_count_setup.test_setup.stage_setups[0].results[0] + ) ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[1].results[0]) + ref_count_setup.test_setup.stage_setups[1].results[0] + ) ref_count_setup.test_setup.reset_mock() return ref_count_setup def act(ref_count_setup: ReferenceCountingSetup) -> ResultHandlerReturnValue: result = ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[2].results[0]) + ref_count_setup.test_setup.stage_setups[2].results[0] + ) return result ref_count_setup = arrange() result = act(ref_count_setup) - assert ref_count_setup.test_setup.state_setup.reference_counting_bag_mock_setup.bag.mock_calls == \ - [ - call.__contains__(ref_count_setup.test_setup.stage_setups[0].results[0].result), - call.add(ref_count_setup.test_setup.stage_setups[0].results[0].result), - call.__contains__(ref_count_setup.test_setup.stage_setups[0].results[0].result), - call.remove(ref_count_setup.test_setup.stage_setups[0].results[0].result), - ] + assert ( + ref_count_setup.test_setup.state_setup.reference_counting_bag_mock_setup.bag.mock_calls + == [ + call.__contains__( + ref_count_setup.test_setup.stage_setups[0].results[0].result + ), + call.add(ref_count_setup.test_setup.stage_setups[0].results[0].result), + call.__contains__( + ref_count_setup.test_setup.stage_setups[0].results[0].result + ), + call.remove(ref_count_setup.test_setup.stage_setups[0].results[0].result), + ] + ) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 0) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 1) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 2) @@ -242,31 +291,44 @@ def test_handle_result_diamond_return_finish_new_result_part4(): This test use test_handle_result_diamond_return_finish_new_result_part3 as setup and calls handle_result with the result of the forth stage. 
""" + def arrange() -> ReferenceCountingSetup: - ref_count_setup = create_diamond_setup_with_finish_with_last_stage_returning_new_result() + ref_count_setup = ( + create_diamond_setup_with_finish_with_last_stage_returning_new_result() + ) ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[0].results[0]) + ref_count_setup.test_setup.stage_setups[0].results[0] + ) ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[1].results[0]) + ref_count_setup.test_setup.stage_setups[1].results[0] + ) ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[2].results[0]) + ref_count_setup.test_setup.stage_setups[2].results[0] + ) ref_count_setup.test_setup.reset_mock() return ref_count_setup def act(ref_count_setup: ReferenceCountingSetup) -> ResultHandlerReturnValue: result = ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[3].results[0]) + ref_count_setup.test_setup.stage_setups[3].results[0] + ) return result ref_count_setup = arrange() result = act(ref_count_setup) print(ref_count_setup.object_proxy_dict) - ref_count_setup.test_setup.state_setup.reference_counting_bag_mock_setup.bag.assert_has_calls([ - call.__contains__(ref_count_setup.test_setup.stage_setups[1].results[0].result), - call.remove(ref_count_setup.test_setup.stage_setups[1].results[0].result), - call.__contains__(ref_count_setup.test_setup.stage_setups[0].results[0].result), - call.remove(ref_count_setup.test_setup.stage_setups[0].results[0].result), - ]) + ref_count_setup.test_setup.state_setup.reference_counting_bag_mock_setup.bag.assert_has_calls( + [ + call.__contains__( + ref_count_setup.test_setup.stage_setups[1].results[0].result + ), + call.remove(ref_count_setup.test_setup.stage_setups[1].results[0].result), + call.__contains__( + ref_count_setup.test_setup.stage_setups[0].results[0].result + ), + call.remove(ref_count_setup.test_setup.stage_setups[0].results[0].result), + ] + ) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 0) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 1) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 2) @@ -282,20 +344,27 @@ def test_handle_result_diamond_return_finish_existing_result(): result. 
""" + def arrange() -> ReferenceCountingSetup: - ref_count_setup = create_diamond_setup_with_finish_with_last_stage_returning_existing_result() + ref_count_setup = ( + create_diamond_setup_with_finish_with_last_stage_returning_existing_result() + ) ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[0].results[0]) + ref_count_setup.test_setup.stage_setups[0].results[0] + ) ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[1].results[0]) + ref_count_setup.test_setup.stage_setups[1].results[0] + ) ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[2].results[0]) + ref_count_setup.test_setup.stage_setups[2].results[0] + ) ref_count_setup.test_setup.reset_mock() return ref_count_setup def act(ref_count_setup: ReferenceCountingSetup) -> ResultHandlerReturnValue: result = ref_count_setup.test_setup.state_setup.execution_query_handler_state.handle_result( - ref_count_setup.test_setup.stage_setups[3].results[0]) + ref_count_setup.test_setup.stage_setups[3].results[0] + ) return result ref_count_setup = arrange() @@ -303,8 +372,10 @@ def act(ref_count_setup: ReferenceCountingSetup) -> ResultHandlerReturnValue: state_setup = ref_count_setup.test_setup.state_setup mock_cast( - state_setup.reference_counting_bag_mock_setup.bag.transfer_back_to_parent_query_handler_context) \ - .assert_called_once_with(ref_count_setup.test_setup.stage_setups[0].results[0].result) + state_setup.reference_counting_bag_mock_setup.bag.transfer_back_to_parent_query_handler_context + ).assert_called_once_with( + ref_count_setup.test_setup.stage_setups[0].results[0].result + ) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 0) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 1) assert_no_transfer_from_child_to_parent_query_handler_context(ref_count_setup, 2) diff --git a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_single_stage.py b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_single_stage.py index fbf297b8..0d13cfd6 100644 --- a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_single_stage.py +++ b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_single_stage.py @@ -1,36 +1,44 @@ -from typing import Union, List +from typing import List, Union from unittest.mock import MagicMock import pytest -from exasol.analytics.query_handler.query_handler import QueryHandler -from exasol.analytics.query_handler.result import Finish, Continue +from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import ( + ResultHandlerReturnValue, +) +from exasol.analytics.query_handler.graph.stage.sql.input_output import ( + SQLStageInputOutput, +) from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph -from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import ResultHandlerReturnValue -from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput -from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.assert_helper import assert_stage_not_called, \ - assert_reference_counting_bag_not_called, assert_reference_counting_bag_creation, \ - assert_stage_train_query_handler_created, \ - 
assert_release_on_query_handler_context_for_stage, assert_parent_query_handler_context_not_called -from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import TestSetup, \ - create_mocks_for_stage, create_execution_query_handler_state_setup +from exasol.analytics.query_handler.query_handler import QueryHandler +from exasol.analytics.query_handler.result import Continue, Finish +from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.assert_helper import ( + assert_parent_query_handler_context_not_called, + assert_reference_counting_bag_creation, + assert_reference_counting_bag_not_called, + assert_release_on_query_handler_context_for_stage, + assert_stage_not_called, + assert_stage_train_query_handler_created, +) +from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import ( + TestSetup, + create_execution_query_handler_state_setup, + create_mocks_for_stage, +) def create_single_stage_setup( - result_prototypes: List[Union[Continue, Finish, MagicMock]]) -> TestSetup: + result_prototypes: List[Union[Continue, Finish, MagicMock]] +) -> TestSetup: stage_setup = create_mocks_for_stage(result_prototypes, stage_index=0) sql_stage_graph = SQLStageGraph( - start_node=stage_setup.stage, - end_node=stage_setup.stage, - edges=set() + start_node=stage_setup.stage, end_node=stage_setup.stage, edges=set() ) stage_setups = [stage_setup] state_setup = create_execution_query_handler_state_setup( - sql_stage_graph, stage_setups) - return TestSetup( - stage_setups=stage_setups, - state_setup=state_setup + sql_stage_graph, stage_setups ) + return TestSetup(stage_setups=stage_setups, state_setup=state_setup) def test_get_current_query_handler_single_stage_after_init(): @@ -43,8 +51,12 @@ def arrange() -> TestSetup: test_setup = create_single_stage_setup(result_prototypes=[Finish(result=None)]) return test_setup - def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: - current_query_handler = test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + def act( + test_setup: TestSetup, + ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: + current_query_handler = ( + test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + ) return current_query_handler test_setup = arrange() @@ -52,8 +64,10 @@ def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLSta assert_reference_counting_bag_creation(test_setup) assert_stage_train_query_handler_created( - test_setup, stage_index=0, - stage_inputs=[test_setup.state_setup.sql_stage_input_output]) + test_setup, + stage_index=0, + stage_inputs=[test_setup.state_setup.sql_stage_input_output], + ) assert test_setup.stage_setups[0].train_query_handler == result @@ -65,14 +79,17 @@ def test_handle_result_single_stage_return_finish(): def arrange() -> TestSetup: test_setup = create_single_stage_setup(result_prototypes=[Finish(result=None)]) - execution_query_handler_state = test_setup.state_setup.execution_query_handler_state + execution_query_handler_state = ( + test_setup.state_setup.execution_query_handler_state + ) execution_query_handler_state.get_current_query_handler() test_setup.reset_mock() return test_setup def act(test_setup: TestSetup) -> ResultHandlerReturnValue: result = test_setup.state_setup.execution_query_handler_state.handle_result( - test_setup.stage_setups[0].results[0]) + test_setup.stage_setups[0].results[0] + ) return result test_setup = 
arrange() @@ -111,7 +128,9 @@ def test_handle_result_single_stage_return_finish_after_finish(): test_setup.reset_mock() with pytest.raises(RuntimeError, match="No current query handler set."): - execution_query_handler_state.handle_result(test_setup.stage_setups[0].results[0]) + execution_query_handler_state.handle_result( + test_setup.stage_setups[0].results[0] + ) def test_get_current_query_handler_single_stage_return_continue_finish(): @@ -124,12 +143,17 @@ def arrange() -> TestSetup: test_setup = create_single_stage_setup( result_prototypes=[ Continue(query_list=None, input_query=None), - Finish(result=None) - ]) + Finish(result=None), + ] + ) return test_setup - def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: - current_query_handler = test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + def act( + test_setup: TestSetup, + ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: + current_query_handler = ( + test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + ) return current_query_handler test_setup = arrange() @@ -137,8 +161,10 @@ def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLSta assert_reference_counting_bag_creation(test_setup) assert_stage_train_query_handler_created( - test_setup, stage_index=0, - stage_inputs=[test_setup.state_setup.sql_stage_input_output]) + test_setup, + stage_index=0, + stage_inputs=[test_setup.state_setup.sql_stage_input_output], + ) assert test_setup.stage_setups[0].train_query_handler == result @@ -153,16 +179,20 @@ def arrange() -> TestSetup: test_setup = create_single_stage_setup( result_prototypes=[ Continue(query_list=None, input_query=None), - Finish(result=None) - ]) - execution_query_handler_state = test_setup.state_setup.execution_query_handler_state + Finish(result=None), + ] + ) + execution_query_handler_state = ( + test_setup.state_setup.execution_query_handler_state + ) execution_query_handler_state.get_current_query_handler() test_setup.reset_mock() return test_setup def act(test_setup: TestSetup) -> ResultHandlerReturnValue: result = test_setup.state_setup.execution_query_handler_state.handle_result( - test_setup.stage_setups[0].results[0]) + test_setup.stage_setups[0].results[0] + ) return result test_setup = arrange() @@ -185,16 +215,25 @@ def arrange() -> TestSetup: test_setup = create_single_stage_setup( result_prototypes=[ Continue(query_list=None, input_query=None), - Finish(result=None) - ]) - execution_query_handler_state = test_setup.state_setup.execution_query_handler_state + Finish(result=None), + ] + ) + execution_query_handler_state = ( + test_setup.state_setup.execution_query_handler_state + ) execution_query_handler_state.get_current_query_handler() - execution_query_handler_state.handle_result(test_setup.stage_setups[0].results[0]) + execution_query_handler_state.handle_result( + test_setup.stage_setups[0].results[0] + ) test_setup.reset_mock() return test_setup - def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: - current_query_handler = test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + def act( + test_setup: TestSetup, + ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: + current_query_handler = ( + test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + ) return current_query_handler test_setup = arrange() @@ -217,18 +256,24 @@ def arrange() -> TestSetup: 
test_setup = create_single_stage_setup( result_prototypes=[ Continue(query_list=None, input_query=None), - Finish(result=None) - ]) - execution_query_handler_state = test_setup.state_setup.execution_query_handler_state + Finish(result=None), + ] + ) + execution_query_handler_state = ( + test_setup.state_setup.execution_query_handler_state + ) execution_query_handler_state.get_current_query_handler() - execution_query_handler_state.handle_result(test_setup.stage_setups[0].results[0]) + execution_query_handler_state.handle_result( + test_setup.stage_setups[0].results[0] + ) execution_query_handler_state.get_current_query_handler() test_setup.reset_mock() return test_setup def act(test_setup: TestSetup) -> ResultHandlerReturnValue: result = test_setup.state_setup.execution_query_handler_state.handle_result( - test_setup.stage_setups[0].results[1]) + test_setup.stage_setups[0].results[1] + ) return result test_setup = arrange() diff --git a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_two_stage.py b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_two_stage.py index 0dacc930..26e3bce9 100644 --- a/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_two_stage.py +++ b/tests/unit_tests/sql_stage_graph/stage_graph_execution_query_handler/test_state_two_stage.py @@ -1,39 +1,48 @@ -from typing import Union, List +from typing import List, Union from unittest.mock import MagicMock import pytest -from exasol.analytics.query_handler.query_handler import QueryHandler -from exasol.analytics.query_handler.result import Finish, Continue +from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import ( + ResultHandlerReturnValue, +) +from exasol.analytics.query_handler.graph.stage.sql.input_output import ( + SQLStageInputOutput, +) from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph -from exasol.analytics.query_handler.graph.stage.sql.execution.query_handler_state import ResultHandlerReturnValue -from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput -from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.assert_helper import \ - assert_reference_counting_bag_creation, \ - assert_stage_train_query_handler_created, assert_stage_not_called, \ - assert_release_on_query_handler_context_for_stage, assert_reference_counting_bag_not_called, \ - assert_parent_query_handler_context_not_called -from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import \ - create_mocks_for_stage, create_execution_query_handler_state_setup, TestSetup +from exasol.analytics.query_handler.query_handler import QueryHandler +from exasol.analytics.query_handler.result import Continue, Finish +from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.assert_helper import ( + assert_parent_query_handler_context_not_called, + assert_reference_counting_bag_creation, + assert_reference_counting_bag_not_called, + assert_release_on_query_handler_context_for_stage, + assert_stage_not_called, + assert_stage_train_query_handler_created, +) +from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import ( + TestSetup, + create_execution_query_handler_state_setup, + create_mocks_for_stage, +) def create_two_stage_setup( - stage1_result_prototypes: List[Union[Continue, Finish, MagicMock]], - stage2_result_prototypes: List[Union[Continue, Finish, 
MagicMock]]) -> TestSetup: + stage1_result_prototypes: List[Union[Continue, Finish, MagicMock]], + stage2_result_prototypes: List[Union[Continue, Finish, MagicMock]], +) -> TestSetup: stage1_setup = create_mocks_for_stage(stage1_result_prototypes, stage_index=0) stage2_setup = create_mocks_for_stage(stage2_result_prototypes, stage_index=1) sql_stage_graph = SQLStageGraph( start_node=stage1_setup.stage, end_node=stage2_setup.stage, - edges={(stage1_setup.stage, stage2_setup.stage)} + edges={(stage1_setup.stage, stage2_setup.stage)}, ) stage_setups = [stage1_setup, stage2_setup] state_setup = create_execution_query_handler_state_setup( - sql_stage_graph, stage_setups) - return TestSetup( - stage_setups=stage_setups, - state_setup=state_setup + sql_stage_graph, stage_setups ) + return TestSetup(stage_setups=stage_setups, state_setup=state_setup) def test_get_current_query_handler_two_stage_return_finish_part1(): @@ -45,12 +54,16 @@ def test_get_current_query_handler_two_stage_return_finish_part1(): def arrange() -> TestSetup: test_setup = create_two_stage_setup( stage1_result_prototypes=[Finish(result=None)], - stage2_result_prototypes=[Finish(result=None)] + stage2_result_prototypes=[Finish(result=None)], ) return test_setup - def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: - current_query_handler = test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + def act( + test_setup: TestSetup, + ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: + current_query_handler = ( + test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + ) return current_query_handler test_setup = arrange() @@ -58,8 +71,10 @@ def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLSta assert_reference_counting_bag_creation(test_setup) assert_stage_train_query_handler_created( - test_setup, stage_index=0, - stage_inputs=[test_setup.state_setup.sql_stage_input_output]) + test_setup, + stage_index=0, + stage_inputs=[test_setup.state_setup.sql_stage_input_output], + ) assert_stage_not_called(test_setup, stage_index=1) assert result == test_setup.stage_setups[0].train_query_handler @@ -77,16 +92,19 @@ def test_handle_result_two_stage_return_finish_part2(): def arrange() -> TestSetup: test_setup = create_two_stage_setup( stage1_result_prototypes=[Finish(result=None)], - stage2_result_prototypes=[Finish(result=None)] + stage2_result_prototypes=[Finish(result=None)], + ) + execution_query_handler_state = ( + test_setup.state_setup.execution_query_handler_state ) - execution_query_handler_state = test_setup.state_setup.execution_query_handler_state execution_query_handler_state.get_current_query_handler() test_setup.reset_mock() return test_setup def act(test_setup: TestSetup) -> ResultHandlerReturnValue: result = test_setup.state_setup.execution_query_handler_state.handle_result( - test_setup.stage_setups[0].results[0]) + test_setup.stage_setups[0].results[0] + ) return result test_setup = arrange() @@ -95,8 +113,10 @@ def act(test_setup: TestSetup) -> ResultHandlerReturnValue: assert_reference_counting_bag_not_called(test_setup) assert_release_on_query_handler_context_for_stage(test_setup, stage_index=0) assert_stage_train_query_handler_created( - test_setup, stage_index=1, - stage_inputs=[test_setup.stage_setups[0].results[0].result]) + test_setup, + stage_index=1, + stage_inputs=[test_setup.stage_setups[0].results[0].result], + ) assert result == ResultHandlerReturnValue.CONTINUE_PROCESSING 
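The state tests reformatted in these hunks all share one arrange/act shape: arrange() builds a mocked TestSetup, optionally drives the state machine forward, and resets the recorded calls; act() performs exactly one call under test; the assertions then inspect only what act() triggered. A condensed restatement of that pattern, assembled from the hunks above (the test name is illustrative, everything else appears verbatim in the diff):

    def test_handle_result_two_stage_example():  # hypothetical name
        def arrange() -> TestSetup:
            test_setup = create_two_stage_setup(
                stage1_result_prototypes=[Finish(result=None)],
                stage2_result_prototypes=[Finish(result=None)],
            )
            # Advance the state machine to the point under test, then reset
            # the mocks so the assertions only see what act() causes.
            test_setup.state_setup.execution_query_handler_state.get_current_query_handler()
            test_setup.reset_mock()
            return test_setup

        def act(test_setup: TestSetup) -> ResultHandlerReturnValue:
            return test_setup.state_setup.execution_query_handler_state.handle_result(
                test_setup.stage_setups[0].results[0]
            )

        result = act(arrange())
        assert result == ResultHandlerReturnValue.CONTINUE_PROCESSING
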
@@ -110,16 +130,24 @@ def test_get_current_query_handler_two_stage_return_finish_part3(): def arrange() -> TestSetup: test_setup = create_two_stage_setup( stage1_result_prototypes=[Finish(result=None)], - stage2_result_prototypes=[Finish(result=None)] + stage2_result_prototypes=[Finish(result=None)], + ) + execution_query_handler_state = ( + test_setup.state_setup.execution_query_handler_state ) - execution_query_handler_state = test_setup.state_setup.execution_query_handler_state execution_query_handler_state.get_current_query_handler() - execution_query_handler_state.handle_result(test_setup.stage_setups[0].results[0]) + execution_query_handler_state.handle_result( + test_setup.stage_setups[0].results[0] + ) test_setup.reset_mock() return test_setup - def act(test_setup: TestSetup) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: - current_query_handler = test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + def act( + test_setup: TestSetup, + ) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]: + current_query_handler = ( + test_setup.state_setup.execution_query_handler_state.get_current_query_handler() + ) return current_query_handler test_setup = arrange() @@ -142,18 +170,23 @@ def test_handle_result_two_stage_return_finish_part4(): def arrange() -> TestSetup: test_setup = create_two_stage_setup( stage1_result_prototypes=[Finish(result=None)], - stage2_result_prototypes=[Finish(result=None)] + stage2_result_prototypes=[Finish(result=None)], + ) + execution_query_handler_state = ( + test_setup.state_setup.execution_query_handler_state ) - execution_query_handler_state = test_setup.state_setup.execution_query_handler_state execution_query_handler_state.get_current_query_handler() - execution_query_handler_state.handle_result(test_setup.stage_setups[0].results[0]) + execution_query_handler_state.handle_result( + test_setup.stage_setups[0].results[0] + ) execution_query_handler_state.get_current_query_handler() test_setup.reset_mock() return test_setup def act(test_setup: TestSetup) -> ResultHandlerReturnValue: result = test_setup.state_setup.execution_query_handler_state.handle_result( - test_setup.stage_setups[1].results[0]) + test_setup.stage_setups[1].results[0] + ) return result test_setup = arrange() @@ -175,7 +208,7 @@ def test_get_current_query_handler_two_stage_return_finish_part5(): test_setup = create_two_stage_setup( stage1_result_prototypes=[Finish(result=None)], - stage2_result_prototypes=[Finish(result=None)] + stage2_result_prototypes=[Finish(result=None)], ) execution_query_handler_state = test_setup.state_setup.execution_query_handler_state execution_query_handler_state.get_current_query_handler() diff --git a/tests/unit_tests/sql_stage_graph/test_data_partition.py b/tests/unit_tests/sql_stage_graph/test_data_partition.py index 1b81174d..421f530a 100644 --- a/tests/unit_tests/sql_stage_graph/test_data_partition.py +++ b/tests/unit_tests/sql_stage_graph/test_data_partition.py @@ -1,20 +1,20 @@ from enum import Enum, auto import pytest + +from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition +from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependency from exasol.analytics.schema import ( - SchemaName, + Column, ColumnName, + ColumnType, + SchemaName, Table, - Column, TableNameBuilder, - ViewNameBuilder, View, - ColumnType, + ViewNameBuilder, ) -from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition -from 
exasol.analytics.query_handler.graph.stage.sql.dependency import Dependency - class TestEnum(Enum): __test__ = False @@ -25,9 +25,9 @@ class TestEnum(Enum): @pytest.fixture def table(): table = Table( - TableNameBuilder.create( - "table", SchemaName("TEST_SCHEMA")), - columns=[Column(ColumnName("x1"), ColumnType("INTEGER"))]) + TableNameBuilder.create("table", SchemaName("TEST_SCHEMA")), + columns=[Column(ColumnName("x1"), ColumnType("INTEGER"))], + ) return table @@ -38,9 +38,9 @@ def test_with_table(table): @pytest.fixture() def view(): view = View( - ViewNameBuilder.create( - "view", SchemaName("TEST_SCHEMA")), - columns=[Column(ColumnName("x1"), ColumnType("INTEGER"))]) + ViewNameBuilder.create("view", SchemaName("TEST_SCHEMA")), + columns=[Column(ColumnName("x1"), ColumnType("INTEGER"))], + ) return view @@ -50,9 +50,10 @@ def test_with_view(view): def test_dependencies(table, view): view = View( - ViewNameBuilder.create( - "view", SchemaName("TEST_SCHEMA")), - columns=[Column(ColumnName("x1"), ColumnType("INTEGER"))]) - DataPartition(table_like=view, - dependencies={TestEnum.K1: Dependency( - DataPartition(table_like=table))}) + ViewNameBuilder.create("view", SchemaName("TEST_SCHEMA")), + columns=[Column(ColumnName("x1"), ColumnType("INTEGER"))], + ) + DataPartition( + table_like=view, + dependencies={TestEnum.K1: Dependency(DataPartition(table_like=table))}, + ) diff --git a/tests/unit_tests/sql_stage_graph/test_dataset.py b/tests/unit_tests/sql_stage_graph/test_dataset.py index acf26ba5..632d887a 100644 --- a/tests/unit_tests/sql_stage_graph/test_dataset.py +++ b/tests/unit_tests/sql_stage_graph/test_dataset.py @@ -2,19 +2,19 @@ from typing import List import pytest + +from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition +from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset +from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependencies from exasol.analytics.schema import ( - SchemaName, + Column, ColumnName, + ColumnType, + SchemaName, Table, - Column, TableNameBuilder, - ColumnType, ) -from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition -from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset -from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependencies - class TestEnum(Enum): __test__ = False @@ -41,96 +41,108 @@ def target(): def create_table_data_partition( - name: str, - columns: List[Column], - dependencies: Dependencies = None): + name: str, columns: List[Column], dependencies: Dependencies = None +): if dependencies is None: dependencies = {} return DataPartition( table_like=Table( - TableNameBuilder.create( - name, SchemaName("TEST_SCHEMA")), - columns=columns), - dependencies=dependencies + TableNameBuilder.create(name, SchemaName("TEST_SCHEMA")), columns=columns + ), + dependencies=dependencies, ) def test_dataset_partitions_with_same_table_like_name(identifier, sample, target): extra_column = Column(ColumnName("extra"), ColumnType("INTEGER")) - partition1 = create_table_data_partition(name="TRAIN", - columns=[identifier, sample, target, extra_column]) - partition2 = create_table_data_partition(name="TRAIN", - columns=[identifier, sample, target, extra_column]) - with pytest.raises(ValueError, match="The names of table likes of the data partitions should be different."): - Dataset(data_partitions={TestEnum.K1: partition1, - TestEnum.K2: partition2}, - identifier_columns=[identifier], - sample_columns=[sample], - 
target_columns=[target]) + partition1 = create_table_data_partition( + name="TRAIN", columns=[identifier, sample, target, extra_column] + ) + partition2 = create_table_data_partition( + name="TRAIN", columns=[identifier, sample, target, extra_column] + ) + with pytest.raises( + ValueError, + match="The names of table likes of the data partitions should be different.", + ): + Dataset( + data_partitions={TestEnum.K1: partition1, TestEnum.K2: partition2}, + identifier_columns=[identifier], + sample_columns=[sample], + target_columns=[target], + ) def test_dataset_extra_column_valid(identifier, sample, target): extra_column = Column(ColumnName("extra"), ColumnType("INTEGER")) - partition1 = create_table_data_partition(name="TRAIN", - columns=[identifier, sample, target, extra_column]) - partition2 = create_table_data_partition(name="TEST", - columns=[identifier, sample, target, extra_column]) - Dataset(data_partitions={TestEnum.K1: partition1, - TestEnum.K2: partition2}, - identifier_columns=[identifier], - sample_columns=[sample], - target_columns=[target]) + partition1 = create_table_data_partition( + name="TRAIN", columns=[identifier, sample, target, extra_column] + ) + partition2 = create_table_data_partition( + name="TEST", columns=[identifier, sample, target, extra_column] + ) + Dataset( + data_partitions={TestEnum.K1: partition1, TestEnum.K2: partition2}, + identifier_columns=[identifier], + sample_columns=[sample], + target_columns=[target], + ) def test_dataset_partitions_different_columns_throws_exception( - identifier, sample, target): + identifier, sample, target +): extra_column = Column(ColumnName("extra"), ColumnType("INTEGER")) - partition1 = create_table_data_partition(name="TRAIN", - columns=[identifier, sample, target, extra_column]) - partition2 = create_table_data_partition(name="TEST", - columns=[identifier, sample, target]) - with pytest.raises(ValueError, match="Not all data partitions have the same columns."): - Dataset(data_partitions={TestEnum.K1: partition1, - TestEnum.K2: partition2}, - identifier_columns=[identifier], - sample_columns=[sample], - target_columns=[target]) + partition1 = create_table_data_partition( + name="TRAIN", columns=[identifier, sample, target, extra_column] + ) + partition2 = create_table_data_partition( + name="TEST", columns=[identifier, sample, target] + ) + with pytest.raises( + ValueError, match="Not all data partitions have the same columns." 
+ ): + Dataset( + data_partitions={TestEnum.K1: partition1, TestEnum.K2: partition2}, + identifier_columns=[identifier], + sample_columns=[sample], + target_columns=[target], + ) def test_dataset_not_contains_sample_throws_exception(identifier, sample, target): - partition1 = create_table_data_partition(name="TRAIN", - columns=[identifier, target]) - partition2 = create_table_data_partition(name="TEST", - columns=[identifier, target]) + partition1 = create_table_data_partition(name="TRAIN", columns=[identifier, target]) + partition2 = create_table_data_partition(name="TEST", columns=[identifier, target]) with pytest.raises(ValueError, match="Not all sample columns in data partitions."): - Dataset(data_partitions={TestEnum.K1: partition1, - TestEnum.K2: partition2}, - identifier_columns=[identifier], - sample_columns=[sample], - target_columns=[target]) + Dataset( + data_partitions={TestEnum.K1: partition1, TestEnum.K2: partition2}, + identifier_columns=[identifier], + sample_columns=[sample], + target_columns=[target], + ) def test_dataset_not_contains_target_throws_exception(identifier, sample, target): - partition1 = create_table_data_partition(name="TRAIN", - columns=[identifier, sample]) - partition2 = create_table_data_partition(name="TEST", - columns=[identifier, sample]) + partition1 = create_table_data_partition(name="TRAIN", columns=[identifier, sample]) + partition2 = create_table_data_partition(name="TEST", columns=[identifier, sample]) with pytest.raises(ValueError, match="Not all target columns in data partitions."): - Dataset(data_partitions={TestEnum.K1: partition1, - TestEnum.K2: partition2}, - identifier_columns=[identifier], - sample_columns=[sample], - target_columns=[target]) + Dataset( + data_partitions={TestEnum.K1: partition1, TestEnum.K2: partition2}, + identifier_columns=[identifier], + sample_columns=[sample], + target_columns=[target], + ) def test_dataset_not_contains_identifier_throws_exception(identifier, sample, target): - partition1 = create_table_data_partition(name="TRAIN", - columns=[target, sample]) - partition2 = create_table_data_partition(name="TEST", - columns=[target, sample]) - with pytest.raises(ValueError, match="Not all identifier columns in data partitions."): - Dataset(data_partitions={TestEnum.K1: partition1, - TestEnum.K2: partition2}, - identifier_columns=[identifier], - sample_columns=[sample], - target_columns=[target]) + partition1 = create_table_data_partition(name="TRAIN", columns=[target, sample]) + partition2 = create_table_data_partition(name="TEST", columns=[target, sample]) + with pytest.raises( + ValueError, match="Not all identifier columns in data partitions." 
+ ): + Dataset( + data_partitions={TestEnum.K1: partition1, TestEnum.K2: partition2}, + identifier_columns=[identifier], + sample_columns=[sample], + target_columns=[target], + ) diff --git a/tests/unit_tests/sql_stage_graph/test_find_object_proxies.py b/tests/unit_tests/sql_stage_graph/test_find_object_proxies.py index 1f82291b..32a62793 100644 --- a/tests/unit_tests/sql_stage_graph/test_find_object_proxies.py +++ b/tests/unit_tests/sql_stage_graph/test_find_object_proxies.py @@ -2,17 +2,19 @@ import pytest +from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependency +from exasol.analytics.query_handler.graph.stage.sql.execution.find_object_proxies import ( + find_object_proxies, +) from exasol.analytics.schema import ( + ColumnBuilder, + ColumnNameBuilder, + ColumnType, TableBuilder, - ViewName, TableName, - ColumnBuilder, View, - ColumnType, - ColumnNameBuilder, + ViewName, ) -from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependency -from exasol.analytics.query_handler.graph.stage.sql.execution.find_object_proxies import find_object_proxies BUCKETFS_LOCATION = "BUCKETFS_LOCATION" @@ -71,8 +73,9 @@ def test_object_proxy_in_dependency_object(object_proxy): def test_object_proxy_in_sub_dependency(object_proxy): - dependency = Dependency(object="test", - dependencies={TestEnum.K1: Dependency(object=object_proxy)}) + dependency = Dependency( + object="test", dependencies={TestEnum.K1: Dependency(object=object_proxy)} + ) result = find_object_proxies(dependency) assert result == [object_proxy] @@ -80,12 +83,13 @@ def test_object_proxy_in_sub_dependency(object_proxy): def test_object_proxy_in_table(object_proxy): if not isinstance(object_proxy, TableName): pytest.skip() - column = ColumnBuilder() \ - .with_name(ColumnNameBuilder.create("test")) \ - .with_type(ColumnType("INTEGER")).build() - table = TableBuilder() \ - .with_name(object_proxy) \ - .with_columns([column]).build() + column = ( + ColumnBuilder() + .with_name(ColumnNameBuilder.create("test")) + .with_type(ColumnType("INTEGER")) + .build() + ) + table = TableBuilder().with_name(object_proxy).with_columns([column]).build() result = find_object_proxies(table) assert result == [object_proxy] @@ -93,9 +97,12 @@ def test_object_proxy_in_table(object_proxy): def test_object_proxy_in_view(object_proxy): if not isinstance(object_proxy, ViewName): pytest.skip() - column = ColumnBuilder() \ - .with_name(ColumnNameBuilder.create("test")) \ - .with_type(ColumnType("INTEGER")).build() + column = ( + ColumnBuilder() + .with_name(ColumnNameBuilder.create("test")) + .with_type(ColumnType("INTEGER")) + .build() + ) view = View(name=object_proxy, columns=[column]) result = find_object_proxies(view) assert result == [object_proxy] @@ -113,8 +120,11 @@ def test_object_proxy_in_column(object_proxy): if not isinstance(object_proxy, TableName): pytest.skip() column_name = ColumnNameBuilder.create("test", table_like_name=object_proxy) - column = ColumnBuilder().with_name(column_name).with_type(ColumnType("INTEGER")).build() + column = ( + ColumnBuilder().with_name(column_name).with_type(ColumnType("INTEGER")).build() + ) result = find_object_proxies(column) assert result == [object_proxy] + # TODO DataPartition, Dataset, SQLStageInputOutput, arbitrary object diff --git a/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counter.py b/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counter.py index 21be428a..325f7bd6 100644 --- 
a/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counter.py +++ b/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counter.py @@ -1,13 +1,16 @@ import dataclasses from typing import Union -from unittest.mock import MagicMock, create_autospec, call +from unittest.mock import MagicMock, call, create_autospec import pytest + from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext - -from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counter import ObjectProxyReferenceCounter, ReferenceCounterStatus -from tests.mock_cast import mock_cast +from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counter import ( + ObjectProxyReferenceCounter, + ReferenceCounterStatus, +) +from tests.utils.mock_cast import mock_cast MockScopeQueryHandlerContext = Union[ScopeQueryHandlerContext, MagicMock] MockObjectProxy = Union[ObjectProxy, MagicMock] @@ -27,15 +30,21 @@ def reset_mock(self): def create_test_setup() -> TestMockSetup: - parent_query_context_handler: MockScopeQueryHandlerContext = \ - create_autospec(ScopeQueryHandlerContext) - child_query_context_handler: MockScopeQueryHandlerContext = \ - create_autospec(ScopeQueryHandlerContext) - mock_cast(parent_query_context_handler.get_child_query_handler_context).side_effect = [child_query_context_handler] + parent_query_context_handler: MockScopeQueryHandlerContext = create_autospec( + ScopeQueryHandlerContext + ) + child_query_context_handler: MockScopeQueryHandlerContext = create_autospec( + ScopeQueryHandlerContext + ) + mock_cast( + parent_query_context_handler.get_child_query_handler_context + ).side_effect = [child_query_context_handler] object_proxy: MockObjectProxy = create_autospec(ObjectProxy) - return TestMockSetup(mock_parent_query_context_handler=parent_query_context_handler, - mock_child_query_context_handler=child_query_context_handler, - mock_object_proxy=object_proxy) + return TestMockSetup( + mock_parent_query_context_handler=parent_query_context_handler, + mock_child_query_context_handler=child_query_context_handler, + mock_object_proxy=object_proxy, + ) def test_init(): @@ -46,11 +55,17 @@ def test_init(): child_query_context_handler. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) - mock_cast(test_setup.mock_parent_query_context_handler.get_child_query_handler_context).assert_called_once() - mock_cast(test_setup.mock_parent_query_context_handler.transfer_object_to).assert_called_once_with( - test_setup.mock_object_proxy, test_setup.mock_child_query_context_handler) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) + mock_cast( + test_setup.mock_parent_query_context_handler.get_child_query_handler_context + ).assert_called_once() + mock_cast( + test_setup.mock_parent_query_context_handler.transfer_object_to + ).assert_called_once_with( + test_setup.mock_object_proxy, test_setup.mock_child_query_context_handler + ) def test_single_add(): @@ -59,12 +74,15 @@ def test_single_add(): It expects no calls to the parent_query_context_handler and child_query_handler_context. 
""" test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) test_setup.reset_mock() counter.add() - assert test_setup.mock_parent_query_context_handler.mock_calls == [] and \ - test_setup.mock_child_query_context_handler.mock_calls == [] + assert ( + test_setup.mock_parent_query_context_handler.mock_calls == [] + and test_setup.mock_child_query_context_handler.mock_calls == [] + ) def test_single_add_and_single_remove(): @@ -74,14 +92,17 @@ def test_single_add_and_single_remove(): and that the remove returns NOT_RELEASED. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) counter.add() test_setup.reset_mock() reference_counter_status = counter.remove() - assert reference_counter_status == ReferenceCounterStatus.NOT_RELEASED and \ - test_setup.mock_parent_query_context_handler.mock_calls == [] and \ - test_setup.mock_child_query_context_handler.mock_calls == [] + assert ( + reference_counter_status == ReferenceCounterStatus.NOT_RELEASED + and test_setup.mock_parent_query_context_handler.mock_calls == [] + and test_setup.mock_child_query_context_handler.mock_calls == [] + ) def test_single_add_and_two_removes(): @@ -91,15 +112,18 @@ def test_single_add_and_two_removes(): and that the remove returns RELEASED. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) counter.add() counter.remove() test_setup.reset_mock() reference_counter_status = counter.remove() - assert reference_counter_status == ReferenceCounterStatus.RELEASED and \ - test_setup.mock_child_query_context_handler.mock_calls == [call.release()] and \ - test_setup.mock_parent_query_context_handler.mock_calls == [] + assert ( + reference_counter_status == ReferenceCounterStatus.RELEASED + and test_setup.mock_child_query_context_handler.mock_calls == [call.release()] + and test_setup.mock_parent_query_context_handler.mock_calls == [] + ) def test_single_remove(): @@ -109,13 +133,16 @@ def test_single_remove(): and that mock_child_query_context_handler.release is called """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) test_setup.reset_mock() reference_counter_status = counter.remove() - assert reference_counter_status == ReferenceCounterStatus.RELEASED and \ - test_setup.mock_child_query_context_handler.mock_calls == [call.release()] and \ - test_setup.mock_parent_query_context_handler.mock_calls == [] + assert ( + reference_counter_status == ReferenceCounterStatus.RELEASED + and test_setup.mock_child_query_context_handler.mock_calls == [call.release()] + and test_setup.mock_parent_query_context_handler.mock_calls == [] + ) def test_add_after_release(): @@ -123,13 +150,17 @@ def test_add_after_release(): This test checks that we fail, when we call add 
after we already released the counter. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) test_setup.reset_mock() reference_counter_status = counter.remove() assert reference_counter_status == ReferenceCounterStatus.RELEASED - with pytest.raises(RuntimeError, match="ReferenceCounter not valid anymore. " - "ObjectProxy got already garbage collected or transfered back."): + with pytest.raises( + RuntimeError, + match="ReferenceCounter not valid anymore. " + "ObjectProxy got already garbage collected or transfered back.", + ): counter.add() @@ -138,13 +169,17 @@ def test_remove_after_release(): This test checks that we fail, when we call remove after we already released the counter. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) test_setup.reset_mock() reference_counter_status = counter.remove() assert reference_counter_status == ReferenceCounterStatus.RELEASED - with pytest.raises(RuntimeError, match="ReferenceCounter not valid anymore. " - "ObjectProxy got already garbage collected or transfered back."): + with pytest.raises( + RuntimeError, + match="ReferenceCounter not valid anymore. " + "ObjectProxy got already garbage collected or transfered back.", + ): counter.remove() @@ -158,17 +193,21 @@ def test_multiple_adds_and_removes_after_each_other(count: int): - that release gets called on the child query handler context """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) test_setup.reset_mock() for i in range(count): counter.add() reference_counter_status_of_first_removes = [counter.remove() for i in range(count)] last_reference_counter_status = counter.remove() - assert last_reference_counter_status == ReferenceCounterStatus.RELEASED and \ - test_setup.mock_child_query_context_handler.mock_calls == [call.release()] and \ - test_setup.mock_parent_query_context_handler.mock_calls == [] and \ - reference_counter_status_of_first_removes == [ReferenceCounterStatus.NOT_RELEASED] * count + assert ( + last_reference_counter_status == ReferenceCounterStatus.RELEASED + and test_setup.mock_child_query_context_handler.mock_calls == [call.release()] + and test_setup.mock_parent_query_context_handler.mock_calls == [] + and reference_counter_status_of_first_removes + == [ReferenceCounterStatus.NOT_RELEASED] * count + ) @pytest.mark.parametrize("count", list(range(2, 10))) @@ -181,8 +220,9 @@ def test_multiple_adds_and_removes_after_alternating(count: int): - that release gets called on the child query handler context """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) test_setup.reset_mock() reference_counter_status_of_first_removes = [] @@ -191,10 +231,13 @@ def 
test_multiple_adds_and_removes_after_alternating(count: int): reference_counter_status_of_first_removes.append(counter.remove()) last_reference_counter_status = counter.remove() - assert last_reference_counter_status == ReferenceCounterStatus.RELEASED and \ - test_setup.mock_child_query_context_handler.mock_calls == [call.release()] and \ - test_setup.mock_parent_query_context_handler.mock_calls == [] and \ - reference_counter_status_of_first_removes == [ReferenceCounterStatus.NOT_RELEASED] * count + assert ( + last_reference_counter_status == ReferenceCounterStatus.RELEASED + and test_setup.mock_child_query_context_handler.mock_calls == [call.release()] + and test_setup.mock_parent_query_context_handler.mock_calls == [] + and reference_counter_status_of_first_removes + == [ReferenceCounterStatus.NOT_RELEASED] * count + ) def test_transfer_back_to_parent_query_handler_context_after_init(): @@ -205,15 +248,21 @@ def test_transfer_back_to_parent_query_handler_context_after_init(): and the parent query context handler, followed by a release call. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) test_setup.reset_mock() counter.transfer_back_to_parent_query_handler_context() assert test_setup.mock_parent_query_context_handler.mock_calls == [] - test_setup.mock_child_query_context_handler.assert_has_calls([ - call.transfer_object_to(test_setup.mock_object_proxy, test_setup.mock_parent_query_context_handler), - call.release() - ]) + test_setup.mock_child_query_context_handler.assert_has_calls( + [ + call.transfer_object_to( + test_setup.mock_object_proxy, + test_setup.mock_parent_query_context_handler, + ), + call.release(), + ] + ) def test_transfer_back_to_parent_query_handler_context_after_add(): @@ -223,16 +272,22 @@ def test_transfer_back_to_parent_query_handler_context_after_add(): ObjectProxyReferenceCounter. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) counter.add() test_setup.reset_mock() counter.transfer_back_to_parent_query_handler_context() assert test_setup.mock_parent_query_context_handler.mock_calls == [] - test_setup.mock_child_query_context_handler.assert_has_calls([ - call.transfer_object_to(test_setup.mock_object_proxy, test_setup.mock_parent_query_context_handler), - call.release() - ]) + test_setup.mock_child_query_context_handler.assert_has_calls( + [ + call.transfer_object_to( + test_setup.mock_object_proxy, + test_setup.mock_parent_query_context_handler, + ), + call.release(), + ] + ) def test_transfer_back_to_parent_query_handler_context_after_release(): @@ -242,12 +297,16 @@ def test_transfer_back_to_parent_query_handler_context_after_release(): which lead to the release of the ObjectProxy. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) counter.remove() test_setup.reset_mock() - with pytest.raises(RuntimeError, match="ReferenceCounter not valid anymore. 
" - "ObjectProxy got already garbage collected or transfered back."): + with pytest.raises( + RuntimeError, + match="ReferenceCounter not valid anymore. " + "ObjectProxy got already garbage collected or transfered back.", + ): counter.transfer_back_to_parent_query_handler_context() @@ -257,12 +316,16 @@ def test_two_transfer_back_to_parent_query_handler_context(): fails after a first successful transfer_back_to_parent_query_handler_context. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) counter.transfer_back_to_parent_query_handler_context() test_setup.reset_mock() - with pytest.raises(RuntimeError, match="ReferenceCounter not valid anymore. " - "ObjectProxy got already garbage collected or transfered back."): + with pytest.raises( + RuntimeError, + match="ReferenceCounter not valid anymore. " + "ObjectProxy got already garbage collected or transfered back.", + ): counter.transfer_back_to_parent_query_handler_context() @@ -271,12 +334,16 @@ def test_remove_after_transfer_back_to_parent_query_handler_context(): This tests if a remove after a call to transfer_back_to_parent_query_handler_context fails. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) counter.transfer_back_to_parent_query_handler_context() test_setup.reset_mock() - with pytest.raises(RuntimeError, match="ReferenceCounter not valid anymore. " - "ObjectProxy got already garbage collected or transfered back."): + with pytest.raises( + RuntimeError, + match="ReferenceCounter not valid anymore. " + "ObjectProxy got already garbage collected or transfered back.", + ): counter.remove() @@ -285,10 +352,14 @@ def test_add_after_transfer_back_to_parent_query_handler_context(): This tests if a add after a call to transfer_back_to_parent_query_handler_context fails. """ test_setup = create_test_setup() - counter = ObjectProxyReferenceCounter(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy) + counter = ObjectProxyReferenceCounter( + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxy + ) counter.transfer_back_to_parent_query_handler_context() test_setup.reset_mock() - with pytest.raises(RuntimeError, match="ReferenceCounter not valid anymore. " - "ObjectProxy got already garbage collected or transfered back."): + with pytest.raises( + RuntimeError, + match="ReferenceCounter not valid anymore. 
" + "ObjectProxy got already garbage collected or transfered back.", + ): counter.add() diff --git a/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counting_bag_with_mocks.py b/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counting_bag_with_mocks.py index 13dc1aa4..ee957c77 100644 --- a/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counting_bag_with_mocks.py +++ b/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counting_bag_with_mocks.py @@ -1,14 +1,20 @@ import dataclasses -from typing import Union, List -from unittest.mock import create_autospec, MagicMock, call, Mock +from typing import List, Union +from unittest.mock import MagicMock, Mock, call, create_autospec import pytest + from exasol.analytics.query_handler.context.proxy.object_proxy import ObjectProxy from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext - -from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counter import ObjectProxyReferenceCounter, ReferenceCounterStatus -from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counting_bag import ObjectProxyReferenceCountingBag, ObjectProxyReferenceCounterFactory -from tests.mock_cast import mock_cast +from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counter import ( + ObjectProxyReferenceCounter, + ReferenceCounterStatus, +) +from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counting_bag import ( + ObjectProxyReferenceCounterFactory, + ObjectProxyReferenceCountingBag, +) +from tests.utils.mock_cast import mock_cast MockObjectProxyReferenceCounter = Union[ObjectProxyReferenceCounter, MagicMock] MockObjectProxyReferenceCounterFactory = Union[ObjectProxyReferenceCounterFactory, Mock] @@ -34,27 +40,40 @@ def reset_mock(self): def create_test_setup(*, proxy_count: int) -> TestSetup: - parent_query_context_handler: MockScopeQueryHandlerContext = \ - create_autospec(ScopeQueryHandlerContext) - object_proxies: List[MockObjectProxy] = \ - [create_autospec(ObjectProxy) for i in range(proxy_count)] - object_proxy_reference_counter_factory: MockObjectProxyReferenceCounterFactory = Mock() - object_proxy_reference_counters = create_test_setup_with_reference_counters(object_proxies) + parent_query_context_handler: MockScopeQueryHandlerContext = create_autospec( + ScopeQueryHandlerContext + ) + object_proxies: List[MockObjectProxy] = [ + create_autospec(ObjectProxy) for i in range(proxy_count) + ] + object_proxy_reference_counter_factory: MockObjectProxyReferenceCounterFactory = ( + Mock() + ) + object_proxy_reference_counters = create_test_setup_with_reference_counters( + object_proxies + ) object_proxy_reference_counter_factory.side_effect = object_proxy_reference_counters - return TestSetup(parent_query_context_handler, - object_proxy_reference_counter_factory, - object_proxies, - object_proxy_reference_counters) - - -def create_test_setup_with_reference_counters(mock_object_proxies: List[MockObjectProxy]): - object_proxy_reference_counters = [create_mock_reference_counter() - for _ in mock_object_proxies] + return TestSetup( + parent_query_context_handler, + object_proxy_reference_counter_factory, + object_proxies, + object_proxy_reference_counters, + ) + + +def create_test_setup_with_reference_counters( + mock_object_proxies: List[MockObjectProxy], +): + object_proxy_reference_counters = [ + create_mock_reference_counter() for _ in mock_object_proxies + ] return 
object_proxy_reference_counters def create_mock_reference_counter() -> MockObjectProxyReferenceCounter: - object_proxy_reference_counter: MockObjectProxyReferenceCounter = create_autospec(ObjectProxyReferenceCounter) + object_proxy_reference_counter: MockObjectProxyReferenceCounter = create_autospec( + ObjectProxyReferenceCounter + ) @dataclasses.dataclass class Counter: @@ -82,7 +101,8 @@ def test_init(): test_setup = create_test_setup(proxy_count=1) bag = ObjectProxyReferenceCountingBag( test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + test_setup.mock_object_proxy_reference_counter_factory, + ) test_setup.mock_parent_query_context_handler.assert_not_called() test_setup.mock_object_proxy_reference_counter_factory.assert_not_called() assert test_setup.mock_object_proxy_reference_counters[0].mock_calls == [] @@ -96,12 +116,16 @@ def test_single_object_proxy_add(): test_setup = create_test_setup(proxy_count=1) bag = ObjectProxyReferenceCountingBag( test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) test_setup.mock_object_proxy_reference_counter_factory.assert_called_once_with( - test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxies[0]) + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxies[0] + ) assert test_setup.mock_parent_query_context_handler.mock_calls == [] - mock_cast(test_setup.mock_object_proxy_reference_counters[0].remove).assert_not_called() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].remove + ).assert_not_called() def test_single_object_proxy_add_contains(): @@ -112,7 +136,8 @@ def test_single_object_proxy_add_contains(): test_setup = create_test_setup(proxy_count=1) bag = ObjectProxyReferenceCountingBag( test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) assert test_setup.mock_object_proxies[0] in bag @@ -124,7 +149,8 @@ def test_single_object_proxy_not_added_contains(): test_setup = create_test_setup(proxy_count=1) bag = ObjectProxyReferenceCountingBag( test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + test_setup.mock_object_proxy_reference_counter_factory, + ) assert test_setup.mock_object_proxies[0] not in bag @@ -137,7 +163,8 @@ def test_single_object_proxy_add_remove_contains(): test_setup = create_test_setup(proxy_count=1) bag = ObjectProxyReferenceCountingBag( test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.remove(test_setup.mock_object_proxies[0]) assert test_setup.mock_object_proxies[0] not in bag @@ -149,11 +176,16 @@ def test_multiple_object_proxy_add_contains(): expects that __contains__ returns true for these object proxy mocks """ test_setup = create_test_setup(proxy_count=2) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) 
bag.add(test_setup.mock_object_proxies[1]) - assert test_setup.mock_object_proxies[0] in bag and test_setup.mock_object_proxies[1] in bag + assert ( + test_setup.mock_object_proxies[0] in bag + and test_setup.mock_object_proxies[1] in bag + ) def test_single_object_proxy_add_remove(): @@ -162,15 +194,21 @@ def test_single_object_proxy_add_remove(): expects that the remove method of the ObjectProxyReferenceCounter is called """ test_setup = create_test_setup(proxy_count=1) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) test_setup.reset_mock() bag.remove(test_setup.mock_object_proxies[0]) assert test_setup.mock_parent_query_context_handler.mock_calls == [] test_setup.mock_object_proxy_reference_counter_factory.assert_not_called() - mock_cast(test_setup.mock_object_proxy_reference_counters[0].remove).assert_called_once() - mock_cast(test_setup.mock_object_proxy_reference_counters[0].add).assert_not_called() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].remove + ).assert_called_once() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].add + ).assert_not_called() def test_single_object_proxy_add_add(): @@ -179,15 +217,21 @@ def test_single_object_proxy_add_add(): expects besides the behavior for the first add, no further interactions with the mocks. """ test_setup = create_test_setup(proxy_count=1) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) test_setup.reset_mock() bag.add(test_setup.mock_object_proxies[0]) assert test_setup.mock_parent_query_context_handler.mock_calls == [] test_setup.mock_object_proxy_reference_counter_factory.assert_not_called() - mock_cast(test_setup.mock_object_proxy_reference_counters[0].add).assert_called_once() - mock_cast(test_setup.mock_object_proxy_reference_counters[0].remove).assert_not_called() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].add + ).assert_called_once() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].remove + ).assert_not_called() def test_single_object_proxy_add_add_remove(): @@ -196,16 +240,22 @@ def test_single_object_proxy_add_add_remove(): It expects the behavior for the first add and a call to remove of the ObjectProxyRefereneCounter """ test_setup = create_test_setup(proxy_count=1) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.add(test_setup.mock_object_proxies[0]) test_setup.reset_mock() bag.remove(test_setup.mock_object_proxies[0]) assert test_setup.mock_parent_query_context_handler.mock_calls == [] test_setup.mock_object_proxy_reference_counter_factory.assert_not_called() - mock_cast(test_setup.mock_object_proxy_reference_counters[0].remove).assert_called_once() - 
mock_cast(test_setup.mock_object_proxy_reference_counters[0].add).assert_not_called() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].remove + ).assert_called_once() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].add + ).assert_not_called() def test_single_object_proxy_add_add_remove_remove(): @@ -214,8 +264,10 @@ Besides the behavior of the adds and the first remove, we expect a call to the remove of the ObjectProxyReferenceCounter """ test_setup = create_test_setup(proxy_count=1) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.add(test_setup.mock_object_proxies[0]) bag.remove(test_setup.mock_object_proxies[0]) @@ -223,8 +275,12 @@ bag.remove(test_setup.mock_object_proxies[0]) assert test_setup.mock_parent_query_context_handler.mock_calls == [] test_setup.mock_object_proxy_reference_counter_factory.assert_not_called() - mock_cast(test_setup.mock_object_proxy_reference_counters[0].remove).assert_called_once() - mock_cast(test_setup.mock_object_proxy_reference_counters[0].add).assert_not_called() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].remove + ).assert_called_once() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].add + ).assert_not_called() def test_multiple_object_proxies_add(): @@ -233,18 +289,30 @@ It expects the creation of two ObjectProxyReferenceCounter instances with the factory.
""" test_setup = create_test_setup(proxy_count=2) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.add(test_setup.mock_object_proxies[1]) - test_setup.mock_object_proxy_reference_counter_factory.assert_has_calls([ - call(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxies[0]), - call(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxies[1], ) - ]) - mock_cast(test_setup.mock_object_proxy_reference_counters[0].remove).assert_not_called() - mock_cast(test_setup.mock_object_proxy_reference_counters[1].remove).assert_not_called() + test_setup.mock_object_proxy_reference_counter_factory.assert_has_calls( + [ + call( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxies[0], + ), + call( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxies[1], + ), + ] + ) + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].remove + ).assert_not_called() + mock_cast( + test_setup.mock_object_proxy_reference_counters[1].remove + ).assert_not_called() assert test_setup.mock_parent_query_context_handler.mock_calls == [] @@ -254,8 +322,10 @@ def test_multiple_object_proxies_add_remove(): It expects besides the behavior of the adds, calls to remove on the two ObjectProxyReferenceCounter """ test_setup = create_test_setup(proxy_count=2) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.add(test_setup.mock_object_proxies[1]) test_setup.reset_mock() @@ -263,8 +333,12 @@ def test_multiple_object_proxies_add_remove(): bag.remove(test_setup.mock_object_proxies[1]) assert test_setup.mock_parent_query_context_handler.mock_calls == [] test_setup.mock_object_proxy_reference_counter_factory.assert_not_called() - mock_cast(test_setup.mock_object_proxy_reference_counters[0].remove).assert_called_once() - mock_cast(test_setup.mock_object_proxy_reference_counters[1].remove).assert_called_once() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].remove + ).assert_called_once() + mock_cast( + test_setup.mock_object_proxy_reference_counters[1].remove + ).assert_called_once() def test_transfer_back_to_parent_query_handler_context_for_not_added_element(): @@ -273,10 +347,14 @@ def test_transfer_back_to_parent_query_handler_context_for_not_added_element(): and expects that it fails. 
""" test_setup = create_test_setup(proxy_count=1) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) with pytest.raises(KeyError): - bag.transfer_back_to_parent_query_handler_context(test_setup.mock_object_proxies[0]) + bag.transfer_back_to_parent_query_handler_context( + test_setup.mock_object_proxies[0] + ) def test_transfer_back_to_parent_query_handler_context_for_added_element(): @@ -285,15 +363,22 @@ def test_transfer_back_to_parent_query_handler_context_for_added_element(): It expects that transfer_back_to_parent_query_handler_context on the corresponding reference counter is called. """ test_setup = create_test_setup(proxy_count=1) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) test_setup.reset_mock() bag.transfer_back_to_parent_query_handler_context(test_setup.mock_object_proxies[0]) - mock_cast(test_setup.mock_object_proxy_reference_counters[0].transfer_back_to_parent_query_handler_context) \ - .assert_called_once() - assert test_setup.mock_parent_query_context_handler.mock_calls == [] and \ - test_setup.mock_object_proxy_reference_counter_factory.mock_calls == [] + mock_cast( + test_setup.mock_object_proxy_reference_counters[ + 0 + ].transfer_back_to_parent_query_handler_context + ).assert_called_once() + assert ( + test_setup.mock_parent_query_context_handler.mock_calls == [] + and test_setup.mock_object_proxy_reference_counter_factory.mock_calls == [] + ) def test_transfer_back_to_parent_query_handler_context_after_remove(): @@ -302,13 +387,17 @@ def test_transfer_back_to_parent_query_handler_context_after_remove(): It expects that the call fails. """ test_setup = create_test_setup(proxy_count=1) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.remove(test_setup.mock_object_proxies[0]) test_setup.reset_mock() with pytest.raises(KeyError): - bag.transfer_back_to_parent_query_handler_context(test_setup.mock_object_proxies[0]) + bag.transfer_back_to_parent_query_handler_context( + test_setup.mock_object_proxies[0] + ) def test_transfer_back_to_parent_query_handler_context_after_multiple_adds(): @@ -317,16 +406,23 @@ def test_transfer_back_to_parent_query_handler_context_after_multiple_adds(): It expects the same behavior as after the first add. 
""" test_setup = create_test_setup(proxy_count=1) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.add(test_setup.mock_object_proxies[0]) test_setup.reset_mock() bag.transfer_back_to_parent_query_handler_context(test_setup.mock_object_proxies[0]) - mock_cast(test_setup.mock_object_proxy_reference_counters[0].transfer_back_to_parent_query_handler_context) \ - .assert_called_once() - assert test_setup.mock_parent_query_context_handler.mock_calls == [] and \ - test_setup.mock_object_proxy_reference_counter_factory.mock_calls == [] + mock_cast( + test_setup.mock_object_proxy_reference_counters[ + 0 + ].transfer_back_to_parent_query_handler_context + ).assert_called_once() + assert ( + test_setup.mock_parent_query_context_handler.mock_calls == [] + and test_setup.mock_object_proxy_reference_counter_factory.mock_calls == [] + ) def test_remove_after_transfer_back_to_parent_query_handler_context(): @@ -335,8 +431,10 @@ def test_remove_after_transfer_back_to_parent_query_handler_context(): It expects the remove to fail. """ test_setup = create_test_setup(proxy_count=1) - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.add(test_setup.mock_object_proxies[0]) test_setup.reset_mock() @@ -354,18 +452,24 @@ def test_add_after_transfer_back_to_parent_query_handler_context(): """ test_setup = create_test_setup(proxy_count=1) # For this test we need allow the creation of a second reference counter for the same proxy count - test_setup.mock_object_proxy_reference_counter_factory.side_effect = \ + test_setup.mock_object_proxy_reference_counter_factory.side_effect = ( test_setup.mock_object_proxy_reference_counters * 2 - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + ) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.transfer_back_to_parent_query_handler_context(test_setup.mock_object_proxies[0]) test_setup.reset_mock() bag.add(test_setup.mock_object_proxies[0]) test_setup.mock_object_proxy_reference_counter_factory.assert_called_once_with( - test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxies[0]) + test_setup.mock_parent_query_context_handler, test_setup.mock_object_proxies[0] + ) assert test_setup.mock_parent_query_context_handler.mock_calls == [] - mock_cast(test_setup.mock_object_proxy_reference_counters[0].remove).assert_not_called() + mock_cast( + test_setup.mock_object_proxy_reference_counters[0].remove + ).assert_not_called() def test_contains_after_transfer_back_to_parent_query_handler_context(): @@ -375,10 +479,13 @@ def test_contains_after_transfer_back_to_parent_query_handler_context(): """ test_setup = create_test_setup(proxy_count=1) # For this test we need allow the creation of a second reference counter for the same proxy count - 
test_setup.mock_object_proxy_reference_counter_factory.side_effect = \ + test_setup.mock_object_proxy_reference_counter_factory.side_effect = ( test_setup.mock_object_proxy_reference_counters * 2 - bag = ObjectProxyReferenceCountingBag(test_setup.mock_parent_query_context_handler, - test_setup.mock_object_proxy_reference_counter_factory) + ) + bag = ObjectProxyReferenceCountingBag( + test_setup.mock_parent_query_context_handler, + test_setup.mock_object_proxy_reference_counter_factory, + ) bag.add(test_setup.mock_object_proxies[0]) bag.transfer_back_to_parent_query_handler_context(test_setup.mock_object_proxies[0]) test_setup.reset_mock() diff --git a/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counting_bag_with_query_handler_context_impl.py b/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counting_bag_with_query_handler_context_impl.py index 0edb2620..d0842012 100644 --- a/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counting_bag_with_query_handler_context_impl.py +++ b/tests/unit_tests/sql_stage_graph/test_object_proxy_reference_counting_bag_with_query_handler_context_impl.py @@ -1,10 +1,16 @@ from pathlib import PurePosixPath import pytest -from exasol.analytics.query_handler.context.top_level_query_handler_context import TopLevelQueryHandlerContext -from exasol_bucketfs_utils_python.localfs_mock_bucketfs_location import LocalFSMockBucketFSLocation +from exasol_bucketfs_utils_python.localfs_mock_bucketfs_location import ( + LocalFSMockBucketFSLocation, +) -from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counting_bag import ObjectProxyReferenceCountingBag +from exasol.analytics.query_handler.context.top_level_query_handler_context import ( + TopLevelQueryHandlerContext, +) +from exasol.analytics.query_handler.graph.stage.sql.execution.object_proxy_reference_counting_bag import ( + ObjectProxyReferenceCountingBag, +) @pytest.fixture @@ -34,8 +40,7 @@ def test_single_add_remove(context): assert len(context.cleanup_released_object_proxies()) == 1 -def test_single_add_remove_only_the_added_object_proxy_get_removed( - context): +def test_single_add_remove_only_the_added_object_proxy_get_removed(context): """ This test adds and removes an object_proxy to an ObjectProxyReferenceCountingBag. Further, it creates an additional object proxy which it doesn't add and checks that only one proxy was released. @@ -71,12 +76,13 @@ def test_single_add_remove_add(context): bag = ObjectProxyReferenceCountingBag(context) bag.add(table_name) bag.remove(table_name) - with pytest.raises(Exception, match="Object not owned by this ScopeQueryHandlerContext."): + with pytest.raises( + Exception, match="Object not owned by this ScopeQueryHandlerContext."
+ ): bag.add(table_name) -def test_transfer_back_to_parent_query_handler_context_after_add( - context): +def test_transfer_back_to_parent_query_handler_context_after_add(context): table_name = context.get_temporary_table_name() bag = ObjectProxyReferenceCountingBag(context) bag.add(table_name) @@ -86,8 +92,7 @@ def test_transfer_back_to_parent_query_handler_context_after_add( assert len(context.cleanup_released_object_proxies()) == 1 -def test_add_after_transfer_back_to_parent_query_handler_context( - context): +def test_add_after_transfer_back_to_parent_query_handler_context(context): table_name = context.get_temporary_table_name() bag = ObjectProxyReferenceCountingBag(context) bag.add(table_name) diff --git a/tests/unit_tests/sql_stage_graph/test_sql_stage_input_output.py b/tests/unit_tests/sql_stage_graph/test_sql_stage_input_output.py index 31d13547..e6113fac 100644 --- a/tests/unit_tests/sql_stage_graph/test_sql_stage_input_output.py +++ b/tests/unit_tests/sql_stage_graph/test_sql_stage_input_output.py @@ -4,18 +4,23 @@ import pytest from typeguard import TypeCheckError +from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition +from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset +from exasol.analytics.query_handler.graph.stage.sql.dependency import ( + Dependencies, + Dependency, +) +from exasol.analytics.query_handler.graph.stage.sql.input_output import ( + SQLStageInputOutput, +) from exasol.analytics.schema import ( - SchemaName, + Column, ColumnName, + ColumnType, + SchemaName, Table, - Column, TableNameBuilder, - ColumnType, ) -from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition -from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset -from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependencies, Dependency -from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput class TestEnum(Enum): @@ -43,27 +48,27 @@ def target(): def create_table_data_partition( - name: str, - columns: List[Column], - dependencies: Dependencies): + name: str, columns: List[Column], dependencies: Dependencies +): return DataPartition( table_like=Table( - TableNameBuilder.create( - name, SchemaName("TEST_SCHEMA")), - columns=columns), - dependencies=dependencies + TableNameBuilder.create(name, SchemaName("TEST_SCHEMA")), columns=columns + ), + dependencies=dependencies, ) @pytest.fixture() def dataset(identifier, sample, target): - partition1 = create_table_data_partition(name="TRAIN", - columns=[identifier, sample, target], - dependencies={}) - dataset = Dataset(data_partitions={TestEnum.K1: partition1}, - identifier_columns=[identifier], - sample_columns=[sample], - target_columns=[target]) + partition1 = create_table_data_partition( + name="TRAIN", columns=[identifier, sample, target], dependencies={} + ) + dataset = Dataset( + data_partitions={TestEnum.K1: partition1}, + identifier_columns=[identifier], + sample_columns=[sample], + target_columns=[target], + ) return dataset @@ -82,4 +87,6 @@ def test_dataset(dataset): def test_dependencies(dataset): - SQLStageInputOutput(dataset=dataset, dependencies={TestEnum.K2: Dependency(object="mystr")}) + SQLStageInputOutput( + dataset=dataset, dependencies={TestEnum.K2: Dependency(object="mystr")} + ) diff --git a/tests/unit_tests/sql_stage_graph/test_sql_stage_train_query_handler_input.py b/tests/unit_tests/sql_stage_graph/test_sql_stage_train_query_handler_input.py index 24ac86cb..ee074d57 100644 
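The next file, like the test setups above, leans on the typing idiom x: Union[RealType, MagicMock] = create_autospec(RealType). A standalone example of the same pattern, using a made-up Greeter class for illustration: the Union annotation lets the variable be passed where the real type is expected while the mock-only assertion methods remain usable:

    from typing import Union
    from unittest.mock import MagicMock, create_autospec

    class Greeter:
        def greet(self, name: str) -> str:
            return f"hello {name}"

    greeter: Union[Greeter, MagicMock] = create_autospec(Greeter)
    greeter.greet("world")  # the autospec enforces Greeter.greet's signature
    greeter.greet.assert_called_once_with("world")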
--- a/tests/unit_tests/sql_stage_graph/test_sql_stage_train_query_handler_input.py +++ b/tests/unit_tests/sql_stage_graph/test_sql_stage_train_query_handler_input.py @@ -1,72 +1,94 @@ from typing import Union -from unittest.mock import create_autospec, MagicMock +from unittest.mock import MagicMock, create_autospec import pytest -from exasol_bucketfs_utils_python.abstract_bucketfs_location import AbstractBucketFSLocation +from exasol_bucketfs_utils_python.abstract_bucketfs_location import ( + AbstractBucketFSLocation, +) -from exasol.analytics.query_handler.graph.stage.sql.input_output import SQLStageInputOutput -from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import SQLStageTrainQueryHandlerInput +from exasol.analytics.query_handler.graph.stage.sql.input_output import ( + SQLStageInputOutput, +) +from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import ( + SQLStageTrainQueryHandlerInput, +) def test_empty_stage_inputs(): - bucketfs_location: Union[AbstractBucketFSLocation, MagicMock] = create_autospec(AbstractBucketFSLocation) + bucketfs_location: Union[AbstractBucketFSLocation, MagicMock] = create_autospec( + AbstractBucketFSLocation + ) with pytest.raises(AssertionError, match="Empty sql_stage_inputs not allowed."): SQLStageTrainQueryHandlerInput( - sql_stage_inputs=[], - result_bucketfs_location=bucketfs_location + sql_stage_inputs=[], result_bucketfs_location=bucketfs_location ) def test_non_empty_stage_inputs(): - bucketfs_location: Union[AbstractBucketFSLocation, MagicMock] = create_autospec(AbstractBucketFSLocation) - sql_stage_input: Union[SQLStageInputOutput, MagicMock] = create_autospec(SQLStageInputOutput) + bucketfs_location: Union[AbstractBucketFSLocation, MagicMock] = create_autospec( + AbstractBucketFSLocation + ) + sql_stage_input: Union[SQLStageInputOutput, MagicMock] = create_autospec( + SQLStageInputOutput + ) obj = SQLStageTrainQueryHandlerInput( - sql_stage_inputs=[sql_stage_input], - result_bucketfs_location=bucketfs_location + sql_stage_inputs=[sql_stage_input], result_bucketfs_location=bucketfs_location ) assert ( - obj.sql_stage_inputs == [sql_stage_input] - and obj.result_bucketfs_location == bucketfs_location + obj.sql_stage_inputs == [sql_stage_input] + and obj.result_bucketfs_location == bucketfs_location ) def test_equality(): - bucketfs_location: Union[AbstractBucketFSLocation, MagicMock] = create_autospec(AbstractBucketFSLocation) - sql_stage_input: Union[SQLStageInputOutput, MagicMock] = create_autospec(SQLStageInputOutput) + bucketfs_location: Union[AbstractBucketFSLocation, MagicMock] = create_autospec( + AbstractBucketFSLocation + ) + sql_stage_input: Union[SQLStageInputOutput, MagicMock] = create_autospec( + SQLStageInputOutput + ) obj1 = SQLStageTrainQueryHandlerInput( - sql_stage_inputs=[sql_stage_input], - result_bucketfs_location=bucketfs_location + sql_stage_inputs=[sql_stage_input], result_bucketfs_location=bucketfs_location ) obj2 = SQLStageTrainQueryHandlerInput( - sql_stage_inputs=[sql_stage_input], - result_bucketfs_location=bucketfs_location + sql_stage_inputs=[sql_stage_input], result_bucketfs_location=bucketfs_location ) assert obj1 == obj2 + def test_inequality_sql_stage_input(): - bucketfs_location: Union[AbstractBucketFSLocation, MagicMock] = create_autospec(AbstractBucketFSLocation) - sql_stage_input1: Union[SQLStageInputOutput, MagicMock] = create_autospec(SQLStageInputOutput) - sql_stage_input2: Union[SQLStageInputOutput, MagicMock] = create_autospec(SQLStageInputOutput) + 
bucketfs_location: Union[AbstractBucketFSLocation, MagicMock] = create_autospec( + AbstractBucketFSLocation + ) + sql_stage_input1: Union[SQLStageInputOutput, MagicMock] = create_autospec( + SQLStageInputOutput + ) + sql_stage_input2: Union[SQLStageInputOutput, MagicMock] = create_autospec( + SQLStageInputOutput + ) obj1 = SQLStageTrainQueryHandlerInput( - sql_stage_inputs=[sql_stage_input1], - result_bucketfs_location=bucketfs_location + sql_stage_inputs=[sql_stage_input1], result_bucketfs_location=bucketfs_location ) obj2 = SQLStageTrainQueryHandlerInput( - sql_stage_inputs=[sql_stage_input2], - result_bucketfs_location=bucketfs_location + sql_stage_inputs=[sql_stage_input2], result_bucketfs_location=bucketfs_location ) assert obj1 != obj2 + def test_inequality_bucketfs_location(): - bucketfs_location1: Union[AbstractBucketFSLocation, MagicMock] = create_autospec(AbstractBucketFSLocation) - bucketfs_location2: Union[AbstractBucketFSLocation, MagicMock] = create_autospec(AbstractBucketFSLocation) - sql_stage_input: Union[SQLStageInputOutput, MagicMock] = create_autospec(SQLStageInputOutput) + bucketfs_location1: Union[AbstractBucketFSLocation, MagicMock] = create_autospec( + AbstractBucketFSLocation + ) + bucketfs_location2: Union[AbstractBucketFSLocation, MagicMock] = create_autospec( + AbstractBucketFSLocation + ) + sql_stage_input: Union[SQLStageInputOutput, MagicMock] = create_autospec( + SQLStageInputOutput + ) obj1 = SQLStageTrainQueryHandlerInput( - sql_stage_inputs=[sql_stage_input], - result_bucketfs_location=bucketfs_location1 + sql_stage_inputs=[sql_stage_input], result_bucketfs_location=bucketfs_location1 ) obj2 = SQLStageTrainQueryHandlerInput( - sql_stage_inputs=[sql_stage_input], - result_bucketfs_location=bucketfs_location2 + sql_stage_inputs=[sql_stage_input], result_bucketfs_location=bucketfs_location2 ) - assert obj1 != obj2 \ No newline at end of file + assert obj1 != obj2 diff --git a/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_closer.py b/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_closer.py index 9ba265cc..9ac580f5 100644 --- a/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_closer.py +++ b/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_closer.py @@ -1,20 +1,24 @@ import dataclasses from typing import Union -from unittest.mock import MagicMock, create_autospec, call +from unittest.mock import MagicMock, call, create_autospec from exasol.analytics.udf.communication import messages from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.messages import Message from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \ - AbortTimeoutSender -from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.close_connection_sender import \ - CloseConnectionSender -from exasol.analytics.udf.communication.peer_communicator. \ - background_thread.connection_closer.connection_closer import ConnectionCloser -from exasol.analytics.udf.communication.peer_communicator. 
\ - background_thread.connection_closer.connection_is_closed_sender import ConnectionIsClosedSender +from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import ( + AbortTimeoutSender, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.close_connection_sender import ( + CloseConnectionSender, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer import ( + ConnectionCloser, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_is_closed_sender import ( + ConnectionIsClosedSender, +) from exasol.analytics.udf.communication.peer_communicator.sender import Sender @@ -43,19 +47,25 @@ def create_test_setup() -> TestSetup: name="t1", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=11), - group_identifier="g" - )) + group_identifier="g", + ) + ) my_connection_info = ConnectionInfo( name="t0", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=10), - group_identifier="g" + group_identifier="g", ) sender_mock: Union[MagicMock, Sender] = create_autospec(Sender) - abort_timeout_sender_mock: Union[MagicMock, AbortTimeoutSender] = create_autospec(AbortTimeoutSender) - connection_is_closed_sender: Union[MagicMock, ConnectionIsClosedSender] = create_autospec(ConnectionIsClosedSender) - close_connection_sender_mock: Union[MagicMock, CloseConnectionSender] = \ + abort_timeout_sender_mock: Union[MagicMock, AbortTimeoutSender] = create_autospec( + AbortTimeoutSender + ) + connection_is_closed_sender: Union[MagicMock, ConnectionIsClosedSender] = ( + create_autospec(ConnectionIsClosedSender) + ) + close_connection_sender_mock: Union[MagicMock, CloseConnectionSender] = ( create_autospec(CloseConnectionSender) + ) connection_closer = ConnectionCloser( my_connection_info=my_connection_info, peer=peer, @@ -78,10 +88,10 @@ def create_test_setup() -> TestSetup: def test_init(): test_setup = create_test_setup() assert ( - test_setup.close_connection_sender_mock.mock_calls == [] - and test_setup.connection_is_closed_sender_mock.mock_calls == [] - and test_setup.abort_timeout_sender_mock.mock_calls == [] - and test_setup.sender_mock.mock_calls == [] + test_setup.close_connection_sender_mock.mock_calls == [] + and test_setup.connection_is_closed_sender_mock.mock_calls == [] + and test_setup.abort_timeout_sender_mock.mock_calls == [] + and test_setup.sender_mock.mock_calls == [] ) @@ -90,10 +100,10 @@ def test_try_send(): test_setup.reset_mock() test_setup.connection_closer.try_send() assert ( - test_setup.close_connection_sender_mock.mock_calls == [call.try_send()] - and test_setup.connection_is_closed_sender_mock.mock_calls == [call.try_send()] - and test_setup.abort_timeout_sender_mock.mock_calls == [] - and test_setup.sender_mock.mock_calls == [] + test_setup.close_connection_sender_mock.mock_calls == [call.try_send()] + and test_setup.connection_is_closed_sender_mock.mock_calls == [call.try_send()] + and test_setup.abort_timeout_sender_mock.mock_calls == [] + and test_setup.sender_mock.mock_calls == [] ) @@ -102,13 +112,21 @@ def test_received_close_connection(): test_setup.reset_mock() test_setup.connection_closer.received_close_connection() assert ( - test_setup.close_connection_sender_mock.mock_calls == [] - and test_setup.connection_is_closed_sender_mock.mock_calls == [call.received_close_connection()] - and test_setup.abort_timeout_sender_mock.mock_calls == [] - and test_setup.sender_mock.mock_calls == [ 
- call.send(Message(__root__=messages.AcknowledgeCloseConnection( - source=test_setup.my_connection_info, destination=test_setup.peer - )))] + test_setup.close_connection_sender_mock.mock_calls == [] + and test_setup.connection_is_closed_sender_mock.mock_calls + == [call.received_close_connection()] + and test_setup.abort_timeout_sender_mock.mock_calls == [] + and test_setup.sender_mock.mock_calls + == [ + call.send( + Message( + __root__=messages.AcknowledgeCloseConnection( + source=test_setup.my_connection_info, + destination=test_setup.peer, + ) + ) + ) + ] ) @@ -117,8 +135,9 @@ def test_received_acknowledge_close_connection(): test_setup.reset_mock() test_setup.connection_closer.received_acknowledge_close_connection() assert ( - test_setup.close_connection_sender_mock.mock_calls == [call.stop()] - and test_setup.connection_is_closed_sender_mock.mock_calls == [call.received_acknowledge_close_connection()] - and test_setup.abort_timeout_sender_mock.mock_calls == [] - and test_setup.sender_mock.mock_calls == [] + test_setup.close_connection_sender_mock.mock_calls == [call.stop()] + and test_setup.connection_is_closed_sender_mock.mock_calls + == [call.received_acknowledge_close_connection()] + and test_setup.abort_timeout_sender_mock.mock_calls == [] + and test_setup.sender_mock.mock_calls == [] ) diff --git a/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_closer_builder.py b/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_closer_builder.py index b808e096..ac00817a 100644 --- a/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_closer_builder.py +++ b/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_closer_builder.py @@ -1,28 +1,33 @@ import dataclasses -from typing import Union, List -from unittest.mock import MagicMock, Mock, create_autospec, call +from typing import List, Union +from unittest.mock import MagicMock, Mock, call, create_autospec from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \ - AbortTimeoutSenderFactory -from exasol.analytics.udf.communication.peer_communicator. \ - background_thread.connection_closer.close_connection_sender import \ - CloseConnectionSenderFactory -from exasol.analytics.udf.communication.peer_communicator. \ - background_thread.connection_closer.connection_closer_builder import ConnectionCloserBuilder -from exasol.analytics.udf.communication.peer_communicator. \ - background_thread.connection_closer.connection_closer_factory import ConnectionCloserFactory -from exasol.analytics.udf.communication.peer_communicator. \ - background_thread.connection_closer.connection_closer_timeout_config import ConnectionCloserTimeoutConfig -from exasol.analytics.udf.communication.peer_communicator. 
\ - background_thread.connection_closer.connection_is_closed_sender import ConnectionIsClosedSenderFactory +from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import ( + AbortTimeoutSenderFactory, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.close_connection_sender import ( + CloseConnectionSenderFactory, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_builder import ( + ConnectionCloserBuilder, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_factory import ( + ConnectionCloserFactory, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer_timeout_config import ( + ConnectionCloserTimeoutConfig, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_is_closed_sender import ( + ConnectionIsClosedSenderFactory, +) from exasol.analytics.udf.communication.peer_communicator.clock import Clock from exasol.analytics.udf.communication.peer_communicator.sender import Sender from exasol.analytics.udf.communication.peer_communicator.timer import TimerFactory from exasol.analytics.udf.communication.socket_factory.abstract import Socket -from tests.mock_cast import mock_cast +from tests.utils.mock_cast import mock_cast @dataclasses.dataclass() @@ -34,7 +39,9 @@ class TestSetup: clock_mock: Union[MagicMock, Clock] timeout_config: ConnectionCloserTimeoutConfig abort_timeout_sender_factory_mock: Union[MagicMock, AbortTimeoutSenderFactory] - connection_is_closed_sender_factory_mock: Union[MagicMock, ConnectionIsClosedSenderFactory] + connection_is_closed_sender_factory_mock: Union[ + MagicMock, ConnectionIsClosedSenderFactory + ] close_connection_sender_factory_mock: Union[MagicMock, CloseConnectionSenderFactory] timer_factory_mock: Union[MagicMock, TimerFactory] timer_mocks: List[Mock] @@ -59,20 +66,24 @@ def create_test_setup() -> TestSetup: name="t1", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=11), - group_identifier="g" - )) + group_identifier="g", + ) + ) my_connection_info = ConnectionInfo( name="t0", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=10), - group_identifier="g" + group_identifier="g", + ) + abort_timeout_sender_factory_mock: Union[MagicMock, AbortTimeoutSenderFactory] = ( + create_autospec(AbortTimeoutSenderFactory) ) - abort_timeout_sender_factory_mock: Union[MagicMock, AbortTimeoutSenderFactory] = create_autospec( - AbortTimeoutSenderFactory) - conncection_is_ready_sender_factory_mock: Union[MagicMock, ConnectionIsClosedSenderFactory] = create_autospec( - ConnectionIsClosedSenderFactory) - close_connection_sender_factory_mock: Union[MagicMock, CloseConnectionSenderFactory] = create_autospec( - CloseConnectionSenderFactory) + conncection_is_ready_sender_factory_mock: Union[ + MagicMock, ConnectionIsClosedSenderFactory + ] = create_autospec(ConnectionIsClosedSenderFactory) + close_connection_sender_factory_mock: Union[ + MagicMock, CloseConnectionSenderFactory + ] = create_autospec(CloseConnectionSenderFactory) timer_factory_mock: Union[MagicMock, TimerFactory] = create_autospec(TimerFactory) timer_mocks = [Mock(), Mock(), Mock(), Mock(), Mock()] mock_cast(timer_factory_mock.create).side_effect = timer_mocks @@ -84,14 +95,15 @@ def create_test_setup() -> TestSetup: close_retry_timeout_in_ms=2, 
connection_is_closed_wait_time_in_ms=3, ) - connection_closer_factory_mock: Union[MagicMock, ConnectionCloserFactory] = \ + connection_closer_factory_mock: Union[MagicMock, ConnectionCloserFactory] = ( create_autospec(ConnectionCloserFactory) + ) connection_closer_builder = ConnectionCloserBuilder( abort_timeout_sender_factory=abort_timeout_sender_factory_mock, connection_is_closed_sender_factory=conncection_is_ready_sender_factory_mock, close_connection_sender_factory=close_connection_sender_factory_mock, timer_factory=timer_factory_mock, - connection_closer_factory=connection_closer_factory_mock + connection_closer_factory=connection_closer_factory_mock, ) return TestSetup( connection_closer_builder=connection_closer_builder, @@ -106,7 +118,7 @@ def create_test_setup() -> TestSetup: timer_factory_mock=timer_factory_mock, timer_mocks=timer_mocks, sender_mock=sender_mock, - timeout_config=timeout_config + timeout_config=timeout_config, ) @@ -114,8 +126,12 @@ def test_init(): test_setup = create_test_setup() mock_cast(test_setup.timer_factory_mock.create).assert_not_called() mock_cast(test_setup.abort_timeout_sender_factory_mock.create).assert_not_called() - mock_cast(test_setup.close_connection_sender_factory_mock.create).assert_not_called() - mock_cast(test_setup.connection_is_closed_sender_factory_mock.create).assert_not_called() + mock_cast( + test_setup.close_connection_sender_factory_mock.create + ).assert_not_called() + mock_cast( + test_setup.connection_is_closed_sender_factory_mock.create + ).assert_not_called() mock_cast(test_setup.connection_closer_factory_mock.create).assert_not_called() @@ -128,7 +144,7 @@ def test_create(): clock=test_setup.clock_mock, out_control_socket=test_setup.out_control_socket_mock, peer=test_setup.peer, - timeout_config=test_setup.timeout_config + timeout_config=test_setup.timeout_config, ) assert_timer_factory(test_setup) test_setup.sender_mock.assert_not_called() @@ -141,7 +157,9 @@ def test_create(): def assert_connection_is_closed_sender_factory_mock(test_setup): - mock_cast(test_setup.connection_is_closed_sender_factory_mock.create).assert_called_once_with( + mock_cast( + test_setup.connection_is_closed_sender_factory_mock.create + ).assert_called_once_with( my_connection_info=test_setup.my_connection_info, peer=test_setup.peer, out_control_socket=test_setup.out_control_socket_mock, @@ -150,21 +168,25 @@ def assert_connection_is_closed_sender_factory_mock(test_setup): def assert_abort_timeout_sender_factory_mock(test_setup): - mock_cast(test_setup.abort_timeout_sender_factory_mock.create).assert_called_once_with( + mock_cast( + test_setup.abort_timeout_sender_factory_mock.create + ).assert_called_once_with( my_connection_info=test_setup.my_connection_info, peer=test_setup.peer, out_control_socket=test_setup.out_control_socket_mock, timer=test_setup.timer_mocks[1], - reason='Timeout occurred during establishing connection.' 
+ reason="Timeout occurred during establishing connection.", ) def assert_close_connection_sender_factory_mock(test_setup): - mock_cast(test_setup.close_connection_sender_factory_mock.create).assert_called_once_with( + mock_cast( + test_setup.close_connection_sender_factory_mock.create + ).assert_called_once_with( my_connection_info=test_setup.my_connection_info, peer=test_setup.peer, sender=test_setup.sender_mock, - timer=test_setup.timer_mocks[0] + timer=test_setup.timer_mocks[0], ) @@ -174,14 +196,19 @@ def assert_timer_mocks(test_setup): def assert_timer_factory(test_setup): - test_setup.timer_factory_mock.assert_has_calls([ - call.create( - clock=test_setup.clock_mock, - timeout_in_ms=test_setup.timeout_config.close_retry_timeout_in_ms), - call.create( - clock=test_setup.clock_mock, - timeout_in_ms=test_setup.timeout_config.abort_timeout_in_ms), - call.create( - clock=test_setup.clock_mock, - timeout_in_ms=test_setup.timeout_config.connection_is_closed_wait_time_in_ms), - ]) + test_setup.timer_factory_mock.assert_has_calls( + [ + call.create( + clock=test_setup.clock_mock, + timeout_in_ms=test_setup.timeout_config.close_retry_timeout_in_ms, + ), + call.create( + clock=test_setup.clock_mock, + timeout_in_ms=test_setup.timeout_config.abort_timeout_in_ms, + ), + call.create( + clock=test_setup.clock_mock, + timeout_in_ms=test_setup.timeout_config.connection_is_closed_wait_time_in_ms, + ), + ] + ) diff --git a/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_is_closed_sender.py b/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_is_closed_sender.py index e6010f81..0c57e40b 100644 --- a/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_is_closed_sender.py +++ b/tests/unit_tests/udf_communication/peer_communication/background_thread/connection_closer/test_connection_is_closed_sender.py @@ -1,6 +1,6 @@ import dataclasses from typing import Union -from unittest.mock import MagicMock, create_autospec, call +from unittest.mock import MagicMock, call, create_autospec import pytest @@ -8,12 +8,13 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.background_thread. 
\ - connection_closer.connection_is_closed_sender import ConnectionIsClosedSender +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_is_closed_sender import ( + ConnectionIsClosedSender, +) from exasol.analytics.udf.communication.peer_communicator.timer import Timer from exasol.analytics.udf.communication.serialization import serialize_message from exasol.analytics.udf.communication.socket_factory.abstract import Socket -from tests.mock_cast import mock_cast +from tests.utils.mock_cast import mock_cast @dataclasses.dataclass() @@ -35,13 +36,14 @@ def create_test_setup(): name="t2", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=12), - group_identifier="g" - )) + group_identifier="g", + ) + ) my_connection_info = ConnectionInfo( name="t1", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=11), - group_identifier="g" + group_identifier="g", ) timer_mock = create_autospec(Timer) out_control_socket_mock = create_autospec(Socket) @@ -55,15 +57,15 @@ def create_test_setup(): peer=peer, timer_mock=timer_mock, out_control_socket_mock=out_control_socket_mock, - peer_is_closed_sender=connection_is_ready_sender + peer_is_closed_sender=connection_is_ready_sender, ) def test_init(): test_setup = create_test_setup() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [] ) @@ -76,21 +78,19 @@ def test_try_send_after_init(is_time: bool): test_setup.peer_is_closed_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [ - call.is_time() - ] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) -@pytest.mark.parametrize("is_time,send_expected", - [ - (True, True), - (False, False), - ]) -def test_try_send_after_synchronize_connection( - is_time: bool, - send_expected: bool): +@pytest.mark.parametrize( + "is_time,send_expected", + [ + (True, True), + (False, False), + ], +) +def test_try_send_after_synchronize_connection(is_time: bool, send_expected: bool): test_setup = create_test_setup() test_setup.peer_is_closed_sender.received_close_connection() mock_cast(test_setup.timer_mock.is_time).return_value = is_time @@ -99,17 +99,15 @@ def test_try_send_after_synchronize_connection( test_setup.peer_is_closed_sender.try_send() if send_expected: - assert ( - test_setup.out_control_socket_mock.mock_calls == - [ - call.send(serialize_message(messages.ConnectionIsClosed(peer=test_setup.peer))) - ] - and test_setup.timer_mock.mock_calls == [call.is_time()] - ) + assert test_setup.out_control_socket_mock.mock_calls == [ + call.send( + serialize_message(messages.ConnectionIsClosed(peer=test_setup.peer)) + ) + ] and test_setup.timer_mock.mock_calls == [call.is_time()] else: assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) @@ -123,13 +121,15 @@ def test_try_send_after_acknowledge_connection(is_time: bool): test_setup.peer_is_closed_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) 
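Nearly every hunk in these test files follows the same mechanical pattern: a call that exceeds the line limit is exploded to one argument per line and given a trailing comma, which Black-style formatters treat as a request to keep the multi-line layout on later runs. Illustrated with the ConnectionInfo construction used in these tests (imports as in the diff, values taken from the fixtures above):

    from exasol.analytics.udf.communication.connection_info import ConnectionInfo
    from exasol.analytics.udf.communication.ip_address import IPAddress, Port

    # Before: ConnectionInfo(name="t1", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=11), group_identifier="g")
    info = ConnectionInfo(
        name="t1",
        ipaddress=IPAddress(ip_address="127.0.0.1"),
        port=Port(port=11),
        group_identifier="g",  # trailing comma pins the exploded form
    )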
@pytest.mark.parametrize("is_time", [True, False]) -def test_try_send_after_synchronize_connection_and_acknowledge_connection(is_time: bool): +def test_try_send_after_synchronize_connection_and_acknowledge_connection( + is_time: bool, +): test_setup = create_test_setup() test_setup.peer_is_closed_sender.received_close_connection() test_setup.peer_is_closed_sender.received_acknowledge_close_connection() @@ -138,13 +138,9 @@ def test_try_send_after_synchronize_connection_and_acknowledge_connection(is_tim test_setup.peer_is_closed_sender.try_send() - assert ( - test_setup.out_control_socket_mock.mock_calls == - [ - call.send(serialize_message(messages.ConnectionIsClosed(peer=test_setup.peer))) - ] - and test_setup.timer_mock.mock_calls == [call.is_time()] - ) + assert test_setup.out_control_socket_mock.mock_calls == [ + call.send(serialize_message(messages.ConnectionIsClosed(peer=test_setup.peer))) + ] and test_setup.timer_mock.mock_calls == [call.is_time()] @pytest.mark.parametrize("is_time", [True, False]) @@ -158,8 +154,8 @@ def test_try_send_twice_after_synchronize_connection(is_time: bool): test_setup.peer_is_closed_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) @@ -174,6 +170,6 @@ def test_try_send_twice_after_acknowledge_connection(is_time: bool): test_setup.peer_is_closed_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) diff --git a/tests/unit_tests/udf_communication/peer_communication/test_abort_timeout_sender.py b/tests/unit_tests/udf_communication/peer_communication/test_abort_timeout_sender.py index a08b30e7..dcc59d19 100644 --- a/tests/unit_tests/udf_communication/peer_communication/test_abort_timeout_sender.py +++ b/tests/unit_tests/udf_communication/peer_communication/test_abort_timeout_sender.py @@ -1,6 +1,6 @@ import dataclasses -from typing import Union, cast, Any -from unittest.mock import MagicMock, Mock, create_autospec, call +from typing import Any, Union, cast +from unittest.mock import MagicMock, Mock, call, create_autospec import pytest @@ -8,8 +8,9 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \ - AbortTimeoutSender +from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import ( + AbortTimeoutSender, +) from exasol.analytics.udf.communication.peer_communicator.timer import Timer from exasol.analytics.udf.communication.serialization import serialize_message from exasol.analytics.udf.communication.socket_factory.abstract import Socket @@ -38,13 +39,14 @@ def create_test_setup(): name="t2", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=12), - group_identifier="g" - )) + group_identifier="g", + ) + ) my_connection_info = ConnectionInfo( name="t1", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=11), - group_identifier="g" + group_identifier="g", ) timer_mock = create_autospec(Timer) out_control_socket_mock = create_autospec(Socket) @@ -54,21 +56,21 @@ def 
create_test_setup(): my_connection_info=my_connection_info, out_control_socket=out_control_socket_mock, timer=timer_mock, - reason=reason + reason=reason, ) return TestSetup( reason=reason, timer_mock=timer_mock, out_control_socket_mock=out_control_socket_mock, - abort_timeout_sender=abort_timeout_sender + abort_timeout_sender=abort_timeout_sender, ) def test_init(): test_setup = create_test_setup() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [] ) @@ -80,10 +82,8 @@ def test_try_send_after_init_and_is_time_false(): test_setup.abort_timeout_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [ - call.is_time() - ] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) @@ -94,15 +94,9 @@ def test_try_send_after_init_and_is_time_true(): test_setup.abort_timeout_sender.try_send() - assert ( - test_setup.out_control_socket_mock.mock_calls == - [ - call.send(serialize_message(messages.Timeout(reason=test_setup.reason))) - ] - and test_setup.timer_mock.mock_calls == [ - call.is_time() - ] - ) + assert test_setup.out_control_socket_mock.mock_calls == [ + call.send(serialize_message(messages.Timeout(reason=test_setup.reason))) + ] and test_setup.timer_mock.mock_calls == [call.is_time()] def test_try_send_twice_and_is_time_false(): @@ -114,10 +108,8 @@ def test_try_send_twice_and_is_time_false(): test_setup.abort_timeout_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [ - call.is_time() - ] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) @@ -129,8 +121,8 @@ def test_try_send_after_stop(is_time: bool): test_setup.abort_timeout_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) @@ -139,6 +131,6 @@ def test_reset_timer(): print(test_setup.timer_mock.mock_calls) test_setup.abort_timeout_sender.reset_timer() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.reset_timer()] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.reset_timer()] ) diff --git a/tests/unit_tests/udf_communication/peer_communication/test_acknowledge_register_peer_sender.py b/tests/unit_tests/udf_communication/peer_communication/test_acknowledge_register_peer_sender.py index 29b20e09..f2bc04bf 100644 --- a/tests/unit_tests/udf_communication/peer_communication/test_acknowledge_register_peer_sender.py +++ b/tests/unit_tests/udf_communication/peer_communication/test_acknowledge_register_peer_sender.py @@ -1,18 +1,20 @@ import dataclasses from typing import Union -from unittest.mock import MagicMock, create_autospec, call +from unittest.mock import MagicMock, call, create_autospec import pytest from exasol.analytics.udf.communication.connection_info import ConnectionInfo -from exasol.analytics.udf.communication.ip_address import Port, IPAddress +from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from 
exasol.analytics.udf.communication.peer_communicator.acknowledge_register_peer_sender import \ - AcknowledgeRegisterPeerSender -from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import \ - RegisterPeerConnection +from exasol.analytics.udf.communication.peer_communicator.acknowledge_register_peer_sender import ( + AcknowledgeRegisterPeerSender, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import ( + RegisterPeerConnection, +) from exasol.analytics.udf.communication.peer_communicator.timer import Timer -from tests.mock_cast import mock_cast +from tests.utils.mock_cast import mock_cast @dataclasses.dataclass @@ -30,16 +32,23 @@ def reset_mock(self): def create_test_setup(needs_to_send_for_peer: bool) -> TestSetup: - register_peer_connection_mock: Union[RegisterPeerConnection, MagicMock] = create_autospec(RegisterPeerConnection) - my_connection_info = ConnectionInfo(name="t0", - port=Port(port=1), - ipaddress=IPAddress(ip_address="127.0.0.1"), - group_identifier="g") - peer = Peer(connection_info= - ConnectionInfo(name="t1", - port=Port(port=2), - ipaddress=IPAddress(ip_address="127.0.0.1"), - group_identifier="g")) + register_peer_connection_mock: Union[RegisterPeerConnection, MagicMock] = ( + create_autospec(RegisterPeerConnection) + ) + my_connection_info = ConnectionInfo( + name="t0", + port=Port(port=1), + ipaddress=IPAddress(ip_address="127.0.0.1"), + group_identifier="g", + ) + peer = Peer( + connection_info=ConnectionInfo( + name="t1", + port=Port(port=2), + ipaddress=IPAddress(ip_address="127.0.0.1"), + group_identifier="g", + ) + ) timer_mock: Union[Timer, MagicMock] = create_autospec(Timer) acknowledge_register_peer_sender = AcknowledgeRegisterPeerSender( register_peer_connection=register_peer_connection_mock, @@ -53,7 +62,7 @@ def create_test_setup(needs_to_send_for_peer: bool) -> TestSetup: peer=peer, acknowledge_register_peer_sender=acknowledge_register_peer_sender, timer_mock=timer_mock, - register_peer_connection_mock=register_peer_connection_mock + register_peer_connection_mock=register_peer_connection_mock, ) return test_setup @@ -62,51 +71,50 @@ def create_test_setup(needs_to_send_for_peer: bool) -> TestSetup: def test_init(needs_to_send_for_peer: bool): test_setup = create_test_setup(needs_to_send_for_peer) assert ( - test_setup.register_peer_connection_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [] + test_setup.register_peer_connection_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [] ) -@pytest.mark.parametrize("needs_to_send_for_peer, is_time, send_expected", - [ - (True, True, True), - (True, False, False), - (False, True, False), - (False, False, False), - ]) -def test_try_send_after_init(needs_to_send_for_peer: bool, is_time: bool, send_expected: bool): +@pytest.mark.parametrize( + "needs_to_send_for_peer, is_time, send_expected", + [ + (True, True, True), + (True, False, False), + (False, True, False), + (False, False, False), + ], +) +def test_try_send_after_init( + needs_to_send_for_peer: bool, is_time: bool, send_expected: bool +): test_setup = create_test_setup(needs_to_send_for_peer) mock_cast(test_setup.timer_mock.is_time).return_value = is_time test_setup.reset_mock() test_setup.acknowledge_register_peer_sender.try_send() if send_expected: - assert ( - test_setup.register_peer_connection_mock.mock_calls == [call.ack(test_setup.peer)] - and test_setup.timer_mock.mock_calls == [call.is_time(), call.reset_timer()] - ) + assert 
test_setup.register_peer_connection_mock.mock_calls == [ + call.ack(test_setup.peer) + ] and test_setup.timer_mock.mock_calls == [call.is_time(), call.reset_timer()] else: assert ( - test_setup.register_peer_connection_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.register_peer_connection_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) @pytest.mark.parametrize( ["needs_to_send_for_peer", "is_time"], - [ - (True, True), - (True, False), - (False, True), - (False, False) - ]) + [(True, True), (True, False), (False, True), (False, False)], +) def test_stop(needs_to_send_for_peer: bool, is_time: bool): test_setup = create_test_setup(needs_to_send_for_peer) mock_cast(test_setup.timer_mock.is_time).return_value = is_time test_setup.reset_mock() test_setup.acknowledge_register_peer_sender.stop() assert ( - test_setup.register_peer_connection_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [] + test_setup.register_peer_connection_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [] ) @@ -117,7 +125,8 @@ def test_stop(needs_to_send_for_peer: bool, is_time: bool): (True, False), (False, True), (False, False), - ]) + ], +) def test_try_send_after_stop(needs_to_send_for_peer: bool, is_time: bool): test_setup = create_test_setup(needs_to_send_for_peer) mock_cast(test_setup.timer_mock.is_time).return_value = is_time @@ -125,8 +134,8 @@ def test_try_send_after_stop(needs_to_send_for_peer: bool, is_time: bool): test_setup.reset_mock() test_setup.acknowledge_register_peer_sender.try_send() assert ( - test_setup.register_peer_connection_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.register_peer_connection_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) @@ -135,9 +144,9 @@ def test_is_ready_to_stop_after_init(needs_to_send_for_peer: bool): test_setup = create_test_setup(needs_to_send_for_peer) result = test_setup.acknowledge_register_peer_sender.is_ready_to_stop() assert ( - test_setup.register_peer_connection_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [] - and result != needs_to_send_for_peer + test_setup.register_peer_connection_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [] + and result != needs_to_send_for_peer ) @@ -148,7 +157,7 @@ def test_is_ready_to_stop_after_stop(needs_to_send_for_peer: bool): test_setup.reset_mock() result = test_setup.acknowledge_register_peer_sender.is_ready_to_stop() assert ( - test_setup.register_peer_connection_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [] - and result == True + test_setup.register_peer_connection_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [] + and result == True ) diff --git a/tests/unit_tests/udf_communication/peer_communication/test_background_peer_state.py b/tests/unit_tests/udf_communication/peer_communication/test_background_peer_state.py index d1625214..f03cebc0 100644 --- a/tests/unit_tests/udf_communication/peer_communication/test_background_peer_state.py +++ b/tests/unit_tests/udf_communication/peer_communication/test_background_peer_state.py @@ -1,27 +1,32 @@ import dataclasses from typing import Union -from unittest.mock import MagicMock, create_autospec, call +from unittest.mock import MagicMock, call, create_autospec import pytest - from polyfactory.factories.pydantic_factory import ModelFactory from exasol.analytics.udf.communication import messages from 
exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.background_peer_state import \ - BackgroundPeerState -from exasol.analytics.udf.communication.peer_communicator. \ - background_thread.connection_closer.connection_closer import ConnectionCloser -from exasol.analytics.udf.communication.peer_communicator.connection_establisher import \ - ConnectionEstablisher -from exasol.analytics.udf.communication.peer_communicator.payload_handler import PayloadHandler -from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import \ - RegisterPeerForwarder +from exasol.analytics.udf.communication.peer_communicator.background_peer_state import ( + BackgroundPeerState, +) +from exasol.analytics.udf.communication.peer_communicator.background_thread.connection_closer.connection_closer import ( + ConnectionCloser, +) +from exasol.analytics.udf.communication.peer_communicator.connection_establisher import ( + ConnectionEstablisher, +) +from exasol.analytics.udf.communication.peer_communicator.payload_handler import ( + PayloadHandler, +) +from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder import ( + RegisterPeerForwarder, +) from exasol.analytics.udf.communication.peer_communicator.sender import Sender from exasol.analytics.udf.communication.socket_factory.abstract import Frame -from tests.mock_cast import mock_cast +from tests.utils.mock_cast import mock_cast @dataclasses.dataclass() @@ -50,19 +55,28 @@ def create_test_setup() -> TestSetup: name="t1", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=11), - group_identifier="g" - )) + group_identifier="g", + ) + ) my_connection_info = ConnectionInfo( name="t0", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=10), - group_identifier="g" + group_identifier="g", + ) + payload_handler_mock: Union[MagicMock, PayloadHandler] = create_autospec( + PayloadHandler ) - payload_handler_mock: Union[MagicMock, PayloadHandler] = create_autospec(PayloadHandler) sender_mock: Union[MagicMock, Sender] = create_autospec(Sender) - connection_establisher_mock: Union[MagicMock, ConnectionEstablisher] = create_autospec(ConnectionEstablisher) - connection_closer_mock: Union[MagicMock, ConnectionCloser] = create_autospec(ConnectionCloser) - register_peer_forwarder_mock: Union[MagicMock, RegisterPeerForwarder] = create_autospec(RegisterPeerForwarder) + connection_establisher_mock: Union[MagicMock, ConnectionEstablisher] = ( + create_autospec(ConnectionEstablisher) + ) + connection_closer_mock: Union[MagicMock, ConnectionCloser] = create_autospec( + ConnectionCloser + ) + register_peer_forwarder_mock: Union[MagicMock, RegisterPeerForwarder] = ( + create_autospec(RegisterPeerForwarder) + ) background_peer_state = BackgroundPeerState( my_connection_info=my_connection_info, peer=peer, @@ -70,7 +84,7 @@ def create_test_setup() -> TestSetup: connection_establisher=connection_establisher_mock, register_peer_forwarder=register_peer_forwarder_mock, payload_handler=payload_handler_mock, - connection_closer=connection_closer_mock + connection_closer=connection_closer_mock, ) return TestSetup( peer=peer, @@ -80,17 +94,17 @@ def create_test_setup() -> TestSetup: background_peer_state=background_peer_state, connection_establisher_mock=connection_establisher_mock, 
register_peer_forwarder_mock=register_peer_forwarder_mock, - connection_closer_mock=connection_closer_mock + connection_closer_mock=connection_closer_mock, ) def test_init(): test_setup = create_test_setup() assert ( - test_setup.sender_mock.mock_calls == [] - and test_setup.connection_establisher_mock.mock_calls == [] - and test_setup.register_peer_forwarder_mock.mock_calls == [] - and test_setup.payload_handler_mock.mock_calls == [] + test_setup.sender_mock.mock_calls == [] + and test_setup.connection_establisher_mock.mock_calls == [] + and test_setup.register_peer_forwarder_mock.mock_calls == [] + and test_setup.payload_handler_mock.mock_calls == [] ) @@ -104,24 +118,35 @@ def test_init(): (False, True, True), (False, True, False), (False, False, True), - (False, False, False) - ] + (False, False, False), + ], ) -def test_try_send_no_prepare_close(connection_establisher_ready: bool, - register_peer_forwarder_ready: bool, - payload_handler_ready: bool): +def test_try_send_no_prepare_close( + connection_establisher_ready: bool, + register_peer_forwarder_ready: bool, + payload_handler_ready: bool, +): test_setup = create_test_setup() test_setup.reset_mocks() - mock_cast(test_setup.connection_establisher_mock.is_ready_to_stop).return_value = connection_establisher_ready - mock_cast(test_setup.register_peer_forwarder_mock.is_ready_to_stop).return_value = register_peer_forwarder_ready - mock_cast(test_setup.payload_handler_mock.is_ready_to_stop).return_value = payload_handler_ready + mock_cast(test_setup.connection_establisher_mock.is_ready_to_stop).return_value = ( + connection_establisher_ready + ) + mock_cast(test_setup.register_peer_forwarder_mock.is_ready_to_stop).return_value = ( + register_peer_forwarder_ready + ) + mock_cast(test_setup.payload_handler_mock.is_ready_to_stop).return_value = ( + payload_handler_ready + ) test_setup.background_peer_state.try_send() assert ( - test_setup.connection_establisher_mock.mock_calls == [call.try_send(), call.is_ready_to_stop()] - and test_setup.register_peer_forwarder_mock.mock_calls == [call.try_send(), call.is_ready_to_stop()] - and test_setup.sender_mock.mock_calls == [] - and test_setup.payload_handler_mock.mock_calls == [call.try_send(), call.is_ready_to_stop()] - and test_setup.connection_closer_mock.mock_calls == [] + test_setup.connection_establisher_mock.mock_calls + == [call.try_send(), call.is_ready_to_stop()] + and test_setup.register_peer_forwarder_mock.mock_calls + == [call.try_send(), call.is_ready_to_stop()] + and test_setup.sender_mock.mock_calls == [] + and test_setup.payload_handler_mock.mock_calls + == [call.try_send(), call.is_ready_to_stop()] + and test_setup.connection_closer_mock.mock_calls == [] ) @@ -134,25 +159,36 @@ def test_try_send_no_prepare_close(connection_establisher_ready: bool, (False, True, True), (False, True, False), (False, False, True), - (False, False, False) - ] + (False, False, False), + ], ) -def test_try_send_prepare_close_not_ready(connection_establisher_ready: bool, - register_peer_forwarder_ready: bool, - payload_handler_ready: bool): +def test_try_send_prepare_close_not_ready( + connection_establisher_ready: bool, + register_peer_forwarder_ready: bool, + payload_handler_ready: bool, +): test_setup = create_test_setup() test_setup.background_peer_state.prepare_to_stop() test_setup.reset_mocks() - mock_cast(test_setup.connection_establisher_mock.is_ready_to_stop).return_value = connection_establisher_ready - mock_cast(test_setup.register_peer_forwarder_mock.is_ready_to_stop).return_value = 
register_peer_forwarder_ready - mock_cast(test_setup.payload_handler_mock.is_ready_to_stop).return_value = payload_handler_ready + mock_cast(test_setup.connection_establisher_mock.is_ready_to_stop).return_value = ( + connection_establisher_ready + ) + mock_cast(test_setup.register_peer_forwarder_mock.is_ready_to_stop).return_value = ( + register_peer_forwarder_ready + ) + mock_cast(test_setup.payload_handler_mock.is_ready_to_stop).return_value = ( + payload_handler_ready + ) test_setup.background_peer_state.try_send() assert ( - test_setup.connection_establisher_mock.mock_calls == [call.try_send(), call.is_ready_to_stop()] - and test_setup.register_peer_forwarder_mock.mock_calls == [call.try_send(), call.is_ready_to_stop()] - and test_setup.sender_mock.mock_calls == [] - and test_setup.payload_handler_mock.mock_calls == [call.try_send(), call.is_ready_to_stop()] - and test_setup.connection_closer_mock.mock_calls == [] + test_setup.connection_establisher_mock.mock_calls + == [call.try_send(), call.is_ready_to_stop()] + and test_setup.register_peer_forwarder_mock.mock_calls + == [call.try_send(), call.is_ready_to_stop()] + and test_setup.sender_mock.mock_calls == [] + and test_setup.payload_handler_mock.mock_calls + == [call.try_send(), call.is_ready_to_stop()] + and test_setup.connection_closer_mock.mock_calls == [] ) @@ -160,16 +196,23 @@ def test_try_send_prepare_close_ready(): test_setup = create_test_setup() test_setup.background_peer_state.prepare_to_stop() test_setup.reset_mocks() - mock_cast(test_setup.connection_establisher_mock.is_ready_to_stop).return_value = True - mock_cast(test_setup.register_peer_forwarder_mock.is_ready_to_stop).return_value = True + mock_cast(test_setup.connection_establisher_mock.is_ready_to_stop).return_value = ( + True + ) + mock_cast(test_setup.register_peer_forwarder_mock.is_ready_to_stop).return_value = ( + True + ) mock_cast(test_setup.payload_handler_mock.is_ready_to_stop).return_value = True test_setup.background_peer_state.try_send() assert ( - test_setup.connection_establisher_mock.mock_calls == [call.try_send(), call.is_ready_to_stop()] - and test_setup.register_peer_forwarder_mock.mock_calls == [call.try_send(), call.is_ready_to_stop()] - and test_setup.sender_mock.mock_calls == [] - and test_setup.payload_handler_mock.mock_calls == [call.try_send(), call.is_ready_to_stop()] - and test_setup.connection_closer_mock.mock_calls == [call.try_send()] + test_setup.connection_establisher_mock.mock_calls + == [call.try_send(), call.is_ready_to_stop()] + and test_setup.register_peer_forwarder_mock.mock_calls + == [call.try_send(), call.is_ready_to_stop()] + and test_setup.sender_mock.mock_calls == [] + and test_setup.payload_handler_mock.mock_calls + == [call.try_send(), call.is_ready_to_stop()] + and test_setup.connection_closer_mock.mock_calls == [call.try_send()] ) @@ -178,10 +221,11 @@ def test_received_synchronize_connection(): test_setup.reset_mocks() test_setup.background_peer_state.received_synchronize_connection() assert ( - test_setup.connection_establisher_mock.mock_calls == [call.received_synchronize_connection()] - and test_setup.register_peer_forwarder_mock.mock_calls == [] - and test_setup.sender_mock.mock_calls == [] - and test_setup.payload_handler_mock.mock_calls == [] + test_setup.connection_establisher_mock.mock_calls + == [call.received_synchronize_connection()] + and test_setup.register_peer_forwarder_mock.mock_calls == [] + and test_setup.sender_mock.mock_calls == [] + and test_setup.payload_handler_mock.mock_calls == [] ) @@ 
-190,10 +234,11 @@ def test_received_acknowledge_connection():
     test_setup.reset_mocks()
     test_setup.background_peer_state.received_acknowledge_connection()
     assert (
-        test_setup.connection_establisher_mock.mock_calls == [call.received_acknowledge_connection()]
-        and test_setup.register_peer_forwarder_mock.mock_calls == []
-        and test_setup.sender_mock.mock_calls == []
-        and test_setup.payload_handler_mock.mock_calls == []
+        test_setup.connection_establisher_mock.mock_calls
+        == [call.received_acknowledge_connection()]
+        and test_setup.register_peer_forwarder_mock.mock_calls == []
+        and test_setup.sender_mock.mock_calls == []
+        and test_setup.payload_handler_mock.mock_calls == []
     )


@@ -202,10 +247,11 @@ def test_received_acknowledge_register_peer():
     test_setup.reset_mocks()
     test_setup.background_peer_state.received_acknowledge_register_peer()
     assert (
-        test_setup.register_peer_forwarder_mock.mock_calls == [call.received_acknowledge_register_peer()]
-        and test_setup.connection_establisher_mock.mock_calls == []
-        and test_setup.sender_mock.mock_calls == []
-        and test_setup.payload_handler_mock.mock_calls == []
+        test_setup.register_peer_forwarder_mock.mock_calls
+        == [call.received_acknowledge_register_peer()]
+        and test_setup.connection_establisher_mock.mock_calls == []
+        and test_setup.sender_mock.mock_calls == []
+        and test_setup.payload_handler_mock.mock_calls == []
     )


@@ -214,10 +260,11 @@ def test_received_register_peer_complete():
     test_setup.reset_mocks()
     test_setup.background_peer_state.received_register_peer_complete()
     assert (
-        test_setup.register_peer_forwarder_mock.mock_calls == [call.received_register_peer_complete()]
-        and test_setup.connection_establisher_mock.mock_calls == []
-        and test_setup.sender_mock.mock_calls == []
-        and test_setup.payload_handler_mock.mock_calls == []
+        test_setup.register_peer_forwarder_mock.mock_calls
+        == [call.received_register_peer_complete()]
+        and test_setup.connection_establisher_mock.mock_calls == []
+        and test_setup.sender_mock.mock_calls == []
+        and test_setup.payload_handler_mock.mock_calls == []
     )


@@ -227,4 +274,6 @@ def test_send_payload():
     frames = [create_autospec(Frame)]
     payload = ModelFactory.create_factory(model=messages.Payload).build()
     test_setup.background_peer_state.send_payload(message=payload, frames=frames)
-    assert mock_cast(test_setup.payload_handler_mock.send_payload).mock_calls == [call(payload, frames)]
+    assert mock_cast(test_setup.payload_handler_mock.send_payload).mock_calls == [
+        call(payload, frames)
+    ]
diff --git a/tests/unit_tests/udf_communication/peer_communication/test_connection_establisher.py b/tests/unit_tests/udf_communication/peer_communication/test_connection_establisher.py
index 026c31c4..8fc3b930 100644
--- a/tests/unit_tests/udf_communication/peer_communication/test_connection_establisher.py
+++ b/tests/unit_tests/udf_communication/peer_communication/test_connection_establisher.py
@@ -1,20 +1,24 @@
 import dataclasses
 from typing import Union
-from unittest.mock import MagicMock, create_autospec, call
+from unittest.mock import MagicMock, call, create_autospec

 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.ip_address import IPAddress, Port
 from exasol.analytics.udf.communication.messages import AcknowledgeConnection, Message
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \
-    AbortTimeoutSender
-from 
exasol.analytics.udf.communication.peer_communicator.connection_establisher import \ - ConnectionEstablisher -from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import \ - ConnectionIsReadySender +from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import ( + AbortTimeoutSender, +) +from exasol.analytics.udf.communication.peer_communicator.connection_establisher import ( + ConnectionEstablisher, +) +from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import ( + ConnectionIsReadySender, +) from exasol.analytics.udf.communication.peer_communicator.sender import Sender -from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import \ - SynchronizeConnectionSender +from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import ( + SynchronizeConnectionSender, +) @dataclasses.dataclass() @@ -42,19 +46,25 @@ def create_test_setup() -> TestSetup: name="t1", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=11), - group_identifier="g" - )) + group_identifier="g", + ) + ) my_connection_info = ConnectionInfo( name="t0", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=10), - group_identifier="g" + group_identifier="g", ) sender_mock: Union[MagicMock, Sender] = create_autospec(Sender) - abort_timeout_sender_mock: Union[MagicMock, AbortTimeoutSender] = create_autospec(AbortTimeoutSender) - connection_is_ready_sender: Union[MagicMock, ConnectionIsReadySender] = create_autospec(ConnectionIsReadySender) - synchronize_connection_sender_mock: Union[MagicMock, SynchronizeConnectionSender] = \ - create_autospec(SynchronizeConnectionSender) + abort_timeout_sender_mock: Union[MagicMock, AbortTimeoutSender] = create_autospec( + AbortTimeoutSender + ) + connection_is_ready_sender: Union[MagicMock, ConnectionIsReadySender] = ( + create_autospec(ConnectionIsReadySender) + ) + synchronize_connection_sender_mock: Union[ + MagicMock, SynchronizeConnectionSender + ] = create_autospec(SynchronizeConnectionSender) connection_establisher = ConnectionEstablisher( my_connection_info=my_connection_info, peer=peer, @@ -77,10 +87,11 @@ def create_test_setup() -> TestSetup: def test_init(): test_setup = create_test_setup() assert ( - test_setup.synchronize_connection_sender_mock.mock_calls == [call.try_send(force=True)] - and test_setup.connection_is_ready_sender_mock.mock_calls == [] - and test_setup.abort_timeout_sender_mock.mock_calls == [] - and test_setup.sender_mock.mock_calls == [] + test_setup.synchronize_connection_sender_mock.mock_calls + == [call.try_send(force=True)] + and test_setup.connection_is_ready_sender_mock.mock_calls == [] + and test_setup.abort_timeout_sender_mock.mock_calls == [] + and test_setup.sender_mock.mock_calls == [] ) @@ -89,10 +100,10 @@ def test_try_send(): test_setup.reset_mock() test_setup.connection_establisher.try_send() assert ( - test_setup.synchronize_connection_sender_mock.mock_calls == [call.try_send()] - and test_setup.connection_is_ready_sender_mock.mock_calls == [call.try_send()] - and test_setup.abort_timeout_sender_mock.mock_calls == [call.try_send()] - and test_setup.sender_mock.mock_calls == [] + test_setup.synchronize_connection_sender_mock.mock_calls == [call.try_send()] + and test_setup.connection_is_ready_sender_mock.mock_calls == [call.try_send()] + and test_setup.abort_timeout_sender_mock.mock_calls == [call.try_send()] + and test_setup.sender_mock.mock_calls == [] ) @@ -101,14 +112,21 @@ def 
test_received_synchronize_connection(): test_setup.reset_mock() test_setup.connection_establisher.received_synchronize_connection() assert ( - test_setup.synchronize_connection_sender_mock.mock_calls == [] - and test_setup.connection_is_ready_sender_mock.mock_calls == [call.received_synchronize_connection()] - and test_setup.abort_timeout_sender_mock.mock_calls == [call.stop()] - and test_setup.sender_mock.mock_calls == [ - call.send(Message(__root__=AcknowledgeConnection( - source=test_setup.my_connection_info, - destination=test_setup.peer - )))] + test_setup.synchronize_connection_sender_mock.mock_calls == [] + and test_setup.connection_is_ready_sender_mock.mock_calls + == [call.received_synchronize_connection()] + and test_setup.abort_timeout_sender_mock.mock_calls == [call.stop()] + and test_setup.sender_mock.mock_calls + == [ + call.send( + Message( + __root__=AcknowledgeConnection( + source=test_setup.my_connection_info, + destination=test_setup.peer, + ) + ) + ) + ] ) @@ -117,8 +135,9 @@ def test_received_acknowledge_connection(): test_setup.reset_mock() test_setup.connection_establisher.received_acknowledge_connection() assert ( - test_setup.synchronize_connection_sender_mock.mock_calls == [call.stop()] - and test_setup.connection_is_ready_sender_mock.mock_calls == [call.received_acknowledge_connection()] - and test_setup.abort_timeout_sender_mock.mock_calls == [call.stop()] - and test_setup.sender_mock.mock_calls == [] + test_setup.synchronize_connection_sender_mock.mock_calls == [call.stop()] + and test_setup.connection_is_ready_sender_mock.mock_calls + == [call.received_acknowledge_connection()] + and test_setup.abort_timeout_sender_mock.mock_calls == [call.stop()] + and test_setup.sender_mock.mock_calls == [] ) diff --git a/tests/unit_tests/udf_communication/peer_communication/test_connection_establisher_builder.py b/tests/unit_tests/udf_communication/peer_communication/test_connection_establisher_builder.py index 85cc5a3b..59c2b4d7 100644 --- a/tests/unit_tests/udf_communication/peer_communication/test_connection_establisher_builder.py +++ b/tests/unit_tests/udf_communication/peer_communication/test_connection_establisher_builder.py @@ -1,27 +1,33 @@ import dataclasses -from typing import Union, List -from unittest.mock import MagicMock, Mock, create_autospec, call +from typing import List, Union +from unittest.mock import MagicMock, Mock, call, create_autospec from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import \ - AbortTimeoutSenderFactory +from exasol.analytics.udf.communication.peer_communicator.abort_timeout_sender import ( + AbortTimeoutSenderFactory, +) from exasol.analytics.udf.communication.peer_communicator.clock import Clock -from exasol.analytics.udf.communication.peer_communicator.connection_establisher_builder import \ - ConnectionEstablisherBuilder -from exasol.analytics.udf.communication.peer_communicator.connection_establisher_factory import \ - ConnectionEstablisherFactory -from exasol.analytics.udf.communication.peer_communicator.connection_establisher_timeout_config import \ - ConnectionEstablisherTimeoutConfig -from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import \ - ConnectionIsReadySenderFactory +from 
exasol.analytics.udf.communication.peer_communicator.connection_establisher_builder import (
+    ConnectionEstablisherBuilder,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_establisher_factory import (
+    ConnectionEstablisherFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_establisher_timeout_config import (
+    ConnectionEstablisherTimeoutConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import (
+    ConnectionIsReadySenderFactory,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
-from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import \
-    SynchronizeConnectionSenderFactory
+from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import (
+    SynchronizeConnectionSenderFactory,
+)
 from exasol.analytics.udf.communication.peer_communicator.timer import TimerFactory
 from exasol.analytics.udf.communication.socket_factory.abstract import Socket
-from tests.mock_cast import mock_cast
+from tests.utils.mock_cast import mock_cast


 @dataclasses.dataclass()
@@ -33,8 +39,12 @@ class TestSetup:
     clock_mock: Union[MagicMock, Clock]
     timeout_config: ConnectionEstablisherTimeoutConfig
     abort_timeout_sender_factory_mock: Union[MagicMock, AbortTimeoutSenderFactory]
-    connection_is_ready_sender_factory_mock: Union[MagicMock, ConnectionIsReadySenderFactory]
-    synchronize_connection_sender_factory_mock: Union[MagicMock, SynchronizeConnectionSenderFactory]
+    connection_is_ready_sender_factory_mock: Union[
+        MagicMock, ConnectionIsReadySenderFactory
+    ]
+    synchronize_connection_sender_factory_mock: Union[
+        MagicMock, SynchronizeConnectionSenderFactory
+    ]
     timer_factory_mock: Union[MagicMock, TimerFactory]
     timer_mocks: List[Mock]
     sender_mock: Union[MagicMock, Sender]
@@ -58,20 +68,24 @@ def create_test_setup() -> TestSetup:
             name="t1",
             ipaddress=IPAddress(ip_address="127.0.0.1"),
             port=Port(port=11),
-            group_identifier="g"
-        ))
+            group_identifier="g",
+        )
+    )
     my_connection_info = ConnectionInfo(
         name="t0",
         ipaddress=IPAddress(ip_address="127.0.0.1"),
         port=Port(port=10),
-        group_identifier="g"
+        group_identifier="g",
+    )
+    abort_timeout_sender_factory_mock: Union[MagicMock, AbortTimeoutSenderFactory] = (
+        create_autospec(AbortTimeoutSenderFactory)
     )
-    abort_timeout_sender_factory_mock: Union[MagicMock, AbortTimeoutSenderFactory] = create_autospec(
-        AbortTimeoutSenderFactory)
-    conncection_is_ready_sender_factory_mock: Union[MagicMock, ConnectionIsReadySenderFactory] = create_autospec(
-        ConnectionIsReadySenderFactory)
-    synchronize_connection_sender_factory_mock: Union[MagicMock, SynchronizeConnectionSenderFactory] = create_autospec(
-        SynchronizeConnectionSenderFactory)
+    conncection_is_ready_sender_factory_mock: Union[
+        MagicMock, ConnectionIsReadySenderFactory
+    ] = create_autospec(ConnectionIsReadySenderFactory)
+    synchronize_connection_sender_factory_mock: Union[
+        MagicMock, SynchronizeConnectionSenderFactory
+    ] = create_autospec(SynchronizeConnectionSenderFactory)
     timer_factory_mock: Union[MagicMock, TimerFactory] = create_autospec(TimerFactory)
     timer_mocks = [Mock(), Mock(), Mock(), Mock(), Mock()]
     mock_cast(timer_factory_mock.create).side_effect = timer_mocks
@@ -82,16 +96,16 @@ def create_test_setup() -> TestSetup:
         abort_timeout_in_ms=1,
         synchronize_retry_timeout_in_ms=2,
         connection_is_ready_wait_time_in_ms=3,
-
     )
-    connection_establisher_factory_mock: Union[MagicMock, ConnectionEstablisherFactory] = \
-        
create_autospec(ConnectionEstablisherFactory) + connection_establisher_factory_mock: Union[ + MagicMock, ConnectionEstablisherFactory + ] = create_autospec(ConnectionEstablisherFactory) connection_establisher_builder = ConnectionEstablisherBuilder( abort_timeout_sender_factory=abort_timeout_sender_factory_mock, connection_is_ready_sender_factory=conncection_is_ready_sender_factory_mock, synchronize_connection_sender_factory=synchronize_connection_sender_factory_mock, timer_factory=timer_factory_mock, - connection_establisher_factory=connection_establisher_factory_mock + connection_establisher_factory=connection_establisher_factory_mock, ) return TestSetup( connection_establisher_builder=connection_establisher_builder, @@ -106,7 +120,7 @@ def create_test_setup() -> TestSetup: timer_factory_mock=timer_factory_mock, timer_mocks=timer_mocks, sender_mock=sender_mock, - timeout_config=timeout_config + timeout_config=timeout_config, ) @@ -114,8 +128,12 @@ def test_init(): test_setup = create_test_setup() mock_cast(test_setup.timer_factory_mock.create).assert_not_called() mock_cast(test_setup.abort_timeout_sender_factory_mock.create).assert_not_called() - mock_cast(test_setup.synchronize_connection_sender_factory_mock.create).assert_not_called() - mock_cast(test_setup.connection_is_ready_sender_factory_mock.create).assert_not_called() + mock_cast( + test_setup.synchronize_connection_sender_factory_mock.create + ).assert_not_called() + mock_cast( + test_setup.connection_is_ready_sender_factory_mock.create + ).assert_not_called() mock_cast(test_setup.connection_establisher_factory_mock.create).assert_not_called() @@ -128,7 +146,7 @@ def test_create(): clock=test_setup.clock_mock, out_control_socket=test_setup.out_control_socket_mock, peer=test_setup.peer, - timeout_config=test_setup.timeout_config + timeout_config=test_setup.timeout_config, ) assert_timer_factory(test_setup) test_setup.sender_mock.assert_not_called() @@ -141,7 +159,9 @@ def test_create(): def assert_connection_is_ready_sender_factory_mock(test_setup): - mock_cast(test_setup.connection_is_ready_sender_factory_mock.create).assert_called_once_with( + mock_cast( + test_setup.connection_is_ready_sender_factory_mock.create + ).assert_called_once_with( my_connection_info=test_setup.my_connection_info, peer=test_setup.peer, out_control_socket=test_setup.out_control_socket_mock, @@ -150,21 +170,25 @@ def assert_connection_is_ready_sender_factory_mock(test_setup): def assert_abort_timeout_sender_factory_mock(test_setup): - mock_cast(test_setup.abort_timeout_sender_factory_mock.create).assert_called_once_with( + mock_cast( + test_setup.abort_timeout_sender_factory_mock.create + ).assert_called_once_with( my_connection_info=test_setup.my_connection_info, peer=test_setup.peer, out_control_socket=test_setup.out_control_socket_mock, timer=test_setup.timer_mocks[1], - reason='Timeout occurred during establishing connection.' 
+ reason="Timeout occurred during establishing connection.", ) def assert_synchronize_connection_sender_factory_mock(test_setup): - mock_cast(test_setup.synchronize_connection_sender_factory_mock.create).assert_called_once_with( + mock_cast( + test_setup.synchronize_connection_sender_factory_mock.create + ).assert_called_once_with( my_connection_info=test_setup.my_connection_info, peer=test_setup.peer, sender=test_setup.sender_mock, - timer=test_setup.timer_mocks[0] + timer=test_setup.timer_mocks[0], ) @@ -174,14 +198,19 @@ def assert_timer_mocks(test_setup): def assert_timer_factory(test_setup): - test_setup.timer_factory_mock.assert_has_calls([ - call.create( - clock=test_setup.clock_mock, - timeout_in_ms=test_setup.timeout_config.synchronize_retry_timeout_in_ms), - call.create( - clock=test_setup.clock_mock, - timeout_in_ms=test_setup.timeout_config.abort_timeout_in_ms), - call.create( - clock=test_setup.clock_mock, - timeout_in_ms=test_setup.timeout_config.connection_is_ready_wait_time_in_ms), - ]) + test_setup.timer_factory_mock.assert_has_calls( + [ + call.create( + clock=test_setup.clock_mock, + timeout_in_ms=test_setup.timeout_config.synchronize_retry_timeout_in_ms, + ), + call.create( + clock=test_setup.clock_mock, + timeout_in_ms=test_setup.timeout_config.abort_timeout_in_ms, + ), + call.create( + clock=test_setup.clock_mock, + timeout_in_ms=test_setup.timeout_config.connection_is_ready_wait_time_in_ms, + ), + ] + ) diff --git a/tests/unit_tests/udf_communication/peer_communication/test_connection_is_ready_sender.py b/tests/unit_tests/udf_communication/peer_communication/test_connection_is_ready_sender.py index 134fdd02..078e82e8 100644 --- a/tests/unit_tests/udf_communication/peer_communication/test_connection_is_ready_sender.py +++ b/tests/unit_tests/udf_communication/peer_communication/test_connection_is_ready_sender.py @@ -1,6 +1,6 @@ import dataclasses from typing import Union -from unittest.mock import MagicMock, create_autospec, call +from unittest.mock import MagicMock, call, create_autospec import pytest @@ -8,12 +8,13 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import \ - ConnectionIsReadySender +from exasol.analytics.udf.communication.peer_communicator.connection_is_ready_sender import ( + ConnectionIsReadySender, +) from exasol.analytics.udf.communication.peer_communicator.timer import Timer from exasol.analytics.udf.communication.serialization import serialize_message from exasol.analytics.udf.communication.socket_factory.abstract import Socket -from tests.mock_cast import mock_cast +from tests.utils.mock_cast import mock_cast @dataclasses.dataclass() @@ -35,13 +36,14 @@ def create_test_setup(): name="t2", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=12), - group_identifier="g" - )) + group_identifier="g", + ) + ) my_connection_info = ConnectionInfo( name="t1", ipaddress=IPAddress(ip_address="127.0.0.1"), port=Port(port=11), - group_identifier="g" + group_identifier="g", ) timer_mock = create_autospec(Timer) out_control_socket_mock = create_autospec(Socket) @@ -55,15 +57,15 @@ def create_test_setup(): peer=peer, timer_mock=timer_mock, out_control_socket_mock=out_control_socket_mock, - peer_is_ready_sender=connection_is_ready_sender + peer_is_ready_sender=connection_is_ready_sender, ) def test_init(): 
test_setup = create_test_setup() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [] ) @@ -76,21 +78,19 @@ def test_try_send_after_init(is_time: bool): test_setup.peer_is_ready_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [ - call.is_time() - ] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) -@pytest.mark.parametrize("is_time,send_expected", - [ - (True, True), - (False, False), - ]) -def test_try_send_after_synchronize_connection( - is_time: bool, - send_expected: bool): +@pytest.mark.parametrize( + "is_time,send_expected", + [ + (True, True), + (False, False), + ], +) +def test_try_send_after_synchronize_connection(is_time: bool, send_expected: bool): test_setup = create_test_setup() test_setup.peer_is_ready_sender.received_synchronize_connection() mock_cast(test_setup.timer_mock.is_time).return_value = is_time @@ -99,17 +99,15 @@ def test_try_send_after_synchronize_connection( test_setup.peer_is_ready_sender.try_send() if send_expected: - assert ( - test_setup.out_control_socket_mock.mock_calls == - [ - call.send(serialize_message(messages.ConnectionIsReady(peer=test_setup.peer))) - ] - and test_setup.timer_mock.mock_calls == [call.is_time()] - ) + assert test_setup.out_control_socket_mock.mock_calls == [ + call.send( + serialize_message(messages.ConnectionIsReady(peer=test_setup.peer)) + ) + ] and test_setup.timer_mock.mock_calls == [call.is_time()] else: assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) @@ -122,17 +120,15 @@ def test_try_send_after_acknowledge_connection(is_time: bool): test_setup.peer_is_ready_sender.try_send() - assert ( - test_setup.out_control_socket_mock.mock_calls == - [ - call.send(serialize_message(messages.ConnectionIsReady(peer=test_setup.peer))) - ] - and test_setup.timer_mock.mock_calls == [call.is_time()] - ) + assert test_setup.out_control_socket_mock.mock_calls == [ + call.send(serialize_message(messages.ConnectionIsReady(peer=test_setup.peer))) + ] and test_setup.timer_mock.mock_calls == [call.is_time()] @pytest.mark.parametrize("is_time", [True, False]) -def test_try_send_after_synchronize_connection_and_acknowledge_connection(is_time: bool): +def test_try_send_after_synchronize_connection_and_acknowledge_connection( + is_time: bool, +): test_setup = create_test_setup() test_setup.peer_is_ready_sender.received_synchronize_connection() test_setup.peer_is_ready_sender.received_acknowledge_connection() @@ -141,13 +137,9 @@ def test_try_send_after_synchronize_connection_and_acknowledge_connection(is_tim test_setup.peer_is_ready_sender.try_send() - assert ( - test_setup.out_control_socket_mock.mock_calls == - [ - call.send(serialize_message(messages.ConnectionIsReady(peer=test_setup.peer))) - ] - and test_setup.timer_mock.mock_calls == [call.is_time()] - ) + assert test_setup.out_control_socket_mock.mock_calls == [ + call.send(serialize_message(messages.ConnectionIsReady(peer=test_setup.peer))) + ] and test_setup.timer_mock.mock_calls == [call.is_time()] @pytest.mark.parametrize("is_time", [True, False]) @@ -161,8 +153,8 @@ def 
test_try_send_twice_after_synchronize_connection(is_time: bool): test_setup.peer_is_ready_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) @@ -177,6 +169,6 @@ def test_try_send_twice_after_acknowledge_connection(is_time: bool): test_setup.peer_is_ready_sender.try_send() assert ( - test_setup.out_control_socket_mock.mock_calls == [] - and test_setup.timer_mock.mock_calls == [call.is_time()] + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.timer_mock.mock_calls == [call.is_time()] ) diff --git a/tests/unit_tests/udf_communication/peer_communication/test_payload_handler.py b/tests/unit_tests/udf_communication/peer_communication/test_payload_handler.py index d2c7a5e2..4f0233fe 100644 --- a/tests/unit_tests/udf_communication/peer_communication/test_payload_handler.py +++ b/tests/unit_tests/udf_communication/peer_communication/test_payload_handler.py @@ -1,16 +1,22 @@ import dataclasses from typing import Union -from unittest.mock import create_autospec, MagicMock, call +from unittest.mock import MagicMock, call, create_autospec import pytest from polyfactory.factories.pydantic_factory import ModelFactory from exasol.analytics.udf.communication import messages -from exasol.analytics.udf.communication.peer_communicator.payload_handler import PayloadHandler -from exasol.analytics.udf.communication.peer_communicator.payload_receiver import PayloadReceiver -from exasol.analytics.udf.communication.peer_communicator.payload_sender import PayloadSender +from exasol.analytics.udf.communication.peer_communicator.payload_handler import ( + PayloadHandler, +) +from exasol.analytics.udf.communication.peer_communicator.payload_receiver import ( + PayloadReceiver, +) +from exasol.analytics.udf.communication.peer_communicator.payload_sender import ( + PayloadSender, +) from exasol.analytics.udf.communication.socket_factory.abstract import Frame -from tests.mock_cast import mock_cast +from tests.utils.mock_cast import mock_cast @dataclasses.dataclass @@ -41,15 +47,19 @@ def create_test_setup() -> TestSetup: def test_init(): test_setup = create_test_setup() - assert test_setup.payload_receiver_mock.mock_calls == [] and \ - test_setup.payload_sender_mock.mock_calls == [] + assert ( + test_setup.payload_receiver_mock.mock_calls == [] + and test_setup.payload_sender_mock.mock_calls == [] + ) def test_try_send(): test_setup = create_resetted_test_setup() test_setup.payload_handler.try_send() - assert test_setup.payload_receiver_mock.mock_calls == [] and \ - test_setup.payload_sender_mock.mock_calls == [call.try_send()] + assert ( + test_setup.payload_receiver_mock.mock_calls == [] + and test_setup.payload_sender_mock.mock_calls == [call.try_send()] + ) def create_resetted_test_setup(): @@ -63,8 +73,11 @@ def test_send_payload(): frames = [create_autospec(Frame)] message = ModelFactory.create_factory(model=messages.Payload).build() test_setup.payload_handler.send_payload(message, frames) - assert test_setup.payload_receiver_mock.mock_calls == [] and \ - test_setup.payload_sender_mock.mock_calls == [call.send_payload(message, frames)] + assert ( + test_setup.payload_receiver_mock.mock_calls == [] + and test_setup.payload_sender_mock.mock_calls + == [call.send_payload(message, frames)] + ) def test_received_payload(): @@ -72,30 +85,46 @@ def test_received_payload(): frames = 
[create_autospec(Frame)] message = ModelFactory.create_factory(model=messages.Payload).build() test_setup.payload_handler.received_payload(message, frames) - assert test_setup.payload_receiver_mock.mock_calls == [call.received_payload(message, frames)] and \ - test_setup.payload_sender_mock.mock_calls == [] + assert ( + test_setup.payload_receiver_mock.mock_calls + == [call.received_payload(message, frames)] + and test_setup.payload_sender_mock.mock_calls == [] + ) def test_received_acknowledge_payload(): test_setup = create_resetted_test_setup() message = ModelFactory.create_factory(model=messages.AcknowledgePayload).build() test_setup.payload_handler.received_acknowledge_payload(message) - assert test_setup.payload_receiver_mock.mock_calls == [] and \ - test_setup.payload_sender_mock.mock_calls == [call.received_acknowledge_payload(message)] - - -@pytest.mark.parametrize("payload_receiver_answer,payload_sender_answer, expected", - [ - (True, True, True), - (True, False, False), - (False, True, False), - (False, False, False), - ]) -def test_is_ready_to_stop(payload_receiver_answer: bool, payload_sender_answer: bool, expected: bool): + assert ( + test_setup.payload_receiver_mock.mock_calls == [] + and test_setup.payload_sender_mock.mock_calls + == [call.received_acknowledge_payload(message)] + ) + + +@pytest.mark.parametrize( + "payload_receiver_answer,payload_sender_answer, expected", + [ + (True, True, True), + (True, False, False), + (False, True, False), + (False, False, False), + ], +) +def test_is_ready_to_stop( + payload_receiver_answer: bool, payload_sender_answer: bool, expected: bool +): test_setup = create_resetted_test_setup() - mock_cast(test_setup.payload_sender_mock.is_ready_to_stop).return_value = payload_sender_answer - mock_cast(test_setup.payload_receiver_mock.is_ready_to_stop).return_value = payload_receiver_answer + mock_cast(test_setup.payload_sender_mock.is_ready_to_stop).return_value = ( + payload_sender_answer + ) + mock_cast(test_setup.payload_receiver_mock.is_ready_to_stop).return_value = ( + payload_receiver_answer + ) result = test_setup.payload_handler.is_ready_to_stop() - assert result == expected \ - and test_setup.payload_receiver_mock.mock_calls == [call.is_ready_to_stop()] \ - and test_setup.payload_sender_mock.mock_calls == [call.is_ready_to_stop()] + assert ( + result == expected + and test_setup.payload_receiver_mock.mock_calls == [call.is_ready_to_stop()] + and test_setup.payload_sender_mock.mock_calls == [call.is_ready_to_stop()] + ) diff --git a/tests/unit_tests/udf_communication/peer_communication/test_payload_message_sender.py b/tests/unit_tests/udf_communication/peer_communication/test_payload_message_sender.py index 8eff401e..490b3287 100644 --- a/tests/unit_tests/udf_communication/peer_communication/test_payload_message_sender.py +++ b/tests/unit_tests/udf_communication/peer_communication/test_payload_message_sender.py @@ -1,6 +1,6 @@ import dataclasses from typing import List, Union -from unittest.mock import create_autospec, MagicMock, call +from unittest.mock import MagicMock, call, create_autospec import pytest @@ -8,13 +8,14 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.payload_message_sender import \ - PayloadMessageSender +from exasol.analytics.udf.communication.peer_communicator.payload_message_sender import ( + 
PayloadMessageSender, +) from exasol.analytics.udf.communication.peer_communicator.sender import Sender from exasol.analytics.udf.communication.peer_communicator.timer import Timer from exasol.analytics.udf.communication.serialization import serialize_message -from exasol.analytics.udf.communication.socket_factory.abstract import Socket, Frame -from tests.mock_cast import mock_cast +from exasol.analytics.udf.communication.socket_factory.abstract import Frame, Socket +from tests.utils.mock_cast import mock_cast @dataclasses.dataclass(frozen=True) @@ -44,19 +45,23 @@ def create_test_setup() -> TestSetup: retry_timer_mock = create_autospec(Timer) frame_mocks = [create_autospec(Frame)] message = messages.Payload( - source=Peer(connection_info=ConnectionInfo( - name="t1", - ipaddress=IPAddress(ip_address="127.0.0.1"), - port=Port(port=1000), - group_identifier="group" - )), - destination=Peer(connection_info=ConnectionInfo( - name="t1", - ipaddress=IPAddress(ip_address="127.0.0.1"), - port=Port(port=1000), - group_identifier="group" - )), - sequence_number=0 + source=Peer( + connection_info=ConnectionInfo( + name="t1", + ipaddress=IPAddress(ip_address="127.0.0.1"), + port=Port(port=1000), + group_identifier="group", + ) + ), + destination=Peer( + connection_info=ConnectionInfo( + name="t1", + ipaddress=IPAddress(ip_address="127.0.0.1"), + port=Port(port=1000), + group_identifier="group", + ) + ), + sequence_number=0, ) payload_message_sender = PayloadMessageSender( sender=sender_mock, @@ -64,7 +69,7 @@ def create_test_setup() -> TestSetup: retry_timer=retry_timer_mock, out_control_socket=out_control_socket_mock, message=message, - frames=frame_mocks + frames=frame_mocks, ) return TestSetup( message=message, @@ -73,33 +78,38 @@ def create_test_setup() -> TestSetup: out_control_socket_mock=out_control_socket_mock, abort_time_mock=abort_time_mock, retry_timer_mock=retry_timer_mock, - payload_message_sender=payload_message_sender + payload_message_sender=payload_message_sender, ) def test_init(): test_setup = create_test_setup() - assert test_setup.out_control_socket_mock.mock_calls == [] \ - and mock_cast(test_setup.sender_mock.send_multipart).mock_calls == [call(test_setup.frame_mocks)] \ - and test_setup.retry_timer_mock.mock_calls == [] \ - and test_setup.abort_time_mock.mock_calls == [] + assert ( + test_setup.out_control_socket_mock.mock_calls == [] + and mock_cast(test_setup.sender_mock.send_multipart).mock_calls + == [call(test_setup.frame_mocks)] + and test_setup.retry_timer_mock.mock_calls == [] + and test_setup.abort_time_mock.mock_calls == [] + ) -@pytest.mark.parametrize( - "is_retry_time", - [True, False] -) +@pytest.mark.parametrize("is_retry_time", [True, False]) def test_try_send_abort_timer_is_time_true(is_retry_time: bool): test_setup = create_test_setup() test_setup.reset_mocks() mock_cast(test_setup.abort_time_mock.is_time).return_value = True mock_cast(test_setup.retry_timer_mock.is_time).return_value = is_retry_time test_setup.payload_message_sender.try_send() - abort_payload = messages.AbortPayload(payload=test_setup.message, reason="Send timeout reached") - assert mock_cast(test_setup.out_control_socket_mock.send).mock_calls == [call(serialize_message(abort_payload))] \ - and test_setup.sender_mock.mock_calls == [] \ - and test_setup.retry_timer_mock.mock_calls == [] \ - and test_setup.abort_time_mock.mock_calls == [call.is_time()] + abort_payload = messages.AbortPayload( + payload=test_setup.message, reason="Send timeout reached" + ) + assert ( + 
mock_cast(test_setup.out_control_socket_mock.send).mock_calls + == [call(serialize_message(abort_payload))] + and test_setup.sender_mock.mock_calls == [] + and test_setup.retry_timer_mock.mock_calls == [] + and test_setup.abort_time_mock.mock_calls == [call.is_time()] + ) def test_try_send_abort_timer_is_time_false_retry_timer_is_time_false(): @@ -108,10 +118,12 @@ def test_try_send_abort_timer_is_time_false_retry_timer_is_time_false(): mock_cast(test_setup.abort_time_mock.is_time).return_value = False mock_cast(test_setup.retry_timer_mock.is_time).return_value = False test_setup.payload_message_sender.try_send() - assert mock_cast(test_setup.out_control_socket_mock.send).mock_calls == [] \ - and test_setup.sender_mock.mock_calls == [] \ - and test_setup.retry_timer_mock.mock_calls == [call.is_time()] \ - and test_setup.abort_time_mock.mock_calls == [call.is_time()] + assert ( + mock_cast(test_setup.out_control_socket_mock.send).mock_calls == [] + and test_setup.sender_mock.mock_calls == [] + and test_setup.retry_timer_mock.mock_calls == [call.is_time()] + and test_setup.abort_time_mock.mock_calls == [call.is_time()] + ) def test_try_send_abort_timer_is_time_false_retry_timer_is_time_true(): @@ -120,20 +132,19 @@ def test_try_send_abort_timer_is_time_false_retry_timer_is_time_true(): mock_cast(test_setup.abort_time_mock.is_time).return_value = False mock_cast(test_setup.retry_timer_mock.is_time).return_value = True test_setup.payload_message_sender.try_send() - assert mock_cast(test_setup.out_control_socket_mock.send).mock_calls == [] \ - and mock_cast(test_setup.sender_mock.send_multipart).mock_calls == [call(test_setup.frame_mocks)] \ - and test_setup.retry_timer_mock.mock_calls == [call.is_time(), call.reset_timer()] \ - and test_setup.abort_time_mock.mock_calls == [call.is_time()] + assert ( + mock_cast(test_setup.out_control_socket_mock.send).mock_calls == [] + and mock_cast(test_setup.sender_mock.send_multipart).mock_calls + == [call(test_setup.frame_mocks)] + and test_setup.retry_timer_mock.mock_calls + == [call.is_time(), call.reset_timer()] + and test_setup.abort_time_mock.mock_calls == [call.is_time()] + ) @pytest.mark.parametrize( ["is_retry_time", "is_abort_time"], - [ - (True, True), - (True, False), - (False, True), - (False, False) - ] + [(True, True), (True, False), (False, True), (False, False)], ) def test_try_send_after_stop(is_retry_time: bool, is_abort_time: bool): test_setup = create_test_setup() @@ -142,7 +153,9 @@ def test_try_send_after_stop(is_retry_time: bool, is_abort_time: bool): mock_cast(test_setup.abort_time_mock.is_time).return_value = is_abort_time mock_cast(test_setup.retry_timer_mock.is_time).return_value = is_retry_time test_setup.payload_message_sender.try_send() - assert test_setup.out_control_socket_mock.mock_calls == [] \ - and test_setup.sender_mock.mock_calls == [] \ - and test_setup.retry_timer_mock.mock_calls == [call.is_time()] \ - and test_setup.abort_time_mock.mock_calls == [call.is_time()] + assert ( + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.sender_mock.mock_calls == [] + and test_setup.retry_timer_mock.mock_calls == [call.is_time()] + and test_setup.abort_time_mock.mock_calls == [call.is_time()] + ) diff --git a/tests/unit_tests/udf_communication/peer_communication/test_payload_receiver.py b/tests/unit_tests/udf_communication/peer_communication/test_payload_receiver.py index 84796bc0..075a7286 100644 --- a/tests/unit_tests/udf_communication/peer_communication/test_payload_receiver.py +++ 
b/tests/unit_tests/udf_communication/peer_communication/test_payload_receiver.py @@ -1,6 +1,6 @@ import dataclasses -from typing import Union, Tuple, List -from unittest.mock import create_autospec, MagicMock, call +from typing import List, Tuple, Union +from unittest.mock import MagicMock, call, create_autospec import pytest @@ -8,9 +8,11 @@ from exasol.analytics.udf.communication.connection_info import ConnectionInfo from exasol.analytics.udf.communication.ip_address import IPAddress, Port from exasol.analytics.udf.communication.peer import Peer -from exasol.analytics.udf.communication.peer_communicator.payload_receiver import PayloadReceiver +from exasol.analytics.udf.communication.peer_communicator.payload_receiver import ( + PayloadReceiver, +) from exasol.analytics.udf.communication.peer_communicator.sender import Sender -from exasol.analytics.udf.communication.socket_factory.abstract import Socket, Frame +from exasol.analytics.udf.communication.socket_factory.abstract import Frame, Socket @dataclasses.dataclass @@ -30,50 +32,66 @@ def reset_mock(self): def create_test_setup() -> TestSetup: sender_mock = create_autospec(Sender) out_control_socket_mock = create_autospec(Socket) - my_connection_info = ConnectionInfo(name="t1", - ipaddress=IPAddress(ip_address="127.0.0.1"), - port=Port(port=1000), - group_identifier="group") - peer = Peer(connection_info=ConnectionInfo(name="t2", - ipaddress=IPAddress(ip_address="127.0.0.1"), - port=Port(port=2000), - group_identifier="group")) - payload_receiver = PayloadReceiver(sender=sender_mock, - out_control_socket=out_control_socket_mock, - my_connection_info=my_connection_info, - peer=peer) + my_connection_info = ConnectionInfo( + name="t1", + ipaddress=IPAddress(ip_address="127.0.0.1"), + port=Port(port=1000), + group_identifier="group", + ) + peer = Peer( + connection_info=ConnectionInfo( + name="t2", + ipaddress=IPAddress(ip_address="127.0.0.1"), + port=Port(port=2000), + group_identifier="group", + ) + ) + payload_receiver = PayloadReceiver( + sender=sender_mock, + out_control_socket=out_control_socket_mock, + my_connection_info=my_connection_info, + peer=peer, + ) return TestSetup( peer=peer, my_connection_info=my_connection_info, sender_mock=sender_mock, out_control_socket_mock=out_control_socket_mock, - payload_receiver=payload_receiver + payload_receiver=payload_receiver, ) -def create_acknowledge_payload_message(test_setup: TestSetup, message: messages.Payload) -> messages.Message: - acknowledge_message = messages.Message(__root__=messages.AcknowledgePayload( - source=Peer(connection_info=test_setup.my_connection_info), - sequence_number=message.sequence_number, - destination=test_setup.peer - )) +def create_acknowledge_payload_message( + test_setup: TestSetup, message: messages.Payload +) -> messages.Message: + acknowledge_message = messages.Message( + __root__=messages.AcknowledgePayload( + source=Peer(connection_info=test_setup.my_connection_info), + sequence_number=message.sequence_number, + destination=test_setup.peer, + ) + ) return acknowledge_message -def create_payload_message(test_setup: TestSetup, sequence_number: int) -> Tuple[messages.Payload, List[Frame]]: +def create_payload_message( + test_setup: TestSetup, sequence_number: int +) -> Tuple[messages.Payload, List[Frame]]: frames = [create_autospec(Frame)] message = messages.Payload( source=test_setup.peer, destination=Peer(connection_info=test_setup.my_connection_info), - sequence_number=sequence_number + sequence_number=sequence_number, ) return message, frames def 
test_init(): test_setup = create_test_setup() - assert test_setup.out_control_socket_mock.mock_calls == [] \ - and test_setup.sender_mock.mock_calls == [] + assert ( + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.sender_mock.mock_calls == [] + ) @pytest.mark.parametrize("number_of_messages", [i for i in range(1, 10)]) @@ -88,8 +106,9 @@ def test_received_payload_in_sequence(number_of_messages: int): message, frames = create_payload_message(test_setup, sequence_number) test_setup.payload_receiver.received_payload(message, frames) acknowledge_message = create_acknowledge_payload_message(test_setup, message) - assert test_setup.out_control_socket_mock.mock_calls == [call.send_multipart(frames)] \ - and test_setup.sender_mock.mock_calls == [call.send(message=acknowledge_message)] + assert test_setup.out_control_socket_mock.mock_calls == [ + call.send_multipart(frames) + ] and test_setup.sender_mock.mock_calls == [call.send(message=acknowledge_message)] @pytest.mark.parametrize("number_of_messages", [i for i in range(1, 10)]) @@ -105,15 +124,24 @@ def test_received_payload_in_reverse_sequence(number_of_messages: int): message, frames = create_payload_message(test_setup, sequence_number) test_setup.payload_receiver.received_payload(message, frames) acknowledge_message = create_acknowledge_payload_message(test_setup, message) - send_of_previous_messages = [call.send_multipart(frames) for frames in reversed(frames_of_previous_message)] + send_of_previous_messages = [ + call.send_multipart(frames) for frames in reversed(frames_of_previous_message) + ] out_control_mock_calls = [call.send_multipart(frames)] + send_of_previous_messages - assert test_setup.out_control_socket_mock.mock_calls == out_control_mock_calls \ - and test_setup.sender_mock.mock_calls == [call.send(message=acknowledge_message)] + assert ( + test_setup.out_control_socket_mock.mock_calls == out_control_mock_calls + and test_setup.sender_mock.mock_calls + == [call.send(message=acknowledge_message)] + ) -@pytest.mark.parametrize("number_of_messages, duplicated_message", - [(i, j) for i in range(1, 10) for j in range(0, i)]) -def test_received_payload_in_sequence_multiple_times(number_of_messages: int, duplicated_message: int): +@pytest.mark.parametrize( + "number_of_messages, duplicated_message", + [(i, j) for i in range(1, 10) for j in range(0, i)], +) +def test_received_payload_in_sequence_multiple_times( + number_of_messages: int, duplicated_message: int +): test_setup = create_test_setup() for sequence_number in range(number_of_messages): message, frames = create_payload_message(test_setup, sequence_number) @@ -123,14 +151,20 @@ def test_received_payload_in_sequence_multiple_times(number_of_messages: int, du message, frames = create_payload_message(test_setup, sequence_number) test_setup.payload_receiver.received_payload(message, frames) acknowledge_message = create_acknowledge_payload_message(test_setup, message) - assert test_setup.out_control_socket_mock.mock_calls == [] \ - and test_setup.sender_mock.mock_calls == [call.send(message=acknowledge_message)] + assert ( + test_setup.out_control_socket_mock.mock_calls == [] + and test_setup.sender_mock.mock_calls + == [call.send(message=acknowledge_message)] + ) -@pytest.mark.parametrize("number_of_messages, duplicated_message", - [(i, j) for i in range(1, 10) for j in range(1, i)]) +@pytest.mark.parametrize( + "number_of_messages, duplicated_message", + [(i, j) for i in range(1, 10) for j in range(1, i)], +) def 
test_received_payload_in_reverse_sequence_incomplete_multiple_times(
-        number_of_messages: int, duplicated_message: int):
+    number_of_messages: int, duplicated_message: int
+):
     test_setup = create_test_setup()
     frames_of_previous_message = []
     for sequence_number in range(number_of_messages - 1, 0, -1):
@@ -142,8 +176,11 @@ def test_received_payload_in_reverse_sequence_incomplete_multiple_times(
         message, frames = create_payload_message(test_setup, sequence_number)
         test_setup.payload_receiver.received_payload(message, frames)
         acknowledge_message = create_acknowledge_payload_message(test_setup, message)
-        assert test_setup.out_control_socket_mock.mock_calls == [] \
-               and test_setup.sender_mock.mock_calls == [call.send(message=acknowledge_message)]
+        assert (
+            test_setup.out_control_socket_mock.mock_calls == []
+            and test_setup.sender_mock.mock_calls
+            == [call.send(message=acknowledge_message)]
+        )
 
 
 @pytest.mark.parametrize("number_of_messages", [i for i in range(1, 10)])
@@ -159,7 +196,9 @@ def test_is_ready_to_close_after_received_payload_in_sequence(number_of_messages
 
 
 @pytest.mark.parametrize("number_of_messages", [i for i in range(1, 10)])
-def test_is_ready_to_stop_after_received_payload_in_reverse_sequence(number_of_messages: int):
+def test_is_ready_to_stop_after_received_payload_in_reverse_sequence(
+    number_of_messages: int,
+):
     test_setup = create_test_setup()
     frames_of_previous_message = []
     for sequence_number in range(number_of_messages - 1, -1, -1):
@@ -171,10 +210,13 @@ def test_is_ready_to_stop_after_received_payload_in_reverse_sequence(number_of_m
     assert is_ready_to_stop
 
 
-@pytest.mark.parametrize("number_of_messages, duplicated_message_sequence_number",
-                         [(i, j) for i in range(1, 10) for j in range(0, i)])
-def test_is_ready_to_stop_payload_in_sequence_multiple_times(number_of_messages: int,
-                                                             duplicated_message_sequence_number: int):
+@pytest.mark.parametrize(
+    "number_of_messages, duplicated_message_sequence_number",
+    [(i, j) for i in range(1, 10) for j in range(0, i)],
+)
+def test_is_ready_to_stop_payload_in_sequence_multiple_times(
+    number_of_messages: int, duplicated_message_sequence_number: int
+):
     test_setup = create_test_setup()
     for sequence_number in range(number_of_messages):
         message, frames = create_payload_message(test_setup, sequence_number)
@@ -188,7 +230,9 @@ def test_is_ready_to_stop_payload_in_sequence_multiple_times(number_of_messages:
 
 
 @pytest.mark.parametrize("number_of_messages", [i for i in range(2, 10)])
-def test_is_ready_to_stop_after_received_payload_in_reverse_sequence_incomplete(number_of_messages: int):
+def test_is_ready_to_stop_after_received_payload_in_reverse_sequence_incomplete(
+    number_of_messages: int,
+):
     test_setup = create_test_setup()
     frames_of_previous_message = []
     for sequence_number in range(number_of_messages - 1, 0, -1):
diff --git a/tests/unit_tests/udf_communication/peer_communication/test_payload_sender.py b/tests/unit_tests/udf_communication/peer_communication/test_payload_sender.py
index 89ba53ef..dd25d134 100644
--- a/tests/unit_tests/udf_communication/peer_communication/test_payload_sender.py
+++ b/tests/unit_tests/udf_communication/peer_communication/test_payload_sender.py
@@ -1,23 +1,28 @@
 import dataclasses
-from typing import Union, Tuple, List
-from unittest.mock import create_autospec, MagicMock, call
+from typing import List, Tuple, Union
+from unittest.mock import MagicMock, call, create_autospec
 
 from exasol.analytics.udf.communication import messages
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.ip_address import IPAddress, Port
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator.clock import Clock
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender import \
-    PayloadMessageSender
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_factory import \
-    PayloadMessageSenderFactory
-from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config import \
-    PayloadMessageSenderTimeoutConfig
-from exasol.analytics.udf.communication.peer_communicator.payload_sender import PayloadSender
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender import (
+    PayloadMessageSender,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_factory import (
+    PayloadMessageSenderFactory,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_message_sender_timeout_config import (
+    PayloadMessageSenderTimeoutConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.payload_sender import (
+    PayloadSender,
+)
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
 from exasol.analytics.udf.communication.serialization import serialize_message
-from exasol.analytics.udf.communication.socket_factory.abstract import Socket, Frame
-from tests.mock_cast import mock_cast
+from exasol.analytics.udf.communication.socket_factory.abstract import Frame, Socket
+from tests.utils.mock_cast import mock_cast
 
 
 @dataclasses.dataclass
@@ -45,30 +50,42 @@ def reset_mock(self):
 def create_test_setup(number_of_messages: int) -> TestSetup:
     sender_mock = create_autospec(Sender)
     out_control_socket_mock = create_autospec(Socket)
-    my_connection_info = ConnectionInfo(name="t1",
-                                        ipaddress=IPAddress(ip_address="127.0.0.1"),
-                                        port=Port(port=1000),
-                                        group_identifier="group")
-    peer = Peer(connection_info=ConnectionInfo(name="t2",
-                                               ipaddress=IPAddress(ip_address="127.0.0.1"),
-                                               port=Port(port=2000),
-                                               group_identifier="group"))
+    my_connection_info = ConnectionInfo(
+        name="t1",
+        ipaddress=IPAddress(ip_address="127.0.0.1"),
+        port=Port(port=1000),
+        group_identifier="group",
+    )
+    peer = Peer(
+        connection_info=ConnectionInfo(
+            name="t2",
+            ipaddress=IPAddress(ip_address="127.0.0.1"),
+            port=Port(port=2000),
+            group_identifier="group",
+        )
+    )
     clock_mock = create_autospec(Clock)
-    payload_message_sender_factory_mock: Union[MagicMock, PayloadMessageSenderFactory] = \
-        create_autospec(PayloadMessageSenderFactory)
-    payload_message_sender_mocks = [create_autospec(PayloadMessageSender) for i in range(number_of_messages)]
-    mock_cast(payload_message_sender_factory_mock.create).side_effect = payload_message_sender_mocks
+    payload_message_sender_factory_mock: Union[
+        MagicMock, PayloadMessageSenderFactory
+    ] = create_autospec(PayloadMessageSenderFactory)
+    payload_message_sender_mocks = [
+        create_autospec(PayloadMessageSender) for i in range(number_of_messages)
+    ]
+    mock_cast(payload_message_sender_factory_mock.create).side_effect = (
+        payload_message_sender_mocks
+    )
     payload_message_sender_timeout_config = PayloadMessageSenderTimeoutConfig(
-        abort_timeout_in_ms=2,
-        retry_timeout_in_ms=1
-    )
-    payload_sender = PayloadSender(sender=sender_mock,
-                                   out_control_socket=out_control_socket_mock,
-                                   my_connection_info=my_connection_info,
-                                   peer=peer,
-                                   clock=clock_mock,
-                                   payload_message_sender_factory=payload_message_sender_factory_mock,
-                                   payload_message_sender_timeout_config=payload_message_sender_timeout_config)
+        abort_timeout_in_ms=2, retry_timeout_in_ms=1
+    )
+    payload_sender = PayloadSender(
+        sender=sender_mock,
+        out_control_socket=out_control_socket_mock,
+        my_connection_info=my_connection_info,
+        peer=peer,
+        clock=clock_mock,
+        payload_message_sender_factory=payload_message_sender_factory_mock,
+        payload_message_sender_timeout_config=payload_message_sender_timeout_config,
+    )
     return TestSetup(
         peer=peer,
         my_connection_info=my_connection_info,
@@ -78,45 +95,55 @@ def create_test_setup(number_of_messages: int) -> TestSetup:
         payload_message_sender_factory_mock=payload_message_sender_factory_mock,
         payload_message_sender_mocks=payload_message_sender_mocks,
         payload_message_sender_timeout_config=payload_message_sender_timeout_config,
-        payload_sender=payload_sender
+        payload_sender=payload_sender,
     )
 
 
-def create_acknowledge_payload_message(test_setup: TestSetup, message: messages.Payload) -> messages.Message:
-    acknowledge_message = messages.Message(__root__=messages.AcknowledgePayload(
-        source=Peer(connection_info=test_setup.my_connection_info),
-        sequence_number=message.sequence_number,
-        destination=test_setup.peer
-    ))
+def create_acknowledge_payload_message(
+    test_setup: TestSetup, message: messages.Payload
+) -> messages.Message:
+    acknowledge_message = messages.Message(
+        __root__=messages.AcknowledgePayload(
+            source=Peer(connection_info=test_setup.my_connection_info),
+            sequence_number=message.sequence_number,
+            destination=test_setup.peer,
+        )
+    )
     return acknowledge_message
 
 
-def create_payload_message(test_setup: TestSetup, sequence_number: int) -> Tuple[messages.Payload, List[Frame]]:
+def create_payload_message(
+    test_setup: TestSetup, sequence_number: int
+) -> Tuple[messages.Payload, List[Frame]]:
     frames = [create_autospec(Frame)]
     message = messages.Payload(
         source=test_setup.peer,
         destination=Peer(connection_info=test_setup.my_connection_info),
-        sequence_number=sequence_number
+        sequence_number=sequence_number,
     )
     return message, frames
 
 
 def test_init():
     test_setup = create_test_setup(number_of_messages=0)
-    assert test_setup.out_control_socket_mock.mock_calls == [] \
-           and test_setup.sender_mock.mock_calls == [] \
-           and test_setup.clock_mock.mock_calls == [] \
-           and test_setup.payload_message_sender_factory_mock.mock_calls == []
+    assert (
+        test_setup.out_control_socket_mock.mock_calls == []
+        and test_setup.sender_mock.mock_calls == []
+        and test_setup.clock_mock.mock_calls == []
+        and test_setup.payload_message_sender_factory_mock.mock_calls == []
+    )
 
 
 def test_try_send():
     test_setup = create_test_setup(number_of_messages=0)
     test_setup.reset_mock()
     test_setup.payload_sender.try_send()
-    assert test_setup.out_control_socket_mock.mock_calls == [] \
-           and test_setup.sender_mock.mock_calls == [] \
-           and test_setup.clock_mock.mock_calls == [] \
-           and test_setup.payload_message_sender_factory_mock.mock_calls == []
+    assert (
+        test_setup.out_control_socket_mock.mock_calls == []
+        and test_setup.sender_mock.mock_calls == []
+        and test_setup.clock_mock.mock_calls == []
+        and test_setup.payload_message_sender_factory_mock.mock_calls == []
+    )
 
 
 def test_send_payload():
@@ -124,17 +151,23 @@ def test_send_payload():
     test_setup.reset_mock()
     payload_message, frames = create_payload_message(test_setup, 0)
     test_setup.payload_sender.send_payload(payload_message, frames)
-    assert test_setup.out_control_socket_mock.mock_calls == [] \
-           and test_setup.sender_mock.mock_calls == [] \
-           and test_setup.clock_mock.mock_calls == [] \
-           and test_setup.payload_message_sender_factory_mock.mock_calls == [
-               call.create(message=payload_message,
-                           frames=frames,
-                           sender=test_setup.sender_mock,
-                           out_control_socket=test_setup.out_control_socket_mock,
-                           clock=test_setup.clock_mock,
-                           payload_message_sender_timeout_config=test_setup.payload_message_sender_timeout_config)] \
-           and test_setup.payload_message_sender_mocks[0].mock_calls == []
+    assert (
+        test_setup.out_control_socket_mock.mock_calls == []
+        and test_setup.sender_mock.mock_calls == []
+        and test_setup.clock_mock.mock_calls == []
+        and test_setup.payload_message_sender_factory_mock.mock_calls
+        == [
+            call.create(
+                message=payload_message,
+                frames=frames,
+                sender=test_setup.sender_mock,
+                out_control_socket=test_setup.out_control_socket_mock,
+                clock=test_setup.clock_mock,
+                payload_message_sender_timeout_config=test_setup.payload_message_sender_timeout_config,
+            )
+        ]
+        and test_setup.payload_message_sender_mocks[0].mock_calls == []
+    )
 
 
 def test_try_send_after_send_payload():
@@ -143,11 +176,13 @@ def test_try_send_after_send_payload():
     test_setup.payload_sender.send_payload(payload_message, frames)
     test_setup.reset_mock()
     test_setup.payload_sender.try_send()
-    assert test_setup.out_control_socket_mock.mock_calls == [] \
-           and test_setup.sender_mock.mock_calls == [] \
-           and test_setup.clock_mock.mock_calls == [] \
-           and test_setup.payload_message_sender_factory_mock.mock_calls == [] \
-           and test_setup.payload_message_sender_mocks[0].mock_calls == [call.try_send()]
+    assert (
+        test_setup.out_control_socket_mock.mock_calls == []
+        and test_setup.sender_mock.mock_calls == []
+        and test_setup.clock_mock.mock_calls == []
+        and test_setup.payload_message_sender_factory_mock.mock_calls == []
+        and test_setup.payload_message_sender_mocks[0].mock_calls == [call.try_send()]
+    )
 
 
 def test_received_acknowledge_payload_after_send_payload():
@@ -155,26 +190,40 @@ def test_received_acknowledge_payload_after_send_payload():
     payload_message, frames = create_payload_message(test_setup, 0)
     test_setup.payload_sender.send_payload(payload_message, frames)
     test_setup.reset_mock()
-    acknowledge_payload_message = create_acknowledge_payload_message(test_setup, payload_message)
-    test_setup.payload_sender.received_acknowledge_payload(message=acknowledge_payload_message.__root__)
-    assert test_setup.out_control_socket_mock.mock_calls == \
-           [call.send(serialize_message(acknowledge_payload_message))] \
-           and test_setup.sender_mock.mock_calls == [] \
-           and test_setup.clock_mock.mock_calls == [] \
-           and test_setup.payload_message_sender_factory_mock.mock_calls == [] \
-           and test_setup.payload_message_sender_mocks[0].mock_calls == [call.stop()]
+    acknowledge_payload_message = create_acknowledge_payload_message(
+        test_setup, payload_message
+    )
+    test_setup.payload_sender.received_acknowledge_payload(
+        message=acknowledge_payload_message.__root__
+    )
+    assert (
+        test_setup.out_control_socket_mock.mock_calls
+        == [call.send(serialize_message(acknowledge_payload_message))]
+        and test_setup.sender_mock.mock_calls == []
+        and test_setup.clock_mock.mock_calls == []
+        and test_setup.payload_message_sender_factory_mock.mock_calls == []
+        and test_setup.payload_message_sender_mocks[0].mock_calls == [call.stop()]
+    )
 
 
 def test_received_acknowledge_payload_twice_after_send_payload():
     test_setup = create_test_setup(number_of_messages=1)
     payload_message, frames = create_payload_message(test_setup, 0)
     test_setup.payload_sender.send_payload(payload_message, frames)
-    acknowledge_payload_message = create_acknowledge_payload_message(test_setup, payload_message)
-    test_setup.payload_sender.received_acknowledge_payload(message=acknowledge_payload_message.__root__)
+    acknowledge_payload_message = create_acknowledge_payload_message(
+        test_setup, payload_message
+    )
+    test_setup.payload_sender.received_acknowledge_payload(
+        message=acknowledge_payload_message.__root__
+    )
     test_setup.reset_mock()
-    test_setup.payload_sender.received_acknowledge_payload(message=acknowledge_payload_message.__root__)
-    assert test_setup.out_control_socket_mock.mock_calls == [] \
-           and test_setup.sender_mock.mock_calls == [] \
-           and test_setup.clock_mock.mock_calls == [] \
-           and test_setup.payload_message_sender_factory_mock.mock_calls == [] \
-           and test_setup.payload_message_sender_mocks[0].mock_calls == []
+    test_setup.payload_sender.received_acknowledge_payload(
+        message=acknowledge_payload_message.__root__
+    )
+    assert (
+        test_setup.out_control_socket_mock.mock_calls == []
+        and test_setup.sender_mock.mock_calls == []
+        and test_setup.clock_mock.mock_calls == []
+        and test_setup.payload_message_sender_factory_mock.mock_calls == []
+        and test_setup.payload_message_sender_mocks[0].mock_calls == []
+    )
diff --git a/tests/unit_tests/udf_communication/peer_communication/test_register_peer_forwarder_is_ready_sender.py b/tests/unit_tests/udf_communication/peer_communication/test_register_peer_forwarder_is_ready_sender.py
index 40d0e863..a05f49b5 100644
--- a/tests/unit_tests/udf_communication/peer_communication/test_register_peer_forwarder_is_ready_sender.py
+++ b/tests/unit_tests/udf_communication/peer_communication/test_register_peer_forwarder_is_ready_sender.py
@@ -1,6 +1,6 @@
 import dataclasses
 from typing import Union
-from unittest.mock import MagicMock, create_autospec, call
+from unittest.mock import MagicMock, call, create_autospec
 
 import pytest
 
@@ -8,14 +8,16 @@
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.ip_address import IPAddress, Port
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_behavior_config \
-    import RegisterPeerForwarderBehaviorConfig
-from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_is_ready_sender import \
-    RegisterPeerForwarderIsReadySender
+from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_behavior_config import (
+    RegisterPeerForwarderBehaviorConfig,
+)
+from exasol.analytics.udf.communication.peer_communicator.register_peer_forwarder_is_ready_sender import (
+    RegisterPeerForwarderIsReadySender,
+)
 from exasol.analytics.udf.communication.peer_communicator.timer import Timer
 from exasol.analytics.udf.communication.serialization import serialize_message
 from exasol.analytics.udf.communication.socket_factory.abstract import Socket
-from tests.mock_cast import mock_cast
+from tests.utils.mock_cast import mock_cast
 
 
 @dataclasses.dataclass
@@ -32,16 +34,23 @@ def reset_mock(self):
         self.out_control_socket_mock.reset_mock()
 
 
-def create_test_setup(behavior_config: RegisterPeerForwarderBehaviorConfig) -> TestSetup:
-    my_connection_info = ConnectionInfo(name="t0",
-                                        port=Port(port=1),
-                                        ipaddress=IPAddress(ip_address="127.0.0.1"),
-                                        group_identifier="g")
-    peer = Peer(connection_info=
-                ConnectionInfo(name="t1",
-                               port=Port(port=2),
-                               ipaddress=IPAddress(ip_address="127.0.0.1"),
-                               group_identifier="g"))
+def create_test_setup(
+    behavior_config: RegisterPeerForwarderBehaviorConfig,
+) -> TestSetup:
+    my_connection_info = ConnectionInfo(
+        name="t0",
+        port=Port(port=1),
+        ipaddress=IPAddress(ip_address="127.0.0.1"),
+        group_identifier="g",
+    )
+    peer = Peer(
+        connection_info=ConnectionInfo(
+            name="t1",
+            port=Port(port=2),
+            ipaddress=IPAddress(ip_address="127.0.0.1"),
+            group_identifier="g",
+        )
+    )
     timer_mock: Union[Timer, MagicMock] = create_autospec(Timer)
     out_control_socket_mock: Union[Socket, MagicMock] = create_autospec(Socket)
     register_peer_forwarder_is_ready_sender = RegisterPeerForwarderIsReadySender(
@@ -49,32 +58,34 @@ def create_test_setup(behavior_config: RegisterPeerForwarderBehaviorConfig) -> T
         my_connection_info=my_connection_info,
         behavior_config=behavior_config,
         timer=timer_mock,
-        out_control_socket=out_control_socket_mock
+        out_control_socket=out_control_socket_mock,
     )
     test_setup = TestSetup(
         my_connection_info=my_connection_info,
         peer=peer,
         register_peer_forwarder_is_ready_sender=register_peer_forwarder_is_ready_sender,
         timer_mock=timer_mock,
-        out_control_socket_mock=out_control_socket_mock
+        out_control_socket_mock=out_control_socket_mock,
     )
     return test_setup
 
 
-@pytest.mark.parametrize("needs_to_send_acknowledge_register_peer,needs_to_send_register_peer", [
-    (True, True),
-    (True, False),
-    (False, True),
-    (False, False)
-])
-def test_init(needs_to_send_acknowledge_register_peer: bool, needs_to_send_register_peer: bool):
-    test_setup = create_test_setup(RegisterPeerForwarderBehaviorConfig(
-        needs_to_send_register_peer=needs_to_send_register_peer,
-        needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer
-    ))
+@pytest.mark.parametrize(
+    "needs_to_send_acknowledge_register_peer,needs_to_send_register_peer",
+    [(True, True), (True, False), (False, True), (False, False)],
+)
+def test_init(
+    needs_to_send_acknowledge_register_peer: bool, needs_to_send_register_peer: bool
+):
+    test_setup = create_test_setup(
+        RegisterPeerForwarderBehaviorConfig(
+            needs_to_send_register_peer=needs_to_send_register_peer,
+            needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer,
+        )
+    )
     assert (
-            test_setup.timer_mock.mock_calls == []
-            and test_setup.out_control_socket_mock.mock_calls == []
+        test_setup.timer_mock.mock_calls == []
+        and test_setup.out_control_socket_mock.mock_calls == []
    )
 
 
@@ -91,32 +102,38 @@ def test_init(needs_to_send_acknowledge_register_peer: bool, needs_to_send_regis
         (False, True, True, False),
         (False, True, False, False),
         (False, False, True, True),
-        (False, False, False, True)
-    ])
-def test_try_send_after_init(needs_to_send_acknowledge_register_peer: bool,
-                             needs_to_send_register_peer: bool,
-                             is_time: bool,
-                             expected_to_send):
-    test_setup = create_test_setup(RegisterPeerForwarderBehaviorConfig(
-        needs_to_send_register_peer=needs_to_send_register_peer,
-        needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer
-    ))
+        (False, False, False, True),
+    ],
+)
+def test_try_send_after_init(
+    needs_to_send_acknowledge_register_peer: bool,
+    needs_to_send_register_peer: bool,
+    is_time: bool,
+    expected_to_send,
+):
+    test_setup = create_test_setup(
+        RegisterPeerForwarderBehaviorConfig(
+            needs_to_send_register_peer=needs_to_send_register_peer,
+            needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer,
+        )
+    )
     mock_cast(test_setup.timer_mock.is_time).return_value = is_time
     test_setup.reset_mock()
     test_setup.register_peer_forwarder_is_ready_sender.try_send()
     if expected_to_send:
-        assert (
-                test_setup.timer_mock.mock_calls == [call.is_time()]
-                and test_setup.out_control_socket_mock.mock_calls == [
-                    call.send(serialize_message(messages.PeerRegisterForwarderIsReady(
-                        peer=test_setup.peer
-                    )))
-                ]
-        )
+        assert test_setup.timer_mock.mock_calls == [
+            call.is_time()
+        ] and test_setup.out_control_socket_mock.mock_calls == [
+            call.send(
+                serialize_message(
+                    messages.PeerRegisterForwarderIsReady(peer=test_setup.peer)
+                )
+            )
+        ]
     else:
         assert (
-                test_setup.timer_mock.mock_calls == [call.is_time()]
-                and test_setup.out_control_socket_mock.mock_calls == []
+            test_setup.timer_mock.mock_calls == [call.is_time()]
+            and test_setup.out_control_socket_mock.mock_calls == []
         )
 
 
@@ -133,34 +150,39 @@ def test_try_send_after_init(needs_to_send_acknowledge_register_peer: bool,
         (False, True, True, True),
         (False, True, False, True),
         (False, False, True, True),
-        (False, False, False, True)
-    ])
+        (False, False, False, True),
+    ],
+)
 def test_try_send_after_received_acknowledge_register_peer(
-        needs_to_send_acknowledge_register_peer: bool,
-        needs_to_send_register_peer: bool,
-        is_time: bool,
-        expected_to_send):
-    test_setup = create_test_setup(RegisterPeerForwarderBehaviorConfig(
-        needs_to_send_register_peer=needs_to_send_register_peer,
-        needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer
-    ))
+    needs_to_send_acknowledge_register_peer: bool,
+    needs_to_send_register_peer: bool,
+    is_time: bool,
+    expected_to_send,
+):
+    test_setup = create_test_setup(
+        RegisterPeerForwarderBehaviorConfig(
+            needs_to_send_register_peer=needs_to_send_register_peer,
+            needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer,
+        )
+    )
     mock_cast(test_setup.timer_mock.is_time).return_value = is_time
     test_setup.register_peer_forwarder_is_ready_sender.received_acknowledge_register_peer()
     test_setup.reset_mock()
     test_setup.register_peer_forwarder_is_ready_sender.try_send()
     if expected_to_send:
-        assert (
-                test_setup.timer_mock.mock_calls == [call.is_time()]
-                and test_setup.out_control_socket_mock.mock_calls == [
-                    call.send(serialize_message(messages.PeerRegisterForwarderIsReady(
-                        peer=test_setup.peer
-                    )))
-                ]
-        )
+        assert test_setup.timer_mock.mock_calls == [
+            call.is_time()
+        ] and test_setup.out_control_socket_mock.mock_calls == [
+            call.send(
+                serialize_message(
+                    messages.PeerRegisterForwarderIsReady(peer=test_setup.peer)
+                )
+            )
+        ]
     else:
         assert (
-                test_setup.timer_mock.mock_calls == [call.is_time()]
-                and test_setup.out_control_socket_mock.mock_calls == []
+            test_setup.timer_mock.mock_calls == [call.is_time()]
+            and test_setup.out_control_socket_mock.mock_calls == []
         )
 
 
@@ -177,34 +199,39 @@ def test_try_send_after_received_acknowledge_register_peer(
         (False, True, True, False),
         (False, True, False, False),
         (False, False, True, True),
-        (False, False, False, True)
-    ])
+        (False, False, False, True),
+    ],
+)
 def test_try_send_after_received_register_peer_complete(
-        needs_to_send_acknowledge_register_peer: bool,
-        needs_to_send_register_peer: bool,
-        is_time: bool,
-        expected_to_send):
-    test_setup = create_test_setup(RegisterPeerForwarderBehaviorConfig(
-        needs_to_send_register_peer=needs_to_send_register_peer,
-        needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer
-    ))
+    needs_to_send_acknowledge_register_peer: bool,
+    needs_to_send_register_peer: bool,
+    is_time: bool,
+    expected_to_send,
+):
+    test_setup = create_test_setup(
+        RegisterPeerForwarderBehaviorConfig(
+            needs_to_send_register_peer=needs_to_send_register_peer,
+            needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer,
+        )
+    )
     mock_cast(test_setup.timer_mock.is_time).return_value = is_time
     test_setup.register_peer_forwarder_is_ready_sender.received_register_peer_complete()
     test_setup.reset_mock()
     test_setup.register_peer_forwarder_is_ready_sender.try_send()
     if expected_to_send:
-        assert (
-                test_setup.timer_mock.mock_calls == [call.is_time()]
-                and test_setup.out_control_socket_mock.mock_calls == [
-                    call.send(serialize_message(messages.PeerRegisterForwarderIsReady(
-                        peer=test_setup.peer
-                    )))
-                ]
-        )
+        assert test_setup.timer_mock.mock_calls == [
+            call.is_time()
+        ] and test_setup.out_control_socket_mock.mock_calls == [
+            call.send(
+                serialize_message(
+                    messages.PeerRegisterForwarderIsReady(peer=test_setup.peer)
+                )
+            )
+        ]
     else:
         assert (
-                test_setup.timer_mock.mock_calls == [call.is_time()]
-                and test_setup.out_control_socket_mock.mock_calls == []
+            test_setup.timer_mock.mock_calls == [call.is_time()]
+            and test_setup.out_control_socket_mock.mock_calls == []
         )
 
 
@@ -221,33 +248,38 @@ def test_try_send_after_received_register_peer_complete(
         (False, True, True, True),
         (False, True, False, True),
         (False, False, True, True),
-        (False, False, False, True)
-    ])
+        (False, False, False, True),
+    ],
+)
 def test_try_send_after_received_acknowledge_register_peer_and_received_register_peer_complete(
-        needs_to_send_acknowledge_register_peer: bool,
-        needs_to_send_register_peer: bool,
-        is_time: bool,
-        expected_to_send):
-    test_setup = create_test_setup(RegisterPeerForwarderBehaviorConfig(
-        needs_to_send_register_peer=needs_to_send_register_peer,
-        needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer
-    ))
+    needs_to_send_acknowledge_register_peer: bool,
+    needs_to_send_register_peer: bool,
+    is_time: bool,
+    expected_to_send,
+):
+    test_setup = create_test_setup(
+        RegisterPeerForwarderBehaviorConfig(
+            needs_to_send_register_peer=needs_to_send_register_peer,
+            needs_to_send_acknowledge_register_peer=needs_to_send_acknowledge_register_peer,
+        )
+    )
     mock_cast(test_setup.timer_mock.is_time).return_value = is_time
     test_setup.register_peer_forwarder_is_ready_sender.received_register_peer_complete()
     test_setup.register_peer_forwarder_is_ready_sender.received_acknowledge_register_peer()
     test_setup.reset_mock()
     test_setup.register_peer_forwarder_is_ready_sender.try_send()
     if expected_to_send:
-        assert (
-                test_setup.timer_mock.mock_calls == [call.is_time()]
-                and test_setup.out_control_socket_mock.mock_calls == [
-                    call.send(serialize_message(messages.PeerRegisterForwarderIsReady(
-                        peer=test_setup.peer
-                    )))
-                ]
-        )
+        assert test_setup.timer_mock.mock_calls == [
+            call.is_time()
+        ] and test_setup.out_control_socket_mock.mock_calls == [
+            call.send(
+                serialize_message(
+                    messages.PeerRegisterForwarderIsReady(peer=test_setup.peer)
+                )
+            )
+        ]
     else:
         assert (
-                test_setup.timer_mock.mock_calls == [call.is_time()]
-                and test_setup.out_control_socket_mock.mock_calls == []
+            test_setup.timer_mock.mock_calls == [call.is_time()]
+            and test_setup.out_control_socket_mock.mock_calls == []
         )
diff --git a/tests/unit_tests/udf_communication/peer_communication/test_register_peer_sender.py b/tests/unit_tests/udf_communication/peer_communication/test_register_peer_sender.py
index c9a4ab9a..f96e42ad 100644
--- a/tests/unit_tests/udf_communication/peer_communication/test_register_peer_sender.py
+++ b/tests/unit_tests/udf_communication/peer_communication/test_register_peer_sender.py
@@ -1,18 +1,20 @@
 import dataclasses
-from typing import Union, cast, Any
-from unittest.mock import MagicMock, Mock, create_autospec, call
+from typing import Any, Union, cast
+from unittest.mock import MagicMock, Mock, call, create_autospec
 
 import pytest
 
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.ip_address import IPAddress, Port
 from exasol.analytics.udf.communication.peer import Peer
-from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import \
-    RegisterPeerConnection
-from exasol.analytics.udf.communication.peer_communicator.register_peer_sender import \
-    RegisterPeerSender
+from exasol.analytics.udf.communication.peer_communicator.register_peer_connection import (
+    RegisterPeerConnection,
+)
+from exasol.analytics.udf.communication.peer_communicator.register_peer_sender import (
+    RegisterPeerSender,
+)
 from exasol.analytics.udf.communication.peer_communicator.timer import Timer
-from tests.mock_cast import mock_cast
+from tests.utils.mock_cast import mock_cast
 
 
 @dataclasses.dataclass()
@@ -34,13 +36,14 @@ def create_test_setup(needs_to_send_for_peer: bool):
         name="t2",
         ipaddress=IPAddress(ip_address="127.0.0.1"),
         port=Port(port=12),
-        group_identifier="g"
-    ))
+        group_identifier="g",
+        )
+    )
     my_connection_info = ConnectionInfo(
         name="t1",
         ipaddress=IPAddress(ip_address="127.0.0.1"),
         port=Port(port=11),
-        group_identifier="g"
+        group_identifier="g",
     )
     timer_mock = create_autospec(Timer)
     register_peer_connection = create_autospec(RegisterPeerConnection)
@@ -49,38 +52,43 @@ def create_test_setup(needs_to_send_for_peer: bool):
         my_connection_info=my_connection_info,
         register_peer_connection=register_peer_connection,
         timer=timer_mock,
-        needs_to_send_for_peer=needs_to_send_for_peer
+        needs_to_send_for_peer=needs_to_send_for_peer,
     )
     return TestSetup(
         peer=peer,
         timer_mock=timer_mock,
         register_peer_connection=register_peer_connection,
-        register_peer_sender=register_peer_sender
+        register_peer_sender=register_peer_sender,
    )
 
 
-@pytest.mark.parametrize("needs_to_send_for_peer",
-                         [
-                             (True,),
-                             (False,),
-                         ])
+@pytest.mark.parametrize(
+    "needs_to_send_for_peer",
+    [
+        (True,),
+        (False,),
+    ],
+)
 def test_init(needs_to_send_for_peer: bool):
     test_setup = create_test_setup(needs_to_send_for_peer=needs_to_send_for_peer)
     assert (
-            test_setup.register_peer_connection.mock_calls == []
-            and test_setup.timer_mock.mock_calls == []
+        test_setup.register_peer_connection.mock_calls == []
+        and test_setup.timer_mock.mock_calls == []
     )
 
 
-@pytest.mark.parametrize("needs_to_send_for_peer,is_time,send_expected",
-                         [
-                             (True, True, True),
-                             (True, False, False),
-                             (False, True, False),
-                             (False, False, False),
-
-                         ])
-def test_try_send_after_init(needs_to_send_for_peer: bool, is_time: bool, send_expected: bool):
+@pytest.mark.parametrize(
+    "needs_to_send_for_peer,is_time,send_expected",
+    [
+        (True, True, True),
+        (True, False, False),
+        (False, True, False),
+        (False, False, False),
+    ],
+)
+def test_try_send_after_init(
+    needs_to_send_for_peer: bool, is_time: bool, send_expected: bool
+):
     test_setup = create_test_setup(needs_to_send_for_peer=needs_to_send_for_peer)
     mock_cast(test_setup.timer_mock.is_time).return_value = is_time
     test_setup.reset_mock()
@@ -88,33 +96,28 @@ def test_try_send_after_init(needs_to_send_for_peer: bool, is_time: bool, send_e
     test_setup.register_peer_sender.try_send()
 
     if send_expected:
-        assert (
-                test_setup.register_peer_connection.mock_calls ==
-                [
-                    call.forward(test_setup.peer)
-                ]
-                and test_setup.timer_mock.mock_calls == [
-                    call.is_time(),
-                    call.reset_timer()
-                ]
-        )
+        assert test_setup.register_peer_connection.mock_calls == [
+            call.forward(test_setup.peer)
+        ] and test_setup.timer_mock.mock_calls == [call.is_time(), call.reset_timer()]
     else:
         assert (
-                test_setup.register_peer_connection.mock_calls == []
-                and test_setup.timer_mock.mock_calls == [
-                    call.is_time()
-                ]
+            test_setup.register_peer_connection.mock_calls == []
+            and test_setup.timer_mock.mock_calls == [call.is_time()]
         )
 
 
-@pytest.mark.parametrize("needs_to_send_for_peer,is_time,send_expected",
-                         [
-                             (True, True, True),
-                             (True, False, False),
-                             (False, True, False),
-                             (False, False, False),
-                         ])
-def test_try_send_after_init_twice(needs_to_send_for_peer: bool, is_time: bool, send_expected: bool):
+@pytest.mark.parametrize(
+    "needs_to_send_for_peer,is_time,send_expected",
+    [
+        (True, True, True),
+        (True, False, False),
+        (False, True, False),
+        (False, False, False),
+    ],
+)
+def test_try_send_after_init_twice(
+    needs_to_send_for_peer: bool, is_time: bool, send_expected: bool
+):
     test_setup = create_test_setup(needs_to_send_for_peer=needs_to_send_for_peer)
     mock_cast(test_setup.timer_mock.is_time).return_value = is_time
     test_setup.register_peer_sender.try_send()
@@ -122,32 +125,25 @@ def test_try_send_after_init_twice(needs_to_send_for_peer: bool, is_time: bool,
     test_setup.register_peer_sender.try_send()
 
     if send_expected:
-        assert (
-                test_setup.register_peer_connection.mock_calls ==
-                [
-                    call.forward(test_setup.peer)
-                ]
-                and test_setup.timer_mock.mock_calls == [
-                    call.is_time(),
-                    call.reset_timer()
-                ]
-        )
+        assert test_setup.register_peer_connection.mock_calls == [
+            call.forward(test_setup.peer)
+        ] and test_setup.timer_mock.mock_calls == [call.is_time(), call.reset_timer()]
     else:
         assert (
-                test_setup.register_peer_connection.mock_calls == []
-                and test_setup.timer_mock.mock_calls == [
-                    call.is_time()
-                ]
+            test_setup.register_peer_connection.mock_calls == []
+            and test_setup.timer_mock.mock_calls == [call.is_time()]
         )
 
 
-@pytest.mark.parametrize("needs_to_send_for_peer,is_time",
-                         [
-                             (True, True),
-                             (True, False),
-                             (False, True),
-                             (False, False),
-                         ])
+@pytest.mark.parametrize(
+    "needs_to_send_for_peer,is_time",
+    [
+        (True, True),
+        (True, False),
+        (False, True),
+        (False, False),
+    ],
+)
 def test_try_send_after_stop(needs_to_send_for_peer: bool, is_time: bool):
     test_setup = create_test_setup(needs_to_send_for_peer=needs_to_send_for_peer)
     test_setup.register_peer_sender.stop()
@@ -157,6 +153,6 @@ def test_try_send_after_stop(needs_to_send_for_peer: bool, is_time: bool):
     test_setup.register_peer_sender.try_send()
 
     assert (
-            test_setup.register_peer_connection.mock_calls == []
-            and test_setup.timer_mock.mock_calls == [call.is_time()]
+        test_setup.register_peer_connection.mock_calls == []
+        and test_setup.timer_mock.mock_calls == [call.is_time()]
     )
diff --git a/tests/unit_tests/udf_communication/peer_communication/test_synchronize_connection_sender.py b/tests/unit_tests/udf_communication/peer_communication/test_synchronize_connection_sender.py
index 3b918659..a5b3ad2a 100644
--- a/tests/unit_tests/udf_communication/peer_communication/test_synchronize_connection_sender.py
+++ b/tests/unit_tests/udf_communication/peer_communication/test_synchronize_connection_sender.py
@@ -1,14 +1,15 @@
 import dataclasses
 from typing import Union
-from unittest.mock import MagicMock, create_autospec, call
+from unittest.mock import MagicMock, call, create_autospec
 
 from exasol.analytics.udf.communication import messages
 from exasol.analytics.udf.communication.connection_info import ConnectionInfo
 from exasol.analytics.udf.communication.ip_address import IPAddress, Port
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator.sender import Sender
-from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import \
-    SynchronizeConnectionSender
+from exasol.analytics.udf.communication.peer_communicator.synchronize_connection_sender import (
+    SynchronizeConnectionSender,
+)
 from exasol.analytics.udf.communication.peer_communicator.timer import Timer
 from tests.utils.mock_cast import mock_cast
 
 
@@ -33,13 +34,14 @@ def create_test_setup():
         name="t2",
         ipaddress=IPAddress(ip_address="127.0.0.1"),
         port=Port(port=12),
-        group_identifier="g"
-    ))
+        group_identifier="g",
+        )
+    )
     my_connection_info = ConnectionInfo(
         name="t1",
         ipaddress=IPAddress(ip_address="127.0.0.1"),
         port=Port(port=11),
-        group_identifier="g"
+        group_identifier="g",
     )
     timer_mock = create_autospec(Timer)
     sender_mock = create_autospec(Sender)
@@ -47,22 +49,22 @@ def create_test_setup():
         sender=sender_mock,
         timer=timer_mock,
         my_connection_info=my_connection_info,
-        peer=peer
+        peer=peer,
     )
     return TestSetup(
         peer=peer,
         sender_mock=sender_mock,
         timer_mock=timer_mock,
         my_connection_info=my_connection_info,
-        synchronize_connection_sender=synchronize_connection_sender
+        synchronize_connection_sender=synchronize_connection_sender,
     )
 
 
 def test_init():
     test_setup = create_test_setup()
     assert (
-            test_setup.sender_mock.mock_calls == []
-            and test_setup.timer_mock.mock_calls == []
+        test_setup.sender_mock.mock_calls == []
+        and test_setup.timer_mock.mock_calls == []
     )
 
 
@@ -74,8 +76,8 @@ def test_try_send_after_init_and_is_time_false():
     test_setup.synchronize_connection_sender.try_send()
 
     assert (
-            test_setup.sender_mock.mock_calls == []
-            and test_setup.timer_mock.mock_calls == [call.is_time()]
+        test_setup.sender_mock.mock_calls == []
+        and test_setup.timer_mock.mock_calls == [call.is_time()]
     )
 
 
@@ -86,22 +88,17 @@ def test_try_send_after_init_and_is_time_false_and_force():
 
     test_setup.synchronize_connection_sender.try_send(force=True)
 
-    assert (
-            test_setup.sender_mock.mock_calls ==
-            [
-                call.send(
-                    messages.Message(__root__=messages.SynchronizeConnection(
-                        source=test_setup.my_connection_info,
-                        destination=test_setup.peer,
-                        attempt=1
-                    )))
-            ]
-            and test_setup.timer_mock.mock_calls ==
-            [
-                call.is_time(),
-                call.reset_timer()
-            ]
-    )
+    assert test_setup.sender_mock.mock_calls == [
+        call.send(
+            messages.Message(
+                __root__=messages.SynchronizeConnection(
+                    source=test_setup.my_connection_info,
+                    destination=test_setup.peer,
+                    attempt=1,
+                )
+            )
+        )
+    ] and test_setup.timer_mock.mock_calls == [call.is_time(), call.reset_timer()]
 
 
 def test_try_send_after_init_and_is_time_true():
@@ -111,22 +108,17 @@ def test_try_send_after_init_and_is_time_true():
 
     test_setup.synchronize_connection_sender.try_send()
 
-    assert (
-            test_setup.sender_mock.mock_calls ==
-            [
-                call.send(
-                    messages.Message(__root__=messages.SynchronizeConnection(
-                        source=test_setup.my_connection_info,
-                        destination=test_setup.peer,
-                        attempt=1
-                    )))
-            ]
-            and test_setup.timer_mock.mock_calls ==
-            [
-                call.is_time(),
-                call.reset_timer()
-            ]
-    )
+    assert test_setup.sender_mock.mock_calls == [
+        call.send(
+            messages.Message(
+                __root__=messages.SynchronizeConnection(
+                    source=test_setup.my_connection_info,
+                    destination=test_setup.peer,
+                    attempt=1,
+                )
+            )
+        )
+    ] and test_setup.timer_mock.mock_calls == [call.is_time(), call.reset_timer()]
 
 
 def test_try_send_twice_and_is_time_true():
@@ -137,22 +129,17 @@ def test_try_send_twice_and_is_time_true():
 
     test_setup.synchronize_connection_sender.try_send()
 
-    assert (
-            test_setup.sender_mock.mock_calls ==
-            [
-                call.send(
-                    messages.Message(__root__=messages.SynchronizeConnection(
-                        source=test_setup.my_connection_info,
-                        destination=test_setup.peer,
-                        attempt=2
-                    )))
-            ]
-            and test_setup.timer_mock.mock_calls ==
-            [
-                call.is_time(),
-                call.reset_timer()
-            ]
-    )
+    assert test_setup.sender_mock.mock_calls == [
+        call.send(
+            messages.Message(
+                __root__=messages.SynchronizeConnection(
+                    source=test_setup.my_connection_info,
+                    destination=test_setup.peer,
+                    attempt=2,
+                )
+            )
+        )
+    ] and test_setup.timer_mock.mock_calls == [call.is_time(), call.reset_timer()]
 
 
 def test_received_acknowledge_connection_after_init():
@@ -163,8 +150,8 @@ def test_received_acknowledge_connection_after_init():
     test_setup.synchronize_connection_sender.stop()
 
     assert (
-            test_setup.sender_mock.mock_calls == []
-            and test_setup.timer_mock.mock_calls == []
+        test_setup.sender_mock.mock_calls == []
+        and test_setup.timer_mock.mock_calls == []
     )
 
 
@@ -177,8 +164,8 @@ def test_received_acknowledge_connection_after_send():
     test_setup.synchronize_connection_sender.stop()
 
     assert (
-            test_setup.sender_mock.mock_calls == []
-            and test_setup.timer_mock.mock_calls == []
+        test_setup.sender_mock.mock_calls == []
+        and test_setup.timer_mock.mock_calls == []
     )
 
 
@@ -191,6 +178,6 @@ def test_try_send_after_received_acknowledge_connection_and_is_time_true():
     test_setup.synchronize_connection_sender.try_send()
 
     assert (
-            test_setup.sender_mock.mock_calls == []
-            and test_setup.timer_mock.mock_calls == [call.is_time()]
+        test_setup.sender_mock.mock_calls == []
+        and test_setup.timer_mock.mock_calls == [call.is_time()]
     )
diff --git a/tests/unit_tests/udf_communication/peer_communication/test_timer.py b/tests/unit_tests/udf_communication/peer_communication/test_timer.py
index 27f2e21a..2d2a8afb 100644
--- a/tests/unit_tests/udf_communication/peer_communication/test_timer.py
+++ b/tests/unit_tests/udf_communication/peer_communication/test_timer.py
@@ -1,9 +1,9 @@
 from typing import Union
-from unittest.mock import create_autospec, MagicMock, call
+from unittest.mock import MagicMock, call, create_autospec
 
 from exasol.analytics.udf.communication.peer_communicator.clock import Clock
 from exasol.analytics.udf.communication.peer_communicator.timer import Timer
-from tests.mock_cast import mock_cast
+from tests.utils.mock_cast import mock_cast
 
 
 def test_init():
@@ -35,6 +35,7 @@ def test_is_time_true():
     assert result == True and clock_mock.mock_calls == [call.current_timestamp_in_ms()]
 
 
+
 def test_is_time_true_after_true():
     clock_mock: Union[MagicMock, Clock] = create_autospec(Clock)
     mock_cast(clock_mock.current_timestamp_in_ms).side_effect = [0, 11, 12]
diff --git a/tests/unit_tests/udf_communication/socket_factory/fault_injection/test_socket.py b/tests/unit_tests/udf_communication/socket_factory/fault_injection/test_socket.py
index 8eedada9..ff325adb 100644
--- a/tests/unit_tests/udf_communication/socket_factory/fault_injection/test_socket.py
+++ b/tests/unit_tests/udf_communication/socket_factory/fault_injection/test_socket.py
@@ -1,13 +1,12 @@
-from typing import Union, Optional
-from unittest.mock import create_autospec, MagicMock, call
+from typing import Optional, Union
+from unittest.mock import MagicMock, call, create_autospec
 
 import numpy as np
 import pytest
 from numpy.random import RandomState
 
-from exasol.analytics.udf.communication.socket_factory import abstract
-from exasol.analytics.udf.communication.socket_factory import fault_injection
-from tests.mock_cast import mock_cast
+from exasol.analytics.udf.communication.socket_factory import abstract, fault_injection
+from tests.utils.mock_cast import mock_cast
 
 
 def test_create_socket_with():
@@ -72,8 +71,9 @@ def test_socket_send_no_fault_connect_inproc():
 
 def test_socket_send_mulitpart_fault():
     socket_mock: Union[abstract.Socket, MagicMock] = create_autospec(abstract.Socket)
-    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(abstract.Frame,
-                                                                   spec_set=True)
+    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(
+        abstract.Frame, spec_set=True
+    )
     random_state_mock: Union[RandomState, MagicMock] = create_autospec(RandomState)
     mock_cast(random_state_mock.random_sample).side_effect = [np.array([0.09])]
     message = [frame_mock]
@@ -84,8 +84,9 @@ def test_socket_send_mulitpart_fault():
 
 def test_socket_send_mulitpart_should_be_fault_but_bind_inproc_is_reliable():
     socket_mock: Union[abstract.Socket, MagicMock] = create_autospec(abstract.Socket)
-    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(abstract.Frame,
-                                                                   spec_set=True)
+    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(
+        abstract.Frame, spec_set=True
+    )
     frame = fault_injection.Frame(frame_mock)
     random_state_mock: Union[RandomState, MagicMock] = create_autospec(RandomState)
     mock_cast(random_state_mock.random_sample).side_effect = [np.array([0.09])]
@@ -98,8 +99,9 @@ def test_socket_send_mulitpart_should_be_fault_but_bind_inproc_is_reliable():
 
 def test_socket_send_mulitpart_should_be_fault_but_bind_random_port_inproc_is_reliable():
     socket_mock: Union[abstract.Socket, MagicMock] = create_autospec(abstract.Socket)
-    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(abstract.Frame,
-                                                                   spec_set=True)
+    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(
+        abstract.Frame, spec_set=True
+    )
     frame = fault_injection.Frame(frame_mock)
     random_state_mock: Union[RandomState, MagicMock] = create_autospec(RandomState)
     mock_cast(random_state_mock.random_sample).side_effect = [np.array([0.09])]
@@ -112,8 +114,9 @@ def test_socket_send_mulitpart_should_be_fault_but_bind_random_port_inproc_is_re
 
 def test_socket_send_mulitpart_should_be_fault_but_connect_inproc_is_reiliable():
     socket_mock: Union[abstract.Socket, MagicMock] = create_autospec(abstract.Socket)
-    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(abstract.Frame,
-                                                                   spec_set=True)
+    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(
+        abstract.Frame, spec_set=True
+    )
     frame = fault_injection.Frame(frame_mock)
     random_state_mock: Union[RandomState, MagicMock] = create_autospec(RandomState)
     mock_cast(random_state_mock.random_sample).side_effect = [np.array([0.09])]
@@ -126,8 +129,9 @@ def test_socket_send_mulitpart_should_be_fault_but_connect_inproc_is_reiliable()
 
 def test_socket_send_multipart_no_fault():
     socket_mock: Union[abstract.Socket, MagicMock] = create_autospec(abstract.Socket)
-    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(abstract.Frame,
-                                                                   spec_set=True)
+    frame_mock: Union[abstract.Frame, MagicMock] = create_autospec(
+        abstract.Frame, spec_set=True
+    )
     frame = fault_injection.Frame(frame_mock)
     random_state_mock: Union[RandomState, MagicMock] = create_autospec(RandomState)
     mock_cast(random_state_mock.random_sample).side_effect = [np.array([0.1])]
@@ -148,7 +152,7 @@ def test_socket_receive():
 
 def test_socket_bind():
     socket_mock: Union[abstract.Socket, MagicMock] = create_autospec(abstract.Socket)
     random_state_mock: Union[RandomState, MagicMock] = create_autospec(RandomState)
-    address = 'address'
+    address = "address"
     with fault_injection.Socket(socket_mock, 0.1, random_state_mock) as socket:
         socket.bind(address)
         assert mock_cast(socket_mock.bind).mock_calls == [call(address)]
@@ -157,7 +161,7 @@ def test_socket_bind_random_port():
     socket_mock: Union[abstract.Socket, MagicMock] = create_autospec(abstract.Socket)
     random_state_mock: Union[RandomState, MagicMock] = create_autospec(RandomState)
-    address = 'address'
+    address = "address"
     with fault_injection.Socket(socket_mock, 0.1, random_state_mock) as socket:
         socket.bind_to_random_port(address)
         assert mock_cast(socket_mock.bind_to_random_port).mock_calls == [call(address)]
@@ -166,7 +170,7 @@ def test_socket_connect():
     socket_mock: Union[abstract.Socket, MagicMock] = create_autospec(abstract.Socket)
     random_state_mock: Union[RandomState, MagicMock] = create_autospec(RandomState)
-    address = 'address'
+    address = "address"
     with fault_injection.Socket(socket_mock, 0.1, random_state_mock) as socket:
         socket.connect(address)
         assert mock_cast(socket_mock.connect).mock_calls == [call(address)]
@@ -177,7 +181,9 @@ def test_socket_poll():
     random_state_mock: Union[RandomState, MagicMock] = create_autospec(RandomState)
     with fault_injection.Socket(socket_mock, 0.1, random_state_mock) as socket:
         socket.poll(abstract.PollerFlag.POLLIN, timeout_in_ms=1)
-        assert mock_cast(socket_mock.poll).mock_calls == [call(abstract.PollerFlag.POLLIN, 1)]
+        assert mock_cast(socket_mock.poll).mock_calls == [
+            call(abstract.PollerFlag.POLLIN, 1)
+        ]
 
 
 def test_socket_set_identity():
diff --git a/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_poller.py b/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_poller.py
index d0e6236e..7c135d9c 100644
--- a/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_poller.py
+++ b/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_poller.py
@@ -1,17 +1,24 @@
 import zmq
 
-from exasol.analytics.udf.communication.socket_factory.abstract import SocketType, \
-    PollerFlag
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory
+from exasol.analytics.udf.communication.socket_factory.abstract import (
+    PollerFlag,
+    SocketType,
+)
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQSocketFactory,
+)
 
 
 def test_create_poller():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
-        with factory.create_socket(SocketType.PAIR) as socket1, \
-                factory.create_socket(SocketType.PAIR) as socket2, \
-                factory.create_socket(SocketType.PAIR) as socket3, \
-                factory.create_socket(SocketType.PAIR) as socket4:
+        with factory.create_socket(SocketType.PAIR) as socket1, factory.create_socket(
+            SocketType.PAIR
+        ) as socket2, factory.create_socket(
+            SocketType.PAIR
+        ) as socket3, factory.create_socket(
+            SocketType.PAIR
+        ) as socket4:
             socket1.bind("inproc://test1")
             socket2.connect("inproc://test1")
             socket3.bind("inproc://test2")
@@ -26,5 +33,5 @@ def test_create_poller():
             result = poller.poll()
             assert result == {
                 socket2: {PollerFlag.POLLIN},
-                socket4: {PollerFlag.POLLIN}
+                socket4: {PollerFlag.POLLIN},
             }
diff --git a/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_socket.py b/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_socket.py
index 54878429..2c416d06 100644
--- a/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_socket.py
+++ b/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_socket.py
@@ -1,20 +1,24 @@
-from typing import Union, Optional
-from typing import Union, Optional
-from unittest.mock import create_autospec, MagicMock
+from typing import Optional, Union
+from unittest.mock import MagicMock, create_autospec
 
 import pytest
 import zmq
 from zmq import ZMQError
 
-from exasol.analytics.udf.communication.socket_factory.abstract import SocketType, \
-    PollerFlag
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory, \
-    ZMQFrame, ZMQSocket
-from tests.mock_cast import mock_cast
+from exasol.analytics.udf.communication.socket_factory.abstract import (
+    PollerFlag,
+    SocketType,
+)
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQFrame,
+    ZMQSocket,
+    ZMQSocketFactory,
+)
+from tests.utils.mock_cast import mock_cast
 
 
 def test_create_socket_with():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         with factory.create_socket(SocketType.PAIR) as socket1:
             with factory.create_socket(SocketType.PAIR) as socket2:
@@ -26,7 +30,7 @@ def test_create_socket_with():
 
 
 def test_socket_send_receive():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         with factory.create_socket(SocketType.PAIR) as socket1:
             with factory.create_socket(SocketType.PAIR) as socket2:
@@ -39,7 +43,7 @@ def test_socket_send_receive():
 
 
 def test_socket_bind_random_port():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         with factory.create_socket(SocketType.ROUTER) as socket1:
             with factory.create_socket(SocketType.DEALER) as socket2:
@@ -52,7 +56,7 @@ def test_socket_bind_random_port():
 
 
 def test_socket_send_receive_multipart():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         with factory.create_socket(SocketType.PAIR) as socket1:
             with factory.create_socket(SocketType.PAIR) as socket2:
@@ -60,7 +64,7 @@ def test_socket_send_receive_multipart():
                 socket2.connect("inproc://test")
                 input_message = [
                     factory.create_frame(b"123"),
-                    factory.create_frame(b"456")
+                    factory.create_frame(b"456"),
                 ]
                 socket1.send_multipart(input_message)
                 output_message = socket2.receive_multipart()
@@ -68,13 +72,15 @@ def test_socket_send_receive_multipart():
                 input_message_type = [type(frame) for frame in input_message]
                 output_message_bytes = [frame.to_bytes() for frame in output_message]
                 output_message_type = [type(frame) for frame in output_message]
-                assert input_message_bytes == output_message_bytes \
-                       and input_message_type == output_message_type \
-                       and input_message_type[0] == ZMQFrame
+                assert (
+                    input_message_bytes == output_message_bytes
+                    and input_message_type == output_message_type
+                    and input_message_type[0] == ZMQFrame
+                )
 
 
 def test_socket_poll_in():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         with factory.create_socket(SocketType.PAIR) as socket1:
             with factory.create_socket(SocketType.PAIR) as socket2:
@@ -86,7 +92,7 @@ def test_socket_poll_in():
 
 
 def test_socket_poll_out():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         with factory.create_socket(SocketType.PAIR) as socket1:
             with factory.create_socket(SocketType.PAIR) as socket2:
@@ -97,7 +103,7 @@ def test_socket_poll_out():
 
 
 def test_socket_poll_in_out():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         with factory.create_socket(SocketType.PAIR) as socket1:
             with factory.create_socket(SocketType.PAIR) as socket2:
@@ -109,7 +115,7 @@ def test_socket_poll_in_out():
 
 
 def test_socket_set_identity():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         with factory.create_socket(SocketType.PAIR) as socket1:
             name = "test"
diff --git a/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_socket_factory.py b/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_socket_factory.py
index 3e6f8096..72c46190 100644
--- a/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_socket_factory.py
+++ b/tests/unit_tests/udf_communication/socket_factory/zmq/test_zmq_socket_factory.py
@@ -2,37 +2,50 @@
 import zmq
 
 from exasol.analytics.udf.communication.socket_factory.abstract import SocketType
-from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import ZMQSocketFactory, \
-    ZMQSocket, ZMQFrame, ZMQPoller
+from exasol.analytics.udf.communication.socket_factory.zmq_wrapper import (
+    ZMQFrame,
+    ZMQPoller,
+    ZMQSocket,
+    ZMQSocketFactory,
+)
 
 
-@pytest.mark.parametrize("socket_type,zmq_socket_type",
-                         [
-                             (SocketType.PAIR, zmq.PAIR),
-                             (SocketType.DEALER, zmq.DEALER),
-                             (SocketType.ROUTER, zmq.ROUTER)
-                         ])
+@pytest.mark.parametrize(
+    "socket_type,zmq_socket_type",
+    [
+        (SocketType.PAIR, zmq.PAIR),
+        (SocketType.DEALER, zmq.DEALER),
+        (SocketType.ROUTER, zmq.ROUTER),
+    ],
+)
 def test_create_socket(socket_type: SocketType, zmq_socket_type):
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         socket = factory.create_socket(socket_type)
-        assert isinstance(socket, ZMQSocket) \
-               and isinstance(socket._internal_socket, zmq.Socket) \
-               and socket._internal_socket.type == zmq_socket_type
+        assert (
+            isinstance(socket, ZMQSocket)
+            and isinstance(socket._internal_socket, zmq.Socket)
+            and socket._internal_socket.type == zmq_socket_type
+        )
+
 
 def test_create_frame():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         value = b"123"
         frame = factory.create_frame(value)
-        assert isinstance(frame, ZMQFrame) \
-               and frame.to_bytes() == value \
-               and isinstance(frame._internal_frame, zmq.Frame) \
-               and frame._internal_frame.bytes == value
+        assert (
+            isinstance(frame, ZMQFrame)
+            and frame.to_bytes() == value
+            and isinstance(frame._internal_frame, zmq.Frame)
+            and frame._internal_frame.bytes == value
+        )
 
 
 def test_create_poller():
-    with zmq.Context() as context:
+    with zmq.Context() as context:
         factory = ZMQSocketFactory(context)
         poller = factory.create_poller()
-        assert isinstance(poller, ZMQPoller) and isinstance(poller._internal_poller, zmq.Poller)
+        assert isinstance(poller, ZMQPoller) and isinstance(
+            poller._internal_poller, zmq.Poller
+        )
diff --git a/tests/unit_tests/udf_communication/test_broadcast_operation.py b/tests/unit_tests/udf_communication/test_broadcast_operation.py
index b890e0dc..055a3b45 100644
--- a/tests/unit_tests/udf_communication/test_broadcast_operation.py
+++ b/tests/unit_tests/udf_communication/test_broadcast_operation.py
@@ -1,6 +1,6 @@
 import dataclasses
-from typing import Union, List, Optional
-from unittest.mock import MagicMock, create_autospec, call, Mock
+from typing import List, Optional, Union
+from unittest.mock import MagicMock, Mock, call, create_autospec
 
 from polyfactory.factories.pydantic_factory import ModelFactory
 
@@ -9,8 +9,11 @@
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator
 from exasol.analytics.udf.communication.serialization import serialize_message
-from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory, Frame
-from tests.mock_cast import mock_cast
+from exasol.analytics.udf.communication.socket_factory.abstract import (
+    Frame,
+    SocketFactory,
+)
+from tests.utils.mock_cast import mock_cast
 
 
 @dataclasses.dataclass(frozen=True)
@@ -30,15 +33,21 @@ def reset_mocks(self):
 
 def create_setup(value: Optional[bytes]) -> Fixture:
     sequence_number = 0
-    localhost_communicator_mock: Union[MagicMock, PeerCommunicator] = create_autospec(PeerCommunicator)
-    multi_node_communicator_mock: Union[MagicMock, PeerCommunicator] = create_autospec(PeerCommunicator)
-    socket_factory_mock: Union[MagicMock, SocketFactory] = create_autospec(SocketFactory)
+    localhost_communicator_mock: Union[MagicMock, PeerCommunicator] = create_autospec(
+        PeerCommunicator
+    )
+    multi_node_communicator_mock: Union[MagicMock, PeerCommunicator] = create_autospec(
+        PeerCommunicator
+    )
+    socket_factory_mock: Union[MagicMock, SocketFactory] = create_autospec(
+        SocketFactory
+    )
     broadcast_operation = BroadcastOperation(
         sequence_number=sequence_number,
         value=value,
         localhost_communicator=localhost_communicator_mock,
         multi_node_communicator=multi_node_communicator_mock,
-        socket_factory=socket_factory_mock
+        socket_factory=socket_factory_mock,
     )
     test_setup = Fixture(
         sequence_number=sequence_number,
@@ -46,7 +55,7 @@ def create_setup(value: Optional[bytes]) -> Fixture:
         localhost_communicator_mock=localhost_communicator_mock,
         multi_node_communicator_mock=multi_node_communicator_mock,
         socket_factory_mock=socket_factory_mock,
-        broadcast_operation=broadcast_operation
+        broadcast_operation=broadcast_operation,
     )
     return test_setup
 
@@ -54,9 +63,9 @@ def create_setup(value: Optional[bytes]) -> Fixture:
 def test_init():
     test_setup = create_setup(value=None)
     assert (
-            test_setup.multi_node_communicator_mock.mock_calls == []
-            and test_setup.localhost_communicator_mock.mock_calls == []
-            and test_setup.socket_factory_mock.mock_calls == []
+        test_setup.multi_node_communicator_mock.mock_calls == []
+        and test_setup.localhost_communicator_mock.mock_calls == []
+        and test_setup.socket_factory_mock.mock_calls == []
     )
 
 
@@ -69,19 +78,27 @@ def test_call_localhost_rank_greater_zero():
     leader = ModelFactory.create_factory(Peer).build()
     test_setup.localhost_communicator_mock.peer = peer
     test_setup.localhost_communicator_mock.leader = leader
-    frames: List[Union[Frame, MagicMock]] = [create_autospec(Frame), create_autospec(Frame)]
-    mock_cast(frames[0].to_bytes).return_value = serialize_message(messages.Broadcast(
-        source=leader,
-        destination=peer,
-        sequence_number=test_setup.sequence_number,
-    ))
+    frames: List[Union[Frame, MagicMock]] = [
+        create_autospec(Frame),
+        create_autospec(Frame),
+    ]
+    mock_cast(frames[0].to_bytes).return_value = serialize_message(
+        messages.Broadcast(
+            source=leader,
+            destination=peer,
+            sequence_number=test_setup.sequence_number,
+        )
+    )
     mock_cast(frames[1].to_bytes).return_value = expected_value
     mock_cast(test_setup.localhost_communicator_mock.recv).side_effect = [frames]
 
     result = test_setup.broadcast_operation()
 
-    assert result == expected_value \
-           and mock_cast(test_setup.localhost_communicator_mock.recv).mock_calls == [call(peer=leader)] \
-           and test_setup.socket_factory_mock.mock_calls == [] \
-           and test_setup.multi_node_communicator_mock.mock_calls == []
+    assert (
+        result == expected_value
+        and mock_cast(test_setup.localhost_communicator_mock.recv).mock_calls
+        == [call(peer=leader)]
+        and test_setup.socket_factory_mock.mock_calls == []
+        and test_setup.multi_node_communicator_mock.mock_calls == []
+    )
 
 
 def test_call_localhost_rank_equal_zero_multi_node_rank_greater_zero():
@@ -98,32 +115,43 @@ def test_call_localhost_rank_equal_zero_multi_node_rank_greater_zero():
     localhost_leader = ModelFactory.create_factory(Peer).build()
     test_setup.localhost_communicator_mock.leader = localhost_leader
     test_setup.multi_node_communicator_mock.leader = multi_node_leader
-    frames: List[Union[Frame, MagicMock]] = [create_autospec(Frame), create_autospec(Frame)]
-    mock_cast(frames[0].to_bytes).return_value = serialize_message(messages.Broadcast(
-        source=multi_node_leader,
-        destination=multi_node_peer,
-        sequence_number=test_setup.sequence_number,
-    ))
+    frames: List[Union[Frame, MagicMock]] = [
+        create_autospec(Frame),
+        create_autospec(Frame),
+    ]
+    mock_cast(frames[0].to_bytes).return_value = serialize_message(
+        messages.Broadcast(
+            source=multi_node_leader,
+            destination=multi_node_peer,
+            sequence_number=test_setup.sequence_number,
+        )
+    )
     mock_cast(frames[1].to_bytes).return_value = expected_value
     mock_cast(test_setup.multi_node_communicator_mock.recv).side_effect = [frames]
-    mock_cast(test_setup.localhost_communicator_mock.peers).return_value = [localhost_leader, localhost_peer]
+    mock_cast(test_setup.localhost_communicator_mock.peers).return_value = [
+        localhost_leader,
+        localhost_peer,
+    ]
 
     result = test_setup.broadcast_operation()
 
-    assert result == expected_value \
-           and mock_cast(test_setup.localhost_communicator_mock.send).mock_calls == [
-               call(peer=localhost_peer, message=[frame_mocks[0], frames[1]])
-           ] \
-           and mock_cast(test_setup.multi_node_communicator_mock.recv).mock_calls == [
-               call(multi_node_leader)
-           ] \
-           and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls == [
-               call(serialize_message(
-                   messages.Broadcast(
-                       source=localhost_leader,
-                       destination=localhost_peer,
-                       sequence_number=0
-                   )
-               ))
-           ]
+    assert (
+        result == expected_value
+        and mock_cast(test_setup.localhost_communicator_mock.send).mock_calls
+        == [call(peer=localhost_peer, message=[frame_mocks[0], frames[1]])]
+        and mock_cast(test_setup.multi_node_communicator_mock.recv).mock_calls
+        == [call(multi_node_leader)]
+        and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls
+        == [
+            call(
+                serialize_message(
+                    messages.Broadcast(
+                        source=localhost_leader,
+                        destination=localhost_peer,
+                        sequence_number=0,
+                    )
+                )
+            )
+        ]
+    )
 
 
 def test_call_localhost_rank_equal_zero_multi_node_rank_equal_zero_multi_node_number_of_peers_one():
@@ -136,32 +164,47 @@ def test_call_localhost_rank_equal_zero_multi_node_rank_equal_zero_multi_node_nu
     localhost_leader = ModelFactory.create_factory(Peer).build()
     test_setup.localhost_communicator_mock.leader = localhost_leader
     test_setup.multi_node_communicator_mock.leader = multi_node_leader
-    frame_mocks: List[Union[Frame, MagicMock]] = [create_autospec(Frame), create_autospec(Frame)]
+    frame_mocks: List[Union[Frame, MagicMock]] = [
+        create_autospec(Frame),
+        create_autospec(Frame),
+    ]
     mock_cast(test_setup.socket_factory_mock.create_frame).side_effect = frame_mocks
-    mock_cast(frame_mocks[0].to_bytes).return_value = serialize_message(messages.Broadcast(
-        source=localhost_leader,
-        destination=localhost_peer,
-        sequence_number=test_setup.sequence_number,
-    ))
+    mock_cast(frame_mocks[0].to_bytes).return_value = serialize_message(
+        messages.Broadcast(
+            source=localhost_leader,
+            destination=localhost_peer,
+            sequence_number=test_setup.sequence_number,
+        )
+    )
     mock_cast(frame_mocks[1].to_bytes).return_value = test_setup.value
-    mock_cast(test_setup.localhost_communicator_mock.peers).return_value = [localhost_leader, localhost_peer]
-    mock_cast(test_setup.multi_node_communicator_mock.peers).return_value = [multi_node_leader]
+    mock_cast(test_setup.localhost_communicator_mock.peers).return_value = [
+        localhost_leader,
+        localhost_peer,
+    ]
+    mock_cast(test_setup.multi_node_communicator_mock.peers).return_value = [
+        multi_node_leader
+    ]
 
     result = test_setup.broadcast_operation()
 
-    assert result == test_setup.value \
-           and mock_cast(test_setup.localhost_communicator_mock.send).mock_calls == [
-               call(peer=localhost_peer, message=[frame_mocks[1], frame_mocks[0]])
-           ] \
-           and mock_cast(test_setup.multi_node_communicator_mock.peers).mock_calls == [call()] \
-           and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls == [
-               call(b'0'),
-               call(serialize_message(
-                   messages.Broadcast(
-                       source=localhost_leader,
-                       destination=localhost_peer,
-                       sequence_number=0
-                   )
-               ))
-           ]
+    assert (
+        result == test_setup.value
+        and mock_cast(test_setup.localhost_communicator_mock.send).mock_calls
+        == [call(peer=localhost_peer, message=[frame_mocks[1], frame_mocks[0]])]
+        and mock_cast(test_setup.multi_node_communicator_mock.peers).mock_calls
+        == [call()]
+        and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls
+        == [
+            call(b"0"),
+            call(
+                serialize_message(
+                    messages.Broadcast(
+                        source=localhost_leader,
+                        destination=localhost_peer,
+                        sequence_number=0,
+                    )
+                )
+            ),
+        ]
+    )
 
 
 def test_call_localhost_rank_equal_zero_multi_node_rank_equal_zero_multi_node_number_of_peers_two():
@@ -174,29 +217,44 @@ def test_call_localhost_rank_equal_zero_multi_node_rank_equal_zero_multi_node_nu
     localhost_leader = ModelFactory.create_factory(Peer).build()
     test_setup.localhost_communicator_mock.leader = localhost_leader
     test_setup.multi_node_communicator_mock.leader = multi_node_leader
-    frame_mocks: List[Union[Frame, MagicMock]] = [create_autospec(Frame), create_autospec(Frame)]
+    frame_mocks: List[Union[Frame, MagicMock]] = [
+        create_autospec(Frame),
+        create_autospec(Frame),
+    ]
     mock_cast(test_setup.socket_factory_mock.create_frame).side_effect = frame_mocks
-    mock_cast(frame_mocks[0].to_bytes).return_value = serialize_message(messages.Broadcast(
-        source=multi_node_leader,
-        destination=multi_node_peer,
-        sequence_number=test_setup.sequence_number,
-    ))
+    mock_cast(frame_mocks[0].to_bytes).return_value = serialize_message(
+        messages.Broadcast(
+            source=multi_node_leader,
+            destination=multi_node_peer,
+            sequence_number=test_setup.sequence_number,
+        )
+    )
     mock_cast(frame_mocks[1].to_bytes).return_value = test_setup.value
-    mock_cast(test_setup.localhost_communicator_mock.peers).return_value = [localhost_leader]
-    mock_cast(test_setup.multi_node_communicator_mock.peers).return_value = [multi_node_leader, multi_node_peer]
+    mock_cast(test_setup.localhost_communicator_mock.peers).return_value = [
+        localhost_leader
+    ]
+    mock_cast(test_setup.multi_node_communicator_mock.peers).return_value = [
+        multi_node_leader,
+        multi_node_peer,
+    ]
 
     result = test_setup.broadcast_operation()
 
-    assert result == test_setup.value \
-           and mock_cast(test_setup.multi_node_communicator_mock.send).mock_calls == [
-               call(peer=multi_node_peer, message=[frame_mocks[1], frame_mocks[0]])
-           ] \
-           and mock_cast(test_setup.localhost_communicator_mock.peers).mock_calls == [call()] \
-           and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls == [
-               call(b'0'),
-               call(serialize_message(
-                   messages.Broadcast(
-                       source=multi_node_leader,
-                       destination=multi_node_peer,
-                       sequence_number=0
-                   )
-               ))
-           ]
+    assert (
+        result == test_setup.value
+        and mock_cast(test_setup.multi_node_communicator_mock.send).mock_calls
+        == [call(peer=multi_node_peer, message=[frame_mocks[1], frame_mocks[0]])]
+        and mock_cast(test_setup.localhost_communicator_mock.peers).mock_calls
+        == [call()]
+        and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls
+        == [
+            call(b"0"),
+            call(
+                serialize_message(
+                    messages.Broadcast(
+                        source=multi_node_leader,
+                        destination=multi_node_peer,
+                        sequence_number=0,
+                    )
+                )
+            ),
+        ]
+    )
diff --git a/tests/unit_tests/udf_communication/test_gather_operation.py b/tests/unit_tests/udf_communication/test_gather_operation.py
index e9aebf2f..342eddae 100644
--- a/tests/unit_tests/udf_communication/test_gather_operation.py
+++ b/tests/unit_tests/udf_communication/test_gather_operation.py
@@ -1,6 +1,6 @@
 import dataclasses
 from typing import Union
-from unittest.mock import MagicMock, create_autospec, call, Mock
+from unittest.mock import MagicMock, Mock, call, create_autospec
 
 from polyfactory.factories.pydantic_factory import ModelFactory
 
@@ -8,9 +8,15 @@
 from exasol.analytics.udf.communication.messages import Gather
 from exasol.analytics.udf.communication.peer import Peer
 from exasol.analytics.udf.communication.peer_communicator import PeerCommunicator
-from exasol.analytics.udf.communication.serialization import serialize_message, deserialize_message
-from exasol.analytics.udf.communication.socket_factory.abstract import SocketFactory, Frame
-from tests.mock_cast import mock_cast
+from exasol.analytics.udf.communication.serialization import (
+    deserialize_message,
+    serialize_message,
+)
+from exasol.analytics.udf.communication.socket_factory.abstract import (
+    Frame,
+    SocketFactory,
+)
+from tests.utils.mock_cast import mock_cast
 
 
 @dataclasses.dataclass(frozen=True)
@@ -32,16 +38,22 @@ def reset_mocks(self):
 def create_setup(number_of_instances_per_node: int) -> Fixture:
     sequence_number = 0
     value = b"0"
-    localhost_communicator_mock: Union[MagicMock, PeerCommunicator] = create_autospec(PeerCommunicator)
-    multi_node_communicator_mock: Union[MagicMock, PeerCommunicator] = create_autospec(PeerCommunicator)
-    socket_factory_mock: Union[MagicMock, SocketFactory] = create_autospec(SocketFactory)
+    localhost_communicator_mock: Union[MagicMock, PeerCommunicator] = create_autospec(
+        PeerCommunicator
+    )
+    multi_node_communicator_mock: Union[MagicMock, PeerCommunicator] = create_autospec(
+        PeerCommunicator
+    )
+    socket_factory_mock: Union[MagicMock, SocketFactory] = create_autospec(
+        SocketFactory
+    )
     gather_operation = GatherOperation(
         sequence_number=sequence_number,
         value=value,
         number_of_instances_per_node=number_of_instances_per_node,
         localhost_communicator=localhost_communicator_mock,
         multi_node_communicator=multi_node_communicator_mock,
-        socket_factory=socket_factory_mock
+
socket_factory=socket_factory_mock, ) test_setup = Fixture( sequence_number=sequence_number, @@ -50,7 +62,7 @@ def create_setup(number_of_instances_per_node: int) -> Fixture: localhost_communicator_mock=localhost_communicator_mock, multi_node_communicator_mock=multi_node_communicator_mock, socket_factory_mock=socket_factory_mock, - gather_operation=gather_operation + gather_operation=gather_operation, ) return test_setup @@ -58,10 +70,11 @@ def create_setup(number_of_instances_per_node: int) -> Fixture: def test_init(): test_setup = create_setup(number_of_instances_per_node=2) assert ( - test_setup.multi_node_communicator_mock.mock_calls == [] - and test_setup.localhost_communicator_mock.mock_calls == [] + test_setup.multi_node_communicator_mock.mock_calls == [] + and test_setup.localhost_communicator_mock.mock_calls == [] and test_setup.socket_factory_mock.mock_calls == [] - ) + ) + def test_call_localhost_rank_greater_zero(): test_setup = create_setup(number_of_instances_per_node=2) @@ -74,20 +87,26 @@ def test_call_localhost_rank_greater_zero(): test_setup.localhost_communicator_mock.peer = peer test_setup.localhost_communicator_mock.leader = leader result = test_setup.gather_operation() - assert result is None \ - and mock_cast(test_setup.localhost_communicator_mock.send).mock_calls == [ - call(peer=leader, message=[frame_mocks[1], frame_mocks[0]]) - ] and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls == [ - call(test_setup.value), - call(serialize_message( - Gather( - source=peer, - destination=leader, - position=1, - sequence_number=test_setup.sequence_number, - ) - )) - ] and test_setup.multi_node_communicator_mock.mock_calls == [] + assert ( + result is None + and mock_cast(test_setup.localhost_communicator_mock.send).mock_calls + == [call(peer=leader, message=[frame_mocks[1], frame_mocks[0]])] + and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls + == [ + call(test_setup.value), + call( + serialize_message( + Gather( + source=peer, + destination=leader, + position=1, + sequence_number=test_setup.sequence_number, + ) + ) + ), + ] + and test_setup.multi_node_communicator_mock.mock_calls == [] + ) def test_call_localhost_rank_equal_zero_multi_node_rank_greater_zero(): @@ -104,49 +123,67 @@ def test_call_localhost_rank_equal_zero_multi_node_rank_greater_zero(): test_setup.multi_node_communicator_mock.peer = multi_node_peer test_setup.localhost_communicator_mock.peer = localhost_leader recv_message_frame_mock: Union[MagicMock, Frame] = create_autospec(Frame) - mock_cast(recv_message_frame_mock.to_bytes).return_value = serialize_message(Gather( - source=localhost_peer, - destination=localhost_leader, - sequence_number=test_setup.sequence_number, - position=1 - )) + mock_cast(recv_message_frame_mock.to_bytes).return_value = serialize_message( + Gather( + source=localhost_peer, + destination=localhost_leader, + sequence_number=test_setup.sequence_number, + position=1, + ) + ) recv_value_frame_mock: Union[MagicMock, Frame] = create_autospec(Frame) - mock_cast(test_setup.multi_node_communicator_mock.peers).return_value = [multi_node_leader, multi_node_peer] + mock_cast(test_setup.multi_node_communicator_mock.peers).return_value = [ + multi_node_leader, + multi_node_peer, + ] test_setup.multi_node_communicator_mock.leader = multi_node_leader - mock_cast(test_setup.localhost_communicator_mock.peers).return_value = [localhost_leader, localhost_peer] - mock_cast(test_setup.localhost_communicator_mock.poll_peers).return_value = [localhost_peer] + 
mock_cast(test_setup.localhost_communicator_mock.peers).return_value = [ + localhost_leader, + localhost_peer, + ] + mock_cast(test_setup.localhost_communicator_mock.poll_peers).return_value = [ + localhost_peer + ] mock_cast(test_setup.localhost_communicator_mock.recv).side_effect = [ - [recv_message_frame_mock, recv_value_frame_mock]] + [recv_message_frame_mock, recv_value_frame_mock] + ] result = test_setup.gather_operation() - assert result is None \ - and test_setup.localhost_communicator_mock.mock_calls == [ - call.peers(), - call.poll_peers(), - call.recv(localhost_peer) - ] \ - and test_setup.multi_node_communicator_mock.mock_calls == [ - call.send(peer=multi_node_leader, message=[frame_mocks[1], frame_mocks[0]]), - call.send(peer=multi_node_leader, message=[frame_mocks[2], recv_value_frame_mock]) - ] \ - and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls == [ - call(b'0'), - call(serialize_message( - Gather( - source=multi_node_peer, - destination=multi_node_leader, - position=2, - sequence_number=0 - ) - )), - call(serialize_message( - Gather( - source=multi_node_peer, - destination=multi_node_leader, - position=3, - sequence_number=0 - ) - )), - ] + assert ( + result is None + and test_setup.localhost_communicator_mock.mock_calls + == [call.peers(), call.poll_peers(), call.recv(localhost_peer)] + and test_setup.multi_node_communicator_mock.mock_calls + == [ + call.send(peer=multi_node_leader, message=[frame_mocks[1], frame_mocks[0]]), + call.send( + peer=multi_node_leader, message=[frame_mocks[2], recv_value_frame_mock] + ), + ] + and mock_cast(test_setup.socket_factory_mock.create_frame).mock_calls + == [ + call(b"0"), + call( + serialize_message( + Gather( + source=multi_node_peer, + destination=multi_node_leader, + position=2, + sequence_number=0, + ) + ) + ), + call( + serialize_message( + Gather( + source=multi_node_peer, + destination=multi_node_leader, + position=3, + sequence_number=0, + ) + ) + ), + ] + ) def test_call_localhost_rank_equal_zero_multi_node_rank_equal_zero_multi_node_number_of_peers_one(): @@ -162,24 +199,29 @@ def test_call_localhost_rank_equal_zero_multi_node_rank_equal_zero_multi_node_nu test_setup.multi_node_communicator_mock.peer = multi_node_peer test_setup.localhost_communicator_mock.peer = localhost_leader recv_message_frame_mock: Union[MagicMock, Frame] = create_autospec(Frame) - mock_cast(recv_message_frame_mock.to_bytes).return_value = serialize_message(Gather( - source=localhost_peer, - destination=localhost_leader, - sequence_number=test_setup.sequence_number, - position=1 - )) + mock_cast(recv_message_frame_mock.to_bytes).return_value = serialize_message( + Gather( + source=localhost_peer, + destination=localhost_leader, + sequence_number=test_setup.sequence_number, + position=1, + ) + ) recv_value_frame_mock: Union[MagicMock, Frame] = create_autospec(Frame) - mock_cast(test_setup.localhost_communicator_mock.poll_peers).return_value = [localhost_peer] + mock_cast(test_setup.localhost_communicator_mock.poll_peers).return_value = [ + localhost_peer + ] mock_cast(test_setup.localhost_communicator_mock.recv).side_effect = [ - [recv_message_frame_mock, recv_value_frame_mock]] + [recv_message_frame_mock, recv_value_frame_mock] + ] result = test_setup.gather_operation() - assert result is not None \ - and test_setup.localhost_communicator_mock.mock_calls == [ - call.poll_peers(), - call.recv(localhost_peer) - ] \ - and test_setup.multi_node_communicator_mock.mock_calls == [] \ - and 
mock_cast(test_setup.socket_factory_mock).mock_calls == [] + assert ( + result is not None + and test_setup.localhost_communicator_mock.mock_calls + == [call.poll_peers(), call.recv(localhost_peer)] + and test_setup.multi_node_communicator_mock.mock_calls == [] + and mock_cast(test_setup.socket_factory_mock).mock_calls == [] + ) def test_call_localhost_rank_equal_zero_multi_node_rank_equal_zero_multi_node_number_of_peers_two(): @@ -192,23 +234,31 @@ def test_call_localhost_rank_equal_zero_multi_node_rank_equal_zero_multi_node_nu multi_node_leader = ModelFactory.create_factory(Peer).build() test_setup.multi_node_communicator_mock.peer = multi_node_peer recv_message_frame_mock: Union[MagicMock, Frame] = create_autospec(Frame) - mock_cast(recv_message_frame_mock.to_bytes).return_value = serialize_message(Gather( - source=multi_node_peer, - destination=multi_node_leader, - sequence_number=test_setup.sequence_number, - position=1 - )) + mock_cast(recv_message_frame_mock.to_bytes).return_value = serialize_message( + Gather( + source=multi_node_peer, + destination=multi_node_leader, + sequence_number=test_setup.sequence_number, + position=1, + ) + ) recv_value_frame_mock: Union[MagicMock, Frame] = create_autospec(Frame) - mock_cast(test_setup.multi_node_communicator_mock.peers).return_value = [multi_node_leader, multi_node_peer] + mock_cast(test_setup.multi_node_communicator_mock.peers).return_value = [ + multi_node_leader, + multi_node_peer, + ] test_setup.multi_node_communicator_mock.leader = multi_node_leader - mock_cast(test_setup.multi_node_communicator_mock.poll_peers).return_value = [multi_node_peer] + mock_cast(test_setup.multi_node_communicator_mock.poll_peers).return_value = [ + multi_node_peer + ] mock_cast(test_setup.multi_node_communicator_mock.recv).side_effect = [ - [recv_message_frame_mock, recv_value_frame_mock]] + [recv_message_frame_mock, recv_value_frame_mock] + ] result = test_setup.gather_operation() - assert result is not None \ - and test_setup.localhost_communicator_mock.mock_calls == [] \ - and test_setup.multi_node_communicator_mock.mock_calls == [ - call.poll_peers(), - call.recv(multi_node_peer) - ] \ - and mock_cast(test_setup.socket_factory_mock).mock_calls == [] + assert ( + result is not None + and test_setup.localhost_communicator_mock.mock_calls == [] + and test_setup.multi_node_communicator_mock.mock_calls + == [call.poll_peers(), call.recv(multi_node_peer)] + and mock_cast(test_setup.socket_factory_mock).mock_calls == [] + ) diff --git a/tests/unit_tests/udf_communication/test_messages.py b/tests/unit_tests/udf_communication/test_messages.py index 165ca504..d7546144 100644 --- a/tests/unit_tests/udf_communication/test_messages.py +++ b/tests/unit_tests/udf_communication/test_messages.py @@ -5,7 +5,10 @@ from pydantic.fields import ModelField from exasol.analytics.udf.communication.messages import * -from exasol.analytics.udf.communication.serialization import serialize_message, deserialize_message +from exasol.analytics.udf.communication.serialization import ( + deserialize_message, + serialize_message, +) base_message_subclasses = BaseMessage.__subclasses__() @@ -23,7 +26,10 @@ def test_message_serialization(message_class: Type): def test_message_type(message_class: Type): factory = ModelFactory.create_factory(model=message_class) message = factory.build() - assert "message_type" in message.__dict__ and message.message_type == message.__class__.__name__ + assert ( + "message_type" in message.__dict__ + and message.message_type == message.__class__.__name__ + 
) def test_all_base_message_subclasses_are_registered_in_root_field_of_message(): diff --git a/tests/unit_tests/udf_framework/mock_query_handlers.py b/tests/unit_tests/udf_framework/mock_query_handlers.py index e11a66e2..bce84d88 100644 --- a/tests/unit_tests/udf_framework/mock_query_handlers.py +++ b/tests/unit_tests/udf_framework/mock_query_handlers.py @@ -1,23 +1,24 @@ -from typing import Dict, Any, Union - -from exasol.analytics.schema.column import \ - Column -from exasol.analytics.schema.column_name import \ - ColumnName -from exasol.analytics.schema.column_type import \ - ColumnType - -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.query.select import SelectQueryWithColumnDefinition, SelectQuery -from exasol.analytics.query_handler.result import Finish, Continue -from exasol.analytics.query_handler.query.result.interface import QueryResult -from exasol.analytics.query_handler.udf.interface import UDFQueryHandler -from exasol.analytics.query_handler.udf.interface import UDFQueryHandlerFactory +from typing import Any, Dict, Union + +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.query.result.interface import QueryResult +from exasol.analytics.query_handler.query.select import ( + SelectQuery, + SelectQueryWithColumnDefinition, +) +from exasol.analytics.query_handler.result import Continue, Finish +from exasol.analytics.query_handler.udf.interface import ( + UDFQueryHandler, + UDFQueryHandlerFactory, +) +from exasol.analytics.schema.column import Column +from exasol.analytics.schema.column_name import ColumnName +from exasol.analytics.schema.column_type import ColumnType TEST_CONNECTION = "TEST_CONNECTION" TEST_INPUT = "<>" -FINAL_RESULT = '<>' +FINAL_RESULT = "<>" QUERY_LIST = [SelectQuery("SELECT 1 FROM DUAL"), SelectQuery("SELECT 2 FROM DUAL")] @@ -28,43 +29,51 @@ def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerConte if not isinstance(parameter, str): raise AssertionError(f"Expected parameter={parameter} to be a string.") if parameter != TEST_INPUT: - raise AssertionError(f"Expected parameter={parameter} to be '{TEST_INPUT}'.") + raise AssertionError( + f"Expected parameter={parameter} to be '{TEST_INPUT}'." 
+ ) def start(self) -> Union[Continue, Finish[str]]: return Finish(result=FINAL_RESULT) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: pass class MockQueryHandlerWithOneIterationFactory(UDFQueryHandlerFactory): - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: return MockQueryHandlerWithOneIteration(parameter, query_handler_context) class MockQueryHandlerWithTwoIterations(UDFQueryHandler): - def __init__(self, - parameter: str, - query_handler_context: ScopeQueryHandlerContext): + def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerContext): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[str]]: - return_query = "SELECT a, table1.b, c FROM table1, table2 " \ - "WHERE table1.b=table2.b" + return_query = ( + "SELECT a, table1.b, c FROM table1, table2 " "WHERE table1.b=table2.b" + ) return_query_columns = [ Column(ColumnName("a"), ColumnType("INTEGER")), - Column(ColumnName("b"), ColumnType("INTEGER"))] + Column(ColumnName("b"), ColumnType("INTEGER")), + ] query_handler_return_query = SelectQueryWithColumnDefinition( - query_string=return_query, - output_columns=return_query_columns) + query_string=return_query, output_columns=return_query_columns + ) query_handler_result = Continue( - query_list=QUERY_LIST, - input_query=query_handler_return_query) + query_list=QUERY_LIST, input_query=query_handler_return_query + ) return query_handler_result - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[Dict[str, Any]]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[Dict[str, Any]]]: a = query_result.a if a != 1: raise AssertionError(f"Expected query_result.a={a} to be 1.") @@ -80,7 +89,9 @@ def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Fini class MockQueryHandlerWithTwoIterationsFactory(UDFQueryHandlerFactory): - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: return MockQueryHandlerWithTwoIterations(parameter, query_handler_context) @@ -93,17 +104,25 @@ def start(self) -> Union[Continue, Finish[str]]: self._query_handler_context.get_temporary_table_name() return Finish(result=FINAL_RESULT) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: pass class QueryHandlerTestWithOneIterationAndTempTableFactory(UDFQueryHandlerFactory): - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return QueryHandlerTestWithOneIterationAndTempTable(parameter, query_handler_context) + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: + return QueryHandlerTestWithOneIterationAndTempTable( + parameter, query_handler_context + ) -class MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext(UDFQueryHandler): +class MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext( + 
UDFQueryHandler +): def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerContext): super().__init__(parameter, query_handler_context) @@ -113,14 +132,22 @@ def start(self) -> Union[Continue, Finish[str]]: self.child = self._query_handler_context.get_child_query_handler_context() return Finish(result=FINAL_RESULT) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: pass -class MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContextFactory(UDFQueryHandlerFactory): +class MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContextFactory( + UDFQueryHandlerFactory +): - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext(parameter, query_handler_context) + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: + return MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext( + parameter, query_handler_context + ) class MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObject(UDFQueryHandler): @@ -135,14 +162,22 @@ def start(self) -> Union[Continue, Finish[str]]: self.proxy = self.child.get_temporary_table_name() return Finish(result=FINAL_RESULT) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: pass -class MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObjectFactory(UDFQueryHandlerFactory): +class MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObjectFactory( + UDFQueryHandlerFactory +): - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObject(parameter, query_handler_context) + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: + return MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObject( + parameter, query_handler_context + ) class MockQueryHandlerUsingConnection(UDFQueryHandler): @@ -153,13 +188,18 @@ def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerConte def start(self) -> Union[Continue, Finish[str]]: connection = self._query_handler_context.get_connection(TEST_CONNECTION) return Finish( - f"{connection.name},{connection.address},{connection.user},{connection.password}") + f"{connection.name},{connection.address},{connection.user},{connection.password}" + ) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[str]]: pass class MockQueryHandlerUsingConnectionFactory(UDFQueryHandlerFactory): - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: + def create( + self, parameter: str, query_handler_context: ScopeQueryHandlerContext + ) -> UDFQueryHandler: return MockQueryHandlerUsingConnection(parameter, query_handler_context) diff --git a/tests/unit_tests/udf_framework/test_dynamic_modules.py b/tests/unit_tests/udf_framework/test_dynamic_modules.py index 5b4a2166..c519bee4 100644 --- a/tests/unit_tests/udf_framework/test_dynamic_modules.py +++ 
b/tests/unit_tests/udf_framework/test_dynamic_modules.py @@ -1,8 +1,6 @@ import pytest -from exasol.analytics.utils.dynamic_modules import ( - create_module, - ModuleExistsException, -) + +from exasol.analytics.utils.dynamic_modules import ModuleExistsException, create_module class ExampleClass: @@ -17,20 +15,25 @@ def test_create_module_with_class(): mod = create_module("xx1") mod.add_to_module(ExampleClass) import xx1 + instance = xx1.ExampleClass() - assert isinstance(instance, ExampleClass) and \ - ExampleClass.__module__ == "xx1" + assert isinstance(instance, ExampleClass) and ExampleClass.__module__ == "xx1" def test_add_function(): mod = create_module("xx2") import xx2 + xx2.add_to_module(example_function) - assert xx2.example_function() == "example_function return value" \ + assert ( + xx2.example_function() == "example_function return value" and example_function.__module__ == "xx2" + ) def test_add_function_to_existing_module(): create_module("xx3") - with pytest.raises(ModuleExistsException, match='Module "xx3" already exists') as ex: + with pytest.raises( + ModuleExistsException, match='Module "xx3" already exists' + ) as ex: create_module("xx3") diff --git a/tests/unit_tests/udf_framework/test_json_udf_query_handler.py b/tests/unit_tests/udf_framework/test_json_udf_query_handler.py index 4e186a38..aacafb32 100644 --- a/tests/unit_tests/udf_framework/test_json_udf_query_handler.py +++ b/tests/unit_tests/udf_framework/test_json_udf_query_handler.py @@ -1,44 +1,46 @@ import json -import pytest - from json import JSONDecodeError from typing import Union -from exasol.analytics.schema import ( - Column, - ColumnType, - ColumnName, -) +import pytest -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.json_udf_query_handler import JSONQueryHandler, JSONType -from exasol.analytics.query_handler.result import Continue, Finish -from exasol.analytics.query_handler.query.result.python_query_result import PythonQueryResult +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.json_udf_query_handler import ( + JSONQueryHandler, + JSONType, +) from exasol.analytics.query_handler.query.result.interface import QueryResult +from exasol.analytics.query_handler.query.result.python_query_result import ( + PythonQueryResult, +) +from exasol.analytics.query_handler.result import Continue, Finish from exasol.analytics.query_handler.udf.json_impl import JsonUDFQueryHandler +from exasol.analytics.schema import Column, ColumnName, ColumnType class ConstructorTestJSONQueryHandler(JSONQueryHandler): - def __init__(self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) def start(self) -> Union[Continue, Finish[JSONType]]: raise AssertionError("Should not be called") - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[JSONType]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[JSONType]]: raise AssertionError("Should not be called") def test_constructor_valid_json(top_level_query_handler_context_mock): - parameter = { - "test_key": "test_value" - } + parameter = {"test_key": "test_value"} json_str_parameter = json.dumps(parameter) query_handler = JsonUDFQueryHandler( parameter=json_str_parameter, 
query_handler_context=top_level_query_handler_context_mock, - wrapped_json_query_handler_class=ConstructorTestJSONQueryHandler + wrapped_json_query_handler_class=ConstructorTestJSONQueryHandler, ) @@ -47,32 +49,34 @@ def test_constructor_invalid_json(top_level_query_handler_context_mock): query_handler = JsonUDFQueryHandler( parameter="'abc'='ced'", query_handler_context=top_level_query_handler_context_mock, - wrapped_json_query_handler_class=ConstructorTestJSONQueryHandler + wrapped_json_query_handler_class=ConstructorTestJSONQueryHandler, ) class StartReturnParameterTestJSONQueryHandler(JSONQueryHandler): - def __init__(self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[JSONType]]: return Finish[JSONType](self._parameter) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[JSONType]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[JSONType]]: raise AssertionError("Should not be called") def test_start_return_parameter(top_level_query_handler_context_mock): - parameter = { - "test_key": "test_value" - } + parameter = {"test_key": "test_value"} json_str_parameter = json.dumps(parameter) query_handler = JsonUDFQueryHandler( parameter=json_str_parameter, query_handler_context=top_level_query_handler_context_mock, - wrapped_json_query_handler_class=StartReturnParameterTestJSONQueryHandler + wrapped_json_query_handler_class=StartReturnParameterTestJSONQueryHandler, ) result = query_handler.start() assert isinstance(result, Finish) and result.result == json_str_parameter @@ -80,30 +84,33 @@ def test_start_return_parameter(top_level_query_handler_context_mock): class HandleQueryResultCheckQueryResultTestJSONQueryHandler(JSONQueryHandler): - def __init__(self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext): + def __init__( + self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[JSONType]]: raise AssertionError("Should not be called") - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[JSONType]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[JSONType]]: a = query_result.a return Finish[JSONType]({"a": a}) def test_handle_query_result_check_query_result(top_level_query_handler_context_mock): - parameter = { - "test_key": "test_value" - } + parameter = {"test_key": "test_value"} json_str_parameter = json.dumps(parameter) query_handler = JsonUDFQueryHandler( parameter=json_str_parameter, query_handler_context=top_level_query_handler_context_mock, - wrapped_json_query_handler_class=HandleQueryResultCheckQueryResultTestJSONQueryHandler + wrapped_json_query_handler_class=HandleQueryResultCheckQueryResultTestJSONQueryHandler, ) result = query_handler.handle_query_result( - PythonQueryResult(data=[(1,)], - columns=[Column(ColumnName("a"), - ColumnType("INTEGER"))])) + PythonQueryResult( + data=[(1,)], columns=[Column(ColumnName("a"), ColumnType("INTEGER"))] + ) + ) assert isinstance(result, Finish) and result.result == '{"a": 1}' diff --git a/tests/unit_tests/udf_framework/test_json_udf_query_handler_factory.py 
b/tests/unit_tests/udf_framework/test_json_udf_query_handler_factory.py index d0e9f869..9f7dbf8d 100644 --- a/tests/unit_tests/udf_framework/test_json_udf_query_handler_factory.py +++ b/tests/unit_tests/udf_framework/test_json_udf_query_handler_factory.py @@ -1,37 +1,42 @@ import json - from typing import Union -from exasol.analytics.schema import ( - Column, - ColumnType, - ColumnName, +from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext +from exasol.analytics.query_handler.json_udf_query_handler import ( + JSONQueryHandler, + JSONType, ) - -from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext -from exasol.analytics.query_handler.json_udf_query_handler import JSONQueryHandler, JSONType -from exasol.analytics.query_handler.result import Continue, Finish -from exasol.analytics.query_handler.query.result.python_query_result import PythonQueryResult from exasol.analytics.query_handler.query.result.interface import QueryResult -from exasol.analytics.query_handler.udf.json_impl import JsonUDFQueryHandlerFactory +from exasol.analytics.query_handler.query.result.python_query_result import ( + PythonQueryResult, +) +from exasol.analytics.query_handler.result import Continue, Finish from exasol.analytics.query_handler.udf.interface import UDFQueryHandler +from exasol.analytics.query_handler.udf.json_impl import JsonUDFQueryHandlerFactory +from exasol.analytics.schema import Column, ColumnName, ColumnType class TestJSONQueryHandler(JSONQueryHandler): __test__ = False - def __init__(self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext): + + def __init__( + self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext + ): super().__init__(parameter, query_handler_context) self._parameter = parameter def start(self) -> Union[Continue, Finish[JSONType]]: return Finish[JSONType](self._parameter) - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[JSONType]]: + def handle_query_result( + self, query_result: QueryResult + ) -> Union[Continue, Finish[JSONType]]: return Finish[JSONType](self._parameter) class TestJsonUDFQueryHandlerFactory(JsonUDFQueryHandlerFactory): __test__ = False + def __init__(self): super().__init__(TestJSONQueryHandler) @@ -39,12 +44,17 @@ def __init__(self): def test(top_level_query_handler_context_mock): test_input = {"a": 1} json_str = json.dumps(test_input) - query_handler = TestJsonUDFQueryHandlerFactory().create(json_str, top_level_query_handler_context_mock) + query_handler = TestJsonUDFQueryHandlerFactory().create( + json_str, top_level_query_handler_context_mock + ) start_result = query_handler.start() handle_query_result = query_handler.handle_query_result( - PythonQueryResult(data=[(1,)], - columns=[Column(ColumnName("a"), - ColumnType("INTEGER"))]) + PythonQueryResult( + data=[(1,)], columns=[Column(ColumnName("a"), ColumnType("INTEGER"))] + ) + ) + assert ( + isinstance(query_handler, UDFQueryHandler) + and start_result.result == json_str + and handle_query_result.result == json_str ) - assert isinstance(query_handler, UDFQueryHandler) \ - and start_result.result == json_str and handle_query_result.result == json_str diff --git a/tests/unit_tests/udf_framework/test_query_handler_runner_udf_mock.py b/tests/unit_tests/udf_framework/test_query_handler_runner_udf_mock.py index d4e81ebf..86531a76 100644 --- a/tests/unit_tests/udf_framework/test_query_handler_runner_udf_mock.py +++ b/tests/unit_tests/udf_framework/test_query_handler_runner_udf_mock.py 
@@ -1,9 +1,8 @@ import json -import pytest import re - from typing import Any, Dict +import pytest from exasol_udf_mock_python.column import Column from exasol_udf_mock_python.connection import Connection from exasol_udf_mock_python.group import Group @@ -19,7 +18,6 @@ from tests.unit_tests.udf_framework.mock_query_handlers import TEST_CONNECTION from tests.utils.test_utils import pytest_regex - TEMPORARY_NAME_PREFIX = "temporary_name_prefix" BUCKETFS_DIRECTORY = "directory" @@ -51,7 +49,10 @@ def query_handler_bfs_connection(tmp_path): base_path=f"{path}", ) -def create_mocked_exa_env(udf_script_name: str, bfs_connection, connections: Dict[str, Any] = {}): + +def create_mocked_exa_env( + udf_script_name: str, bfs_connection, connections: Dict[str, Any] = {} +): meta = create_mock_data(udf_script_name) connections[BUCKETFS_CONNECTION_NAME] = bfs_connection return MockExaEnvironment(metadata=meta, connections=connections) @@ -61,6 +62,7 @@ def create_mocked_exa_env(udf_script_name: str, bfs_connection, connections: Dic def udf_script_name(): return "AAF_TEST_UDF" + @pytest.fixture def mocked_exa_env(query_handler_bfs_connection, udf_script_name): return create_mocked_exa_env(udf_script_name, query_handler_bfs_connection) @@ -68,6 +70,7 @@ def mocked_exa_env(query_handler_bfs_connection, udf_script_name): def _udf_wrapper(): from exasol_udf_mock_python.udf_context import UDFContext + from exasol.analytics.query_handler.udf.runner.udf import QueryHandlerRunnerUDF udf = QueryHandlerRunnerUDF(exa) @@ -91,9 +94,7 @@ def create_mock_data(script_name: str): Column("7", str, "VARCHAR(2000000)"), # parameters ], output_type="EMITS", - output_columns=[ - Column("outputs", str, "VARCHAR(2000000)") - ], + output_columns=[Column("outputs", str, "VARCHAR(2000000)")], is_variadic_input=True, script_name=script_name, ) @@ -109,15 +110,22 @@ def test_query_handler_udf_with_one_iteration(mocked_exa_env): "temp_schema", "MockQueryHandlerWithOneIterationFactory", "tests.unit_tests.udf_framework.mock_query_handlers", - mock_query_handlers.TEST_INPUT + mock_query_handlers.TEST_INPUT, ) result = UDFMockExecutor().run([Group([input_data])], mocked_exa_env) rows = [row[0] for row in result[0].rows] - expected_rows = [None, None, QueryHandlerStatus.FINISHED.name, mock_query_handlers.FINAL_RESULT] + expected_rows = [ + None, + None, + QueryHandlerStatus.FINISHED.name, + mock_query_handlers.FINAL_RESULT, + ] assert rows == expected_rows -def test_query_handler_udf_with_one_iteration_with_not_released_child_query_handler_context(mocked_exa_env): +def test_query_handler_udf_with_one_iteration_with_not_released_child_query_handler_context( + mocked_exa_env, +): input_data = ( 0, BUCKETFS_CONNECTION_NAME, @@ -126,18 +134,22 @@ def test_query_handler_udf_with_one_iteration_with_not_released_child_query_hand "temp_schema", "MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContextFactory", "tests.unit_tests.udf_framework.mock_query_handlers", - "{}" + "{}", ) result = UDFMockExecutor().run([Group([input_data])], mocked_exa_env) rows = [row[0] for row in result[0].rows] - expected_rows = [None, - None, - QueryHandlerStatus.ERROR.name, - pytest_regex(r".*The following child contexts were not released:*", re.DOTALL)] + expected_rows = [ + None, + None, + QueryHandlerStatus.ERROR.name, + pytest_regex(r".*The following child contexts were not released:*", re.DOTALL), + ] assert rows == expected_rows -def test_query_handler_udf_with_one_iteration_with_not_released_temporary_object(mocked_exa_env): +def 
test_query_handler_udf_with_one_iteration_with_not_released_temporary_object( + mocked_exa_env, +): input_data = ( 0, BUCKETFS_CONNECTION_NAME, @@ -146,15 +158,17 @@ def test_query_handler_udf_with_one_iteration_with_not_released_temporary_object "temp_schema", "MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObjectFactory", "tests.unit_tests.udf_framework.mock_query_handlers", - "{}" + "{}", ) result = UDFMockExecutor().run([Group([input_data])], mocked_exa_env) rows = [row[0] for row in result[0].rows] - expected_rows = [None, - None, - QueryHandlerStatus.ERROR.name, - pytest_regex(r".*The following child contexts were not released.*", re.DOTALL), - 'DROP TABLE IF EXISTS "temp_schema"."temporary_name_prefix_2_1";'] + expected_rows = [ + None, + None, + QueryHandlerStatus.ERROR.name, + pytest_regex(r".*The following child contexts were not released.*", re.DOTALL), + 'DROP TABLE IF EXISTS "temp_schema"."temporary_name_prefix_2_1";', + ] assert rows == expected_rows @@ -167,19 +181,30 @@ def test_query_handler_udf_with_one_iteration_and_temp_table(mocked_exa_env): "temp_schema", "QueryHandlerTestWithOneIterationAndTempTableFactory", "tests.unit_tests.udf_framework.mock_query_handlers", - "{}" + "{}", ) result = UDFMockExecutor().run([Group([input_data])], mocked_exa_env) rows = [row[0] for row in result[0].rows] - table_cleanup_query = 'DROP TABLE IF EXISTS "temp_schema"."temporary_name_prefix_1";' - expected_rows = [None, None, QueryHandlerStatus.FINISHED.name, mock_query_handlers.FINAL_RESULT, - table_cleanup_query] + table_cleanup_query = ( + 'DROP TABLE IF EXISTS "temp_schema"."temporary_name_prefix_1";' + ) + expected_rows = [ + None, + None, + QueryHandlerStatus.FINISHED.name, + mock_query_handlers.FINAL_RESULT, + table_cleanup_query, + ] assert rows == expected_rows -def test_query_handler_udf_with_two_iteration(query_handler_bfs_connection, udf_script_name): +def test_query_handler_udf_with_two_iteration( + query_handler_bfs_connection, udf_script_name +): def state_file_exists(iteration: int) -> bool: - bucketfs_location = create_bucketfs_location_from_conn_object(query_handler_bfs_connection) + bucketfs_location = create_bucketfs_location_from_conn_object( + query_handler_bfs_connection + ) bucketfs_path = f"{BUCKETFS_DIRECTORY}/{TEMPORARY_NAME_PREFIX}/state" state_file = f"{str(iteration)}.pkl" return (bucketfs_location / bucketfs_path / state_file).exists() @@ -193,22 +218,30 @@ def state_file_exists(iteration: int) -> bool: "temp_schema", "MockQueryHandlerWithTwoIterationsFactory", "tests.unit_tests.udf_framework.mock_query_handlers", - "{}" + "{}", ) executor = UDFMockExecutor() result = executor.run([Group([input_data])], exa) rows = [row[0] for row in result[0].rows] - expected_return_query_view = 'CREATE VIEW "temp_schema"."temporary_name_prefix_2_1" AS ' \ - 'SELECT a, table1.b, c ' \ - 'FROM table1, table2 ' \ - 'WHERE table1.b=table2.b;' - return_query = f'SELECT "TEST_SCHEMA"."{udf_script_name}"(' \ - '1,' \ - "'bucketfs_connection','directory','temporary_name_prefix'," \ - '"a","b") ' \ - 'FROM "temp_schema"."temporary_name_prefix_2_1";' - expected_rows = [expected_return_query_view, return_query, QueryHandlerStatus.CONTINUE.name, "{}"] + \ - [query.query_string for query in mock_query_handlers.QUERY_LIST] + expected_return_query_view = ( + 'CREATE VIEW "temp_schema"."temporary_name_prefix_2_1" AS ' + "SELECT a, table1.b, c " + "FROM table1, table2 " + "WHERE table1.b=table2.b;" + ) + return_query = ( + f'SELECT "TEST_SCHEMA"."{udf_script_name}"(' + "1," + 
"'bucketfs_connection','directory','temporary_name_prefix'," + '"a","b") ' + 'FROM "temp_schema"."temporary_name_prefix_2_1";' + ) + expected_rows = [ + expected_return_query_view, + return_query, + QueryHandlerStatus.CONTINUE.name, + "{}", + ] + [query.query_string for query in mock_query_handlers.QUERY_LIST] assert rows == expected_rows previous = 0 @@ -222,16 +255,19 @@ def state_file_exists(iteration: int) -> bool: input_columns=[ Column("0", int, "INTEGER"), # iter_num Column("1", str, "VARCHAR(2000000)"), # temporary_bfs_location_conn - Column("2", str, "VARCHAR(2000000)"), # temporary_bfs_location_directory + Column( + "2", str, "VARCHAR(2000000)" + ), # temporary_bfs_location_directory Column("3", str, "VARCHAR(2000000)"), # temporary_name_prefix Column("4", int, "INTEGER"), # column a of the input query Column("5", int, "INTEGER"), # column b of the input query - ], output_type="EMITS", - output_columns=[ - Column("outputs", str, "VARCHAR(2000000)") ], - is_variadic_input=True), - connections={BUCKETFS_CONNECTION_NAME: query_handler_bfs_connection}) + output_type="EMITS", + output_columns=[Column("outputs", str, "VARCHAR(2000000)")], + is_variadic_input=True, + ), + connections={BUCKETFS_CONNECTION_NAME: query_handler_bfs_connection}, + ) input_data = ( 1, @@ -239,17 +275,26 @@ def state_file_exists(iteration: int) -> bool: BUCKETFS_DIRECTORY, TEMPORARY_NAME_PREFIX, 1, - 2 + 2, ) result = executor.run([Group([input_data])], exa) rows = [row[0] for row in result[0].rows] - cleanup_return_query_view = 'DROP VIEW IF EXISTS "temp_schema"."temporary_name_prefix_2_1";' - expected_rows = [None, None, QueryHandlerStatus.FINISHED.name, mock_query_handlers.FINAL_RESULT, - cleanup_return_query_view] + cleanup_return_query_view = ( + 'DROP VIEW IF EXISTS "temp_schema"."temporary_name_prefix_2_1";' + ) + expected_rows = [ + None, + None, + QueryHandlerStatus.FINISHED.name, + mock_query_handlers.FINAL_RESULT, + cleanup_return_query_view, + ] assert rows == expected_rows -def test_query_handler_udf_using_connection(query_handler_bfs_connection, udf_script_name): +def test_query_handler_udf_using_connection( + query_handler_bfs_connection, udf_script_name +): test_connection = udf_mock_connection( address="test_connection", user="test_connection_user", @@ -258,7 +303,7 @@ def test_query_handler_udf_using_connection(query_handler_bfs_connection, udf_sc exa = create_mocked_exa_env( udf_script_name, query_handler_bfs_connection, - { TEST_CONNECTION: test_connection }, + {TEST_CONNECTION: test_connection}, ) input_data = ( 0, @@ -268,17 +313,21 @@ def test_query_handler_udf_using_connection(query_handler_bfs_connection, udf_sc "temp_schema", "MockQueryHandlerUsingConnectionFactory", "tests.unit_tests.udf_framework.mock_query_handlers", - "{}" + "{}", ) result = UDFMockExecutor().run([Group([input_data])], exa) rows = [row[0] for row in result[0].rows] expected_rows = [ - None, None, QueryHandlerStatus.FINISHED.name, - ",".join([ - TEST_CONNECTION, - test_connection.address, - test_connection.user, - test_connection.password, - ]) + None, + None, + QueryHandlerStatus.FINISHED.name, + ",".join( + [ + TEST_CONNECTION, + test_connection.address, + test_connection.user, + test_connection.password, + ] + ), ] assert rows == expected_rows diff --git a/tests/utils/db_queries.py b/tests/utils/db_queries.py index 30843681..60a622a9 100644 --- a/tests/utils/db_queries.py +++ b/tests/utils/db_queries.py @@ -1,16 +1,12 @@ from typing import List -deployed_script_list = [ - "AAF_QUERY_HANDLER_UDF", - 
"AAF_RUN_QUERY_HANDLER" -] +deployed_script_list = ["AAF_QUERY_HANDLER_UDF", "AAF_RUN_QUERY_HANDLER"] class DBQueries: @staticmethod def get_all_scripts(db_conn, schema_name) -> List[int]: - query_all_scripts = \ - f""" + query_all_scripts = f""" SELECT SCRIPT_NAME FROM EXA_ALL_SCRIPTS WHERE SCRIPT_SCHEMA = '{schema_name.upper()}' @@ -20,8 +16,7 @@ def get_all_scripts(db_conn, schema_name) -> List[int]: @staticmethod def check_all_scripts_deployed(db_conn, schema_name) -> bool: - all_scripts = DBQueries.get_all_scripts( - db_conn, schema_name) + all_scripts = DBQueries.get_all_scripts(db_conn, schema_name) return all(script in all_scripts for script in deployed_script_list) @staticmethod @@ -41,4 +36,5 @@ def get_language_settings_from(db_conn, alter_type): def set_language_settings_to(db_conn, alter_type, language_settings): db_conn.execute( f"""ALTER {alter_type.upper()} SET SCRIPT_LANGUAGES= - '{language_settings}'""") + '{language_settings}'""" + ) diff --git a/tests/utils/revert_language_settings.py b/tests/utils/revert_language_settings.py index 7b0a6f54..6ea349f9 100644 --- a/tests/utils/revert_language_settings.py +++ b/tests/utils/revert_language_settings.py @@ -1,5 +1,7 @@ import contextlib -import pyexasol # type: ignore + +import pyexasol # type: ignore + @contextlib.contextmanager def revert_language_settings(connection: pyexasol.ExaConnection): @@ -11,5 +13,9 @@ def revert_language_settings(connection: pyexasol.ExaConnection): try: yield finally: - connection.execute(f"ALTER SYSTEM SET SCRIPT_LANGUAGES='{language_settings[0]}';") - connection.execute(f"ALTER SESSION SET SCRIPT_LANGUAGES='{language_settings[1]}';") + connection.execute( + f"ALTER SYSTEM SET SCRIPT_LANGUAGES='{language_settings[0]}';" + ) + connection.execute( + f"ALTER SESSION SET SCRIPT_LANGUAGES='{language_settings[1]}';" + ) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 35ae6079..db00347c 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -1,5 +1,6 @@ import re + class pytest_regex: """Assert that a given string meets some expectations.""" @@ -11,4 +12,4 @@ def __eq__(self, actual): return bool(match) def __repr__(self): - return self._regex.pattern \ No newline at end of file + return self._regex.pattern