diff --git a/python/lib/communication/dmod/communication/__init__.py b/python/lib/communication/dmod/communication/__init__.py index 5dadb9b75..8178b3a70 100644 --- a/python/lib/communication/dmod/communication/__init__.py +++ b/python/lib/communication/dmod/communication/__init__.py @@ -1,6 +1,6 @@ from ._version import __version__ -from .client import DataServiceClient, InternalServiceClient, ModelExecRequestClient, ExternalRequestClient, \ - PartitionerServiceClient, SchedulerClient +from .client import AuthClient, CachedAuthClient, DataServiceClient, ExternalRequestClient, PartitionerServiceClient, \ + RequestClient, SchedulerClient, TransportLayerClient, WebSocketClient from .maas_request import get_available_models, get_available_outputs, get_distribution_types, get_parameters, \ get_request, AbstractNgenRequest, Distribution, DmodJobRequest, ExternalRequest, ExternalRequestResponse,\ ModelExecRequest, ModelExecRequestResponse, NWMRequest, NWMRequestResponse, Scalar, NGENRequest, \ diff --git a/python/lib/communication/dmod/communication/_version.py b/python/lib/communication/dmod/communication/_version.py index ef9199407..a842d05a7 100644 --- a/python/lib/communication/dmod/communication/_version.py +++ b/python/lib/communication/dmod/communication/_version.py @@ -1 +1 @@ -__version__ = '0.14.0' +__version__ = '0.15.0' diff --git a/python/lib/communication/dmod/communication/client.py b/python/lib/communication/dmod/communication/client.py index 571f0f97d..dd05b1cec 100644 --- a/python/lib/communication/dmod/communication/client.py +++ b/python/lib/communication/dmod/communication/client.py @@ -2,25 +2,21 @@ import datetime import json import ssl -import traceback import typing from abc import ABC, abstractmethod from asyncio import AbstractEventLoop +from deprecated import deprecated from pathlib import Path from typing import Generic, Optional, Type, TypeVar, Union -from dmod.core.serializable import Serializable import websockets -from .maas_request import ExternalRequest, ExternalRequestResponse, ModelExecRequest, ModelExecRequestResponse, NWMRequest, \ - NGENRequest -from .message import AbstractInitRequest, Message, Response, InitRequestResponseReason -from .partition_request import PartitionRequest, PartitionResponse -from .dataset_management_message import DatasetManagementMessage, DatasetManagementResponse -from .scheduler_request import SchedulerRequestMessage, SchedulerRequestResponse -from .evaluation_request import EvaluationConnectionRequest +from .maas_request import ExternalRequest, ExternalRequestResponse +from .message import AbstractInitRequest, Response +from .partition_request import PartitionResponse +from .dataset_management_message import DatasetManagementResponse +from .scheduler_request import SchedulerRequestResponse from .evaluation_request import EvaluationConnectionRequestResponse -from .validator import NWMRequestJsonValidator from .update_message import UpdateMessage, UpdateMessageResponse import logging @@ -28,15 +24,7 @@ # TODO: refactor this to allow for implementation-specific overriding more easily logger = logging.getLogger("gui_log") -M = TypeVar("M", bound=AbstractInitRequest) -R = TypeVar("R", bound=Response) - -EXTERN_REQ_M = TypeVar("EXTERN_REQ_M", bound=ExternalRequest) -EXTERN_REQ_R = TypeVar("EXTERN_REQ_R", bound=ExternalRequestResponse) - -MOD_EX_M = TypeVar("MOD_EX_M", bound=ModelExecRequest) -MOD_EX_R = TypeVar("MOD_EX_R", bound=ModelExecRequestResponse) - +CONN = TypeVar("CONN") def get_or_create_eventloop() -> AbstractEventLoop: """ @@ -58,31 +46,99 @@ def get_or_create_eventloop() -> AbstractEventLoop: raise -class AbstractClient(ABC): +class TransportLayerClient(ABC): """ - Abstract client capable of securely communicating with a server at some endpoint. + Abstract client capable of communicating with a server at some endpoint. + + Abstract client for interacting with a service at the OSI transport layer. It provides an interface for sending + data to accept data and send this data to a server at some endpoint. The interface function for this behavior + supports optionally waiting for and returning a raw data response. Alternatively, the type provides a function for + receiving a response from the server independently. - Abstract client with an interface for securely sending data to a server at some endpoint. The interface function - for this behavior supports optionally waiting for and returning a raw response. Alternatively, the type provides an - interface for receiving a response from the server independently. + Instances are capable of securing communications using an ::class:`SSLContext`. A customized context or default + context can be created, depending on the parameters passed during init. """ - def __init__(self, endpoint_uri: str, *args, **kwargs): + + @classmethod + @abstractmethod + def get_endpoint_protocol_str(cls, use_secure_connection: bool = True) -> str: + """ + Get the protocol substring portion for valid connection URI strings for an instance of this class. + + Parameters + ---------- + use_secure_connection : bool + Whether to get the protocol substring applicable for secure connections (``True`` by default). + + Returns + ------- + str + The protocol substring portion for valid connection URI strings for an instance of this class. + """ + pass + + def __init__(self, *, endpoint_host: str, endpoint_port: Union[int, str], endpoint_path: Optional[str] = None, + cafile: Optional[Path] = None, capath: Optional[Path] = None, use_default_context: bool = False, + **kwargs): """ Initialize this instance. + Initialization may or may not include creation of an ::class:`SSLContext`, according to these rules: + - If ``cafile`` is ``None``, ``capath`` is ``None``, and ``use_default_context`` is ``False`` (which are the + default values for each), then no ::class:`SSLContext` is created. + - If ``use_default_context`` is ``True``, ::function:`ssl.create_default_context` is used to create a + context object, with ``cafile`` and ``capath`` passed as kwargs. + - If either ``cafile`` or ``capath`` is not ``None``, and ``use_default_context`` is ``False``, a customized + context object is created, with certificates loaded from locations at ``cafile`` and/or ``capath``. + Parameters ---------- - endpoint_uri: str + endpoint_host: str + The host component for building this client's endpoint URI for opening a connection. Does not include the protocol. The endpoint for the client to connect to when opening a connection. + endpoint_port: Union[int, str] + The host port component for building this client's endpoint URI for opening a connection. + endpoint_path: Optional[str] + The optional path component for building this client's endpoint URI for opening a connection. + cafile: Optional[Path] + Optional path to CA certificates PEM file. + capath: Optional[Path] + Optional path to directory containing CA certificates PEM files, following an OpenSSL specific layout (see + ::function:`ssl.SSLContext.load_verify_locations`). + use_default_context: bool + Whether to use ::function:`ssl.create_default_context` to create a ::class:`SSLContext` (default ``False``). args Other unused positional parameters. kwargs Other unused keyword parameters. """ - super().__init__(*args, **kwargs) + super().__init__() + + self._endpoint_host: str = endpoint_host.strip() + self._endpoint_port = endpoint_port.strip() if isinstance(endpoint_port, str) else endpoint_port + self._endpoint_path: str = '' if endpoint_path is None else endpoint_path.strip() + + self._endpoint_uri = None - self.endpoint_uri = endpoint_uri - """str: The endpoint for the client to connect to when opening a connection.""" + if use_default_context: + self._client_ssl_context = ssl.create_default_context(cafile=cafile, capath=capath) + elif cafile is not None or capath is not None: + self._client_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self._client_ssl_context.load_verify_locations(cafile=cafile, capath=capath) + else: + self._client_ssl_context = None + + @abstractmethod + def _get_endpoint_uri(self) -> str: + """ + Get the endpoint for the client to connect to when opening a connection. + + Returns + ------- + str + The endpoint for the client to connect to when opening a connection. + """ + pass @abstractmethod async def async_send(self, data: Union[str, bytearray, bytes], await_response: bool = False) -> Optional[str]: @@ -115,26 +171,162 @@ async def async_recv(self) -> str: """ pass + +class AuthClient: + """ + Simple client object responsible for handling acquiring and applying authenticated session details to requests. + """ + def __init__(self, transport_client: TransportLayerClient, *args, **kwargs): + self._transport_client: TransportLayerClient = transport_client + # TODO: get full session implementation if possible + self._session_id, self._session_secret, self._session_created = None, None, None + self._force_reauth = False + + def _acquire_session(self) -> bool: + """ + Synchronous function to acquire an authenticated session. + + Wrapper convenience function for use outside of the async event loop. + + Returns + ------- + bool + Whether acquiring an authenticated session was successful. + + See Also + ------- + _async_acquire_session + """ + try: + return get_or_create_eventloop().run_until_complete(self._async_acquire_session()) + except Exception as e: + msg = f"{self.__class__.__name__} failed to acquire auth credential due to {e.__class__.__name__}: {str(e)}" + logger.error(msg) + return False + + async def _async_acquire_session(self) -> bool: + """ + Acquire an authenticated session. + + Returns + ------- + bool + Whether acquiring an authenticated session was successful. + """ + # Clear anything previously set when forced reauth + if self.force_reauth: + self._session_id, self._session_secret, self._session_created = None, None, None + self.force_reauth = False + # Otherwise, if we have the session details already, just return True + elif all([self._session_id, self._session_secret, self._session_created]): + return True + + try: + auth_resp = await self._transport_client.async_send(data=json.dumps(self._prepare_auth_request_payload()), + await_response=True) + return self._parse_auth_data(auth_resp) + # In the future, consider whether we should treat ConnectionResetError separately + except Exception as e: + msg = f"{self.__class__.__name__} failed to acquire auth credential due to {e.__class__.__name__}: {str(e)}" + logger.error(msg) + return False + + def _parse_auth_data(self, auth_data_str: str): + """ + Parse serialized authentication data and update instance state accordingly. + + Parse the given serialized authentication data and update the state of the instance accordingly to represent the + successful authentication (assuming the data parses appropriately). This method must support, at minimum, + parsing the text data returned from the service as the response to the authentication payload, + + Note that a return value of ``True`` indicates the instance holds valid authentication details that can be + applied to requests. + + Parameters + ---------- + auth_data_str : str + The data to be parsed, such as that returned in the service response to an authentication payload. + + Returns + ---------- + bool + Whether parsing was successful. + """ + try: + auth_response = json.loads(auth_data_str) + # TODO: consider making sure this parses to a SessionInitResponse + session_id = auth_response['data']['session_id'] + session_secret = auth_response['data']['session_secret'] + session_created = auth_response['data']['created'] + if all((session_id, session_secret, session_created)): + self._session_id, self._session_secret, self._session_created = session_id, session_secret, session_created + return True + else: + return False + except Exception as e: + return False + + def _prepare_auth_request_payload(self) -> dict: + """ + Generate JSON payload to be transmitted by ::method:`async_acquire_session` to service when requesting auth. + + Returns + ------- + dict + The JSON payload to be transmitted by ::method:`async_acquire_session` to the service when requesting auth. + """ + # Right now, it doesn't matter as long as it is valid + # TODO: Fix this to not be ... fixed ... + return {'username': 'someone', 'user_secret': 'something'} + + async def apply_auth(self, external_request: ExternalRequest) -> bool: + """ + Apply appropriate authentication details to this request object, acquiring them first if needed. + + Parameters + ---------- + external_request : ExternalRequest + A request that needs the appropriate session secret applied. + + Returns + ---------- + bool + Whether the secret was obtained and applied successfully. + """ + if await self._async_acquire_session(): + external_request.session_secret = self._session_secret + return True + else: + return False + @property - @abstractmethod - def client_ssl_context(self) -> ssl.SSLContext: + def force_reauth(self) -> bool: """ - Get the client SSL context property. + Whether the client should be forced to reacquire a new authenticated session from the service. Returns ------- - ssl.SSLContext - The client SSL context for secure connections. + bool + Whether the client should be forced to re-authenticate and get a new session from the auth service. """ - pass + return self._force_reauth + @force_reauth.setter + def force_reauth(self, should_force_new: bool): + self._force_reauth = should_force_new + + @property + def session_created(self) -> Optional[str]: + return self._session_created + + @property + def session_id(self) -> Optional[str]: + return self._session_id -class ExternalClient(AbstractClient, ABC): - """ - Abstract client encapsulating the logic for using external connections secured using sessions. - Abstract client type that requires connections that work using secure sessions. It is able to serialize session - details to a file and, by default, load them from this file if appropriate. +class CachedAuthClient(AuthClient): + """ + Extension of ::class:`AuthClient` that supports caching the session to a file. """ def __init__(self, session_file: Optional[Path] = None, *args, **kwargs): @@ -142,12 +334,13 @@ def __init__(self, session_file: Optional[Path] = None, *args, **kwargs): Initialize this instance, including creating empty session-related attributes. If a ``session_file`` is not given, a default path in the home directory with a timestamp-based name will be - used. + used. If ``session_file`` is a directory, similiarly a timestamp-based default basename will be used for a file + in this directory. Parameters ---------- session_file : Optional[Path] - Optional path to file for a serialized session, both for loading from and saving to. + Optional specified path to file for a serialized session, both for loading from and saving to. args kwargs @@ -158,92 +351,63 @@ def __init__(self, session_file: Optional[Path] = None, *args, **kwargs): superclass init. """ super().__init__(*args, **kwargs) - # TODO: get full session implementation if possible - self._session_id, self._session_secret, self._session_created, self._is_new_session = None, None, None, None + + self._is_new_session = None + self._force_reload = False + + default_basename = '.dmod_session' + if session_file is None: - self._cached_session_file = Path.home().joinpath( - '.{}_session'.format(datetime.datetime.now().strftime('%Y%m%d%H%M%S%s'))) + self._cached_session_file = Path.home().joinpath(default_basename) + elif session_file.is_dir(): + self._cached_session_file = session_file.joinpath(default_basename) else: self._cached_session_file = session_file - def _acquire_new_session(self): - try: - return get_or_create_eventloop().run_until_complete(self._async_acquire_new_session()) - except Exception as e: - logger.info("Expecting exception to follow") - logger.exception("Failed _acquire_session_info") - return False + assert isinstance(self._cached_session_file, Path) + assert self._cached_session_file.is_file() or not self._cached_session_file.exists() - def _acquire_session_info(self, use_current_values: bool = True, force_new: bool = False) -> bool: + async def _async_acquire_session(self) -> bool: """ - Attempt to set the session information properties needed for a secure connection. - - Parameters - ---------- - use_current_values : bool - Whether to use currently held attribute values for session details, if already not None (disregarded if - ``force_new`` is ``True``). - force_new : bool - Whether to force acquiring a new session, regardless of data available is available on an existing session. + Acquire an authenticated session. Returns ------- bool - Whether session details were acquired and set successfully. + Whether acquiring an authenticated session was successful. """ - logger.debug("{}._acquire_session_info: getting session info".format(self.__class__.__name__)) - if not force_new and not self._check_if_new_session_needed(use_current_values=use_current_values): - logger.debug('Using previously acquired session details (new session not forced)') + if not self._check_if_new_session_needed(): return True - else: - logger.debug("Session from {}}: force_new={}".format(self.__class__.__name__, force_new)) - tmp = self._acquire_new_session() - logger.debug("Session Info Return: {}".format(tmp)) - return tmp - async def _async_acquire_session_info(self, use_current_values: bool = True, force_new: bool = False) -> bool: - """ - Async attempt to set the session information properties needed for a secure connection. - - Parameters - ---------- - use_current_values : bool - Whether to use currently held attribute values for session details, if already not None (disregarded if - ``force_new`` is ``True``). - force_new : bool - Whether to force acquiring a new session, regardless of data available is available on an existing session. + try: + auth_resp = await self._transport_client.async_send(data=json.dumps(self._prepare_auth_request_payload()), + await_response=True) + # Execute the call to the parsing function before attempting to write, but don't set the attributes yet + session_attribute_vals_tuple = self._parse_auth_data(auth_resp) - Returns - ------- - bool - Whether session details were acquired and set successfully. - """ - if not force_new and not self._check_if_new_session_needed(use_current_values=use_current_values): - logger.debug('Using previously acquired session details (new session not forced)') - return True - else: - tmp = await self._async_acquire_new_session(cached_session_file=self._cached_session_file) - logger.debug("Session Info Return: {}".format(tmp)) - return tmp + # Need a nested try block here to control what happens with a failure to cache the session + try: + self._cached_session_file.write_text(auth_resp) + except Exception as inner_e: + # TODO: consider having parameters/attributes to control exactly how this is handled ... + # ... for now just catch and pass so a bad save file doesn't tank us + msg = f"{self.__class__.__name__} successfully authenticated but failed to cache details to file " \ + f"'{str(self._cached_session_file)}' due to {inner_e.__class__.__name__}: {str(inner_e)}" + logger.warning(msg) + pass - async def _async_acquire_new_session(self, cached_session_file: Optional[Path] = None): - try: - logger.info("Connection to request handler web socket") - auth_details = await self.authenticate(cached_session_file=cached_session_file) - logger.info("auth_details returned") - self._session_id, self._session_secret, self._session_created = auth_details + # Wait until after the cache file write section to modify any instance state + self._session_id, self._session_secret, self._session_created = session_attribute_vals_tuple + self.force_reauth = False self._is_new_session = True return True - except ConnectionResetError as e: - logger.info("Expecting exception to follow") - logger.exception("Failed _acquire_session_info") - return False + # In the future, consider whether we should treat ConnectionResetError separately except Exception as e: - logger.info("Expecting exception to follow") - logger.exception("Failed _acquire_session_info") + msg = f"{self.__class__.__name__} failed to acquire auth credential due to {e.__class__.__name__}: {str(e)}" + logger.error(msg) return False - def _check_if_new_session_needed(self, use_current_values: bool = True) -> bool: + def _check_if_new_session_needed(self) -> bool: """ Check if a new session is required, potentially loading a cached session from an implementation-specific source. @@ -252,147 +416,244 @@ def _check_if_new_session_needed(self, use_current_values: bool = True) -> bool: For the default implementation of this function, the source for a cached session is a serialized session file. - Loading of a cached session will not be done if ``use_current_values`` is ``True`` and session attributes are - properly set (i.e., non-``None`` and non-empty). Further, loaded cached session details will not be used if any - is empty or ``None``. + For a new session to be needed, there must be no other **acceptable** source of authenticated session data. + + If ::attribute:`force_reauth` is set to ``True``, any currently stored session attributes are cleared and the + function returns ``True``. Nothing is loaded from a cached session file. + + If ::attribute:`force_reload` is set to ``True``, any currently stored session attributes are cleared. However, + the function does not return at this point, and instead proceeds with remaining logic. + + The session attributes of this instance subsequently checked for acceptable session data. If at this point they + are all properly set (i.e., non-``None`` and non-empty) and ::attribute:`force_reload` is ``False``, then the + function returns ``False``. + + If any session attributes are not properly set or ::attribute:`force_reload` is ``True``, the function attempts + to load a session from the cached session file. If valid session attributes can be loaded, the function then + returns ``False``. If they could not be loaded, the function will return ``True``, indicating a new session + needs to be acquired. The function will return ``False`` IFF all session attributes are non-``None`` and non-empty at the end of the function's execution. - Parameters - ---------- - use_current_values : bool - Whether it is acceptable to use the current values of the instance's session-related attributes, if all such - attributes already have values set. - Returns ------- bool Whether a new session must be acquired. """ - # If we should use current values, and current values constitute a valid session, then we do not need a new one - if use_current_values and all([self._session_id, self._session_secret, self._session_created]): + # If we need to re-auth, clear any old session data and immediately return True (i.e., new session is needed) + if self.force_reauth: + self._session_id, self._session_secret, self._session_created = None, None, None + return True + + # If we need to reload, also clear any old session data, but this time proceed with the rest of the function + if self.force_reload: + self._session_id, self._session_secret, self._session_created = None, None, None + # Once we force clearing these to ensure a reload is attempted, reset the attribute + self.force_reload = False + # If not set to force a reload, we may already have valid session attributes; short here if so + elif all([self._session_id, self._session_secret, self._session_created]): return False + # If there is a cached session file, we will try to load from it if self._cached_session_file.exists(): try: - session_id, secret, created = self.parse_session_auth_text(self._cached_session_file.read_text()) + session_id, secret, created = self._parse_auth_data(self._cached_session_file.read_text()) # Only set if all three read properties are valid if all([session_id, secret, created]): self._session_id = session_id self._session_secret = secret self._session_created = created + self._is_new_session = False except Exception as e: pass # Return opposite of whether session properties are now set correctly (that would mean don't need a session) return not all([self._session_id, self._session_secret, self._session_created]) - # Otherwise (i.e., don't/can't use current session details + no cached file to load), need a new session else: return True - # TODO: ... - async def authenticate(self, cached_session_file: Optional[Path] = None): - #async with websockets.connect(self.endpoint_uri, ssl=self.client_ssl_context) as websocket: - #async with websockets.connect(self.maas_endpoint_uri) as websocket: - # return await EditView._authenticate_over_websocket(websocket) - # Right now, it doesn't matter as long as it is valid - # TODO: Fix this to not be ... fixed ... - json_as_dict = {'username': 'someone', 'user_secret': 'something'} - response_txt = await self.async_send(data=json.dumps(json_as_dict), await_response=True) - try: - if cached_session_file is not None and not cached_session_file.is_dir() \ - and cached_session_file.parent.is_dir(): - cached_session_file.write_text(response_txt) - except Exception as e: - # TODO: consider logging something here, but for now just handle so a bad save file doesn't tank us - pass - #print('*************** Auth response: ' + json.dumps(response_txt)) - return self.parse_session_auth_text(response_txt) - @property - def is_new_session(self): - return self._is_new_session + def force_reload(self) -> bool: + """ + Whether client should be forced to reload cached auth data on the next call to ::method:`async_acquire_session`. - def parse_session_auth_text(self, auth_text: str): - auth_response = json.loads(auth_text) - # TODO: consider making sure this parses to a SessionInitResponse - maas_session_id = auth_response['data']['session_id'] - maas_session_secret = auth_response['data']['session_secret'] - maas_session_created = auth_response['data']['created'] - return maas_session_id, maas_session_secret, maas_session_created + Note that this property will be (re)set to ``False`` after the next call to ::method:`async_acquire_session`. - @property - def session_created(self): - return self._session_created + Returns + ------- + bool + Whether to force reloading cached auth data on the next called to ::method:`async_acquire_session`. + """ + return self._force_reload - @property - def session_id(self): - return self._session_id + @force_reload.setter + def force_reload(self, should_force_reload: bool): + self._force_reload = should_force_reload @property - def session_secret(self): - return self._session_secret + def is_new_session(self) -> Optional[bool]: + """ + Whether the current session was obtained newly from the service, as opposed to read from cache. + Returns + ------- + Optional[bool] + Whether the current session was obtained newly from the service, as opposed to read from cache; ``None`` if + no session is yet acquired/loaded. + """ + return self._is_new_session -class WebSocketClient(AbstractClient, ABC): - """ - Abstract subtype of ::class:`AbstractClient` that specifically works over websocket connections. - An abstract websocket-based implementation of ::class:`AbstractClient`. Instances are also async context managers - for runtime contexts that handle websocket connections, with the manager function returning the instance itself. +class RequestClient: + """ + Simple DMOD service client, dealing with DMOD request message and response objects. - A new runtime context will check whether there is an open websocket connection already and open a connection if not. - In all cases, it maintains an instance attribute that is a counter of the number of active usages of the connection - (i.e., the number of separate, active contexts). When the context is exited, the instance's active usage counter is - reduced by one and, if that context represents the last active use of the connection, the connection object is - closed and then has its reference removed. + Basic client type for interaction with a DMOD service. Its primary function, ::method:`async_make_request`, accepts + some DMOD ::class:`AbstractInitRequest` object, uses a ::class:`TransportLayerClient` to submit the request object + to a service, and receives/returns the service's response. - The ::method:`async_send` and ::method:`async_recv` functions can be used without already being in an active context - (i.e., they will enter a new context for the scope of the function). However, within in an already open context, - calls to ::method:`async_send` and ::method:`async_recv` can be used as needed to support arbitrarily communication - over the websocket. + To parse responses, instances must know the appropriate class type for a response. This can be provided as an + optional parameter to ::method:`async_make_request`. A default response class type can also be supplied to an + instance during init, which is used by ::method:`async_make_request` if a class type is not provided. One of the + two must be set for ::method:`async_make_request` to function. """ - @classmethod - def build_endpoint_uri(cls, host: str, port: Union[int, str], path: Optional[str] = None, is_secure: bool = True): - proto = 'wss' if is_secure else 'ws' - if path is None: - path = '' - else: - path = path.strip() - if path[0] != '/': - path = '/' + path - return '{}://{}:{}{}'.format(proto, host.strip(), str(port).strip(), path) - - def __init__(self, ssl_directory: Path, *args, **kwargs): + def __init__(self, *, + transport_client: TransportLayerClient, + default_response_type: Optional[Type[Response]] = None, + **kwargs): """ - Initialize this instance. + Initialize. Parameters ---------- - ssl_directory - args + transport_client : TransportLayerClient + The client for handling the underlying raw OSI transport layer communications with the service. + default_response_type: Optional[Type[Response]] + Optional class type for responses, to use when no response class param is given when making a request. kwargs + """ + self._transport_client = transport_client + self._default_response_type: Optional[Type[Response]] = default_response_type - Other Parameters + def _process_request_response(self, response_str: str, response_type: Optional[Type[Response]] = None) -> Response: + """ + Process the serial form of a response returned by ::method:`async_send` into a response object. + + Parameters ---------- - endpoint_uri : str - The endpoint for the client to connect to when opening a connection, for superclass init. + response_str : str + The string returned by a request made via ::method:`async_send`. + response_type: Optional[Type[Response]] + An optional class type for the response that, if ``None`` (the default) is replaced with the default + provided at initialization. + + Returns + ------- + Response + The inflated response object. + + See Also + ------- + async_send """ - super().__init__(*args, **kwargs) + if response_type is None: + response_type = self._default_response_type + + response_json = {} + try: + # Consume the response confirmation by deserializing first to JSON, then from this to a response object + response_json = json.loads(response_str) + try: + response_object = response_type.factory_init_from_deserialized_json(response_json) + if response_object is None: + msg = f'********** {self.__class__.__name__} could not deserialize {response_type.__name__} ' \ + f'from raw websocket response: `{str(response_str)}`' + reason = f'{self.__class__.__name__} Could Not Deserialize To {response_type.__name__}' + response_object = response_type(success=False, reason=reason, message=msg, data=response_json) + except Exception as e2: + msg = f'********** While deserializing {response_type.__name__}, {self.__class__.__name__} ' \ + f'encountered {e2.__class__.__name__}: {str(e2)}' + reason = f'{self.__class__.__name__} {e2.__class__.__name__} Deserialize {response_type.__name__}' + response_object = response_type(success=False, reason=reason, message=msg, data=response_json) + except Exception as e: + reason = 'Invalid JSON Response' + msg = f'Encountered {e.__class__.__name__} loading response to JSON: {str(e)}' + response_object = response_type(success=False, reason=reason, message=msg, data=response_json) + + if not response_object.success: + logging.error(response_object.message) + logging.debug(f'{self.__class__.__name__} returning {str(response_type)} {response_str}') + return response_object - self._ssl_directory = ssl_directory - """Path: The parent directory of the cert PEM file used for the client SSL context.""" + async def async_make_request(self, message: AbstractInitRequest, response_type: Optional[Type[Response]] = None) -> Response: + """ + Async send a request message object and return the received response. - # Setup this as a property to allow more private means to override the actual filename of the cert PEM file - self._client_ssl_context = None - """ssl.SSLContext: The private field for the client SSL context property.""" + Send (within Python's async functionality) the appropriate type of request :class:`Message` for this client + implementation type and return the response as a corresponding, appropriate :class:`Response` instance. - self._cert_pem_file_basename: str = 'certificate.pem' - """str: The basename of the certificate PEM file to use.""" + Parameters + ---------- + message : AbstractInitRequest + The request message object. + response_type: Optional[Type[Response]] + An optional class type for the response that, if ``None`` (the default) is replaced with the default + provided at initialization. - self.connection: typing.Optional[websockets.WebSocketClientProtocol] = None - """Optional[websockets.client.Connect]: The open websocket connection, if set, for this client's context.""" + Returns + ------- + Response + the request response object + """ + if response_type is None: + if self._default_response_type is None: + msg = f"{self.__class__.__name__} can't make request with neither response type parameter or default" + raise RuntimeError(msg) + else: + response_type = self._default_response_type + + response_json = {} + try: + # Send the request and get the service response + serialized_response = await self._transport_client.async_send(data=str(message), await_response=True) + if serialized_response is None: + raise ValueError(f'Serialized response from {self.__class__.__name__} async message was `None`') + except Exception as e: + reason = f'{self.__class__.__name__} Send {message.__class__.__name__} Failure ({e.__class__.__name__})' + msg = f'Sending {message.__class__.__name__} raised {e.__class__.__name__}: {str(e)}' + logger.error(msg) + return response_type(success=False, reason=reason, message=msg, data=response_json) + + assert isinstance(serialized_response, str) + return self._process_request_response(serialized_response) + + +class ConnectionContextClient(Generic[CONN], TransportLayerClient, ABC): + """ + Transport client subtype that maintains connections via an async managed contexts. + + Instances of this type will increment an active connections counter upon entering the context. If the counter was + at ``0``, a new connection will be opened using ::method:`_establish_connection` and assigned to + ::attribute:`connection`. The reverse happens on context close, with ::method:`_close_connection` being used to + close the connection once the counter is ``0`` again. + + Subtypes should provide implementations for ::method:`_establish_connection` and ::method:`_close_connection`. + + Implementations of ::method:`async_send` and ::method:`async_recv` functions are provided. They can be used without + already being in an active context (i.e., they will enter a new context for the scope of the function). However, + within in an already open context, calls to ::method:`async_send` and ::method:`async_recv` can be used as needed to + support arbitrarily communication over the websocket. + + The ::method:`async_send` and ::method:`async_recv` implementations depend on ::method:`_connection_send` and + ::method:`_connection_recv`, which must be provided by subtypes. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._connection: typing.Optional[CONN] = None + """Optional[CONN]: The open connection, if set, for this client's context.""" self._opening_connection = False """bool: Whether some task is in the process of opening a new connection in the context, but is awaiting.""" @@ -415,12 +676,12 @@ async def __aenter__(self): # Safely conclude at this point that nothing else (worth paying attention to) is in the middle of opening a # connection, so check whether there already is one ... - if self.connection is None: + if self._connection is None: # If not, mark that this exec is opening a connection, before giving up control during the await self._opening_connection = True # Then asynchronously open the connection ... try: - self.connection = await websockets.connect(self.endpoint_uri, ssl=self.client_ssl_context) + self._connection = await self._establish_connection() except Exception as e: raise e # And now, note that we are no longer in the middle of an attempt to open a connection @@ -435,225 +696,204 @@ async def __aexit__(self, *exc_info): """ self.active_connections -= 1 if self.active_connections < 1: - await self.connection.close() - self.connection = None + await self._close_connection() + self._connection = None self.active_connections = 0 - async def async_send(self, data: Union[str, bytearray], await_response: bool = False): + @abstractmethod + async def _connection_recv(self) -> Optional[str]: """ - Send data to websocket, by default returning immediately, but optionally receiving and returning response. + Perform operations to receive data over already opened ::attribute:`connection`. - The function will cause the runtime context to be entered, opening a connection if needed. In such cases, - the connection will also be closed at the conclusion of this function. + Returns + ------- + Optional[str] + Data received over already opened ::attribute:`connection`. + """ + pass + + @abstractmethod + async def _connection_send(self, data: Union[str, bytearray]): + """ + Perform operations to send data over already opened ::attribute:`connection`. Parameters ---------- - data: Optional[str] + data The data to send. - await_response - Whether the method should also await a response on the websocket connection and return it. + """ + pass - Returns - ------- - Optional[str] - The response to the sent data, if one should be awaited; otherwise ``None``. + @abstractmethod + async def _close_connection(self): """ - async with self as websocket: - #TODO ensure correct type for data??? - await websocket.connection.send(data) - return await websocket.connection.recv() if await_response else None + Close the managed context's established connection. + """ + pass - async def listen(self) -> typing.Union[str, bytes]: + @abstractmethod + async def _establish_connection(self) -> CONN: """ - Waits for a message through the websocket connection + Establish a connection for the managed context. - Returns: - A string for data sent through the socket as a string and bytes for data sent as binary + Returns + ------- + CONN + A newly established connection. """ - async with self as websocket: - return await websocket.connection.recv() + pass - @abstractmethod - async def async_make_request(self, message: Message) -> Response: + async def async_send(self, data: Union[str, bytearray], await_response: bool = False): """ - Send (within Python's async functionality) the appropriate type of request :class:`Message` for this client - implementation type and return the response as a corresponding, appropriate :class:`Response` instance. + Send data over connection, by default returning immediately, but optionally receiving and returning response. + + The function will cause the runtime context to be entered, opening a connection if needed. In such cases, + the connection will also be closed at the conclusion of this function. Parameters ---------- - message - the request message object + data: Optional[str] + The data to send. + await_response + Whether the method should also await a response on the connection and return it. Returns ------- - response - the request response object + Optional[str] + The response to the sent data, if one should be awaited; otherwise ``None``. """ - pass + async with self as connection_owner: + await connection_owner._connection_send(data) + return await connection_owner._connection_recv() if await_response else None async def async_recv(self) -> Union[str, bytes]: """ - Receive data over the websocket connection. + Receive data over the connection. Returns ------- Union[str, bytes] The data received over the connection. """ - with self as websocket: - return await websocket.connection.recv() + with self as connection_owner: + return await connection_owner._connection_recv() @property - def client_ssl_context(self) -> ssl.SSLContext: - """ - Get the client SSL context property, lazily instantiating if necessary. + def connection(self) -> Optional[CONN]: + return self._connection - Returns - ------- - ssl.SSLContext - the client SSL context for secure connections - """ - if self._client_ssl_context is None: - self._client_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - endpoint_pem = self._ssl_directory.joinpath(self._cert_pem_file_basename) - self.client_ssl_context.load_verify_locations(endpoint_pem) - return self._client_ssl_context - -class InternalServiceClient(WebSocketClient, Generic[M, R], ABC): +class WebSocketClient(ConnectionContextClient[websockets.WebSocketClientProtocol]): """ - Abstraction for a client that interacts with some internal, non-public-facing DMOD service. + Subtype of ::class:`ConnectionContextClient` that specifically works over SSL-secured websocket connections. + + A websocket-based implementation of ::class:`ConnectionContextClient`. Instances are also async context managers for + runtime contexts that handle websocket connections, with the manager function returning the instance itself. + + A new runtime context will check whether there is an open websocket connection already and open a connection if not. + In all cases, it maintains an instance attribute that is a counter of the number of active usages of the connection + (i.e., the number of separate, active contexts). When the context is exited, the instance's active usage counter is + reduced by one and, if that context represents the last active use of the connection, the connection object is + closed and then has its reference removed. + + The ::method:`async_send` and ::method:`async_recv` functions can be used without already being in an active context + (i.e., they will enter a new context for the scope of the function). However, within in an already open context, + calls to ::method:`async_send` and ::method:`async_recv` can be used as needed to support arbitrarily communication + over the websocket. """ @classmethod - @abstractmethod - def get_response_subtype(cls) -> Type[R]: + def get_endpoint_protocol_str(cls, use_secure_connection: bool = True) -> str: """ - Return the response subtype class appropriate for this client implementation. + Get the protocol substring portion for valid connection URI strings for an instance of this class. + + Parameters + ---------- + use_secure_connection : bool + Whether to get the protocol substring applicable for secure connections (``True`` by default). Returns ------- - Type[R] - The response subtype class appropriate for this client implementation. + str + The protocol substring portion for valid connection URI strings for an instance of this class. """ - pass + return 'wss' if use_secure_connection else 'ws' - def build_response(self, success: bool, reason: str, message: str = '', data: Optional[dict] = None, - **kwargs) -> R: + async def _connection_recv(self) -> Optional[str]: """ - Build a response of the appropriate subtype from the given response details. + Perform operations to receive data over already opened ::attribute:`connection`. - Build a response of the appropriate subtype for this particular implementation, using the given parameters for - this function as the initialization params for the response. Per the design of ::class:`Response`, the primary - attributes are ::attribute:`Response.success`, ::attribute:`Response.reason`, ::attribute:`Response.message`, - and ::attribute:`Response.data`. However, implementations may permit or require additional param values, which - can be supplied via keyword args. - - As with the init of ::class:`Request`, defaults of ``''`` (empty string) and ``None`` are in place for for - ``message`` and ``data`` respectively. + Returns + ------- + Optional[str] + Data received over already opened ::attribute:`connection`. + """ + return await self.connection.recv() - A default implementation is provided that initializes an instance of the type return by - ::method:`get_response_subtype`. Keyword args are not used in this default implementation. + async def _connection_send(self, data: Union[str, bytearray]): + """ + Perform operations to send data over already opened ::attribute:`connection`. Parameters ---------- - success : bool - The value for ::attribute:`Response.success` to use when initializing the response object. - reason : str - The value for ::attribute:`Response.reason` to use when initializing the response object. - message : str - The value for ::attribute:`Response.message` to use when initializing the response object (default: ``''``). - data : dict - The value for ::attribute:`Response.data` to use when initializing the response object (default: ``None``). - kwargs : dict - A dict for any additional implementation specific init params for the response object. - - Returns - ------- - R - A response object of the appropriate subtype. + data + The data to send. """ - return self.get_response_subtype()(success=success, reason=reason, message=message, data=data) - - def _process_request_response(self, response_str: str): - response_type = self.get_response_subtype() - my_class_name = self.__class__.__name__ - response_json = {} - try: - # Consume the response confirmation by deserializing first to JSON, then from this to a response object - response_json = json.loads(response_str) - try: - response_object = response_type.factory_init_from_deserialized_json(response_json) - if response_object is None: - msg = '********** {} could not deserialize {} from raw websocket response: `{}`'.format( - my_class_name, response_type.__name__, str(response_str)) - reason = '{} Could Not Deserialize To {}'.format(my_class_name, response_type.__name__) - response_object = self.build_response(success=False, reason=reason, message=msg, data=response_json) - except Exception as e2: - msg = '********** While deserializing {}, {} encountered {}: {}'.format( - response_type.__name__, my_class_name, e2.__class__.__name__, str(e2)) - reason = '{} {} Deserializing {}'.format(my_class_name, e2.__class__.__name__, response_type.__name__) - response_object = self.build_response(success=False, reason=reason, message=msg, data=response_json) - except Exception as e: - reason = 'Invalid JSON Response' - msg = 'Encountered {} loading response to JSON: {}'.format(e.__class__.__name__, str(e)) - response_object = self.build_response(success=False, reason=reason, message=msg, data=response_json) + await self.connection.send(data) - if not response_object.success: - logging.error(response_object.message) - logging.debug('************* {} returning {} object {}'.format(self.__class__.__name__, response_type.__name__, - response_object.to_json())) - return response_object + async def _close_connection(self): + """ + Close the managed context's established connection. + """ + await self.connection.close() - async def async_make_request(self, message: M) -> R: + async def _establish_connection(self) -> CONN: """ - Async send the given request and return the corresponding response. + Establish a connection for the managed context. - Send (within Python's async functionality) the appropriate type of request :class:`Message` for this client - implementation type and return the response as a corresponding, appropriate :class:`Response` instance. + Returns + ------- + CONN + A newly established connection. + """ + return await websockets.connect(self._get_endpoint_uri(), ssl=self._client_ssl_context) - Parameters - ---------- - message : M - The request message object. + def _get_endpoint_uri(self) -> str: + """ + The endpoint for the client to connect to when opening a connection. Returns ------- - response : R - The request response object. - """ - response_type = self.get_response_subtype() - expected_req_type = response_type.get_response_to_type() - my_class_name = self.__class__.__name__ - req_class_name = message.__class__.__name__ - - if not isinstance(message, expected_req_type): - reason = '{} Received Unexpected Type {}'.format(my_class_name, req_class_name) - msg = '{} received unexpected {} instance as request, rather than a {} instance; not submitting'.format( - my_class_name, req_class_name, expected_req_type.__name__) - logger.error(msg) - return self.build_response(success=False, reason=reason, message=msg) + str + The endpoint for the client to connect to when opening a connection. + """ + if self._endpoint_uri is None: + proto = self.get_endpoint_protocol_str(use_secure_connection=self._client_ssl_context is not None) - response_json = {} - try: - # Send the request and get the service response - serialized_response = await self.async_send(data=str(message), await_response=True) - if serialized_response is None: - raise ValueError('Serialized response from {} async message was `None`'.format(my_class_name)) - except Exception as e: - reason = '{} Send {} Failure ({})'.format(my_class_name, req_class_name, e.__class__.__name__) - msg = '{} encountered {} sending {}: {}'.format(my_class_name, e.__class__.__name__, req_class_name, str(e)) - logger.error(msg) - return self.build_response(success=False, reason=reason, message=msg, data=response_json) + if self._endpoint_path and self._endpoint_path[0] != '/': + path_str = '/' + self._endpoint_path + else: + path_str = self._endpoint_path - return self._process_request_response(serialized_response) + self._endpoint_uri = f"{proto}://{self._endpoint_host}:{self._endpoint_port!s}{path_str}" + return self._endpoint_uri + async def listen(self) -> typing.Union[str, bytes]: + """ + Waits for a message through the websocket connection + + Returns: + A string for data sent through the socket as a string and bytes for data sent as binary + """ + async with self as websocket: + return await websocket.connection.recv() -class SchedulerClient(InternalServiceClient[SchedulerRequestMessage, SchedulerRequestResponse]): - @classmethod - def get_response_subtype(cls) -> Type[SchedulerRequestResponse]: - return SchedulerRequestResponse +@deprecated("Use RequestClient or ExternalRequestClient instead") +class SchedulerClient(RequestClient): + + def __init__(self, *args, **kwargs): + super().__init__(default_response_type=SchedulerRequestResponse, *args, **kwargs) async def async_send_update(self, message: UpdateMessage) -> UpdateMessageResponse: """ @@ -678,7 +918,7 @@ async def async_send_update(self, message: UpdateMessage) -> UpdateMessageRespon response_json = {} serialized_response = None try: - serialized_response = await self.async_send(data=str(message), await_response=True) + serialized_response = await self._transport_client.async_send(data=str(message), await_response=True) if serialized_response is None: raise ValueError('Response from {} async update message was `None`'.format(self.__class__.__name__)) response_object = UpdateMessageResponse.factory_init_from_deserialized_json(json.loads(serialized_response)) @@ -692,72 +932,10 @@ async def async_send_update(self, message: UpdateMessage) -> UpdateMessageRespon return UpdateMessageResponse(digest=message.digest, object_found=False, success=False, reason=reason, response_text='None' if serialized_response is None else serialized_response) - async def get_results(self): - logging.debug('************* {} preparing to yield results'.format(self.__class__.__name__)) - async for message in self.connection: - logging.debug('************* {} yielding result: {}'.format(self.__class__.__name__, str(message))) - yield message - - -class ExternalRequestClient(ExternalClient, WebSocketClient, Generic[EXTERN_REQ_M, EXTERN_REQ_R], ABC): - - @staticmethod - def _request_failed_due_to_expired_session(response_obj: EXTERN_REQ_R): - """ - Test if request failed due to an expired session. - - Test if the response to a websocket-sent request failed specifically because the utilized session is consider to - be expired, either because the session is explicitly expired or there is no longer a record of the session with - the session secret in the init request (i.e., it is implicitly expired). - Parameters - ---------- - response_obj - - Returns - ------- - bool - whether a failure occur and it specifically was due to a lack of authorization over the used session - """ - is_expired = response_obj.reason_enum == InitRequestResponseReason.UNRECOGNIZED_SESSION_SECRET - is_expired = is_expired or response_obj.reason_enum == InitRequestResponseReason.EXPIRED_SESSION - return response_obj is not None and not response_obj.success and is_expired +class ExternalRequestClient(RequestClient): - @classmethod - def _run_validation(cls, message: Union[EXTERN_REQ_M, EXTERN_REQ_R]): - """ - Run validation for the given message object using the appropriate validator subtype. - - Parameters - ---------- - message - The message to validate, which will be either a ``ExternalRequest`` or a ``ExternalRequestResponse`` subtype. - - Returns - ------- - tuple - A tuple with the first item being whether or not the message was valid, and the second being either None or - the particular error that caused the message to be identified as invalid - - Raises - ------- - RuntimeError - Raised if the message is of a particular type for which there is not a supported validator type configured. - """ - if message is None: - return False, None - elif isinstance(message, NWMRequest): - is_valid, error = NWMRequestJsonValidator().validate(message.to_dict()) - return is_valid, error - elif isinstance(message, NGENRequest): - is_valid, error = NWMRequestJsonValidator().validate(message.to_dict()) - return is_valid, error - elif isinstance(message, Serializable): - return message.__class__.factory_init_from_deserialized_json(message.to_dict()) == message, None - else: - raise RuntimeError('Unsupported ExternalRequest subtype: ' + str(message.__class__)) - - def __init__(self, *args, **kwargs): + def __init__(self, auth_client: AuthClient, *args, **kwargs): """ Initialize instance. @@ -768,42 +946,47 @@ def __init__(self, *args, **kwargs): Other Parameters ---------- - endpoint_uri : str - The client connection endpoint for opening new websocket connections, required for superclass init. - ssl_directory : Path - The directory of the SSL certificate files for the client SSL context. - session_file : Optional[Path] - Optional path to file for a serialized session, both for loading from and saving to. + transport_client: TransportLayerClient """ super().__init__(*args, **kwargs) + self._auth_client: AuthClient = auth_client + self._errors = None self._warnings = None self._info = None - @abstractmethod - def _update_after_valid_response(self, response: EXTERN_REQ_R): + async def async_make_request(self, message: ExternalRequest, + response_type: Optional[Type[ExternalRequestResponse]] = None) -> ExternalRequestResponse: """ - Perform any required internal updates immediately after a request gets back a successful, valid response. + Async send a request message object and return the received response. - This provides a way of extending the behavior of this type specifically regarding the ::method:make_maas_request - function. Any updates specific to the type, which should be performed after a request receives back a valid, - successful response object, can be implemented here. + Send (within Python's async functionality) the appropriate type of request :class:`Message` for this client + implementation type and return the response as a corresponding, appropriate :class:`Response` instance. - In the base implementation, no further action is taken. + Parameters + ---------- + message : ExternalRequest + The request message object. + response_type: Optional[Type[ExternalRequestResponse]] + An optional class type for the response that, if ``None`` (the default) is replaced with the default + provided at initialization. - See Also + Returns ------- - ::method:make_maas_request + EXTERN_REQ_R + the request response object """ - pass + if response_type is None: + response_type = self._default_response_type - # TODO: this can probably be taken out, as the superclass implementation should suffice - async def async_make_request(self, request: EXTERN_REQ_M) -> EXTERN_REQ_R: - async with websockets.connect(self.endpoint_uri, ssl=self.client_ssl_context) as websocket: - await websocket.send(request.to_json()) - response = await websocket.recv() - return request.__class__.factory_init_correct_response_subtype(json_obj=json.loads(response)) + if await self._auth_client.apply_auth(message): + return await super().async_make_request(message, response_type=response_type) + else: + reason = f'{self.__class__.__name__} Request Auth Failure' + msg = f'{self.__class__.__name__} async_make_request could not apply auth to {message.__class__.__name__}' + logger.error(msg) + return response_type(success=False, reason=reason, message=msg) @property def errors(self): @@ -813,104 +996,23 @@ def errors(self): def info(self): return self._info - @property - def is_new_session(self): - return self._is_new_session - - def make_maas_request(self, maas_request: EXTERN_REQ_M, force_new_session: bool = False): - request_type_str = maas_request.__class__.__name__ - logger.debug("client Making {} type request".format(request_type_str)) - self._acquire_session_info(force_new=force_new_session) - # Make sure to set if empty or reset if a new session was forced and just acquired - if force_new_session or maas_request.session_secret is None: - maas_request.session_secret = self._session_secret - # If able to get session details, proceed with making a job request - if self._session_secret is not None: - print("******************* Request: " + maas_request.to_json()) - try: - is_request_valid, request_validation_error = self._run_validation(message=maas_request) - if is_request_valid: - try: - response_obj: EXTERN_REQ_R = get_or_create_eventloop().run_until_complete( - self.async_make_request(maas_request)) - print('***************** Response: ' + str(response_obj)) - # Try to get a new session if session is expired (and we hadn't already gotten a new session) - if self._request_failed_due_to_expired_session(response_obj) and not force_new_session: - return self.make_maas_request(maas_request=maas_request, force_new_session=True) - elif not self.validate_maas_request_response(response_obj): - raise RuntimeError('Invalid response received for requested job: ' + str(response_obj)) - elif not response_obj.success: - template = 'Request failed (reason: {}): {}' - raise RuntimeError(template.format(response_obj.reason, response_obj.message)) - else: - self._update_after_valid_response(response_obj) - return response_obj - except Exception as e: - # TODO: log error instead of print - msg_template = 'Encountered error submitting {} over session {} : \n{}: {}' - msg = msg_template.format(request_type_str, str(self._session_id), str(type(e)), str(e)) - print(msg) - traceback.print_exc() - self.errors.append(msg) - else: - msg_template = 'Could not submit invalid MaaS request over session {} ({})' - msg = msg_template.format(str(self._session_id), str(request_validation_error)) - print(msg) - self.errors.append(msg) - except RuntimeError as e: - print(str(e)) - self.errors.append(str(e)) - else: - logger.info("client Unable to aquire session details") - self.errors.append("Unable to acquire session details or authenticate new session for request") - return None - - def validate_maas_request_response(self, maas_request_response: EXTERN_REQ_R): - return self._run_validation(message=maas_request_response)[0] - @property def warnings(self): return self._warnings -class DataServiceClient(InternalServiceClient[DatasetManagementMessage, DatasetManagementResponse]): +@deprecated("Use RequestClient or ExternalRequestClient instead") +class DataServiceClient(RequestClient): """ Client for data service communication between internal DMOD services. """ - @classmethod - def get_response_subtype(cls) -> Type[DatasetManagementResponse]: - return DatasetManagementResponse - - -class ModelExecRequestClient(ExternalRequestClient[MOD_EX_M, MOD_EX_R], ABC): - - def __init__(self, endpoint_uri: str, ssl_directory: Path): - super().__init__(endpoint_uri=endpoint_uri, ssl_directory=ssl_directory) - - def _update_after_valid_response(self, response: MOD_EX_R): - """ - Perform any required internal updates immediately after a request gets back a successful, valid response. - - This provides a way of extending the behavior of this type specifically regarding the ::method:make_maas_request - function. Any updates specific to the type, which should be performed after a request receives back a valid, - successful response object, can be implemented here. - In this implementation, the ::attribute:`info` property is appended to, noting that the job of the given id has - just been started by the scheduler. - - See Also - ------- - ::method:make_maas_request - """ - #self.job_id = self.resp_as_json['data']['job_id'] - #results = self.resp_as_json['data']['results'] - #jobs = self.resp_as_json['data']['all_jobs'] - #self.info.append("Scheduler started job, id {}, results: {}".format(self.job_id, results)) - #self.info.append("All user jobs: {}".format(jobs)) - self.info.append("Scheduler started job, id {}".format(response.data['job_id'])) + def __init__(self, *args, **kwargs): + super().__init__(default_response_type=DatasetManagementResponse, *args, **kwargs) -class PartitionerServiceClient(InternalServiceClient[PartitionRequest, PartitionResponse]): +@deprecated("Use RequestClient or ExternalRequestClient instead") +class PartitionerServiceClient(RequestClient): """ A client for interacting with the partitioner service. @@ -918,31 +1020,14 @@ class PartitionerServiceClient(InternalServiceClient[PartitionRequest, Partition does not need to be a (public) ::class:`ExternalRequestClient` based type. """ - @classmethod - def get_response_subtype(cls) -> Type[PartitionResponse]: - """ - Return the response subtype class appropriate for this client implementation. - - Returns - ------- - Type[PartitionResponse] - The response subtype class appropriate for this client implementation. - """ - return PartitionResponse + def __init__(self, *args, **kwargs): + super().__init__(default_response_type=PartitionResponse, *args, **kwargs) -class EvaluationServiceClient(InternalServiceClient[EvaluationConnectionRequest, EvaluationConnectionRequestResponse]): +@deprecated("Use RequestClient or ExternalRequestClient instead") +class EvaluationServiceClient(RequestClient): """ A client for interacting with the evaluation service """ - - @classmethod - def get_response_subtype(cls) -> Type[EvaluationConnectionRequestResponse]: - """ - Return the response subtype class appropriate for this client implementation - - Returns: - Type[EvaluationConnectionRequestResponse] - The response subtype class appropriate for this client implementation - """ - return EvaluationConnectionRequestResponse + def __init__(self, *args, **kwargs): + super().__init__(default_response_type=EvaluationConnectionRequestResponse, *args, **kwargs) diff --git a/python/lib/communication/dmod/test/test_scheduler_client.py b/python/lib/communication/dmod/test/test_scheduler_client.py index f2142a1b1..2e834345c 100644 --- a/python/lib/communication/dmod/test/test_scheduler_client.py +++ b/python/lib/communication/dmod/test/test_scheduler_client.py @@ -1,20 +1,20 @@ import asyncio -import json import logging +import ssl import unittest -from pathlib import Path from typing import Optional, Union -from ..communication import NWMRequest, SchedulerClient, SchedulerRequestMessage, SchedulerRequestResponse +from ..communication import NWMRequest, SchedulerClient, SchedulerRequestMessage, SchedulerRequestResponse, \ + TransportLayerClient -class MockSendTestingSchedulerClient(SchedulerClient): - """ - Customized extension of ``SchedulerClient`` for testing purposes, where the :meth:`async_send` method has been - overridden with a mock implementation to allow for testing without actually needing a real websocket connection. - """ +class MockTransportLayerClient(TransportLayerClient): + + @classmethod + def get_endpoint_protocol_str(cls, use_secure_connection: bool = True) -> str: + return "mock" def __init__(self): - super().__init__(endpoint_uri='', ssl_directory=Path('.')) + super().__init__(endpoint_host='', endpoint_port=8888) self.test_responses = dict() @@ -52,20 +52,56 @@ async def async_send(self, data: Union[str, bytearray], await_response: bool = F else: return str(response) - def set_scheduler_response_none(self): + async def async_recv(self) -> str: + pass + + @property + def client_ssl_context(self) -> ssl.SSLContext: + pass + + @property + def _get_endpoint_uri(self) -> str: + return '' + + def set_client_response_none(self): self.test_response_selection = 0 - def set_scheduler_response_non_json_string(self): + def set_client_response_non_json_string(self): self.test_response_selection = 1 - def set_scheduler_response_unrecognized_json(self): + def set_client_response_unrecognized_json(self): self.test_response_selection = 2 - def set_scheduler_response_valid_obj_for_failure(self): + def set_client_response_valid_obj_for_failure(self): self.test_response_selection = 3 - def set_scheduler_response_valid_obj_for_success(self): + def set_client_response_valid_obj_for_success(self): self.test_response_selection = 4 + + +class MockSendTestingSchedulerClient(SchedulerClient): + """ + Customized extension of ``SchedulerClient`` for testing purposes, where the :meth:`async_send` method has been + overridden with a mock implementation to allow for testing without actually needing a real websocket connection. + """ + + def __init__(self): + super().__init__(transport_client=MockTransportLayerClient()) + + def set_scheduler_response_none(self): + self._transport_client.test_response_selection = 0 + + def set_scheduler_response_non_json_string(self): + self._transport_client.test_response_selection = 1 + + def set_scheduler_response_unrecognized_json(self): + self._transport_client.test_response_selection = 2 + + def set_scheduler_response_valid_obj_for_failure(self): + self._transport_client.test_response_selection = 3 + + def set_scheduler_response_valid_obj_for_success(self): + self._transport_client.test_response_selection = 4 class TestSchedulerClient(unittest.TestCase): @@ -103,26 +139,6 @@ def tearDown(self) -> None: self.loop.stop() self.loop.close() - def test_get_response_subtype_1_a(self): - """ - Test that ``get_response_subtype`` returns the right type. - """ - self.assertEqual(SchedulerRequestResponse, self.client.get_response_subtype()) - - def test_build_response_1_a(self): - """ - Basic test to ensure this function operates correctly. - """ - response = self.client.build_response(success=True, reason='Test Good', message='Test worked correctly') - self.assertTrue(isinstance(response, SchedulerRequestResponse)) - - def test_build_response_1_b(self): - """ - Basic test to ensure this response has the expected ``success`` value. - """ - response = self.client.build_response(success=True, reason='Test Good', message='Test worked correctly') - self.assertTrue(response.success) - def test_async_make_request_1_a(self): """ Test when function gets ``None`` returned over websocket that response object ``success`` is ``False``. diff --git a/python/lib/externalrequests/dmod/externalrequests/_version.py b/python/lib/externalrequests/dmod/externalrequests/_version.py index b703f5c96..9bdd4d277 100644 --- a/python/lib/externalrequests/dmod/externalrequests/_version.py +++ b/python/lib/externalrequests/dmod/externalrequests/_version.py @@ -1 +1 @@ -__version__ = '0.4.1' \ No newline at end of file +__version__ = '0.5.0' \ No newline at end of file diff --git a/python/lib/externalrequests/dmod/externalrequests/maas_request_handlers.py b/python/lib/externalrequests/dmod/externalrequests/maas_request_handlers.py index 23283d9bb..913d9a45b 100644 --- a/python/lib/externalrequests/dmod/externalrequests/maas_request_handlers.py +++ b/python/lib/externalrequests/dmod/externalrequests/maas_request_handlers.py @@ -6,8 +6,8 @@ from dmod.access import Authorizer from dmod.communication import AbstractRequestHandler, DataServiceClient, FullAuthSession, ExternalRequest, \ - InitRequestResponseReason, InternalServiceClient, PartitionRequest, PartitionResponse, PartitionerServiceClient, \ - Session, SessionManager + InitRequestResponseReason, RequestClient, PartitionRequest, PartitionResponse, PartitionerServiceClient, \ + TransportLayerClient, Session, SessionManager, WebSocketClient from dmod.communication.dataset_management_message import MaaSDatasetManagementMessage, MaaSDatasetManagementResponse, \ ManagementAction from dmod.communication.data_transmit_message import DataTransmitMessage, DataTransmitResponse @@ -38,6 +38,7 @@ def __init__(self, session_manager: SessionManager, authorizer: Authorizer, serv self._service_port = service_port self._service_ssl_dir = service_ssl_dir self._service_url = None + self._transport_client = None async def _is_authorized(self, request: ExternalRequest, session: FullAuthSession) -> bool: """ @@ -129,15 +130,25 @@ async def get_authorized_session(self, request: ExternalRequest) -> Tuple[ msg = None return session, is_authorized, reason, msg + @property + def transport_client(self) -> TransportLayerClient: + if self._transport_client is None: + # TODO: parameterize whether to, e.g., use websocket uri/protocol, as opposed to something else + # TODO: subsequent PR that removes this from these types (receive a service client on init) or at least has + # it supplied on init. + self._transport_client = WebSocketClient(endpoint_host=self._service_host, endpoint_port=self._service_port, + cafile=self.service_ssl_dir.joinpath("certificate.pem")) + return self._transport_client + @property @abstractmethod - def service_client(self) -> InternalServiceClient: + def service_client(self) -> RequestClient: """ - Get the client for interacting with the service, which also is a context manager for connections. + Get the client for interacting with the service. Returns ------- - InternalServiceClient + RequestClient The client for interacting with the service. """ pass @@ -146,12 +157,6 @@ def service_client(self) -> InternalServiceClient: def service_ssl_dir(self) -> Path: return self._service_ssl_dir - @property - def service_url(self) -> str: - if self._service_url is None: - self._service_url = 'wss://{}:{}'.format(str(self._service_host), str(self._service_port)) - return self._service_url - class PartitionRequestHandler(MaaSRequestHandler): @@ -202,7 +207,7 @@ async def determine_required_access_types(self, request: PartitionRequest, user) @property def service_client(self) -> PartitionerServiceClient: if self._service_client is None: - self._service_client = PartitionerServiceClient(self.service_url, self.service_ssl_dir) + self._service_client = PartitionerServiceClient(transport_client=self.transport_client) return self._service_client async def handle_request(self, request: PartitionRequest, **kwargs) -> PartitionResponse: @@ -332,5 +337,5 @@ async def handle_request(self, request: MaaSDatasetManagementMessage, **kwargs) @property def service_client(self) -> DataServiceClient: if self._service_client is None: - self._service_client = DataServiceClient(endpoint_uri=self.service_url, ssl_directory=self.service_ssl_dir) + self._service_client = DataServiceClient(transport_client=self.transport_client) return self._service_client diff --git a/python/lib/externalrequests/dmod/externalrequests/model_exec_request_handler.py b/python/lib/externalrequests/dmod/externalrequests/model_exec_request_handler.py index bc0a641b3..d7effd285 100644 --- a/python/lib/externalrequests/dmod/externalrequests/model_exec_request_handler.py +++ b/python/lib/externalrequests/dmod/externalrequests/model_exec_request_handler.py @@ -183,7 +183,7 @@ async def handle_request(self, request: ModelExecRequest, **kwargs) -> ModelExec @property def service_client(self) -> SchedulerClient: if self._scheduler_client is None: - self._scheduler_client = SchedulerClient(ssl_directory=self.service_ssl_dir, endpoint_uri=self.service_url) + self._scheduler_client = SchedulerClient(transport_client=self.transport_client) return self._scheduler_client diff --git a/python/lib/externalrequests/setup.py b/python/lib/externalrequests/setup.py index 8a304ea9f..f5ac1b936 100644 --- a/python/lib/externalrequests/setup.py +++ b/python/lib/externalrequests/setup.py @@ -20,6 +20,6 @@ author_email='', url='', license='', - install_requires=['websockets', 'dmod-core>=0.1.0', 'dmod-communication>=0.4.2', 'dmod-access>=0.1.1'], + install_requires=['websockets', 'dmod-core>=0.1.0', 'dmod-communication>=0.15.0', 'dmod-access>=0.1.1'], packages=find_namespace_packages(exclude=['dmod.test', 'schemas', 'ssl', 'src']) )