diff --git a/pyproject.toml b/pyproject.toml index 1171f8de..5f8cab9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,11 +168,21 @@ suppress-none-returning = true # https://docs.astral.sh/ruff/settings/#suppress- [tool.ruff.lint.flake8-comprehensions] allow-dict-calls-with-keyword-arguments = true # https://docs.astral.sh/ruff/settings/#allow-dict-calls-with-keyword-arguments + [tool.ruff.lint.pycodestyle] max-doc-length = 120 # https://docs.astral.sh/ruff/settings/#max-doc-length [tool.ruff.lint.per-file-ignores] "__init__.py" = ["D104"] +"tests/*" = [ + "SLF001", # Private member accessed + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "PLR0915", # Too many statements + "PLR2004", # Magic value used in comparison + "E722", # Do not use bare `except` +] [tool.mypy] python_version = "3.9" # https://mypy.readthedocs.io/en/stable/config_file.html#confval-python_version diff --git a/src/sdc11073/consumer/serviceclients/contextservice.py b/src/sdc11073/consumer/serviceclients/contextservice.py index c5b95215..e1375ce1 100644 --- a/src/sdc11073/consumer/serviceclients/contextservice.py +++ b/src/sdc11073/consumer/serviceclients/contextservice.py @@ -1,3 +1,4 @@ +"""The module contains the implementation of the BICEPS context service.""" from __future__ import annotations from typing import TYPE_CHECKING @@ -41,13 +42,13 @@ def mk_proposed_context_object(self, descriptor_handle: str, mdib = self._mdib_wref() if mdib is None: raise ApiUsageError('no mdib information') - context_descriptor_container = mdib.descriptions.handle.get_one(descriptor_handle) + context_entity = mdib.entities.by_handle(descriptor_handle) if handle is None: - cls = data_model.get_state_container_class(context_descriptor_container.STATE_QNAME) - obj = cls(descriptor_container=context_descriptor_container) + cls = data_model.get_state_container_class(context_entity.descriptor.STATE_QNAME) + obj = cls(descriptor_container=context_entity.descriptor) obj.Handle = descriptor_handle # this indicates that this is a new context state else: - _obj = mdib.context_states.handle.get_one(handle) + _obj = context_entity.states[handle] obj = _obj.mk_copy() return obj diff --git a/src/sdc11073/entity_mdib/__init__.py b/src/sdc11073/entity_mdib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/sdc11073/entity_mdib/entities.py b/src/sdc11073/entity_mdib/entities.py new file mode 100644 index 00000000..2bec8d89 --- /dev/null +++ b/src/sdc11073/entity_mdib/entities.py @@ -0,0 +1,448 @@ +"""Implementation of entities for EntityProviderMdib.""" +from __future__ import annotations + +import copy +import uuid +from typing import TYPE_CHECKING, Union + +from lxml.etree import QName + +from sdc11073.namespaces import QN_TYPE, text_to_qname +from sdc11073.xml_types import pm_qnames +from sdc11073.xml_types.pm_types import CodedValue + +if TYPE_CHECKING: + from sdc11073.mdib.containerbase import ContainerBase + from sdc11073.mdib.descriptorcontainers import AbstractDescriptorContainer + from sdc11073.mdib.statecontainers import AbstractMultiStateContainer, AbstractStateContainer + from sdc11073.xml_utils import LxmlElement + + from .entity_consumermdib import EntityConsumerMdib + from .entity_providermdib import EntityProviderMdib + +# Many types are fixed in schema. This table maps from tag in Element to its type +_static_type_lookup = { + pm_qnames.Mds: pm_qnames.MdsDescriptor, + pm_qnames.Vmd: pm_qnames.VmdDescriptor, + pm_qnames.Channel: pm_qnames.ChannelDescriptor, + pm_qnames.AlertSystem: pm_qnames.AlertSystemDescriptor, + pm_qnames.AlertCondition: pm_qnames.AlertConditionDescriptor, + pm_qnames.AlertSignal: pm_qnames.AlertSignalDescriptor, + pm_qnames.Sco: pm_qnames.ScoDescriptor, + pm_qnames.SystemContext: pm_qnames.SystemContextDescriptor, + pm_qnames.PatientContext: pm_qnames.PatientContextDescriptor, + pm_qnames.LocationContext: pm_qnames.LocationContextDescriptor, + pm_qnames.Clock: pm_qnames.ClockDescriptor, + pm_qnames.Battery: pm_qnames.BatteryDescriptor, +} + + +def get_xsi_type(element: LxmlElement) -> QName: + """Return the BICEPS type of an element. + + If there is a xsi:type entry, this specifies the type. + If not, the tag is used to determine the type. + """ + xsi_type_str = element.attrib.get(QN_TYPE) + + if xsi_type_str: + return text_to_qname(xsi_type_str, element.nsmap) + _xsi_type = QName(element.tag) + try: + return _static_type_lookup[_xsi_type] + except KeyError as err: # pragma: no cover + raise KeyError(str(_xsi_type)) from err + + +class _XmlEntityBase: + """A descriptor element and some info about it for easier access.""" + + def __init__(self, + parent_handle: str | None, + source_mds: str | None, + node_type: QName, + descriptor: LxmlElement): + self.parent_handle = parent_handle + self.source_mds = source_mds + self.node_type = node_type # name of descriptor type + self._descriptor = None + self.coded_value: CodedValue | None = None + self.descriptor = descriptor # setter updates self._descriptor and self.coded_value + + @property + def descriptor(self) -> LxmlElement: + return self._descriptor + + @descriptor.setter + def descriptor(self, new_descriptor: LxmlElement): + self._descriptor = new_descriptor + type_node = self.descriptor.find(pm_qnames.Type) + if type_node is not None: + self.coded_value = CodedValue.from_node(type_node) + else: + self.coded_value = None + + def __str__(self): + return f'{self.__class__.__name__} {self.node_type.localname} handle={self._descriptor.get("Handle")}' + + +class XmlEntity(_XmlEntityBase): + """Groups descriptor and state.""" + + def __init__(self, + parent_handle: str | None, + source_mds: str | None, + node_type: QName, + descriptor: LxmlElement, + state: LxmlElement | None): + super().__init__(parent_handle, source_mds, node_type, descriptor) + self.state = state + + @property + def is_multi_state(self) -> bool: + """Return False because this is not a multi state entity.""" + return False + + def mk_entity(self, mdib: EntityConsumerMdib) -> ConsumerEntity: + """Return a corresponding entity with containers.""" + return ConsumerEntity(self, mdib) + + +class XmlMultiStateEntity(_XmlEntityBase): + """Groups descriptor and list of multi-states.""" + + def __init__(self, + parent_handle: str | None, + source_mds: str | None, + node_type: QName, + descriptor: LxmlElement, + states: list[LxmlElement]): + super().__init__(parent_handle, source_mds, node_type, descriptor) + self.states: dict[str, LxmlElement] = {node.get('Handle'): node for node in states} + + @property + def is_multi_state(self) -> bool: + """Return True because this is a multi state entity.""" + return True + + def mk_entity(self, mdib: EntityConsumerMdib) -> ConsumerMultiStateEntity: + """Return a corresponding entity with containers.""" + return ConsumerMultiStateEntity(self, mdib) + + +class ConsumerEntityBase: + """A descriptor container and a weak reference to the corresponding xml entity.""" + + def __init__(self, + source: XmlEntity | XmlMultiStateEntity, + mdib: EntityConsumerMdib, # needed if a new state needs to be added + ): + self._mdib: EntityConsumerMdib = mdib + + cls = mdib.sdc_definitions.data_model.get_descriptor_container_class(source.node_type) + if cls is None: # pragma: no cover + raise ValueError(f'do not know how to make container from {source.node_type!s}') + handle = source.descriptor.get('Handle') + self.descriptor: AbstractDescriptorContainer = cls(handle, parent_handle=source.parent_handle) + self.descriptor.update_from_node(source.descriptor) + self.descriptor.set_source_mds(source.source_mds) + self.source_mds = source.source_mds + + @property + def handle(self) -> str: + """Return the handle of the descriptor.""" + return self.descriptor.Handle + + @property + def parent_handle(self) -> str | None: + """Return the parent handle of the descriptor.""" + return self.descriptor.parent_handle + + @property + def node_type(self) -> QName: + """Return the node type of the descriptor.""" + return self.descriptor.NODETYPE + + def __str__(self): + return f'{self.__class__.__name__} {self.node_type} handle={self.handle}' + + +class ConsumerEntity(ConsumerEntityBase): + """Groups descriptor container and state container.""" + + def __init__(self, + source: XmlEntity, + mdib: EntityConsumerMdib): + super().__init__(source, mdib) + self.state: AbstractStateContainer | None = None + if source.state is not None: + cls = mdib.sdc_definitions.data_model.get_state_container_class(self.descriptor.STATE_QNAME) + self.state = cls(self.descriptor) + self.state.update_from_node(source.state) + + def update(self): + """Update the entity from current data in mdib.""" + xml_entity = self._mdib.internal_entities.get(self.handle) + if xml_entity is None: + raise ValueError('entity no longer exists in mdib') + if int(xml_entity.descriptor.get('DescriptorVersion', '0')) != self.descriptor.DescriptorVersion: + self.descriptor.update_from_node(xml_entity.descriptor) + if int(xml_entity.state.get('StateVersion', '0')) != self.state.StateVersion: + self.state.update_from_node(xml_entity.state) + + +class ConsumerMultiStateEntity(ConsumerEntityBase): + """Groups descriptor container and list of multi-state containers.""" + + def __init__(self, + source: XmlMultiStateEntity, + mdib: EntityConsumerMdib): + super().__init__(source, mdib) + self.states: dict[str, AbstractMultiStateContainer] = {} + for handle, state in source.states.items(): + state_type = get_xsi_type(state) + cls = mdib.sdc_definitions.data_model.get_state_container_class(state_type) + state_container = cls(self.descriptor) + state_container.update_from_node(state) + self.states[handle] = state_container + + def update(self): + """Update the entity from current data in mdib.""" + xml_entity = self._mdib.internal_entities.get(self.handle) + if xml_entity is None: + raise ValueError('entity no longer exists in mdib') + if int(xml_entity.descriptor.get('DescriptorVersion', '0')) != self.descriptor.DescriptorVersion: + self.descriptor.update_from_node(xml_entity.descriptor) + + for handle, xml_state in xml_entity.states.items(): + create_new_state = False + try: + existing_state = self.states[handle] + except KeyError: + create_new_state = True + else: + if existing_state.StateVersion != int(xml_state.get('StateVersion', '0')): + existing_state.update_from_node(xml_state) + + if create_new_state: + xsi_type = get_xsi_type(xml_state) + cls = self._mdib.sdc_definitions.data_model.get_state_container_class(xsi_type) + state_container = cls(self.descriptor) + state_container.update_from_node(xml_state) + self.states[handle] = state_container + + # delete states that are no longer in xml_entity + for handle in list(self.states.keys()): + if handle not in xml_entity.states: + self.states.pop(handle) + + def new_state(self, state_handle: str | None = None) -> AbstractMultiStateContainer: + """Create a new state. + + The new state has handle of descriptor container as handle. + If this new state is used as a proposed context state in SetContextState operation, this means a new + state shall be created on providers side. + """ + if state_handle in self.states: # pragma: no cover + raise ValueError( + f'State handle {state_handle} already exists in {self.__class__.__name__}, handle = {self.handle}') + cls = self._mdib.data_model.get_state_container_class(self.descriptor.STATE_QNAME) + state = cls(descriptor_container=self.descriptor) + state.Handle = state_handle or self.handle + self.states[state.Handle] = state + return state + + +ConsumerEntityType = Union[ConsumerEntity, ConsumerMultiStateEntity] +ConsumerInternalEntityType = Union[XmlEntity, XmlMultiStateEntity] + + +############## provider ########################## + +class ProviderInternalEntityBase: + """A descriptor element and some info about it for easier access.""" + + def __init__(self, descriptor: AbstractDescriptorContainer): + self.descriptor = descriptor + + @property + def handle(self) -> str: + """Return the handle of the descriptor.""" + return self.descriptor.Handle + + @property + def parent_handle(self) -> str | None: + """Return the parent handle of the descriptor.""" + return self.descriptor.parent_handle + + @property + def source_mds(self) -> str: + """Return the source mds of the descriptor.""" + return self.descriptor.source_mds + + @property + def node_type(self) -> QName: + """Return the node type of the descriptor.""" + return self.descriptor.NODETYPE + + def __str__(self): + return f'{self.__class__.__name__} {self.node_type.localname} handle={self.handle}' + + +class ProviderInternalEntity(ProviderInternalEntityBase): + """Groups descriptor and state.""" + + def __init__(self, + descriptor: AbstractDescriptorContainer, + state: AbstractStateContainer | None): + super().__init__(descriptor) + self._state = state + if state is not None: + self._state.descriptor_container = self.descriptor + + @property + def state(self) -> AbstractStateContainer | None: + """Return the state member of the entity.""" + return self._state + + @state.setter + def state(self, new_state: AbstractStateContainer): + self._state = new_state + self._state.descriptor_container = self.descriptor + + @property + def is_multi_state(self) -> bool: + """Return False because this is not a multi state entity.""" + return False + + def mk_entity(self, mdib: EntityProviderMdib) -> ProviderEntity: + """Return a corresponding entity with containers.""" + return ProviderEntity(self, mdib) + + +class ProviderInternalMultiStateEntity(ProviderInternalEntityBase): + """Groups descriptor and list of multi-states.""" + + def __init__(self, + descriptor: AbstractDescriptorContainer, + states: list[AbstractMultiStateContainer]): + super().__init__(descriptor) + self.states = {state.Handle: state for state in states} + + @property + def is_multi_state(self) -> bool: + """Return True because this is a multi state entity.""" + return True + + def mk_entity(self, mdib: EntityProviderMdib) -> ProviderMultiStateEntity: + """Return a corresponding entity with containers.""" + return ProviderMultiStateEntity(self, mdib) + + +def _mk_copy(original: ContainerBase) -> ContainerBase: + """Return a deep copy of original without node member.""" + node, original.node = original.node, None + copied = copy.deepcopy(original) + original.node = node + return copied + + +class ProviderEntityBase: + """A descriptor container and a weak reference to the corresponding xml entity.""" + + def __init__(self, + source: ProviderInternalEntity | ProviderInternalMultiStateEntity, + mdib: EntityProviderMdib): + self._mdib = mdib + self.descriptor = _mk_copy(source.descriptor) + self.source_mds = source.source_mds + + @property + def handle(self) -> str: + """Return the handle of the descriptor.""" + return self.descriptor.Handle + + @property + def parent_handle(self) -> str | None: + """Return the parent handle of the descriptor.""" + return self.descriptor.parent_handle + + @property + def node_type(self) -> QName: + """Return the node type of the descriptor.""" + return self.descriptor.NODETYPE + + def __str__(self): + return f'{self.__class__.__name__} {self.node_type.localname} handle={self.handle}' + + +class ProviderEntity(ProviderEntityBase): + """Groups descriptor container and state container.""" + + def __init__(self, + source: ProviderInternalEntity, + mdib: EntityProviderMdib): + super().__init__(source, mdib) + self.state: AbstractStateContainer | None = None + if source.state is not None: + self.state = _mk_copy(source.state) + + @property + def is_multi_state(self) -> bool: + """Return False because this is not a multi state entity.""" + return False + + def update(self): + """Update from internal entity.""" + source_entity = self._mdib.internal_entities.get(self.handle) + if source_entity is None: + raise ValueError(f'entity {self.handle} no longer exists in mdib') + self.descriptor.update_from_other_container(source_entity.descriptor) + self.state = _mk_copy(source_entity.state) + + +class ProviderMultiStateEntity(ProviderEntityBase): + """Groups descriptor container and list of multi-state containers.""" + + def __init__(self, + source: ProviderInternalMultiStateEntity, + mdib: EntityProviderMdib): + super().__init__(source, mdib) + self.states: dict[str, AbstractMultiStateContainer] = {st.Handle: _mk_copy(st) for st in source.states.values()} + + @property + def is_multi_state(self) -> bool: + """Return True because this is a multi state entity.""" + return True + + def update(self): + """Update from internal entity.""" + source_entity = self._mdib.internal_entities.get(self.handle) + if source_entity is None: # pragma: no cover + raise ValueError(f'entity {self.handle} no longer exists in mdib') + self.descriptor.update_from_other_container(source_entity.descriptor) + for handle, src_state in source_entity.states.items(): + dest_state = self.states.get(handle) + if dest_state is None: + self.states[handle] = _mk_copy(src_state) + else: + dest_state.update_from_other_container(src_state) + # remove states that are no longer present is source_entity + for handle in list(self.states.keys()): + if handle not in source_entity.states: + self.states.pop(handle) + + def new_state(self, state_handle: str | None = None) -> AbstractMultiStateContainer: + """Create a new state.""" + if state_handle in self.states: + raise ValueError( + f'State handle {state_handle} already exists in {self.__class__.__name__}, handle = {self.handle}') + cls = self._mdib.data_model.get_state_container_class(self.descriptor.STATE_QNAME) + state = cls(descriptor_container=self.descriptor) + state.Handle = state_handle or uuid.uuid4().hex + self.states[state.Handle] = state + return state + + +ProviderInternalEntityType = Union[ProviderInternalEntity, ProviderInternalMultiStateEntity] +ProviderEntityType = Union[ProviderEntity, ProviderMultiStateEntity] diff --git a/src/sdc11073/entity_mdib/entity_consumermdib.py b/src/sdc11073/entity_mdib/entity_consumermdib.py new file mode 100644 index 00000000..c37cd909 --- /dev/null +++ b/src/sdc11073/entity_mdib/entity_consumermdib.py @@ -0,0 +1,732 @@ +"""The module contains the implementation of the EntityConsumerMdib.""" +from __future__ import annotations + +import copy +import time +from dataclasses import dataclass +from threading import Lock, Thread +from typing import TYPE_CHECKING, Any, Callable + +from sdc11073 import loghelper +from sdc11073 import observableproperties as properties +from sdc11073.exceptions import ApiUsageError +from sdc11073.mdib.consumermdib import ConsumerMdibState +from sdc11073.namespaces import QN_TYPE, default_ns_helper +from sdc11073.xml_types import msg_qnames, pm_qnames + +from .entities import XmlEntity, XmlMultiStateEntity, get_xsi_type +from .entity_consumermdibxtra import EntityConsumerMdibMethods +from .entity_mdibbase import EntityMdibBase + +if TYPE_CHECKING: + from collections.abc import Iterable + + from lxml.etree import QName + + from sdc11073.consumer.consumerimpl import SdcConsumer + from sdc11073.mdib.entityprotocol import EntityGetterProtocol, EntityTypeProtocol + from sdc11073.mdib.mdibbase import MdibVersionGroup + from sdc11073.pysoap.msgreader import ReceivedMessage + from sdc11073.xml_types.pm_types import CodedValue, Coding + from sdc11073.xml_utils import LxmlElement + + from .entities import ConsumerEntityType, ConsumerInternalEntityType + + XmlEntityFactory = Callable[[LxmlElement, str, str], ConsumerInternalEntityType] + + +@dataclass +class _BufferedData: + received_message_data: ReceivedMessage + handler: Callable + + +multi_state_q_names = (pm_qnames.PatientContextDescriptor, + pm_qnames.LocationContextDescriptor, + pm_qnames.WorkflowContextDescriptor, + pm_qnames.OperatorContextDescriptor, + pm_qnames.MeansContextDescriptor, + pm_qnames.EnsembleContextDescriptor) + + +def _mk_xml_entity(node: LxmlElement, parent_handle: str, source_mds: str) -> ConsumerInternalEntityType: + """Return a new XmlEntity or XmlMultiStateEntity. + + This is the default consumer entity factory. + It creates one of ConsumerInternalEntityType, which are factories for ConsumerEntityType. + By using a different factory, the user can change this to use other classes. + """ + xsi_type = get_xsi_type(node) + if xsi_type in multi_state_q_names: + return XmlMultiStateEntity(parent_handle, source_mds, xsi_type, node, []) + return XmlEntity(parent_handle, source_mds, xsi_type, node, None) + + +class EntityGetter: + """Implementation of EntityGetterProtocol.""" + + def __init__(self, entities: dict[str, XmlEntity | XmlMultiStateEntity], mdib: EntityConsumerMdib): + self._entities = entities + self._mdib = mdib + + def by_handle(self, handle: str) -> ConsumerEntityType | None: + """Return entity with given handle.""" + try: + return self._mk_entity(handle) + except KeyError: + return None + + def by_node_type(self, node_type: QName) -> list[EntityTypeProtocol]: + """Return all entities with given node type.""" + ret = [] + for handle, entity in self._entities.items(): + if entity.node_type == node_type: + ret.append(self._mk_entity(handle)) + return ret + + def by_parent_handle(self, parent_handle: str | None) -> list[EntityTypeProtocol]: + """Return all entities with given parent handle.""" + ret = [] + for handle, entity in self._entities.items(): + if entity.parent_handle == parent_handle: + ret.append(self._mk_entity(handle)) + return ret + + def by_coding(self, coding: Coding) -> list[EntityTypeProtocol]: + """Return all entities with given Coding.""" + ret = [] + for handle, xml_entity in self._entities.items(): + if xml_entity.coded_value is not None and xml_entity.coded_value.is_equivalent(coding): + ret.append(self._mk_entity(handle)) + return ret + + def by_coded_value(self, coded_value: CodedValue) -> list[EntityTypeProtocol]: + """Return all entities with given Coding.""" + ret = [] + for handle, xml_entity in self._entities.items(): + if xml_entity.coded_value.is_equivalent(coded_value): + ret.append(self._mk_entity(handle)) + return ret + + def items(self) -> Iterable[tuple[str, [EntityTypeProtocol]]]: + """Return items of a dictionary.""" + for handle in self._entities: + yield handle, self._mk_entity(handle) + + def _mk_entity(self, handle: str) -> ConsumerEntityType: + xml_entity = self._mdib.internal_entities[handle] + return xml_entity.mk_entity(self._mdib) + + def __len__(self) -> int: + """Return number of entities.""" + return len(self._entities) + + +class EntityConsumerMdib(EntityMdibBase): + """Implementation of the consumer side mdib with EntityGetter Interface. + + The internal entities store descriptors and states as XML nodes. This needs only very little CPU time for + handling of notifications. + The instantiation of descriptor and state container instances is only done on demand when the user calls the + EntityGetter interface. + """ + + sequence_or_instance_id_changed_event: bool = properties.ObservableProperty( + default_value=False, fire_only_on_changed_value=False) + # sequence_or_instance_id_changed_event is set to True every time the sequence id changes. + # It is not reset to False any time later. + # It is in the responsibility of the application to react on a changed sequence id. + # Observe this property and call "reload_all" in the observer code. + + MDIB_VERSION_CHECK_DISABLED = False + + # for testing purpose you can disable checking of mdib version, so that every notification is accepted. + + def __init__(self, + sdc_client: SdcConsumer, + extras_cls: type | None = None, + max_realtime_samples: int = 100): + """Construct a ConsumerMdib instance. + + :param sdc_client: a SdcConsumer instance + :param extras_cls: extended functionality + :param max_realtime_samples: determines how many real time samples are stored per RealtimeSampleArray + """ + super().__init__(sdc_client.sdc_definitions, + loghelper.get_logger_adapter('sdc.client.mdib', sdc_client.log_prefix)) + self._entity_factory: XmlEntityFactory = _mk_xml_entity + self._entities: dict[str, XmlEntity | XmlMultiStateEntity] = {} # key is the handle + + self._sdc_client = sdc_client + if extras_cls is None: + extras_cls = EntityConsumerMdibMethods + self._xtra = extras_cls(self, self._logger) + self._state = ConsumerMdibState.invalid + self.rt_buffers = {} # key is a handle, value is a ConsumerRtBuffer + self._max_realtime_samples = max_realtime_samples + self._last_wf_age_log = time.time() + # a buffer for notifications that are received before initial get_mdib is done + self._buffered_notifications = [] + self._buffered_notifications_lock = Lock() + self.entities: EntityGetterProtocol = EntityGetter(self._entities, self) + + @property + def xtra(self) -> Any: + """Give access to extended functionality.""" + return self._xtra + + @property + def internal_entities(self) -> dict[str, ConsumerInternalEntityType]: + """The property is needed by transactions. Do not use it otherwise.""" + return self._entities + + @property + def sdc_client(self) -> SdcConsumer: + """Give access to sdc client.""" + return self._sdc_client + + @property + def is_initialized(self) -> bool: + """Returns True if everything has been set up completely.""" + return self._state == ConsumerMdibState.initialized + + def init_mdib(self): + """Binds own notification handlers to observables of sdc client and calls GetMdib. + + Client mdib is initialized from GetMdibResponse, and from then on updated from incoming notifications. + :return: None + """ + if self.is_initialized: + raise ApiUsageError('ConsumerMdib is already initialized') + # first start receiving notifications, then call get_mdib. + # Otherwise, we might miss notifications. + self._xtra.bind_to_client_observables() + self.reload_all() + self._sdc_client.set_mdib(self) # pylint: disable=protected-access + self._logger.info('initializing mdib done') + + def reload_all(self): + """Delete all data and reloads everything.""" + self._logger.info('reload_all called') + with self.mdib_lock: + self._state = ConsumerMdibState.initializing # notifications are now buffered + self._get_mdib_response_node = None + self._md_state_node = None + self.sequence_id = None + self.instance_id = None + self.mdib_version = None + + get_service = self._sdc_client.client('Get') + self._logger.info('initializing mdib...') + response = get_service.get_mdib() + self._set_root_node(response.p_msg.msg_node) + self._update_mdib_version_group(response.mdib_version_group) + + # process buffered notifications + with self._buffered_notifications_lock: + self._logger.debug('got _buffered_notifications_lock') + for buffered_report in self._buffered_notifications: + # buffered data might contain notifications that do not fit. + mvg = buffered_report.received_message_data.mdib_version_group + if mvg.sequence_id != self.sequence_id: + self.logger.debug('wrong sequence id "%s"; ignore buffered report', + mvg.sequence_id) + continue + if mvg.mdib_version <= self.mdib_version: + self.logger.debug('older mdib version "%d"; ignore buffered report', + mvg.mdib_version) + continue + buffered_report.handler(buffered_report.received_message_data) + del self._buffered_notifications[:] + self._state = ConsumerMdibState.initialized + self._logger.info('reload_all done') + + def process_incoming_metric_states_report(self, received_message_data: ReceivedMessage): + """Check mdib_version_group and process report it if okay.""" + if not self._pre_check_report_ok(received_message_data, + self._process_incoming_metric_states_report): + return + with self.mdib_lock: + self._process_incoming_metric_states_report(received_message_data) + + def _process_incoming_metric_states_report(self, received_message: ReceivedMessage): + """Check mdib version. + + If okay: + - update mdib. + - update observable. + Call this method only if mdib_lock is already acquired. + """ + self.metric_handles = self._process_incoming_state_report(received_message, + msg_qnames.MetricState) + + def process_incoming_alert_states_report(self, received_message_data: ReceivedMessage): + """Check mdib_version_group and process report it if okay.""" + if not self._pre_check_report_ok(received_message_data, + self._process_incoming_metric_states_report): + return + with self.mdib_lock: + self._process_incoming_alert_states_report(received_message_data) + + def _process_incoming_alert_states_report(self, received_message: ReceivedMessage): + """Check mdib version. + + If okay: + - update mdib. + - update observable. + Call this method only if mdib_lock is already acquired. + """ + self.alert_handles = self._process_incoming_state_report(received_message, + msg_qnames.AlertState) + + def process_incoming_component_report(self, received_message_data: ReceivedMessage): + """Check mdib_version_group and process report it if okay.""" + if not self._pre_check_report_ok(received_message_data, + self._process_incoming_metric_states_report): + return + with self.mdib_lock: + self._process_incoming_component_report(received_message_data) + + def _process_incoming_component_report(self, received_message: ReceivedMessage): + """Check mdib version. + + If okay: + - update mdib. + - update observable. + Call this method only if mdib_lock is already acquired. + """ + self.component_handles = self._process_incoming_state_report(received_message, + msg_qnames.ComponentState) + + def process_incoming_operational_state_report(self, received_message_data: ReceivedMessage): + """Check mdib_version_group and process report it if okay.""" + if not self._pre_check_report_ok(received_message_data, + self._process_incoming_metric_states_report): + return + with self.mdib_lock: + self._process_incoming_operational_state_report(received_message_data) + + def _process_incoming_operational_state_report(self, received_message: ReceivedMessage): + """Check mdib version. + + If okay: + - update mdib. + - update observable. + Call this method only if mdib_lock is already acquired. + """ + self.operation_handles = self._process_incoming_state_report(received_message, + msg_qnames.OperationState) + + def process_incoming_context_report(self, received_message_data: ReceivedMessage): + """Check mdib_version_group and process report it if okay.""" + if not self._pre_check_report_ok(received_message_data, + self._process_incoming_metric_states_report): + return + with self.mdib_lock: + self._process_incoming_context_report(received_message_data) + + def _process_incoming_context_report(self, received_message: ReceivedMessage): + """Check mdib version. + + If okay: + - update mdib. + - update observable. + Call this method only if mdib_lock is already acquired. + """ + if self._can_accept_mdib_version(received_message.mdib_version_group.mdib_version, 'component states'): + self._update_mdib_version_group(received_message.mdib_version_group) + + handles = [] + for report_part in received_message.p_msg.msg_node: + for state_node in report_part: + if state_node.tag == msg_qnames.ContextState: + handle = state_node.attrib['Handle'] + descriptor_handle = state_node.attrib['DescriptorHandle'] + xml_entity = self._entities[descriptor_handle] + # modify state_node, but only in a deep copy + state_node = copy.deepcopy(state_node) # noqa: PLW2901 + state_node.tag = pm_qnames.State # xml_entity.state.tag # keep old tag + + # replace state in parent + found = False + parent = self._md_state_node + for st in parent: + if st.get('Handle') == handle: + parent.replace(st, state_node) + found = True + break + if not found: + parent.append(state_node) + + # replace or add in xml entity + xml_entity.states[handle] = state_node + + handles.append(handle) + self.context_handles = handles # update observable + + def process_incoming_waveform_states(self, received_message_data: ReceivedMessage): + """Check mdib_version_group and process report it if okay.""" + if not self._pre_check_report_ok(received_message_data, + self._process_incoming_waveform_states): + return + with self.mdib_lock: + self._process_incoming_waveform_states(received_message_data) + + def _process_incoming_waveform_states(self, received_message: ReceivedMessage): + if self._can_accept_mdib_version(received_message.mdib_version_group.mdib_version, 'waveform states'): + self._update_mdib_version_group(received_message.mdib_version_group) + + handles = [] + for state in received_message.p_msg.msg_node: + handles.append(self._update_state(state, + pm_qnames.RealTimeSampleArrayMetricState)) # replaces states in self._get_mdib_response_node + self.waveform_handles = handles # update observable + + def process_incoming_description_modification_report(self, received_message: ReceivedMessage): + """Check mdib_version_group and process report it if okay.""" + if not self._pre_check_report_ok(received_message, + self._process_incoming_metric_states_report): + return + with self.mdib_lock: + self._process_incoming_description_modification_report(received_message) + + def _process_incoming_description_modification_report(self, received_message: ReceivedMessage): + new_descriptors_handles = [] + updated_descriptors_handles = [] + deleted_descriptors_handles = [] + for report_part in received_message.p_msg.msg_node: + parent_handle = report_part.attrib.get('ParentDescriptor') # can be None in case of MDS, but that is ok + source_mds_handle = report_part[0].text + descriptors = [copy.deepcopy(e) for e in report_part if e.tag == msg_qnames.Descriptor] + states = [copy.deepcopy(e) for e in report_part if e.tag == msg_qnames.State] + modification_type = report_part.attrib.get('ModificationType', 'Upt') # implied value is 'Upt' + if modification_type == 'Upt': + updated_descriptors_handles.extend(self._update_descriptors(parent_handle, + source_mds_handle, + descriptors, + states)) + elif modification_type == 'Crt': + new_descriptors_handles.extend(self._create_descriptors(parent_handle, + source_mds_handle, + descriptors, + states)) + elif modification_type == 'Del': + deleted_descriptors_handles.extend(self._delete_descriptors(descriptors)) + else: + self.logger.error('Unknown modification type %r', modification_type) + if updated_descriptors_handles: + self.updated_descriptors_handles = updated_descriptors_handles + if new_descriptors_handles: + self.new_descriptors_handles = new_descriptors_handles + if deleted_descriptors_handles: + self.deleted_descriptors_handles = deleted_descriptors_handles + + def _update_descriptors(self, + parent_handle: str, + source_mds_handle: str, + descriptors: list[LxmlElement], + states: list[LxmlElement]) -> list[str]: + handles = [] + for descriptor in descriptors: + handle = descriptor.attrib['Handle'] + entity = self._entities.get(handle) + if entity.parent_handle != parent_handle: + self.logger.error('inconsistent parent handle "%s" for "%s"', handle, entity.parent_handle) + if entity.source_mds != source_mds_handle: + self.logger.error('inconsistent source mds handle "%s" for "%s"', + source_mds_handle, entity.source_mds) + if entity is None: + self.logger.error('got descriptor update for not existing handle "%s"', handle) + continue + + current_states = [s for s in states if s.attrib['DescriptorHandle'] == handle] + self._update_descriptor_states(descriptor, current_states) + handles.append(handle) + return handles + + def _create_descriptors(self, + parent_handle: str, + source_mds_handle: str, + descriptors: list[LxmlElement], + states: list[LxmlElement]) -> list[str]: + handles = [] + for descriptor in descriptors: + xsi_type = get_xsi_type(descriptor) + descriptor_handle = descriptor.attrib['Handle'] + current_states = [s for s in states if s.attrib['DescriptorHandle'] == descriptor_handle] + + # add states to parent (MdState node) + for st in current_states: + st.tag = pm_qnames.State + self._md_state_node.append(st) + + xml_entity = self._entity_factory(descriptor, parent_handle, source_mds_handle) + if xml_entity.is_multi_state: + for st in current_states: + xml_entity.states[st.attrib['Handle']] = st + elif len(current_states) != 1: + self.logger.error('create descriptor: Expect one state, got %d', len(current_states)) + # Todo: what to do in this case? add entity without state? + else: + xml_entity.state = current_states[0] + self._entities[descriptor_handle] = xml_entity + handles.append(descriptor_handle) + + # add descriptor to parent + parent_xml_entity = self._entities[parent_handle] + if parent_xml_entity.node_type == pm_qnames.ChannelDescriptor: + # channel children have same tag + descriptor.tag = pm_qnames.Metric + parent_xml_entity.descriptor.append(descriptor) + elif parent_xml_entity.node_type == pm_qnames.VmdDescriptor: + # vmd children have same tag + descriptor.tag = pm_qnames.Channel + parent_xml_entity.descriptor.append(descriptor) + elif parent_xml_entity.node_type == pm_qnames.MdsDescriptor: + # Mds children have different names. + # child_order determines the tag of the element (first tuple position), and the corresponding type + # (2nd position) + child_order: Iterable[tuple[QName, QName]] = ( + (pm_qnames.MetaData, pm_qnames.MetaData), # optional member, no handle + (pm_qnames.SystemContext, pm_qnames.SystemContextDescriptor), + (pm_qnames.Clock, pm_qnames.ClockDescriptor), + (pm_qnames.Battery, pm_qnames.BatteryDescriptor), + (pm_qnames.ApprovedJurisdictions, pm_qnames.ApprovedJurisdictions), + # optional list, no handle + (pm_qnames.Vmd, pm_qnames.VmdDescriptor)) + # Insert at correct position with correct name! + self._insert_child(descriptor, xsi_type, + parent_xml_entity.descriptor, child_order) + return handles + + @staticmethod + def _insert_child(child_node: LxmlElement, + child_xsi_type: QName, + parent_node: + LxmlElement, + child_order: Iterable[tuple[QName, QName]]): + """Rename child_node to correct name acc. to BICEPS schema and insert at correct position.""" + # rename child_node to correct name required by BICEPS schema + add_before_q_names = [] + + for i, entry in enumerate(child_order): + schema_name, xsi_type = entry + if xsi_type == child_xsi_type: + child_node.tag = schema_name + add_before_q_names.extend([x[0] for x in child_order[i + 1:]]) + break + + # find position + existing_children = parent_node[:] + if not existing_children or not add_before_q_names: + parent_node.append(child_node) + return + for tmp_child_node in existing_children: + if tmp_child_node.tag in add_before_q_names: + tmp_child_node.addprevious(child_node) + return + raise RuntimeError('this should not happen') + + def _delete_descriptors(self, descriptors: list[LxmlElement]) -> list[str]: + handles = [] + for descriptor in descriptors: + handle = descriptor.attrib['Handle'] + entity = self._entities.get(handle) + if entity is None: + self.logger.error('shall delete descriptor "%s", but it is unknown', handle) + else: + self._delete_entity(entity, handles) + return handles + # Todo: update self._get_mdib_response_node + + def _delete_entity(self, entity: XmlEntity | XmlMultiStateEntity, deleted_handles: list[str]): + """Recursive method to delete an entity and subtree.""" + parent = entity.descriptor.getparent() + if parent is not None: + parent.remove(entity.descriptor) + states = entity.states.values() if entity.is_multi_state else [entity.state] + for state in states: + parent = state.getparent() + if parent is not None: + parent.remove(state) + handle = entity.descriptor.get('Handle') + del self._entities[handle] + deleted_handles.append(handle) + child_entities = [e for e in self._entities.values() if e.parent_handle == handle] + for e in child_entities: + self._delete_entity(e, deleted_handles) + + def _process_incoming_state_report(self, received_message: ReceivedMessage, expected_q_name: QName) -> list[str]: + """Check mdib version. + + If okay: + - update mdib. + - update observable. + Call this method only if mdib_lock is already acquired. + """ + if self._can_accept_mdib_version(received_message.mdib_version_group.mdib_version, 'state'): + self._update_mdib_version_group(received_message.mdib_version_group) + + handles = [] + for report_part in received_message.p_msg.msg_node: + for state in report_part: + if state.tag == expected_q_name: + handles.append(self._update_state(state)) # replace states in self._get_mdib_response_node + return handles # update observable + + def _update_state(self, state_node: LxmlElement, xsi_type: QName | None = None) -> str: + """Replace state in DOM tree and entity.""" + descriptor_handle = state_node.attrib['DescriptorHandle'] + xml_entity = self._entities[descriptor_handle] + state_node = copy.deepcopy(state_node) # we modify state_node, but only in a deep copy + state_node.tag = pm_qnames.State # xml_entity.state.tag # keep old tag + if xsi_type: + state_node.set(QN_TYPE, default_ns_helper.doc_name_from_qname(xsi_type)) + + # replace state in parent + parent = xml_entity.state.getparent() + parent.replace(xml_entity.state, state_node) + + # replace in xml entity + xml_entity.state = state_node + return descriptor_handle + + def _update_descriptor_states(self, descriptor_node: LxmlElement, state_nodes: list[LxmlElement]) -> str: + """Replace state in DOM tree and entity.""" + for state_node in state_nodes: + state_node.tag = pm_qnames.State # rename in order to have a valid tag acc. to participant model + + descriptor_handle = descriptor_node.attrib['Handle'] + xml_entity = self._entities[descriptor_handle] + descriptor_node.tag = xml_entity.descriptor.tag # keep old tag + + # move all children with a Handle from entity.descriptor to descriptor_node (at identical position ) + children = xml_entity.descriptor[:] + for idx, child in enumerate(children): + if 'Handle' in child.attrib: + descriptor_node.insert(idx, child) + + # replace descriptor in parent + descriptor_parent = xml_entity.descriptor.getparent() + descriptor_parent.replace(xml_entity.descriptor, descriptor_node) + + # replace descriptor in xml_entity + xml_entity.descriptor = descriptor_node + + if xml_entity.is_multi_state: + # replace state_nodes in parent + for state_node in state_nodes: + state_parent = xml_entity.state.getparent() + state_parent.replace(xml_entity.state, state_node) + + # replace state_nodes in xml_entity + xml_entity.states = state_nodes + elif len(state_nodes) != 1: + self.logger.error('update descriptor: Expect one state, got %d', len(state_nodes)) + # Todo: what to do in this case? add entity without state? + else: + state_parent = xml_entity.state.getparent() + state_parent.replace(xml_entity.state, state_nodes[0]) + xml_entity.state = state_nodes[0] + return descriptor_handle + + def _pre_check_report_ok(self, received_message_data: ReceivedMessage, + handler: Callable) -> bool: + """Check if the report can be added to mdib. + + The pre-check runs before the mdib lock is acquired. + The report is buffered if state is 'initializing' and 'is_buffered_report' is False. + :return: True if report can be added to mdib. + """ + self._check_sequence_or_instance_id_changed( + received_message_data.mdib_version_group) # this might change self._state + if self._state == ConsumerMdibState.invalid: + # ignore report in these states + return False + if self._state == ConsumerMdibState.initializing: + with self._buffered_notifications_lock: + # check state again, it might have changed before lock was acquired + if self._state == ConsumerMdibState.initializing: + self._buffered_notifications.append(_BufferedData(received_message_data, handler)) + return False + return True + + def _can_accept_mdib_version(self, new_mdib_version: int, log_prefix: str) -> bool: + if self.MDIB_VERSION_CHECK_DISABLED: + return True + # log deviations from expected mdib version + if new_mdib_version < self.mdib_version: + self._logger.warning('{}: ignoring too old Mdib version, have {}, got {}', # noqa: PLE1205 + log_prefix, self.mdib_version, new_mdib_version) + elif (new_mdib_version - self.mdib_version) > 1: + # This can happen if consumer did not subscribe to all notifications. + # Still log a warning, because mdib is no longer a correct mirror of provider mdib. + self._logger.warning('{}: expect mdib_version {}, got {}', # noqa: PLE1205 + log_prefix, self.mdib_version + 1, new_mdib_version) + # it is possible to receive multiple notifications with the same mdib version => compare ">=" + return new_mdib_version >= self.mdib_version + + def _check_sequence_or_instance_id_changed(self, mdib_version_group: MdibVersionGroup): + """Check if sequence id and instance id are still the same. + + If not, + - set state member to invalid + - set the observable "sequence_or_instance_id_changed_event" in a thread. + This allows to implement an observer that can directly call reload_all without blocking the consumer. + """ + if mdib_version_group.sequence_id == self.sequence_id and mdib_version_group.instance_id == self.instance_id: + return + if self._state == ConsumerMdibState.initialized: + if mdib_version_group.sequence_id != self.sequence_id: + self.logger.warning('sequence id changed from "%s" to "%s"', + self.sequence_id, mdib_version_group.sequence_id) + if mdib_version_group.instance_id != self.instance_id: + self.logger.warning('instance id changed from "%r" to "%r"', + self.instance_id, mdib_version_group.instance_id) + self.logger.warning('mdib is no longer valid!') + + self._state = ConsumerMdibState.invalid + + def _set_observable(): + self.sequence_or_instance_id_changed_event = True + + thr = Thread(target=_set_observable) + thr.start() + + def _set_root_node(self, root_node: LxmlElement): # noqa: C901 + """Set member and create xml entities.""" + if root_node.tag != msg_qnames.GetMdibResponse: + raise ValueError(f'root node must be {msg_qnames.GetMdibResponse!s}, got {root_node.tag!s}') + self._get_mdib_response_node = root_node + self._mdib_node = root_node[0] + self._entities.clear() + for child_element in self._mdib_node: # MdDescription, MdState; both are optional + if child_element.tag == pm_qnames.MdState: + self._md_state_node = child_element + if child_element.tag == pm_qnames.MdDescription: + self._md_description_node = child_element + + def register_children_with_handle(parent_node: LxmlElement, source_mds: str | None = None): + parent_handle = parent_node.attrib.get('Handle') + for child_node in parent_node[:]: + child_handle = child_node.attrib.get('Handle') + if child_node.tag == pm_qnames.Mds: + source_mds = child_handle + if child_handle: + self._entities[child_handle] = self._entity_factory(child_node, + parent_handle, + source_mds) + register_children_with_handle(child_node, source_mds) + + if self._md_description_node is not None: + register_children_with_handle(self._md_description_node) + + if self._md_state_node is not None: + for state_node in self._md_state_node: + descriptor_handle = state_node.attrib['DescriptorHandle'] + entity = self._entities[descriptor_handle] + if entity.is_multi_state: + handle = state_node.attrib['Handle'] + entity.states[handle] = state_node + else: + entity.state = state_node diff --git a/src/sdc11073/entity_mdib/entity_consumermdibxtra.py b/src/sdc11073/entity_mdib/entity_consumermdibxtra.py new file mode 100644 index 00000000..9435f122 --- /dev/null +++ b/src/sdc11073/entity_mdib/entity_consumermdibxtra.py @@ -0,0 +1,61 @@ +"""The module contains extensions to the functionality of the EntityConsumerMdib.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sdc11073 import observableproperties as properties + +if TYPE_CHECKING: + from sdc11073.loghelper import LoggerAdapter + from sdc11073.pysoap.msgreader import ReceivedMessage + + from .entity_consumermdib import EntityConsumerMdib + +class EntityConsumerMdibMethods: + """Extra methods for consumer mdib that are not part of core functionality.""" + + def __init__(self, consumer_mdib: EntityConsumerMdib, logger: LoggerAdapter): + self._mdib = consumer_mdib + self._sdc_client = consumer_mdib.sdc_client + self._msg_reader = self._sdc_client.msg_reader + self._logger = logger + + def bind_to_client_observables(self): + """Connect the mdib with the notifications from consumer.""" + properties.bind(self._sdc_client, waveform_report=self._on_waveform_report) + properties.bind(self._sdc_client, episodic_metric_report=self._on_episodic_metric_report) + properties.bind(self._sdc_client, episodic_alert_report=self._on_episodic_alert_report) + properties.bind(self._sdc_client, episodic_context_report=self._on_episodic_context_report) + properties.bind(self._sdc_client, episodic_component_report=self._on_episodic_component_report) + properties.bind(self._sdc_client, description_modification_report=self._on_description_modification_report) + properties.bind(self._sdc_client, episodic_operational_state_report=self._on_operational_state_report) + + def _on_episodic_metric_report(self, received_message_data: ReceivedMessage): + self._logger.debug('_on_episodic_metric_report') + self._mdib.process_incoming_metric_states_report(received_message_data) + + def _on_episodic_alert_report(self, received_message_data: ReceivedMessage): + self._logger.debug('_on_episodic_alert_report') + self._mdib.process_incoming_alert_states_report(received_message_data) + + def _on_operational_state_report(self, received_message_data: ReceivedMessage): + self._logger.debug('_on_operational_state_report') + self._mdib.process_incoming_operational_state_report(received_message_data) + + def _on_waveform_report(self, received_message_data: ReceivedMessage): + self._logger.debug('_on_waveform_report') + self._mdib.process_incoming_waveform_states(received_message_data) + + def _on_episodic_context_report(self, received_message_data: ReceivedMessage): + self._logger.debug('_on_episodic_context_report') + self._mdib.process_incoming_context_report(received_message_data) + + def _on_episodic_component_report(self, received_message_data: ReceivedMessage): + self._logger.debug('_on_episodic_component_report') + self._mdib.process_incoming_component_report(received_message_data) + + def _on_description_modification_report(self, received_message_data: ReceivedMessage): + self._logger.debug('_on_description_modification_report') + self._mdib.process_incoming_description_modification_report(received_message_data) + + diff --git a/src/sdc11073/entity_mdib/entity_mdibbase.py b/src/sdc11073/entity_mdib/entity_mdibbase.py new file mode 100644 index 00000000..76fe0c45 --- /dev/null +++ b/src/sdc11073/entity_mdib/entity_mdibbase.py @@ -0,0 +1,75 @@ +"""The module implements the base class for consumer and provider specific mdib implementations.""" +from __future__ import annotations + +from threading import RLock +from typing import TYPE_CHECKING + +from sdc11073 import observableproperties as properties +from sdc11073.mdib.mdibbase import MdibVersionGroup + +if TYPE_CHECKING: + from sdc11073.definitions_base import BaseDefinitions + from sdc11073.loghelper import LoggerAdapter + from sdc11073.xml_utils import LxmlElement + + +class EntityMdibBase: + """Base class for consumer and provider specific mdib implementations.""" + + metric_handles = properties.ObservableProperty(fire_only_on_changed_value=False) + waveform_handles = properties.ObservableProperty(fire_only_on_changed_value=False) + alert_handles = properties.ObservableProperty(fire_only_on_changed_value=False) + context_handles = properties.ObservableProperty(fire_only_on_changed_value=False) + component_handles = properties.ObservableProperty(fire_only_on_changed_value=False) + new_descriptors_handles = properties.ObservableProperty(fire_only_on_changed_value=False) + updated_descriptors_handles = properties.ObservableProperty(fire_only_on_changed_value=False) + deleted_descriptors_handles = properties.ObservableProperty(fire_only_on_changed_value=False) + operation_handles = properties.ObservableProperty(fire_only_on_changed_value=False) + sequence_id = properties.ObservableProperty() + instance_id = properties.ObservableProperty() + + def __init__(self, sdc_definitions: type[BaseDefinitions], + logger: LoggerAdapter): + """Construct MdibBase. + + :param sdc_definitions: a class derived from BaseDefinitions + """ + self.sdc_definitions = sdc_definitions + self.data_model = sdc_definitions.data_model + self._logger = logger + self.mdib_version = 0 + self.sequence_id = '' # needs to be set to a reasonable value by derived class + self.instance_id = None # None or an unsigned int + self.log_prefix = '' + self.mdib_lock = RLock() + + self._get_mdib_response_node: LxmlElement | None = None + self._mdib_node: LxmlElement | None = None + self._md_description_node: LxmlElement | None = None + self._md_state_node :LxmlElement | None = None + + @property + def mdib_version_group(self) -> MdibVersionGroup: + """"Get current version data.""" + return MdibVersionGroup(self.mdib_version, self.sequence_id, self.instance_id) + + def _update_mdib_version_group(self, mdib_version_group: MdibVersionGroup): + """Set members and update entries in DOM tree.""" + mdib_node = self._get_mdib_response_node[0] + if mdib_version_group.mdib_version != self.mdib_version: + self.mdib_version = mdib_version_group.mdib_version + self._get_mdib_response_node.set('MdibVersion', str(mdib_version_group.mdib_version)) + mdib_node.set('MdibVersion', str(mdib_version_group.mdib_version)) + if mdib_version_group.sequence_id != self.sequence_id: + self.sequence_id = mdib_version_group.sequence_id + self._get_mdib_response_node.set('SequenceId', str(mdib_version_group.sequence_id)) + mdib_node.set('SequenceId', str(mdib_version_group.sequence_id)) + if mdib_version_group.instance_id != self.instance_id: + self.instance_id = mdib_version_group.instance_id + self._get_mdib_response_node.set('InstanceId', str(mdib_version_group.instance_id)) + mdib_node.set('InstanceId', str(mdib_version_group.instance_id)) + + @property + def logger(self) -> LoggerAdapter: + """Return the logger.""" + return self._logger diff --git a/src/sdc11073/entity_mdib/entity_providermdib.py b/src/sdc11073/entity_mdib/entity_providermdib.py new file mode 100644 index 00000000..ab317bf9 --- /dev/null +++ b/src/sdc11073/entity_mdib/entity_providermdib.py @@ -0,0 +1,511 @@ +"""The module contains a provider mdib implementation that uses entities in internal representation.""" +from __future__ import annotations + +import uuid +from collections import defaultdict +from contextlib import AbstractContextManager, contextmanager +from pathlib import Path +from threading import Lock +from typing import TYPE_CHECKING, Any, Callable + +from lxml.etree import Element, SubElement + +from sdc11073 import loghelper +from sdc11073.definitions_base import ProtocolsRegistry +from sdc11073.mdib.transactionsprotocol import TransactionType +from sdc11073.observableproperties import ObservableProperty +from sdc11073.pysoap.msgreader import MessageReader + +from .entities import ProviderInternalEntity, ProviderInternalEntityType, ProviderInternalMultiStateEntity +from .entity_mdibbase import EntityMdibBase +from .entity_providermdibxtra import EntityProviderMdibMethods +from .entity_transactions import mk_transaction + +if TYPE_CHECKING: + from collections.abc import Iterable + + from lxml.etree import QName + + from sdc11073 import xml_utils + from sdc11073.definitions_base import BaseDefinitions + from sdc11073.loghelper import LoggerAdapter + from sdc11073.mdib.descriptorcontainers import AbstractDescriptorContainer + from sdc11073.mdib.entityprotocol import ProviderEntityGetterProtocol + from sdc11073.mdib.mdibbase import MdibVersionGroup + from sdc11073.mdib.statecontainers import AbstractStateContainer + from sdc11073.mdib.transactionsprotocol import ( + AnyEntityTransactionManagerProtocol, + EntityContextStateTransactionManagerProtocol, + EntityDescriptorTransactionManagerProtocol, + EntityStateTransactionManagerProtocol, + TransactionResultProtocol, + ) + from sdc11073.xml_types.pm_types import CodedValue, Coding + + from .entities import ProviderEntityType, ProviderMultiStateEntity + + ProviderEntityFactory = Callable[[AbstractDescriptorContainer, list[AbstractStateContainer]], + ProviderInternalEntityType] + + +def _mk_internal_entity(descriptor_container: AbstractDescriptorContainer, + states: list[AbstractStateContainer]) -> ProviderInternalEntityType: + """Create an entity. + + This is the default Implementation of ProviderEntityFactory. + """ + for s in states: + s.descriptor_container = descriptor_container + if descriptor_container.is_context_descriptor: + return ProviderInternalMultiStateEntity(descriptor_container, states) + if len(states) == 1: + return ProviderInternalEntity(descriptor_container, states[0]) + if len(states) == 0: + return ProviderInternalEntity(descriptor_container, None) + raise ValueError( + f'found {len(states)} states for {descriptor_container.NODETYPE} handle = {descriptor_container.Handle}') + + +class ProviderEntityGetter: + """Implements entityprotocol.ProviderEntityGetterProtocol.""" + + def __init__(self, + mdib: EntityProviderMdib): + self._mdib = mdib + + def by_handle(self, handle: str) -> ProviderEntityType | None: + """Return entity with given handle.""" + with self._mdib.mdib_lock: + try: + internal_entity = self._mdib.internal_entities[handle] + return internal_entity.mk_entity(self._mdib) + except KeyError: + return None + + def by_context_handle(self, handle: str) -> ProviderMultiStateEntity | None: + """Return multi state entity that contains a state with given handle.""" + with self._mdib.mdib_lock: + for internal_entity in self._mdib.internal_entities.values(): + if internal_entity.is_multi_state and handle in internal_entity.states: + return internal_entity.mk_entity(self._mdib) + return None + + def by_node_type(self, node_type: QName) -> list[ProviderEntityType]: + """Return all entities with given node type.""" + ret = [] + with self._mdib.mdib_lock: + for internal_entity in self._mdib.internal_entities.values(): + if node_type == internal_entity.descriptor.NODETYPE: + ret.append(internal_entity.mk_entity(self._mdib)) + return ret + + def by_parent_handle(self, parent_handle: str | None) -> list[ProviderEntityType]: + """Return all entities with given parent handle.""" + ret = [] + with self._mdib.mdib_lock: + for internal_entity in self._mdib.internal_entities.values(): + if internal_entity.descriptor.parent_handle == parent_handle: + ret.append(internal_entity.mk_entity(self._mdib)) + return ret + + def by_coding(self, coding: Coding) -> list[ProviderEntityType]: + """Return all entities with given Coding.""" + ret = [] + with self._mdib.mdib_lock: + for internal_entity in self._mdib.internal_entities.values(): + if internal_entity.descriptor.Type is not None and internal_entity.descriptor.Type.is_equivalent(coding): + ret.append(internal_entity.mk_entity(self._mdib)) + return ret + + def by_coded_value(self, coded_value: CodedValue) -> list[ProviderEntityType]: + """Return all entities with given Coding.""" + ret = [] + with self._mdib.mdib_lock: + for internal_entity in self._mdib.internal_entities.values(): + if internal_entity.descriptor.Type is not None and internal_entity.descriptor.Type.is_equivalent( + coded_value): + ret.append(internal_entity.mk_entity(self._mdib)) + return ret + + def items(self) -> Iterable[tuple[str, [ProviderEntityType]]]: + """Return the items.""" + with self._mdib.mdib_lock: + for handle, internal_entity in self._mdib.internal_entities.items(): + yield handle, internal_entity.mk_entity(self._mdib) + + def new_entity(self, + node_type: QName, + handle: str, + parent_handle: str | None) -> ProviderEntityType: + """Create an entity. + + User can modify the entity and then add it to transaction via write_entity! + It will not become part of mdib without write_entity call! + """ + if handle in self._mdib.internal_entities or handle in self._mdib.new_entities: + raise ValueError('Handle already exists') + + descr_cls = self._mdib.data_model.get_descriptor_container_class(node_type) + descriptor_container = descr_cls(handle=handle, parent_handle=parent_handle) + if parent_handle is not None: + parent_entity = (self._mdib.new_entities.get(parent_handle) + or self._mdib.internal_entities.get(parent_handle)) + if parent_entity is None: + raise ValueError(f'Entity {handle} has no parent (parent_handle = {parent_handle})!') + descriptor_container.set_source_mds(parent_entity.descriptor.source_mds) + else: + descriptor_container.set_source_mds(descriptor_container.Handle) # this is a mds, source_mds is itself + + new_internal_entity = self._mdib.entity_factory(descriptor_container, []) + if handle in self._mdib.descr_handle_version_lookup: + # This handle existed before. Use last descriptor version + 1 + new_internal_entity.descriptor.DescriptorVersion = self._mdib.descr_handle_version_lookup[handle] + 1 + if not new_internal_entity.is_multi_state: + # create a state + state_cls = self._mdib.data_model.get_state_container_class(descriptor_container.STATE_QNAME) + new_internal_entity.state = state_cls(descriptor_container) + if handle in self._mdib.state_handle_version_lookup: + new_internal_entity.state.StateVersion = self._mdib.state_handle_version_lookup[handle] + 1 + self._mdib.new_entities[descriptor_container.Handle] = new_internal_entity # write to mdib in process_transaction + return new_internal_entity.mk_entity(self._mdib) + + def __len__(self) -> int: + """Return number of entities.""" + return len(self._mdib.internal_entities) + + +class EntityProviderMdib(EntityMdibBase): + """Device side implementation of a mdib. + + Do not modify containers directly, use transactions for that purpose. + Transactions keep track of changes and initiate sending of update notifications to clients. + """ + + transaction: TransactionResultProtocol | None = ObservableProperty(fire_only_on_changed_value=False) + rt_updates = ObservableProperty(fire_only_on_changed_value=False) # different observable for performance + + # ToDo: keep track of DescriptorVersions and StateVersion in order to allow correct StateVersion after delete/create + # new version must be bigger then old version + def __init__(self, + sdc_definitions: type[BaseDefinitions] | None = None, + log_prefix: str | None = None, + extra_functionality: type | None = None, + transaction_factory: Callable[[EntityProviderMdib, TransactionType, LoggerAdapter], + AnyEntityTransactionManagerProtocol] | None = None, + ): + """Construct a ProviderMdib. + + :param sdc_definitions: defaults to sdc11073.definitions_sdc.SdcV1Definitions + :param log_prefix: a string + :param extra_functionality: class for extra functionality, default is ProviderMdibMethods + :param transaction_factory: optional alternative transactions factory. + """ + if sdc_definitions is None: + from sdc11073.definitions_sdc import SdcV1Definitions # lazy import, needed to brake cyclic imports + sdc_definitions = SdcV1Definitions + super().__init__(sdc_definitions, + loghelper.get_logger_adapter('sdc.device.mdib', log_prefix), + ) + + self.nsmapper = sdc_definitions.data_model.ns_helper + + if extra_functionality is None: + extra_functionality = EntityProviderMdibMethods + + self._entities: dict[str, ProviderInternalEntityType] = {} # key is the handle + + # Keep track of entities that were created but are not yet part of mdib. + # They become part of mdib when they are added via transaction. + self._new_entities: dict[str, ProviderInternalEntityType] = {} + + # The official API + self.entities: ProviderEntityGetterProtocol = ProviderEntityGetter(self) + + self._xtra = extra_functionality(self) + self._tr_lock = Lock() # transaction lock + + self.sequence_id = uuid.uuid4().urn # this uuid identifies this mdib instance + + self._annotators = {} + self.current_transaction = None + + self.pre_commit_handler = None # pre_commit_handler can modify transaction if needed before it is committed + self.post_commit_handler = None # post_commit_handler can modify mdib if needed after it is committed + self._transaction_factory = transaction_factory or mk_transaction + self._retrievability_episodic = [] # a list of handles + self.retrievability_periodic = defaultdict(list) + self.mddescription_version = 0 + self.mdstate_version = 0 + self._is_initialized = False + # In order to be able to re-create a descriptor or state with a bigger version than before, + # these lookups keep track of version counters for deleted descriptors and states. + self.descr_handle_version_lookup: dict[str, int] = {} + self.state_handle_version_lookup: dict[str, int] = {} + self.entity_factory: ProviderEntityFactory = _mk_internal_entity + + @property + def xtra(self) -> Any: + """Give access to extended functionality.""" + return self._xtra + + @property + def internal_entities(self) -> dict[str, ProviderInternalEntityType]: + """The property is needed by transactions. Do not use it otherwise.""" + return self._entities + + @property + def new_entities(self) -> dict[str, ProviderInternalEntityType]: + """The property is needed by transactions. Do not use it otherwise.""" + return self._new_entities + + def set_initialized(self): + """Set initialized state = True.""" + self._is_initialized = True + + @property + def is_initialized(self) -> bool: + """Return True if mdib is already initialized.""" + return self._is_initialized + + @contextmanager + def _transaction_manager(self, # noqa: PLR0912, C901 + transaction_type: TransactionType, + set_determination_time: bool = True) -> AbstractContextManager[ + AnyEntityTransactionManagerProtocol]: + """Start a transaction, return a new transaction manager.""" + with self._tr_lock, self.mdib_lock: + try: + self.current_transaction = self._transaction_factory(self, transaction_type, self.logger) + yield self.current_transaction + + if callable(self.pre_commit_handler): + self.pre_commit_handler(self, self.current_transaction) + if self.current_transaction.error: + self._logger.info('transaction_manager: transaction without updates!') + else: + # update observables + transaction_result = self.current_transaction.process_transaction(set_determination_time) + if transaction_result.new_mdib_version is not None: + self.mdib_version = transaction_result.new_mdib_version + self.transaction = transaction_result + + if transaction_result.alert_updates: + self.alert_by_handle = {st.DescriptorHandle: st for st in transaction_result.alert_updates} + if transaction_result.comp_updates: + self.component_by_handle = {st.DescriptorHandle: st for st in transaction_result.comp_updates} + if transaction_result.ctxt_updates: + self.context_by_handle = {st.Handle: st for st in transaction_result.ctxt_updates} + if transaction_result.descr_created: + self.new_descriptors_by_handle = {descr.Handle: descr for descr + in transaction_result.descr_created} + if transaction_result.descr_deleted: + self.deleted_descriptors_by_handle = {descr.Handle: descr for descr + in transaction_result.descr_deleted} + if transaction_result.descr_updated: + self.updated_descriptors_by_handle = {descr.Handle: descr for descr + in transaction_result.descr_updated} + if transaction_result.metric_updates: + self.metrics_by_handle = {st.DescriptorHandle: st for st in transaction_result.metric_updates} + if transaction_result.op_updates: + self.operation_by_handle = {st.DescriptorHandle: st for st in transaction_result.op_updates} + if transaction_result.rt_updates: + self.waveform_by_handle = {st.DescriptorHandle: st for st in transaction_result.rt_updates} + + if callable(self.post_commit_handler): + self.post_commit_handler(self, self.current_transaction) + finally: + self.current_transaction = None + + @contextmanager + def context_state_transaction(self) -> AbstractContextManager[EntityContextStateTransactionManagerProtocol]: + """Return a transaction for context state updates.""" + with self._transaction_manager(TransactionType.context, False) as mgr: + yield mgr + + @contextmanager + def alert_state_transaction(self, set_determination_time: bool = True) \ + -> AbstractContextManager[EntityStateTransactionManagerProtocol]: + """Return a transaction for alert state updates.""" + with self._transaction_manager(TransactionType.alert, set_determination_time) as mgr: + yield mgr + + @contextmanager + def metric_state_transaction(self, set_determination_time: bool = True) \ + -> AbstractContextManager[EntityStateTransactionManagerProtocol]: + """Return a transaction for metric state updates (not real time samples!).""" + with self._transaction_manager(TransactionType.metric, set_determination_time) as mgr: + yield mgr + + @contextmanager + def rt_sample_state_transaction(self, set_determination_time: bool = False) \ + -> AbstractContextManager[EntityStateTransactionManagerProtocol]: + """Return a transaction for real time sample state updates.""" + with self._transaction_manager(TransactionType.rt_sample, set_determination_time) as mgr: + yield mgr + + @contextmanager + def component_state_transaction(self) -> AbstractContextManager[EntityStateTransactionManagerProtocol]: + """Return a transaction for component state updates.""" + with self._transaction_manager(TransactionType.component) as mgr: + yield mgr + + @contextmanager + def operational_state_transaction(self) -> AbstractContextManager[EntityStateTransactionManagerProtocol]: + """Return a transaction for operational state updates.""" + with self._transaction_manager(TransactionType.operational) as mgr: + yield mgr + + @contextmanager + def descriptor_transaction(self) -> AbstractContextManager[EntityDescriptorTransactionManagerProtocol]: + """Return a transaction for descriptor updates. + + This transaction also allows to handle the states that relate to the modified descriptors. + """ + with self._transaction_manager(TransactionType.descriptor) as mgr: + yield mgr + + def make_descriptor_node(self, + descriptor_container: AbstractDescriptorContainer, + parent_node: xml_utils.LxmlElement, + tag: QName, + set_xsi_type: bool = True) -> xml_utils.LxmlElement: + """Create a lxml etree node with subtree from instance data. + + :param descriptor_container: a descriptor container instance + :param parent_node: parent node + :param tag: tag of node + :param set_xsi_type: if true, the NODETYPE will be used to set the xsi:type attribute of the node + :return: an etree node. + """ + ns_map = self.nsmapper.partial_map(self.nsmapper.PM, self.nsmapper.XSI) \ + if set_xsi_type else self.nsmapper.partial_map(self.nsmapper.PM) + node = SubElement(parent_node, + tag, + attrib={'Handle': descriptor_container.Handle}, + nsmap=ns_map) + descriptor_container.update_node(node, self.nsmapper, set_xsi_type) # create all + child_entities = self.entities.by_parent_handle(descriptor_container.Handle) + # append all child containers, then bring all child elements in correct order + for child_entity in child_entities: + child_tag, set_xsi = descriptor_container.tag_name_for_child_descriptor(child_entity.descriptor.NODETYPE) + self.make_descriptor_node(child_entity.descriptor, node, child_tag, set_xsi) + descriptor_container.sort_child_nodes(node) + return node + + def reconstruct_mdib(self) -> (xml_utils.LxmlElement, MdibVersionGroup): + """Build dom tree from current data. + + This method does not include context states! + """ + with self.mdib_lock: + return self._reconstruct_mdib(add_context_states=False), self.mdib_version_group + + def reconstruct_mdib_with_context_states(self) -> (xml_utils.LxmlElement, MdibVersionGroup): + """Build dom tree from current data. + + This method includes the context states. + """ + with self.mdib_lock: + return self._reconstruct_mdib(add_context_states=True), self.mdib_version_group + + def reconstruct_md_description(self) -> (xml_utils.LxmlElement, MdibVersionGroup): + """Build dom tree of descriptors from current data.""" + with self.mdib_lock: + node = self._reconstruct_md_description() + return node, self.mdib_version_group + + def _reconstruct_mdib(self, add_context_states: bool) -> xml_utils.LxmlElement: + """Build dom tree of mdib from current data. + + If add_context_states is False, context states are not included. + """ + pm = self.data_model.pm_names + msg = self.data_model.msg_names + doc_nsmap = self.nsmapper.ns_map + mdib_node = Element(msg.Mdib, nsmap=doc_nsmap) + mdib_node.set('MdibVersion', str(self.mdib_version)) + mdib_node.set('SequenceId', self.sequence_id) + if self.instance_id is not None: + mdib_node.set('InstanceId', str(self.instance_id)) + md_description_node = self._reconstruct_md_description() + mdib_node.append(md_description_node) + + # add a list of states + md_state_node = SubElement(mdib_node, pm.MdState, + attrib={'StateVersion': str(self.mdstate_version)}, + nsmap=doc_nsmap) + tag = pm.State + for entity in self._entities.values(): + if entity.descriptor.is_context_descriptor: + if add_context_states: + for state_container in entity.states.values(): + md_state_node.append(state_container.mk_state_node(tag, self.nsmapper)) + elif entity.state is not None: + md_state_node.append(entity.state.mk_state_node(tag, self.nsmapper)) + return mdib_node + + def _reconstruct_md_description(self) -> xml_utils.LxmlElement: + """Build dom tree of descriptors from current data.""" + pm = self.data_model.pm_names + doc_nsmap = self.nsmapper.ns_map + root_entities = self.entities.by_parent_handle(None) + if root_entities: + md_description_node = Element(pm.MdDescription, + attrib={'DescriptionVersion': str(self.mddescription_version)}, + nsmap=doc_nsmap) + for root_entity in root_entities: + self.make_descriptor_node(root_entity.descriptor, md_description_node, tag=pm.Mds, set_xsi_type=False) + return md_description_node + + @classmethod + def from_mdib_file(cls, + path: str, + protocol_definition: type[BaseDefinitions] | None = None, + xml_reader_class: type[MessageReader] | None = MessageReader, + log_prefix: str | None = None) -> EntityProviderMdib: + """Construct mdib from a file. + + :param path: the input file path for creating the mdib + :param protocol_definition: an optional object derived from BaseDefinitions, forces usage of this definition + :param xml_reader_class: class that is used to read mdib xml file + :param log_prefix: a string or None + :return: instance. + """ + with Path(path).open('rb') as the_file: + xml_text = the_file.read() + return cls.from_string(xml_text, + protocol_definition, + xml_reader_class, + log_prefix) + + @classmethod + def from_string(cls, + xml_text: bytes, + protocol_definition: type[BaseDefinitions] | None = None, + xml_reader_class: type[MessageReader] | None = MessageReader, + log_prefix: str | None = None) -> EntityProviderMdib: + """Construct mdib from a string. + + :param xml_text: the input string for creating the mdib + :param protocol_definition: an optional object derived from BaseDefinitions, forces usage of this definition + :param xml_reader_class: class that is used to read mdib xml file + :param log_prefix: a string or None + :return: instance. + """ + # get protocol definition that matches xml_text + if protocol_definition is None: + for definition_cls in ProtocolsRegistry.protocols: + pm_namespace = definition_cls.data_model.ns_helper.PM.namespace.encode('utf-8') + if pm_namespace in xml_text: + protocol_definition = definition_cls + break + if protocol_definition is None: + raise ValueError('cannot create instance, no known BICEPS schema version identified') + mdib = cls(protocol_definition, log_prefix=log_prefix) + + xml_msg_reader = xml_reader_class(protocol_definition, None, mdib.logger) + descriptor_containers, state_containers = xml_msg_reader.read_mdib_xml(xml_text) + # Todo: msg_reader sets source_mds while reading xml mdib + mdib.xtra.set_initial_content(descriptor_containers, state_containers) + mdib.set_initialized() + return mdib diff --git a/src/sdc11073/entity_mdib/entity_providermdibxtra.py b/src/sdc11073/entity_mdib/entity_providermdibxtra.py new file mode 100644 index 00000000..79118e3e --- /dev/null +++ b/src/sdc11073/entity_mdib/entity_providermdibxtra.py @@ -0,0 +1,227 @@ +"""The module contains extensions to the functionality of the EntityProviderMdib.""" +from __future__ import annotations + +import time +from collections import defaultdict +from typing import TYPE_CHECKING + +from sdc11073.etc import apply_map +from sdc11073.exceptions import ApiUsageError +from sdc11073.xml_types.pm_types import RetrievabilityMethod + +if TYPE_CHECKING: + from sdc11073.location import SdcLocation + from sdc11073.mdib.descriptorcontainers import AbstractDescriptorContainer + from sdc11073.mdib.statecontainers import AbstractStateContainer + from sdc11073.xml_types.pm_types import InstanceIdentifier + + from .entities import ProviderEntity, ProviderMultiStateEntity + from .entity_providermdib import EntityProviderMdib + + +class EntityProviderMdibMethods: + """Extra methods for provider mdib that are not part of core functionality.""" + + def __init__(self, provider_mdib: EntityProviderMdib): + self._mdib = provider_mdib + self.default_validators = (provider_mdib.data_model.pm_types.InstanceIdentifier( + root='rootWithNoMeaning', extension_string='System'),) + + def set_all_source_mds(self): + """Set source mds in all entities.""" + dict_by_parent_handle = defaultdict(list) + descriptor_containers = [entity.descriptor for entity in self._mdib.internal_entities.values()] + for d in descriptor_containers: + dict_by_parent_handle[d.parent_handle].append(d) + + def tag_tree(source_mds_handle: str, descriptor_container: AbstractDescriptorContainer): + descriptor_container.set_source_mds(source_mds_handle) + children = dict_by_parent_handle[descriptor_container.Handle] + for ch in children: + tag_tree(source_mds_handle, ch) + + for mds in dict_by_parent_handle[None]: # only mds has no parent + tag_tree(mds.Handle, mds) + + def set_location(self, sdc_location: SdcLocation, + validators: list[InstanceIdentifier] | None = None, + location_context_descriptor_handle: str | None = None): + """Create a location context state. The new state will be the associated state. + + This method updates only the mdib data! + Use the SdcProvider.set_location method if you want to publish the address on the network. + :param sdc_location: a sdc11073.location.SdcLocation instance + :param validators: a list of InstanceIdentifier objects or None + If None, self.default_validators is used. + :param location_context_descriptor_handle: Only needed if the mdib contains more than one + LocationContextDescriptor. Then this defines the descriptor for which a new LocationContextState + shall be created. + """ + mdib = self._mdib + pm = mdib.data_model.pm_names + + if location_context_descriptor_handle is None: + # assume there is only one descriptor in mdib, user has not provided a handle. + location_entity = mdib.entities.by_node_type(pm.LocationContextDescriptor)[0] + else: + location_entity = mdib.entities.by_handle(location_context_descriptor_handle) + + new_location = location_entity.new_state() + new_location.update_from_sdc_location(sdc_location) + if validators is None: + new_location.Validator = self.default_validators + else: + new_location.Validator = validators + + with mdib.context_state_transaction() as mgr: + # disassociate before creating a new state + handles = self.disassociate_all(location_entity, + mgr.new_mdib_version, + ignored_handle=new_location.Handle) + new_location.BindingMdibVersion = mgr.new_mdib_version + new_location.BindingStartTime = time.time() + new_location.ContextAssociation = mdib.data_model.pm_types.ContextAssociation.ASSOCIATED + handles.append(new_location.Handle) + mgr.write_entity(location_entity, handles) + + def set_initial_content(self, + descriptor_containers: list[AbstractDescriptorContainer], + state_containers: list[AbstractStateContainer]): + """Add states.""" + if self._mdib.is_initialized: # pragma: no cover + raise ApiUsageError('method "set_initial_content" can not be called when mdib is already initialized') + for d in descriptor_containers: + states = [s for s in state_containers if s.DescriptorHandle == d.Handle] + entity = self._mdib.entity_factory(d, states) + self._mdib.internal_entities[d.Handle] = entity + + self.set_all_source_mds() + self.mk_state_containers_for_all_descriptors() + self.set_states_initial_values() + self.update_retrievability_lists() + + def mk_state_containers_for_all_descriptors(self): + """Create a state container for every descriptor that is missing a state in mdib. + + The model requires that there is a state for every descriptor (exception: multi-states) + """ + mdib = self._mdib + pm = mdib.data_model.pm_names + for entity in mdib.internal_entities.values(): + if entity.descriptor.is_context_descriptor: + continue + if entity.state is None: + state_cls = mdib.data_model.get_state_class_for_descriptor(entity.descriptor) + state = state_cls(entity.descriptor) + entity.state = state + # add some initial values where needed + if state.is_alert_condition: + state.DeterminationTime = time.time() + elif state.NODETYPE == pm.AlertSystemState: # noqa: SIM300 + state.LastSelfCheck = time.time() + state.SelfCheckCount = 1 + elif state.NODETYPE == pm.ClockState: # noqa: SIM300 + state.LastSet = time.time() + if mdib.current_transaction is not None: + mdib.current_transaction.add_state(state) + + def set_states_initial_values(self): + """Set all states to defined starting conditions. + + This method is ment to be called directly after the mdib was loaded and before the provider is published + on the network. + It changes values only internally in the mdib, no notifications are sent! + + """ + pm_names = self._mdib.data_model.pm_names + pm_types = self._mdib.data_model.pm_types + + for entity in self._mdib.internal_entities.values(): + if entity.node_type == pm_names.AlertSystemDescriptor: + # alert systems are active + entity.state.ActivationState = pm_types.AlertActivation.ON + entity.state.SystemSignalActivation.append( + pm_types.SystemSignalActivation(manifestation=pm_types.AlertSignalManifestation.AUD, + state=pm_types.AlertActivation.ON)) + elif entity.descriptor.is_alert_condition_descriptor: + # alert conditions are active, but not present + entity.state.ActivationState = pm_types.AlertActivation.ON + entity.state.Presence = False + elif entity.descriptor.is_alert_signal_descriptor: + # alert signals are not present, and delegable signals are also not active + if entity.descriptor.SignalDelegationSupported: + entity.state.Location = pm_types.AlertSignalPrimaryLocation.REMOTE + entity.state.ActivationState = pm_types.AlertActivation.OFF + entity.state.Presence = pm_types.AlertSignalPresence.OFF + else: + entity.state.ActivationState = pm_types.AlertActivation.ON + entity.state.Presence = pm_types.AlertSignalPresence.OFF + elif entity.descriptor.is_component_descriptor: + # all components are active + entity.state.ActivationState = pm_types.ComponentActivation.ON + elif entity.descriptor.is_operational_descriptor: + # all operations are enabled + entity.state.OperatingMode = pm_types.OperatingMode.ENABLED + + def update_retrievability_lists(self): + """Update internal lists, based on current mdib descriptors.""" + mdib = self._mdib + with mdib.mdib_lock: + del mdib._retrievability_episodic[:] # noqa: SLF001 + mdib.retrievability_periodic.clear() + for entity in mdib.internal_entities.values(): + for r in entity.descriptor.get_retrievability(): + for r_by in r.By: + if r_by.Method == RetrievabilityMethod.EPISODIC: + mdib._retrievability_episodic.append(entity.descriptor.Handle) # noqa: SLF001 + elif r_by.Method == RetrievabilityMethod.PERIODIC: + period_float = r_by.UpdatePeriod + period_ms = int(period_float * 1000.0) + mdib.retrievability_periodic[period_ms].append(entity.descriptor.Handle) + + def get_all_entities_in_subtree(self, root_entity: ProviderEntity | ProviderMultiStateEntity, + depth_first: bool = True, + include_root: bool = True, + ) -> list[ProviderEntity | ProviderMultiStateEntity]: + """Return the tree below descriptor_container as a flat list.""" + result = [] + + def _getchildren(parent: ProviderEntity | ProviderMultiStateEntity): + child_containers = [e for e in self._mdib.internal_entities.values() if e.parent_handle == parent.handle] + if not depth_first: + result.extend(child_containers) + apply_map(_getchildren, child_containers) + if depth_first: + result.extend(child_containers) + + if include_root and not depth_first: + result.append(root_entity) + _getchildren(root_entity) + if include_root and depth_first: + result.append(root_entity) + return result + + def disassociate_all(self, + entity: ProviderMultiStateEntity, + unbinding_mdib_version: int, + ignored_handle: str | None = None) -> list[str]: + """Disassociate all associated states in entity. + + The method returns a list of states that were disassociated. + :param entity: ProviderMultiStateEntity + :param ignored_handle: the context state with this Handle shall not be touched. + """ + pm_types = self._mdib.data_model.pm_types + disassociated_state_handles = [] + for state in entity.states.values(): + if state.Handle == ignored_handle: + # If state is already part of this transaction leave it also untouched, accept what the user wanted. + continue + if state.ContextAssociation != pm_types.ContextAssociation.DISASSOCIATED \ + or state.UnbindingMdibVersion is None: + state.ContextAssociation = pm_types.ContextAssociation.DISASSOCIATED + if state.UnbindingMdibVersion is None: + state.UnbindingMdibVersion = unbinding_mdib_version + state.BindingEndTime = time.time() + disassociated_state_handles.append(state.Handle) + return disassociated_state_handles diff --git a/src/sdc11073/entity_mdib/entity_transactions.py b/src/sdc11073/entity_mdib/entity_transactions.py new file mode 100644 index 00000000..bc930ce5 --- /dev/null +++ b/src/sdc11073/entity_mdib/entity_transactions.py @@ -0,0 +1,727 @@ +"""The module implements transactions for EntityProviderMdib.""" +from __future__ import annotations + +import copy +import time +from dataclasses import dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING, Union + +from sdc11073.exceptions import ApiUsageError +from sdc11073.mdib.transactionsprotocol import ( + AnyTransactionManagerProtocol, + TransactionItem, + TransactionResultProtocol, + TransactionType, +) + +if TYPE_CHECKING: + from sdc11073.loghelper import LoggerAdapter + from sdc11073.mdib.descriptorcontainers import AbstractDescriptorProtocol + from sdc11073.mdib.statecontainers import AbstractMultiStateProtocol, AbstractStateProtocol + + from .entities import ( + ProviderEntity, + ProviderInternalEntity, + ProviderInternalMultiStateEntity, + ProviderMultiStateEntity, + ) + from .entity_providermdib import EntityProviderMdib, ProviderInternalEntityType + + AnyProviderEntity = Union[ + ProviderEntity, ProviderMultiStateEntity, ProviderInternalEntity, ProviderInternalMultiStateEntity] + + +class _Modification(Enum): + insert = auto() + update = auto() + delete = auto() + + +@dataclass(frozen=True) +class DescriptorTransactionItem: + """Transaction Item with old and new container.""" + + entity: ProviderEntity | ProviderMultiStateEntity | ProviderInternalEntity | ProviderInternalMultiStateEntity + modification: _Modification + + +def _update_multi_states(mdib: EntityProviderMdib, # noqa: C901 + new: ProviderMultiStateEntity, + old: ProviderMultiStateEntity, + modified_handles: list[str] | None = None, + adjust_state_version: bool = True): + if not (new.is_multi_state and old.is_multi_state): # pragma: no cover + raise ApiUsageError('_update_multi_states only handles context states!') + if new.handle != old.handle: # pragma: no cover + raise ApiUsageError(f'_update_multi_states found different handles! new={new.handle}, old = {old.handle}') + if not modified_handles: + modified_handles = new.states.keys() + for handle in modified_handles: + state_container = new.states.get(handle) + if state_container is None: + # a deleted state : this cannot be communicated via notification. + # delete it internal_entity anf that is all + if handle in old.states: + old.states.pop(handle) + else: + raise KeyError(f'invalid handle {handle}!') + continue + + old_state = old.states.get(state_container.Handle) + tmp = copy.deepcopy(state_container) + + if old_state is None: + # this is a new state + tmp.descriptor_container = old.descriptor + tmp.DescriptorVersion = old.descriptor.DescriptorVersion + if adjust_state_version: + old_state_version = mdib.state_handle_version_lookup.get(tmp.Handle) + if old_state_version: + tmp.StateVersion = old_state_version + 1 + elif adjust_state_version: + tmp.StateVersion = old_state.StateVersion + 1 + + +def _adjust_version_counters(new_entity: ProviderInternalEntityType, + old_entity: ProviderInternalEntityType, + increment_descriptor_version: bool = False): + if increment_descriptor_version: + new_entity.descriptor.DescriptorVersion = old_entity.descriptor.DescriptorVersion + 1 + if new_entity.is_multi_state: + for new_state in new_entity.states.values(): + new_state.DescriptorVersion = new_entity.descriptor.DescriptorVersion + old_state = old_entity.states.get(new_state.Handle) + if old_state is not None: + new_state.StateVersion = old_state.StateVersion + 1 + else: + new_entity.state.DescriptorVersion = new_entity.descriptor.DescriptorVersion + new_entity.state.StateVersion = old_entity.state.StateVersion + 1 + + +class _TransactionBase: + def __init__(self, + provider_mdib: EntityProviderMdib, + logger: LoggerAdapter): + self._mdib = provider_mdib + # provide the new mdib version that the commit of this transaction will create + self.new_mdib_version = provider_mdib.mdib_version + 1 + self._logger = logger + self.metric_state_updates: dict[str, TransactionItem] = {} + self.alert_state_updates: dict[str, TransactionItem] = {} + self.component_state_updates: dict[str, TransactionItem] = {} + self.context_state_updates: dict[str, TransactionItem] = {} + self.operational_state_updates: dict[str, TransactionItem] = {} + self.rt_sample_state_updates: dict[str, TransactionItem] = {} + self._error = False + + def _handle_state_updates(self, state_updates_dict: dict) -> list[TransactionItem]: + """Update mdib table and return a list of states to be sent in notifications.""" + updates_list = [] + for transaction_item in state_updates_dict.values(): + if transaction_item.old is not None and transaction_item.new is not None: + # update + entity = self._mdib.internal_entities[transaction_item.old.DescriptorHandle] + if entity.descriptor.is_context_descriptor: + entity.states[transaction_item.old.Handle] = transaction_item.new + else: + entity.state = transaction_item.new + elif transaction_item.new is not None: + # insert + entity = self._mdib.internal_entities[transaction_item.new.DescriptorHandle] + if entity.descriptor.is_context_descriptor: + entity.states[transaction_item.new.Handle] = transaction_item.new + else: + entity.state = transaction_item.new + else: + # delete + entity = self._mdib.internal_entities[transaction_item.old.DescriptorHandle] + if entity.descriptor.is_context_descriptor: + entity.states.pop(transaction_item.old.Handle) + + if transaction_item.new is not None: + updates_list.append(transaction_item.new.mk_copy(copy_node=False)) + return updates_list + + def get_state_transaction_item(self, handle: str) -> TransactionItem | None: + """If transaction has a state with given handle, return the transaction-item, otherwise None. + + :param handle: the Handle of a context state or the DescriptorHandle in all other cases + """ + if not handle: # pragma: no cover + raise ValueError('No handle for state specified') + for lookup in (self.metric_state_updates, + self.alert_state_updates, + self.component_state_updates, + self.context_state_updates, + self.operational_state_updates, + self.rt_sample_state_updates): + if handle in lookup: + return lookup[handle] + return None + + @property + def error(self) -> bool: + return self._error + + +class DescriptorTransaction(_TransactionBase): + """A Transaction that allows to insert / update / delete Descriptors and to modify states related to them.""" + + def __init__(self, + provider_mdib: EntityProviderMdib, + logger: LoggerAdapter): + super().__init__(provider_mdib, logger) + self.descriptor_updates: dict[str, DescriptorTransactionItem] = {} + self._new_entities: dict[str, ProviderInternalEntity | ProviderInternalMultiStateEntity] = {} + + def transaction__entity(self, descriptor_handle: str) -> ProviderEntity | ProviderMultiStateEntity | None: + """Return the entity in open transaction if it exists. + + The descriptor can already be part of the transaction, and e.g. in pre_commit handlers of role providers + it can be necessary to have access to it. + """ + if not descriptor_handle: # pragma: no cover + raise ValueError('No handle for descriptor specified') + tr_container = self.descriptor_updates.get(descriptor_handle) + if tr_container is not None: + if tr_container.modification == _Modification.delete: + raise ValueError(f'The descriptor {descriptor_handle} is going to be deleted') + return tr_container.entity + return None + + def write_entity(self, # noqa: PLR0912, C901 + entity: ProviderEntity | ProviderMultiStateEntity, + adjust_version_counter: bool = True): + """Insert or update an entity.""" + descriptor_handle = entity.descriptor.Handle + if descriptor_handle in self.descriptor_updates: # pragma: no cover + raise ValueError(f'Entity {descriptor_handle} already in updated set!') + tmp = copy.copy(entity) # cannot deepcopy entity, that would deepcopy also whole mdib + tmp.descriptor = copy.deepcopy(entity.descriptor) # do not touch original entity of user + if entity.is_multi_state: + tmp.states = copy.deepcopy(entity.states) # do not touch original entity of user + else: + tmp.state = copy.deepcopy(entity.state) # do not touch original entity of user + if descriptor_handle in self._mdib.internal_entities: + # update + if adjust_version_counter: + old_entity = self._mdib.internal_entities[descriptor_handle] + _adjust_version_counters(tmp, old_entity, increment_descriptor_version=True) + self.descriptor_updates[descriptor_handle] = DescriptorTransactionItem(tmp, + _Modification.update) + + elif descriptor_handle in self._mdib.new_entities: + # create + if adjust_version_counter: + version = self._mdib.descr_handle_version_lookup.get(descriptor_handle) + if version is not None: + tmp.descriptor.DescriptorVersion = version + if tmp.is_multi_state: + for state in tmp.states.values(): + version = self._mdib.state_handle_version_lookup.get(state.Handle) + if version is not None: + state.StateVersion = version + else: + version = self._mdib.state_handle_version_lookup.get(descriptor_handle) + if version is not None: + tmp.state.StateVersion = version + self.descriptor_updates[descriptor_handle] = DescriptorTransactionItem(tmp, + _Modification.insert) + else: + # create without having internal entity + tmp_entity = self._mdib.entities.new_entity(entity.node_type, entity.handle, entity.parent_handle) + # replace descriptor and state in tmp_entity with values from tmp, but keep existing version counters + descriptor_version = tmp_entity.descriptor.DescriptorVersion + tmp_entity.descriptor = tmp.descriptor + tmp_entity.descriptor.DescriptorVersion = descriptor_version + if entity.is_multi_state: + tmp_entity.states = tmp.states + # change state versions if they were deleted before + for handle, state in tmp_entity.states.items(): + if handle in self._mdib.state_handle_version_lookup: + state.StateVersion = self._mdib.state_handle_version_lookup[handle] + 1 + else: + state_version = tmp_entity.state.StateVersion + tmp_entity.state = tmp.state + tmp_entity.state.StateVersion = state_version + tmp_entity.state.DescriptorVersion = descriptor_version + self.descriptor_updates[descriptor_handle] = DescriptorTransactionItem(tmp_entity, + _Modification.insert) + + def write_entities(self, + entities: list[ProviderEntity | ProviderMultiStateEntity], + adjust_version_counter: bool = True): + """Write entities in order parents first.""" + written_handles = [] + ent_dict = {ent.handle: ent for ent in entities} + while len(written_handles) < len(ent_dict): + for handle, ent in ent_dict.items(): + write_now = True + if (ent.parent_handle is not None + and ent.parent_handle in ent_dict + and ent.parent_handle not in written_handles): + # it has a parent, and parent has not been written yet + write_now = False + if write_now and handle not in written_handles: + self.write_entity(ent, adjust_version_counter) + written_handles.append(handle) + + def remove_entity(self, entity: ProviderEntity | ProviderMultiStateEntity): + """Remove existing descriptor from mdib.""" + if entity.handle in self.descriptor_updates: # pragma: no cover + raise ValueError(f'Descriptor {entity.handle} already in updated set!') + + internal_entity = self._mdib.internal_entities.get(entity.handle) + if internal_entity: + self.descriptor_updates[entity.handle] = DescriptorTransactionItem(internal_entity, + _Modification.delete) + + def process_transaction(self, set_determination_time: bool) -> TransactionResultProtocol: # noqa: ARG002, PLR0915, PLR0912, C901 + """Process transaction and create a TransactionResult. + + The parameter set_determination_time is only present in order to implement the interface correctly. + Determination time is not set, because descriptors have no modification time. + """ + proc = TransactionResult() + if self.descriptor_updates: + proc.new_mdib_version = self.new_mdib_version + # need to know all to be deleted and to be created descriptors + to_be_deleted_handles = [tr_item.entity.descriptor.Handle for tr_item in self.descriptor_updates.values() + if tr_item.modification == _Modification.delete] + to_be_created_handles = [tr_item.entity.descriptor.Handle for tr_item in self.descriptor_updates.values() + if tr_item.modification == _Modification.insert] + to_be_updated_handles = [tr_item.entity.descriptor.Handle for tr_item in self.descriptor_updates.values() + if tr_item.modification == _Modification.update] + + # Remark 1: + # handling only updated states here: If a descriptor is created, it can be assumed that the + # application also creates the state in a transaction. + # The state will then be transported via that notification report. + # Maybe this needs to be reworked, but at the time of this writing it seems fine. + # + # Remark 2: + # DescriptionModificationReport also contains the states that are related to the descriptors. + # => if there is one, update its DescriptorVersion and add it to list of states that shall be sent + # (Assuming that context descriptors (patient, location) are never changed, + # additional check for states in self.context_states is not needed. + # If this assumption is wrong, that functionality must be added!) + + # Restrict transaction to only insert, update or delete stuff. No mixes! + # This simplifies handling a lot! + types = [handle for handle in (to_be_deleted_handles, to_be_created_handles, to_be_updated_handles) + if handle] + if not types: + return proc # nothing changed + if len(types) > 1: # pragma: no cover + raise ValueError('this transaction can only handle one of insert, update, delete!') + + for tr_item in self.descriptor_updates.values(): + if tr_item.modification == _Modification.insert: + # this is a create operation + new_entity = tr_item.entity + + self._logger.debug( # noqa: PLE1205 + 'transaction_manager: new entity Handle={}, node type={}', + new_entity.handle, new_entity.descriptor.NODETYPE) + + # move temporary new internal entity to mdib + internal_entity = self._mdib.new_entities[new_entity.handle] + self._mdib.internal_entities[new_entity.handle] = internal_entity + del self._mdib.new_entities[new_entity.handle] + + self._update_internal_entity(new_entity, internal_entity) + + proc.descr_created.append( + internal_entity.descriptor) # this will cause a Description Modification Report + state_update_list = proc.get_state_updates_list(new_entity.descriptor) + + if internal_entity.is_multi_state: + state_update_list.extend(internal_entity.states) + # Todo: update context state handles in mdib + + else: + state_update_list.append(internal_entity.state) + + if (internal_entity.parent_handle is not None + and internal_entity.parent_handle not in to_be_created_handles): + self._increment_parent_descriptor_version(proc, internal_entity) + + elif tr_item.modification == _Modification.delete: + # this is a delete operation + + # Todo: is tr_item.entity always an internal entity? + handle = tr_item.entity.descriptor.Handle + internal_entity = self._mdib.internal_entities.get(handle) + if internal_entity is None: + self._logger.debug( # noqa: PLE1205 + 'transaction_manager: cannot remove unknown descriptor Handle={}', handle) + return None + + self._logger.debug( # noqa: PLE1205 + 'transaction_manager: rm descriptor Handle={}', handle) + all_entities = self._mdib.xtra.get_all_entities_in_subtree(internal_entity) + for entity in all_entities: + + # save last versions + self._mdib.descr_handle_version_lookup[ + entity.descriptor.Handle] = entity.descriptor.DescriptorVersion + if entity.is_multi_state: + for state in entity.states.values(): + self._mdib.state_handle_version_lookup[state.Handle] = state.StateVersion + else: + self._mdib.state_handle_version_lookup[entity.descriptor.Handle] = entity.state.StateVersion + + self._mdib.internal_entities.pop(entity.handle) + proc.descr_deleted.extend([e.descriptor for e in all_entities]) + # increment DescriptorVersion if a child descriptor is added or deleted. + if internal_entity.parent_handle is not None \ + and internal_entity.parent_handle not in to_be_deleted_handles: + # Todo: whole parent chain should be checked + # only update parent if it is not also deleted in this transaction + self._increment_parent_descriptor_version(proc, internal_entity) + else: + # this is an update operation + # it does not change tr_item.entity! + # Todo: check if state changes exist and raise an error in that case. + # It simplifies code a lot if it is safe to assume that states + # have not changed in description transaction + updated_entity = tr_item.entity + internal_entity = self._mdib.internal_entities[tr_item.entity.handle] + self._logger.debug( # noqa: PLE1205 + 'transaction_manager: update descriptor Handle={}, DescriptorVersion={}', + internal_entity.handle, updated_entity.descriptor.DescriptorVersion) + self._update_internal_entity(updated_entity, internal_entity) + proc.descr_updated.append( + internal_entity.descriptor) # this will cause a Description Modification Report + state_update_list = proc.get_state_updates_list(internal_entity.descriptor) + if updated_entity.is_multi_state: + state_update_list.extend(internal_entity.states.values()) + else: + state_update_list.append(internal_entity.state) + return proc + + def _update_internal_entity(self, modified_entity: ProviderEntity | ProviderMultiStateEntity, + internal_entity: ProviderInternalEntity | ProviderInternalMultiStateEntity): + """Write back information into internal entity.""" + internal_entity.descriptor.update_from_other_container(modified_entity.descriptor) + if modified_entity.is_multi_state: + _update_multi_states(self._mdib, + modified_entity, + internal_entity, + None) + # Todo: update context state handles in mdib + + else: + internal_entity.state.update_from_other_container(modified_entity.state) + + def _increment_parent_descriptor_version(self, proc: TransactionResult, + entity: ProviderInternalEntityType): + """Increment version counter of descriptor and state. + + Add both to transaction result. + """ + parent_entity = self._mdib.internal_entities.get(entity.parent_handle) + updates_list = proc.get_state_updates_list(parent_entity.descriptor) + + if parent_entity is not None: + parent_entity.descriptor.increment_descriptor_version() + # parent entity can never be a multi state + parent_entity.state.increment_state_version() + + # Todo: why make a copy? + proc.descr_updated.append(parent_entity.descriptor.mk_copy()) + updates_list.append(parent_entity.state.mk_copy()) + + +class StateTransactionBase(_TransactionBase): + """Base Class for all transactions that modify states.""" + + def __init__(self, + provider_mdib: EntityProviderMdib, + logger: LoggerAdapter): + super().__init__(provider_mdib, logger) + self._state_updates = {} # will be set to proper value in derived classes + + def has_state(self, descriptor_handle: str) -> bool: + """Check if transaction has a state with given handle.""" + return descriptor_handle in self._state_updates + + def write_entity(self, entity: ProviderEntity, adjust_version_counter: bool = True): + """Update the state of the entity.""" + if not self._is_correct_state_type(entity.state): + raise ApiUsageError(f'Wrong data type in transaction! {self.__class__.__name__}, {entity.state}') + descriptor_handle = entity.state.DescriptorHandle + old_state = self._mdib.internal_entities[descriptor_handle].state + tmp = copy.deepcopy(entity.state) # do not touch original entity of user + if adjust_version_counter: + tmp.DescriptorVersion = old_state.DescriptorVersion + tmp.StateVersion = old_state.StateVersion + 1 + self._state_updates[descriptor_handle] = TransactionItem(old=old_state, + new=tmp) + + def write_entities(self, entities: list[ProviderEntity], adjust_version_counter: bool = True): + """Update the states of entities.""" + for entity in entities: + # check all states before writing any of them + if not self._is_correct_state_type(entity.state): + raise ApiUsageError(f'Wrong data type in transaction! {self.__class__.__name__}, {entity.state}') + for ent in entities: + self.write_entity(ent, adjust_version_counter) + + @staticmethod + def _is_correct_state_type(state: AbstractStateProtocol) -> bool: # noqa: ARG004 + return False + + +class AlertStateTransaction(StateTransactionBase): + """A Transaction for alert states.""" + + def __init__(self, + provider_mdib: EntityProviderMdib, + logger: LoggerAdapter): + super().__init__(provider_mdib, logger) + self._state_updates = self.alert_state_updates + + def process_transaction(self, set_determination_time: bool) -> TransactionResultProtocol: + """Process transaction and create a TransactionResult.""" + if set_determination_time: + for tr_item in self._state_updates.values(): + new_state = tr_item.new + old_state = tr_item.old + if new_state is None or not hasattr(new_state, 'Presence'): + continue + if old_state is None: + if new_state.Presence: + new_state.DeterminationTime = time.time() + elif new_state.is_alert_condition and new_state.Presence != old_state.Presence: + new_state.DeterminationTime = time.time() + proc = TransactionResult() + if self._state_updates: + proc.new_mdib_version = self.new_mdib_version + updates = self._handle_state_updates(self._state_updates) + proc.alert_updates.extend(updates) + return proc + + @staticmethod + def _is_correct_state_type(state: AbstractStateProtocol) -> bool: + return state.is_alert_state + + +class MetricStateTransaction(StateTransactionBase): + """A Transaction for metric states (except real time samples).""" + + def __init__(self, + provider_mdib: EntityProviderMdib, + logger: LoggerAdapter): + super().__init__(provider_mdib, logger) + self._state_updates = self.metric_state_updates + + def process_transaction(self, set_determination_time: bool) -> TransactionResultProtocol: + """Process transaction and create a TransactionResult.""" + if set_determination_time: + for tr_item in self._state_updates.values(): + if tr_item.new is not None and tr_item.new.MetricValue is not None: + tr_item.new.MetricValue.DeterminationTime = time.time() + proc = TransactionResult() + if self._state_updates: + proc.new_mdib_version = self.new_mdib_version + updates = self._handle_state_updates(self._state_updates) + proc.metric_updates.extend(updates) + return proc + + @staticmethod + def _is_correct_state_type(state: AbstractStateProtocol) -> bool: + return state.is_metric_state and not state.is_realtime_sample_array_metric_state + + +class ComponentStateTransaction(StateTransactionBase): + """A Transaction for component states.""" + + def __init__(self, + provider_mdib: EntityProviderMdib, + logger: LoggerAdapter): + super().__init__(provider_mdib, logger) + self._state_updates = self.component_state_updates + + def process_transaction(self, set_determination_time: bool) -> TransactionResultProtocol: # noqa: ARG002 + """Process transaction and create a TransactionResult.""" + proc = TransactionResult() + if self._state_updates: + proc.new_mdib_version = self.new_mdib_version + updates = self._handle_state_updates(self._state_updates) + proc.comp_updates.extend(updates) + return proc + + @staticmethod + def _is_correct_state_type(state: AbstractStateProtocol) -> bool: + return state.is_component_state + + +class RtStateTransaction(StateTransactionBase): + """A Transaction for real time sample states.""" + + def __init__(self, + provider_mdib: EntityProviderMdib, + logger: LoggerAdapter): + super().__init__(provider_mdib, logger) + self._state_updates = self.rt_sample_state_updates + + def process_transaction(self, set_determination_time: bool) -> TransactionResultProtocol: # noqa: ARG002 + """Process transaction and create a TransactionResult.""" + proc = TransactionResult() + if self._state_updates: + proc.new_mdib_version = self.new_mdib_version + updates = self._handle_state_updates(self._state_updates) + proc.rt_updates.extend(updates) + return proc + + @staticmethod + def _is_correct_state_type(state: AbstractStateProtocol) -> bool: + return state.is_realtime_sample_array_metric_state + + +class OperationalStateTransaction(StateTransactionBase): + """A Transaction for operational states.""" + + def __init__(self, + provider_mdib: EntityProviderMdib, + logger: LoggerAdapter): + super().__init__(provider_mdib, logger) + self._state_updates = self.operational_state_updates + + def process_transaction(self, set_determination_time: bool) -> TransactionResultProtocol: # noqa: ARG002 + """Process transaction and create a TransactionResult.""" + proc = TransactionResult() + if self._state_updates: + proc.new_mdib_version = self.new_mdib_version + updates = self._handle_state_updates(self._state_updates) + proc.op_updates.extend(updates) + return proc + + @staticmethod + def _is_correct_state_type(state: AbstractStateProtocol) -> bool: + return state.is_operational_state + + +class ContextStateTransaction(_TransactionBase): + """A Transaction for context states.""" + + def __init__(self, + provider_mdib: EntityProviderMdib, + logger: LoggerAdapter): + super().__init__(provider_mdib, logger) + self._state_updates = self.context_state_updates + + def write_entity(self, entity: ProviderMultiStateEntity, + modified_handles: list[str], + adjust_version_counter: bool = True): + """Insert or update a context state in mdib.""" + internal_entity = self._mdib.internal_entities[entity.descriptor.Handle] + + for handle in modified_handles: + state_container = entity.states.get(handle) + if state_container is None: + # a deleted state : this cannot be communicated via notification. + # delete it internal_entity anf that is all + if handle in internal_entity.states: + internal_entity.states.pop(handle) + else: + raise KeyError(f'invalid handle {handle}!') + continue + if not state_container.is_context_state: + raise ApiUsageError('Transaction only handles context states!') + + old_state = internal_entity.states.get(state_container.Handle) + tmp = copy.deepcopy(state_container) # do not touch original entity of user + + if old_state is None: + # this is a new state + tmp.descriptor_container = internal_entity.descriptor + if adjust_version_counter: + tmp.DescriptorVersion = internal_entity.descriptor.DescriptorVersion + # look for previously existing state with same handle + old_state_version = self._mdib.state_handle_version_lookup.get(tmp.Handle) + if old_state_version: + tmp.StateVersion = old_state_version + 1 + # update + elif adjust_version_counter: + tmp.DescriptorVersion = internal_entity.descriptor.DescriptorVersion + tmp.StateVersion = old_state.StateVersion + 1 + + self._state_updates[state_container.Handle] = TransactionItem(old=old_state, new=tmp) + + def process_transaction(self, set_determination_time: bool) -> TransactionResultProtocol: # noqa: ARG002 + """Process transaction and create a TransactionResult.""" + proc = TransactionResult() + if self._state_updates: + proc.new_mdib_version = self.new_mdib_version + updates = self._handle_state_updates(self._state_updates) + proc.ctxt_updates.extend(updates) + return proc + + @staticmethod + def _is_correct_state_type(state: AbstractStateProtocol) -> bool: + return state.is_context_state + + +class TransactionResult: + """The transaction result. + + Data is used to create notifications. + """ + + def __init__(self): + # states and descriptors that were modified are stored here: + self.new_mdib_version: int | None = None + self.descr_updated: list[AbstractDescriptorProtocol] = [] + self.descr_created: list[AbstractDescriptorProtocol] = [] + self.descr_deleted: list[AbstractDescriptorProtocol] = [] + self.metric_updates: list[AbstractStateProtocol] = [] + self.alert_updates: list[AbstractStateProtocol] = [] + self.comp_updates: list[AbstractStateProtocol] = [] + self.ctxt_updates: list[AbstractMultiStateProtocol] = [] + self.op_updates: list[AbstractStateProtocol] = [] + self.rt_updates: list[AbstractStateProtocol] = [] + + @property + def has_descriptor_updates(self) -> bool: + """Return True if at least one descriptor is in result.""" + return len(self.descr_updated) > 0 or len(self.descr_created) > 0 or len(self.descr_deleted) > 0 + + def all_states(self) -> list[AbstractStateProtocol]: + """Return all states in this transaction.""" + return self.metric_updates + self.alert_updates + self.comp_updates + self.ctxt_updates \ + + self.op_updates + self.rt_updates + + def get_state_updates_list(self, descriptor: AbstractDescriptorProtocol) -> list: + """Return the list that stores updated states of this descriptor.""" + if descriptor.is_context_descriptor: + return self.ctxt_updates + if descriptor.is_alert_descriptor: + return self.alert_updates + if descriptor.is_realtime_sample_array_metric_descriptor: + return self.rt_updates + if descriptor.is_metric_descriptor: + return self.metric_updates + if descriptor.is_operational_descriptor: + return self.op_updates + if descriptor.is_component_descriptor: + return self.comp_updates + raise ValueError(f'do not know how to handle {descriptor}') + + +_transaction_type_lookup = {TransactionType.descriptor: DescriptorTransaction, + TransactionType.alert: AlertStateTransaction, + TransactionType.metric: MetricStateTransaction, + TransactionType.operational: OperationalStateTransaction, + TransactionType.context: ContextStateTransaction, + TransactionType.component: ComponentStateTransaction, + TransactionType.rt_sample: RtStateTransaction} + + +def mk_transaction(provider_mdib: EntityProviderMdib, + transaction_type: TransactionType, + logger: LoggerAdapter) -> AnyTransactionManagerProtocol: + """Create a transaction according to transaction_type.""" + return _transaction_type_lookup[transaction_type](provider_mdib, logger) diff --git a/src/sdc11073/mdib/consumermdib.py b/src/sdc11073/mdib/consumermdib.py index 61d47f80..f42a21d6 100644 --- a/src/sdc11073/mdib/consumermdib.py +++ b/src/sdc11073/mdib/consumermdib.py @@ -1,3 +1,4 @@ +"""The module contains the implementation of Consumermdib.""" from __future__ import annotations import enum @@ -12,6 +13,7 @@ from sdc11073 import loghelper from sdc11073 import observableproperties as properties from sdc11073.exceptions import ApiUsageError + from . import mdibbase from .consumermdibxtra import ConsumerMdibMethods @@ -31,9 +33,13 @@ OperationInvokedReport, ) - from .statecontainers import (AbstractStateContainer, - AbstractContextStateContainer, - RealTimeSampleArrayMetricStateContainer) + from .entityprotocol import EntityGetterProtocol + from .mdibbase import MdibVersionGroup + from .statecontainers import ( + AbstractContextStateContainer, + AbstractStateContainer, + RealTimeSampleArrayMetricStateContainer, + ) @dataclass @@ -140,6 +146,7 @@ class _BufferedData: class ConsumerMdibState(enum.Enum): """ConsumerMdib can be in one of these states.""" + initializing = enum.auto() # the state during reload_all() initialized = enum.auto() # the state when mdib is in sync with provider invalid = enum.auto() # the state when mdib is not in sync with provider @@ -180,6 +187,7 @@ def __init__(self, # a buffer for notifications that are received before initial get_mdib is done self._buffered_notifications = [] self._buffered_notifications_lock = Lock() + self.entities: EntityGetterProtocol = mdibbase.EntityGetter(self) @property def xtra(self) -> Any: @@ -233,8 +241,8 @@ def reload_all(self): mdib_version_group = response.mdib_version_group self.mdib_version = mdib_version_group.mdib_version - self._logger.info('setting initial mdib version to {}', - mdib_version_group.mdib_version) # noqa: PLE1205 + self._logger.info('setting initial mdib version to %d', + mdib_version_group.mdib_version) self.sequence_id = mdib_version_group.sequence_id self._logger.info('setting initial sequence id to {}', mdib_version_group.sequence_id) # noqa: PLE1205 if mdib_version_group.instance_id != self.instance_id: @@ -318,13 +326,14 @@ def _can_accept_mdib_version(self, new_mdib_version: int, log_prefix: str) -> bo # it is possible to receive multiple notifications with the same mdib version => compare ">=" return new_mdib_version >= self.mdib_version - def _check_sequence_or_instance_id_changed(self, mdib_version_group): + def _check_sequence_or_instance_id_changed(self, mdib_version_group: MdibVersionGroup): """Check if sequence id and instance id are still the same. If not, - set state member to invalid - set the observable "sequence_or_instance_id_changed_event" in a thread. - This allows to implement an observer that can directly call reload_all without blocking the consumer.""" + This allows to implement an observer that can directly call reload_all without blocking the consumer. + """ if mdib_version_group.sequence_id == self.sequence_id and mdib_version_group.instance_id == self.instance_id: return if self._state == ConsumerMdibState.initialized: @@ -441,10 +450,13 @@ def process_incoming_metric_states_report(self, mdib_version_group: MdibVersionG def _process_incoming_metric_states_report(self, mdib_version_group: MdibVersionGroupReader, report: EpisodicMetricReport): - """Check mdib version, if okay: + """Check mdib version. + + If okay: - update mdib. - - update observable - Call this method only if mdib_lock is already acquired.""" + - update observable. + Call this method only if mdib_lock is already acquired. + """ states_by_handle = {} try: if self._can_accept_mdib_version(mdib_version_group.mdib_version, 'metric states'): @@ -464,10 +476,13 @@ def process_incoming_alert_states_report(self, mdib_version_group: MdibVersionGr def _process_incoming_alert_states_report(self, mdib_version_group: MdibVersionGroupReader, report: EpisodicAlertReport): - """Check mdib version, if okay: + """Check mdib version. + + If okay: - update mdib. - - update observable - Call this method only if mdib_lock is already acquired.""" + - update observable. + Call this method only if mdib_lock is already acquired. + """ states_by_handle = {} try: if self._can_accept_mdib_version(mdib_version_group.mdib_version, 'alert states'): @@ -487,10 +502,13 @@ def process_incoming_operational_states_report(self, mdib_version_group: MdibVer def _process_incoming_operational_states_report(self, mdib_version_group: MdibVersionGroupReader, report: OperationInvokedReport): - """Check mdib version, if okay: + """Check mdib version. + + If okay: - update mdib. - - update observable - Call this method only if mdib_lock is already acquired.""" + - update observable. + Call this method only if mdib_lock is already acquired. + """ states_by_handle = {} try: if self._can_accept_mdib_version(mdib_version_group.mdib_version, 'operational states'): @@ -510,10 +528,13 @@ def process_incoming_context_states_report(self, mdib_version_group: MdibVersion def _process_incoming_context_states_report(self, mdib_version_group: MdibVersionGroupReader, report: EpisodicContextReport): - """Check mdib version, if okay: + """Check mdib version. + + If okay: - update mdib. - - update observable - Call this method only if mdib_lock is already acquired.""" + - update observable. + Call this method only if mdib_lock is already acquired. + """ states_by_handle = {} try: if self._can_accept_mdib_version(mdib_version_group.mdib_version, 'context states'): @@ -533,10 +554,13 @@ def process_incoming_component_states_report(self, mdib_version_group: MdibVersi def _process_incoming_component_states_report(self, mdib_version_group: MdibVersionGroupReader, report: EpisodicComponentReport): - """Check mdib version, if okay: + """Check mdib version. + + If okay: - update mdib. - - update observable - Call this method only if mdib_lock is already acquired.""" + - update observable. + Call this method only if mdib_lock is already acquired. + """ states_by_handle = {} try: if self._can_accept_mdib_version(mdib_version_group.mdib_version, 'component states'): @@ -546,7 +570,7 @@ def _process_incoming_component_states_report(self, mdib_version_group: MdibVers self.component_by_handle = states_by_handle # update observable def process_incoming_waveform_states(self, mdib_version_group: MdibVersionGroupReader, - state_containers: list[RealTimeSampleArrayMetricStateContainer] + state_containers: list[RealTimeSampleArrayMetricStateContainer], ) -> dict[str, RealTimeSampleArrayMetricStateContainer] | None: """Check mdib_version_group and process state_containers it if okay.""" if not self._pre_check_report_ok(mdib_version_group, state_containers, @@ -557,10 +581,13 @@ def process_incoming_waveform_states(self, mdib_version_group: MdibVersionGroupR def _process_incoming_waveform_states(self, mdib_version_group: MdibVersionGroupReader, state_containers: list[RealTimeSampleArrayMetricStateContainer]): - """Check mdib version, if okay: + """Check mdib version. + + If okay: - update mdib. - - update observable - Call this method only if mdib_lock is already acquired.""" + - update observable. + Call this method only if mdib_lock is already acquired. + """ states_by_handle = {} try: if self._can_accept_mdib_version(mdib_version_group.mdib_version, 'waveform states'): @@ -609,12 +636,16 @@ def process_incoming_description_modifications(self, mdib_version_group: MdibVer with self.mdib_lock: self._process_incoming_description_modifications(mdib_version_group, report) - def _process_incoming_description_modifications(self, mdib_version_group: MdibVersionGroupReader, + def _process_incoming_description_modifications(self, # noqa: PLR0915, PLR0912, C901 + mdib_version_group: MdibVersionGroupReader, report: DescriptionModificationReport): - """Check mdib version, if okay: + """Check mdib version. + + If okay: - update mdib. - - update observables - Call this method only if mdib_lock is already acquired.""" + - update observables. + Call this method only if mdib_lock is already acquired. + """ def multi_key(st_container: AbstractStateContainer) -> mdibbase.StatesLookup | mdibbase.MultiStatesLookup: return self.context_states if st_container.is_context_state else self.states diff --git a/src/sdc11073/mdib/entityprotocol.py b/src/sdc11073/mdib/entityprotocol.py new file mode 100644 index 00000000..160b9ebf --- /dev/null +++ b/src/sdc11073/mdib/entityprotocol.py @@ -0,0 +1,109 @@ +"""The module contains protocol definitions for the entity interface of mdib.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar, Protocol, Union + +if TYPE_CHECKING: + from collections.abc import Iterable + + from lxml.etree import QName + + from sdc11073.mdib.descriptorcontainers import AbstractDescriptorProtocol + from sdc11073.mdib.statecontainers import AbstractMultiStateProtocol, AbstractStateProtocol + from sdc11073.xml_types.pm_types import CodedValue, Coding + + +class EntityProtocol(Protocol): # pragma: no cover + """The protocol defines the interface of single-state entities.""" + + descriptor: AbstractDescriptorProtocol + state: AbstractStateProtocol + is_multi_state: ClassVar[bool] + + @property + def handle(self) -> str: + """Return the parent handle of the descriptor.""" + + @property + def parent_handle(self) -> str: + """Return the parent handle of the descriptor.""" + + @property + def node_type(self) -> QName: + """Return the node type of the descriptor.""" + + def update(self): + """Update entity with current mdib data.""" + + +class MultiStateEntityProtocol(Protocol): # pragma: no cover + """The protocol defines the interface of multi-state entities.""" + + descriptor: AbstractDescriptorProtocol + states: dict[str, AbstractMultiStateProtocol] # key is the Handle member of state + is_multi_state: bool + node_type: QName + handle: str + parent_handle: str + + def update(self): + """Update entity with current data in mdib.""" + + def new_state(self, handle: str | None = None) -> AbstractMultiStateProtocol: + """Create a new state.""" + +EntityTypeProtocol = Union[EntityProtocol, MultiStateEntityProtocol] + + +# Todo: should node_type be QName (this assumes that we talk XML) or just Any to be generic? + +class EntityGetterProtocol(Protocol): # pragma: no cover + """The protocol defines a way to access mdib data as entities. + + This representation is independent of the internal mdib organization. + The entities returned by the provided getter methods contain copies of the internal mdib data. + Changing the data does not change data in the mdib. + Use the EntityTransactionProtocol to write data back to the mdib. + """ + + def by_handle(self, handle: str) -> EntityTypeProtocol | None: + """Return entity with given descriptor handle.""" + ... + + def by_context_handle(self, handle: str) -> MultiStateEntityProtocol | None: + """Return multi state entity that contains a state with given handle.""" + ... + + def by_node_type(self, node_type: QName) -> list[EntityTypeProtocol]: + """Return all entities with given node type.""" + ... + + def by_parent_handle(self, parent_handle: str | None) -> list[EntityTypeProtocol]: + """Return all entities with given parent handle.""" + ... + + def by_coding(self, coding: Coding) -> list[EntityTypeProtocol]: + """Return all entities with equivalent Coding.""" + ... + + def by_coded_value(self, coded_value: CodedValue) -> list[EntityTypeProtocol]: + """Return all entities with equivalent CodedValue.""" + ... + + def items(self) -> Iterable[tuple[str, EntityTypeProtocol]]: + """Return items of a dictionary.""" + ... + + def __len__(self) -> int: + """Return number of entities.""" + ... + + +class ProviderEntityGetterProtocol(EntityGetterProtocol): # pragma: no cover + """The protocol adds the new_entity method to EntityGetterProtocol.""" + + def new_entity(self, + node_type: QName, + handle: str, + parent_handle: str | None) -> EntityTypeProtocol: + """Create an entity.""" diff --git a/src/sdc11073/mdib/mdibbase.py b/src/sdc11073/mdib/mdibbase.py index 1ecd9c21..85bc49bb 100644 --- a/src/sdc11073/mdib/mdibbase.py +++ b/src/sdc11073/mdib/mdibbase.py @@ -1,6 +1,9 @@ +"""The module contains the implementation of MdibBase plus entity interface.""" from __future__ import annotations +import copy import traceback +import uuid from dataclasses import dataclass from threading import Lock from typing import TYPE_CHECKING, Any @@ -13,15 +16,18 @@ from sdc11073.xml_types.pm_types import Coding, have_matching_codes if TYPE_CHECKING: + from collections.abc import Iterable + + from lxml.etree import QName + + from sdc11073 import xml_utils from sdc11073.definitions_base import BaseDefinitions from sdc11073.loghelper import LoggerAdapter from sdc11073.xml_types.pm_types import CodedValue - from sdc11073 import xml_utils from .descriptorcontainers import AbstractDescriptorContainer, AbstractOperationDescriptorContainer from .statecontainers import AbstractMultiStateContainer, AbstractStateContainer - @dataclass class MdibVersionGroup: """These 3 values define an mdib version.""" @@ -208,20 +214,140 @@ def set_version(self, obj: AbstractMultiStateContainer): obj.StateVersion = version + 1 -@dataclass -class Entity: +class _EntityBase: + + def __init__(self, mdib: MdibBase, descriptor: AbstractDescriptorContainer): + self._mdib = mdib + self.descriptor = descriptor + + @property + def handle(self) -> str: + return self.descriptor.Handle + + @property + def parent_handle(self) -> str: + return self.descriptor.parent_handle + + @property + def node_type(self) -> QName: + return self.descriptor.NODETYPE + + def update(self): + """Update the entity from current data in mdib.""" + orig = self._mdib.descriptions.get_one(self.handle) + self.descriptor.update_from_other_container(orig) + +class Entity(_EntityBase): """Groups descriptor and state.""" - descriptor: AbstractDescriptorContainer - state: AbstractStateContainer + def __init__(self, mdib: MdibBase, descriptor: AbstractDescriptorContainer, state: AbstractStateContainer): + super().__init__(mdib, descriptor) + self.state = state -@dataclass -class MultiStateEntity: + @property + def is_multi_state(self) -> bool: + """Return False because this is not a multi state entity.""" + return False + + def update(self): + """Update the entity from current data in mdib.""" + super().update() + orig = self._mdib.states.get_one(self.handle) + self.state.update_from_other_container(orig) + + +class MultiStateEntity(_EntityBase): """Groups descriptor and list of multi-states.""" - descriptor: AbstractDescriptorContainer - states: list[AbstractMultiStateContainer] + def __init__(self, mdib: MdibBase, descriptor: AbstractDescriptorContainer, + states: list[AbstractMultiStateContainer]): + super().__init__(mdib, descriptor) + self.states: dict[str, AbstractMultiStateContainer] = {s.Handle: s for s in states} + + @property + def is_multi_state(self) -> bool: + """Return True because this is a multi state entity.""" + return True + + def update(self): + """Update the entity from current data in mdib.""" + super().update() + + all_orig_states = self._mdib.context_states.descriptor_handle.get(self.handle) + states_dict = { st.Handle: st for st in all_orig_states} + # update existing states, remove deleted ones + for state in list(self.states.values()): + try: + orig = states_dict[state.Handle] + state.update_from_other_container(orig) + except KeyError: + self.states.pop(state.Handle) + # add new states + for handle, _ in states_dict.items(): + if handle not in self.states: + self.states[handle] = states_dict[handle].mk_copy() + + def new_state(self, state_handle: str | None = None) -> AbstractMultiStateContainer: + """Create a new state.""" + if state_handle in self.states: + raise ValueError(f'State handle {state_handle} already exists in {self.__class__.__name__}, handle = {self.handle}') + cls = self._mdib.data_model.get_state_container_class(self.descriptor.STATE_QNAME) + state = cls(descriptor_container=self.descriptor) + state.Handle = state_handle or uuid.uuid4().hex + self.states[state.Handle] = state + return state + + +class EntityGetter: + """Implementation of EntityGetterProtocol for MdibBase.""" + + def __init__(self, mdib: MdibBase): + self._mdib = mdib + + def by_handle(self, handle: str) -> Entity | MultiStateEntity | None: + """Return entity with given handle.""" + descriptor = self._mdib.descriptions.handle.get_one(handle, allow_none=True) + if descriptor is None: + return None + return self._mk_entity(descriptor) + + def by_node_type(self, node_type: QName) -> list[Entity | MultiStateEntity]: + """Return all entities with given node type.""" + descriptors = self._mdib.descriptions.NODETYPE.get(node_type, []) + return [self._mk_entity(d) for d in descriptors] + + def by_parent_handle(self, parent_handle: str | None) -> list[Entity | MultiStateEntity]: + """Return all entities with descriptors parent_handle == provided parent_handle.""" + descriptors = self._mdib.descriptions.parent_handle.get(parent_handle, []) + return [self._mk_entity(d) for d in descriptors] + + def by_coding(self, coding: Coding) -> list[Entity | MultiStateEntity]: + """Return all entities with descriptors type are equivalent to codeding.""" + descriptors = [d for d in self._mdib.descriptions.objects if d.Type.is_equivalent(coding)] + return [self._mk_entity(d) for d in descriptors] + + def by_coded_value(self, coded_value: CodedValue) -> list[Entity | MultiStateEntity]: + """Return all entities with descriptors type are equivalent to coded_value.""" + descriptors = [d for d in self._mdib.descriptions.objects if d.Type.is_equivalent(coded_value)] + return [self._mk_entity(d) for d in descriptors] + + def _mk_entity(self, descriptor: AbstractDescriptorContainer) -> Entity | MultiStateEntity: + if descriptor.is_context_descriptor: + states = self._mdib.context_states.descriptor_handle.get(descriptor.Handle, []) + return MultiStateEntity(self._mdib,copy.deepcopy(descriptor), copy.deepcopy(states)) + state = self._mdib.states.descriptor_handle.get_one(descriptor.Handle) + return Entity(self._mdib, copy.deepcopy(descriptor), copy.deepcopy(state)) + + def items(self) -> Iterable[tuple[str, [Entity | MultiStateEntity]]]: + """Return the items of a dictionary.""" + for descriptor in self._mdib.descriptions.objects: + yield descriptor.Handle, self._mk_entity(descriptor) + + def __len__(self) -> int: + """Return number of entities.""" + return len(self._mdib.descriptions.objects) + class MdibBase: @@ -264,6 +390,7 @@ def __init__(self, sdc_definitions: type[BaseDefinitions], logger: LoggerAdapter self.mdstate_version = 0 self.mddescription_version = 0 + @property def logger(self) -> LoggerAdapter: """Return the logger.""" @@ -511,6 +638,7 @@ def select_descriptors(self, *codings: list[Coding | CodedValue | str]) -> list[ It is not necessary that path starts at the top of a mds, it can start anywhere. :param codings: each element can be a string (which is handled as a Coding with DEFAULT_CODING_SYSTEM), a Coding or a CodedValue. + """ selected_objects = self.descriptions.objects # start with all objects for counter, coding in enumerate(codings): @@ -591,13 +719,13 @@ def get_entity(self, handle: str) -> Entity: """Return descriptor and state as Entity.""" descr = self.descriptions.handle.get_one(handle) state = self.states.descriptor_handle.get_one(handle) - return Entity(descr, state) + return Entity(self, descr, state) def get_context_entity(self, handle: str) -> MultiStateEntity: """Return descriptor and states as MultiStateEntity.""" descr = self.descriptions.handle.get_one(handle) - states = self.context_states.descriptor_handle.get(handle) - return MultiStateEntity(descr, states) + states = self.context_states.descriptor_handle.get(handle, []) + return MultiStateEntity(self, descr, states) def has_multiple_mds(self) -> bool: """Check if there is more than one mds in mdib (convenience method).""" diff --git a/src/sdc11073/mdib/mdibprotocol.py b/src/sdc11073/mdib/mdibprotocol.py new file mode 100644 index 00000000..5bb25b85 --- /dev/null +++ b/src/sdc11073/mdib/mdibprotocol.py @@ -0,0 +1,54 @@ +"""The module defines the interface of a provider mdib.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + from sdc11073.definitions_base import AbstractDataModel, BaseDefinitions + + from .entityprotocol import ProviderEntityGetterProtocol + from .transactionsprotocol import ( + ContextStateTransactionManagerProtocol, + DescriptorTransactionManagerProtocol, + StateTransactionManagerProtocol, + ) + + +class ProviderMdibProtocol(Protocol): # pragma: no cover + """The interface of a provider mdib. + + This interface only expects the ProviderEntityGetterProtocol. + The old implementation with separate lookups for descriptors, states and context states + is not part of this protocol. + """ + + entities: ProviderEntityGetterProtocol + sdc_definitions: type[BaseDefinitions] + data_model: AbstractDataModel + + def descriptor_transaction(self) -> AbstractContextManager[DescriptorTransactionManagerProtocol]: + """Return a transaction.""" + + def context_state_transaction(self) -> AbstractContextManager[ContextStateTransactionManagerProtocol]: + """Return a transaction.""" + + def alert_state_transaction(self, set_determination_time: bool = True) \ + -> AbstractContextManager[StateTransactionManagerProtocol]: + """Return a transaction.""" + + def metric_state_transaction(self, set_determination_time: bool = True) \ + -> AbstractContextManager[StateTransactionManagerProtocol]: + """Return a transaction.""" + + def rt_sample_state_transaction(self, set_determination_time: bool = False) \ + -> AbstractContextManager[StateTransactionManagerProtocol]: + """Return a transaction.""" + + def component_state_transaction(self) -> AbstractContextManager[StateTransactionManagerProtocol]: + """Return a transaction.""" + + def operational_state_transaction(self) -> AbstractContextManager[StateTransactionManagerProtocol]: + """Return a transaction.""" diff --git a/src/sdc11073/mdib/providermdib.py b/src/sdc11073/mdib/providermdib.py index 2ebf30fa..477384d5 100644 --- a/src/sdc11073/mdib/providermdib.py +++ b/src/sdc11073/mdib/providermdib.py @@ -1,3 +1,4 @@ +"""The module contains the implementation of ProviderMdib with ProviderEntityGetter Protocol.""" from __future__ import annotations import uuid @@ -19,19 +20,55 @@ from .transactionsprotocol import AnyTransactionManagerProtocol, TransactionType if TYPE_CHECKING: + from lxml.etree import QName + from sdc11073.definitions_base import BaseDefinitions + from .entityprotocol import ProviderEntityGetterProtocol from .transactionsprotocol import ( ContextStateTransactionManagerProtocol, DescriptorTransactionManagerProtocol, StateTransactionManagerProtocol, - TransactionResultProtocol + TransactionResultProtocol, ) TransactionFactory = Callable[[mdibbase.MdibBase, TransactionType, LoggerAdapter], AnyTransactionManagerProtocol] +class ProviderEntityGetter(mdibbase.EntityGetter): + """Implementation of ProviderEntityGetterProtocol.""" + + def new_entity(self, + node_type: QName, + handle: str, + parent_handle: str | None) -> mdibbase.Entity | mdibbase.MultiStateEntity: + """Create an entity.""" + if (handle in self._mdib.descriptions.handle + or handle in self._mdib.context_states.handle + # or handle in self._new_entities + ): + raise ValueError('Handle already exists') + + # Todo: check if this node type is a valid child of parent + + descr_cls = self._mdib.data_model.get_descriptor_container_class(node_type) + descriptor_container = descr_cls(handle=handle, parent_handle=parent_handle) + parent_descriptor = self._mdib.descriptions.handle.get_one(parent_handle) + if self._mdib.data_model.pm_names.MdsDescriptor == parent_descriptor.NODETYPE: + descriptor_container.set_source_mds(parent_descriptor.Handle) + else: + descriptor_container.set_source_mds(parent_descriptor.source_mds) + + if descriptor_container.is_context_descriptor: + new_entity = mdibbase.MultiStateEntity(self._mdib, descriptor_container, []) + else: + state_cls = self._mdib.data_model.get_state_container_class(descriptor_container.STATE_QNAME) + state = state_cls(descriptor_container) + new_entity = mdibbase.Entity(self._mdib,descriptor_container, state) + return new_entity + + class ProviderMdib(mdibbase.MdibBase): """Device side implementation of a mdib. @@ -75,6 +112,7 @@ def __init__(self, self._transaction_factory = transaction_factory or mk_transaction self._retrievability_episodic = [] # a list of handles self.retrievability_periodic = defaultdict(list) + self.entities: ProviderEntityGetterProtocol = ProviderEntityGetter(self) @property def xtra(self) -> Any: @@ -82,7 +120,7 @@ def xtra(self) -> Any: return self._xtra @contextmanager - def _transaction_manager(self, + def _transaction_manager(self, # noqa: PLR0912, C901 transaction_type: TransactionType, set_determination_time: bool = True) -> AbstractContextManager[ AnyTransactionManagerProtocol]: @@ -123,7 +161,6 @@ def _transaction_manager(self, if transaction_result.rt_updates: self.waveform_by_handle = {st.DescriptorHandle: st for st in transaction_result.rt_updates} - if callable(self.post_commit_handler): self.post_commit_handler(self, self.current_transaction) finally: @@ -229,6 +266,7 @@ def from_string(cls, mdib.add_description_containers(descriptor_containers) mdib.add_state_containers(state_containers) mdib.xtra.mk_state_containers_for_all_descriptors() + mdib.xtra.set_states_initial_values() mdib.xtra.update_retrievability_lists() mdib.xtra.set_all_source_mds() return mdib diff --git a/src/sdc11073/mdib/providermdibxtra.py b/src/sdc11073/mdib/providermdibxtra.py index 298c8b23..43eaa714 100644 --- a/src/sdc11073/mdib/providermdibxtra.py +++ b/src/sdc11073/mdib/providermdibxtra.py @@ -1,3 +1,4 @@ +"""The module contains extensions to the functionality of the ProviderMdib.""" from __future__ import annotations import time @@ -17,10 +18,10 @@ ) from .descriptorcontainers import AbstractDescriptorProtocol + from .entityprotocol import MultiStateEntityProtocol from .providermdib import ProviderMdib from .statecontainers import AbstractStateProtocol - class ProviderMdibMethods: """Extra methods for provider mdib tht are not core functionality.""" @@ -40,7 +41,7 @@ def ensure_location_context_descriptor(self): for system_context_descriptor in system_context_descriptors: child_location_descriptors = [d for d in location_context_descriptors if d.parent_handle == system_context_descriptor.Handle - and d.NODETYPE == pm.LocationContextDescriptor] + and pm.LocationContextDescriptor == d.NODETYPE] if not child_location_descriptors: descr_cls = mdib.data_model.get_descriptor_container_class(pm.LocationContextDescriptor) descr_container = descr_cls(handle=uuid.uuid4().hex, parent_handle=system_context_descriptor.Handle) @@ -57,7 +58,7 @@ def ensure_patient_context_descriptor(self): for system_context_descriptor in system_context_descriptors: child_location_descriptors = [d for d in patient_context_descriptors if d.parent_handle == system_context_descriptor.Handle - and d.NODETYPE == pm.PatientContextDescriptor] + and pm.PatientContextDescriptor == d.NODETYPE] if not child_location_descriptors: descr_cls = mdib.data_model.get_descriptor_container_class(pm.PatientContextDescriptor) descr_container = descr_cls(handle=uuid.uuid4().hex, parent_handle=system_context_descriptor.Handle) @@ -125,6 +126,45 @@ def mk_state_containers_for_all_descriptors(self): else: mdib.states.add_object(state) + def set_states_initial_values(self): + """Set all states to defined starting conditions. + + This method is ment to be called directly after the mdib was loaded and before the provider is published + on the network. + It changes values only internally in the mdib, no notifications are sent! + + """ + pm_names = self._mdib.data_model.pm_names + pm_types = self._mdib.data_model.pm_types + + for state in self._mdib.states.objects: + descriptor = self._mdib.descriptions.handle.get_one(state.DescriptorHandle) + if pm_names.AlertSystemDescriptor == descriptor.NODETYPE: + # alert systems are active + state.ActivationState = pm_types.AlertActivation.ON + state.SystemSignalActivation.append( + pm_types.SystemSignalActivation(manifestation=pm_types.AlertSignalManifestation.AUD, + state=pm_types.AlertActivation.ON)) + elif descriptor.is_alert_condition_descriptor: + # alert conditions are active, but not present + state.ActivationState = pm_types.AlertActivation.ON + state.Presence = False + elif descriptor.is_alert_signal_descriptor: + # alert signals are not present, and delegable signals are also not active + if descriptor.SignalDelegationSupported: + state.Location = pm_types.AlertSignalPrimaryLocation.REMOTE + state.ActivationState = pm_types.AlertActivation.OFF + state.Presence = pm_types.AlertSignalPresence.OFF + else: + state.ActivationState = pm_types.AlertActivation.ON + state.Presence = pm_types.AlertSignalPresence.OFF + elif descriptor.is_component_descriptor: + # all components are active + state.ActivationState = pm_types.ComponentActivation.ON + elif descriptor.is_operational_descriptor: + # all operations are enabled + state.OperatingMode = pm_types.OperatingMode.ENABLED + def update_retrievability_lists(self): """Update internal lists, based on current mdib descriptors.""" mdib = self._mdib @@ -178,6 +218,33 @@ def set_source_mds(self, descriptor_container: AbstractDescriptorProtocol): descriptor_container.set_source_mds(mds.Handle) + def disassociate_all(self, + entity: MultiStateEntityProtocol, + unbinding_mdib_version: int, + ignored_handle: str | None = None) -> list[str]: + """Disassociate all associated states in entity. + + The method returns a list of states that were disassociated. + :param entity: ProviderMultiStateEntity + :param ignored_handle: the context state with this Handle shall not be touched. + """ + pm_types = self._mdib.data_model.pm_types + disassociated_state_handles = [] + for state in entity.states.values(): + if state.Handle == ignored_handle: + # If state is already part of this transaction leave it also untouched, accept what the user wanted. + continue + if state.ContextAssociation != pm_types.ContextAssociation.DISASSOCIATED \ + or state.UnbindingMdibVersion is None: + # self._logger.info('disassociate %s, handle=%s', state.NODETYPE.localname, + # state.Handle) + state.ContextAssociation = pm_types.ContextAssociation.DISASSOCIATED + if state.UnbindingMdibVersion is None: + state.UnbindingMdibVersion = unbinding_mdib_version + state.BindingEndTime = time.time() + disassociated_state_handles.append(state.Handle) + return disassociated_state_handles + class DescriptorFactory: """DescriptorFactory provides some methods to make creation of descriptors easier.""" diff --git a/src/sdc11073/mdib/transactions.py b/src/sdc11073/mdib/transactions.py index b14ecdc5..0f10b2a9 100644 --- a/src/sdc11073/mdib/transactions.py +++ b/src/sdc11073/mdib/transactions.py @@ -1,5 +1,7 @@ +"""The module contains the implementations of transactions for ProviderMdib.""" from __future__ import annotations +import copy import time import uuid from typing import TYPE_CHECKING, cast @@ -18,10 +20,10 @@ from sdc11073.loghelper import LoggerAdapter from .descriptorcontainers import AbstractDescriptorProtocol + from .mdibbase import Entity, MultiStateEntity from .providermdib import ProviderMdib from .statecontainers import AbstractStateProtocol - class _TransactionBase: def __init__(self, device_mdib_container: ProviderMdib, @@ -195,6 +197,79 @@ def add_state(self, state_container: AbstractStateProtocol, adjust_state_version self._mdib.states.set_version(state_container) updates_dict[key] = TransactionItem(None, state_container) + def write_entity(self, # noqa: PLR0912, C901 + entity: Entity | MultiStateEntity, + adjust_version_counter: bool = True): + """Insert or update an entity.""" + descriptor_handle = entity.descriptor.Handle + if descriptor_handle in self.descriptor_updates: + raise ValueError(f'Entity {descriptor_handle} already in updated set!') + + tmp_descriptor = copy.deepcopy(entity.descriptor) + orig_descriptor_container = self._mdib.descriptions.handle.get_one(descriptor_handle, allow_none=True) + + if adjust_version_counter: + if orig_descriptor_container is None: + # new descriptor, update version from saved versions in mdib if exists + self._mdib.descriptions.set_version(tmp_descriptor) + else: + # update from old + tmp_descriptor.DescriptorVersion = orig_descriptor_container.DescriptorVersion + 1 + + self.descriptor_updates[descriptor_handle] = TransactionItem(orig_descriptor_container, + tmp_descriptor) + + if entity.is_multi_state: + old_states = self._mdib.context_states.descriptor_handle.get(descriptor_handle, []) + old_states_dict = {s.Handle: s for s in old_states} + for state_container in entity.states.values(): + tmp_state = copy.deepcopy(state_container) + old_state = old_states_dict.get(tmp_state.Handle) # can be None => new state + if adjust_version_counter: + tmp_state.DescriptorVersion = tmp_descriptor.DescriptorVersion + if old_state is not None: + tmp_state.StateVersion = old_state.StateVersion + 1 + else: + self._mdib.context_states.set_version(tmp_state) + + self.context_state_updates[state_container.Handle] = TransactionItem(old_state, tmp_state) + deleted_states_handles = set(old_states_dict.keys()).difference(set(entity.states.keys())) + for handle in deleted_states_handles: + del_state = old_states_dict[handle] + self.context_state_updates[handle] = TransactionItem(del_state, None) + else: + tmp_state = copy.deepcopy(entity.state) + tmp_state.descriptor_container = tmp_descriptor + old_state = self._mdib.states.descriptor_handle.get_one(descriptor_handle, allow_none=True) + if adjust_version_counter: + if old_state is not None: + tmp_state.StateVersion = old_state.StateVersion + 1 + else: + self._mdib.states.set_version(tmp_state) + + state_updates_dict = self._get_states_update(tmp_state) + state_updates_dict[entity.state.DescriptorHandle] = TransactionItem(old_state, tmp_state) + + def write_entities(self, entities: list[Entity | MultiStateEntity], adjust_version_counter: bool = True): + """Write entities in order parents first.""" + written_handles = [] + ent_dict = {ent.handle: ent for ent in entities} + while len(written_handles) < len(ent_dict): + for handle, ent in ent_dict.items(): + write_now = True + if (ent.parent_handle is not None + and ent.parent_handle in ent_dict + and ent.parent_handle not in written_handles): + # it has a parent, and parent has not been written yet + write_now = False + if write_now and handle not in written_handles: + self.write_entity(ent, adjust_version_counter) + written_handles.append(handle) + + def remove_entity(self, entity: Entity | MultiStateEntity): + """Remove existing entity from mdib.""" + self.remove_descriptor(entity.handle) + def process_transaction(self, set_determination_time: bool) -> TransactionResultProtocol: # noqa: ARG002 """Process transaction and create a TransactionResult. @@ -280,7 +355,7 @@ def _update_corresponding_state(self, descriptor_container: AbstractDescriptorPr if state_update is not None: # the state has also been updated directly in transaction. # update descriptor version - old_state, new_state = state_update + old_state, new_state = state_update.old, state_update.new else: old_state = context_state new_state = old_state.mk_copy() @@ -303,7 +378,7 @@ def _update_corresponding_state(self, descriptor_container: AbstractDescriptorPr descriptor_container.Handle, allow_none=True) if old_state is not None: new_state = old_state.mk_copy() - new_state.descriptor_container = descriptor_container # + new_state.descriptor_container = descriptor_container new_state.DescriptorVersion = descriptor_container.DescriptorVersion new_state.increment_state_version() updates_dict[descriptor_container.Handle] = TransactionItem(old_state, new_state) @@ -375,6 +450,41 @@ def actual_descriptor(self, descriptor_handle: str) -> AbstractDescriptorProtoco """Look descriptor in mdib, state transaction cannot have descriptor changes.""" return self._mdib.descriptions.handle.get_one(descriptor_handle) + def write_entity(self, entity: Entity, adjust_version_counter: bool = True): + """Insert or update an entity.""" + if entity.is_multi_state: + raise ApiUsageError(f'Transaction {self.__class__.__name__} does not handle multi state entities!') + + if not self._is_correct_state_type(entity.state): + raise ApiUsageError(f'Wrong data type in transaction! {self.__class__.__name__}, {entity.state}') + + descriptor_handle = entity.state.DescriptorHandle + old_state = self._mdib.states.descriptor_handle.get_one(entity.handle, allow_none=True) + tmp_state = copy.deepcopy(entity.state) + if adjust_version_counter: + descriptor_container = self._mdib.descriptions.handle.get_one(descriptor_handle) + tmp_state.DescriptorVersion = descriptor_container.DescriptorVersion + if old_state is not None: + # update from old state + tmp_state.StateVersion = old_state.StateVersion + 1 + else: + # new state, update version from saved versions in mdib if exists + self._mdib.states.set_version(tmp_state) + + self._state_updates[descriptor_handle] = TransactionItem(old=old_state, new=tmp_state) + + def write_entities(self, entities: list[Entity | MultiStateEntity], adjust_version_counter: bool = True): + """Write entities in order parents first.""" + for entity in entities: + # check all states before writing any of them + if entity.is_multi_state: + raise ApiUsageError(f'Transaction {self.__class__.__name__} does not handle multi state entities!') + + if not self._is_correct_state_type(entity.state): + raise ApiUsageError(f'Wrong data type in transaction! {self.__class__.__name__}, {entity.state}') + for ent in entities: + self.write_entity(ent, adjust_version_counter) + @staticmethod def _is_correct_state_type(state: AbstractStateProtocol) -> bool: # noqa: ARG004 return False @@ -568,7 +678,7 @@ def mk_context_state(self, descriptor_handle: str, return cast(AbstractMultiStateProtocol, new_state_container) def add_state(self, state_container: AbstractMultiStateProtocol, adjust_state_version: bool = True): - """Add a new state to mdib.""" + """Add a new context state to mdib.""" if not state_container.is_context_state: # prevent this for simplicity reasons raise ApiUsageError('Transaction only handles context states!') @@ -582,6 +692,7 @@ def add_state(self, state_container: AbstractMultiStateProtocol, adjust_state_ve self._mdib.context_states.set_version(state_container) self._state_updates[state_container.Handle] = TransactionItem(None, state_container) + def disassociate_all(self, context_descriptor_handle: str, ignored_handle: str | None = None) -> list[str]: @@ -611,6 +722,37 @@ def disassociate_all(self, disassociated_state_handles.append(transaction_state.Handle) return disassociated_state_handles + def write_entity(self, entity: MultiStateEntity, + modified_handles: list[str], + adjust_version_counter: bool = True): + """Insert or update a context state in mdib.""" + for handle in modified_handles: + state_container = entity.states.get(handle) + old_state = self._mdib.context_states.handle.get_one(handle, allow_none=True) + if state_container is None: + # a deleted state : this cannot be communicated via notification. + # delete it internal_entity anf that is all + if old_state is not None: + self._state_updates[handle] = TransactionItem(old=old_state, new=None) + else: + raise KeyError(f'invalid handle {handle}!') + continue + if not state_container.is_context_state: + raise ApiUsageError('Transaction only handles context states!') + + tmp = copy.deepcopy(state_container) + + if old_state is None: + # this is a new state + tmp.descriptor_container = entity.descriptor + tmp.DescriptorVersion = entity.descriptor.DescriptorVersion + if adjust_version_counter: + self._mdib.context_states.set_version(tmp) + elif adjust_version_counter: + tmp.StateVersion = old_state.StateVersion + 1 + + self._state_updates[state_container.Handle] = TransactionItem(old=old_state, new=tmp) + def process_transaction(self, set_determination_time: bool) -> TransactionResultProtocol: # noqa: ARG002 """Process transaction and create a TransactionResult.""" proc = TransactionResult() diff --git a/src/sdc11073/mdib/transactionsprotocol.py b/src/sdc11073/mdib/transactionsprotocol.py index 55a6e558..63c6d26c 100644 --- a/src/sdc11073/mdib/transactionsprotocol.py +++ b/src/sdc11073/mdib/transactionsprotocol.py @@ -1,3 +1,7 @@ +"""The module declares several protocols that are implemented by transactions. + +Only these protocols shall be used, the old way of transactions in mdib.transactions should no longer be used. +""" from __future__ import annotations from dataclasses import dataclass @@ -9,7 +13,7 @@ if TYPE_CHECKING: from .descriptorcontainers import AbstractDescriptorProtocol - + from .entityprotocol import EntityProtocol, EntityTypeProtocol, MultiStateEntityProtocol class TransactionType(Enum): """The different kinds of transactions. @@ -26,7 +30,7 @@ class TransactionType(Enum): rt_sample = 7 -class TransactionResultProtocol(Protocol): +class TransactionResultProtocol(Protocol): # pragma: no cover """TransactionResult contains all state and descriptors that were modified in the transaction. The states and descriptors are used to create the notification(s) that keep the consumers up to date. @@ -43,11 +47,12 @@ class TransactionResultProtocol(Protocol): rt_updates = list[AbstractStateProtocol] has_descriptor_updates: bool + new_mdib_version: int def all_states(self) -> list[AbstractStateProtocol]: """Return all states in this transaction.""" -class TransactionItemProtocol(Protocol): +class TransactionItemProtocol(Protocol): # pragma: no cover """A container for the old and the new version of a state or descriptor. If old is None, this is an object that is added to mdib. @@ -67,7 +72,9 @@ class TransactionItem: new: AbstractStateProtocol | AbstractDescriptorProtocol | None -class AbstractTransactionManagerProtocol(Protocol): + + +class AbstractTransactionManagerProtocol(Protocol): # pragma: no cover """Interface of a TransactionManager.""" new_mdib_version: int @@ -86,8 +93,107 @@ def process_transaction(self, set_determination_time: bool) -> TransactionResult error: bool -class DescriptorTransactionManagerProtocol(AbstractTransactionManagerProtocol): - """Interface of a TransactionManager that modifies descriptors.""" +class EntityDescriptorTransactionManagerProtocol(AbstractTransactionManagerProtocol): # pragma: no cover + """Entity based transaction manager for modification of descriptors (and associated states). + + The entity based transaction manager protocol can only be used with EntityGetter methods! + The only working approach is: + 1. Read an entity from mdib with one of the EntityGetter Methods. These methods return a + copy of the mdib data. + 2. Manipulate the copied data as required + 3. Create a transaction and write entity data back to mdib with write_entity method + """ + + def get_state_transaction_item(self, handle: str) -> TransactionItem | None: + """If transaction has a state with given handle, return the transaction-item, otherwise None.""" + + def transaction__entity(self, descriptor_handle: str) -> EntityTypeProtocol | None: + """Return the entity in open transaction if it exists. + + The descriptor can already be part of the transaction, and e.g. in pre_commit handlers of role providers + it can be necessary to have access to it. + """ + + def write_entity(self, + entity: EntityTypeProtocol, + adjust_version_counter: bool = True): + """Insert or update an entity (state and descriptor).""" + + def write_entities(self, + entities: list[EntityTypeProtocol], + adjust_version_counter: bool = True): + """Insert or update list of entities.""" + + def remove_entity(self, entity: EntityTypeProtocol): + """Remove existing descriptor from mdib.""" + + +class EntityStateTransactionManagerProtocol(AbstractTransactionManagerProtocol): # pragma: no cover + """Entity based transaction manager for modification of states. + + The entity based transaction manager protocol can only be used with EntityGetter methods! + The only working approach is: + 1. Read an entity from mdib with one of the EntityGetter Methods. These methods return a + copy of the mdib data. + 2. Manipulate the copied state as required + 3. Create a transaction and write entity data back to mdib with write_entity method + """ + + def has_state(self, descriptor_handle: str) -> bool: + """Check if transaction has a state with given handle.""" + + def write_entity(self, + entity: EntityProtocol, + adjust_version_counter: bool = True): + """Update the state of the entity.""" + + def write_entities(self, + entities: list[EntityProtocol], + adjust_version_counter: bool = True): + """Update the states of entities.""" + + +class EntityContextStateTransactionManagerProtocol(AbstractTransactionManagerProtocol): # pragma: no cover + """Entity based transaction manager for modification of context states. + + The entity based transaction manager protocol can only be used with EntityGetter methods! + The only working approach is: + 1. Read an entity from mdib with one of the EntityGetter Methods. These methods return a + copy of the mdib data. + 2. Manipulate the copied states as required + 3. Create a descriptor transaction context and write entity data back to mdib with write_entity method + """ + + def write_entity(self, entity: MultiStateEntityProtocol, + modified_handles: list[str], + adjust_version_counter: bool = True): + """Insert or update a context state in mdib.""" + + +class DescriptorTransactionManagerProtocol(EntityDescriptorTransactionManagerProtocol): # pragma: no cover + """The classic Interface of a TransactionManager that modifies descriptors. + + The classic transaction manager protocol can not be used with EntityGetter methods! + The only working approach is: + case A: update an existing descriptor: + 1. Start a descriptor transaction context + 2. call get_descriptor. This returns a copy of the descriptor in mdib + Manipulate the copied descriptor as required + 3. optional: call get_state / get_context_state. This returns a copy of the state in mdib + Manipulate the copied state as required + + case B: create a descriptor ( not context descriptor): + 1. Start a descriptor transaction context + 2. Create a new descriptor (and state instance if this is not a context state) + 3. Call add_descriptor and add_state + + case C: create a context descriptor: + 1. Start a descriptor transaction context + 2. Create a new descriptor + 3. Call mk_context_state 0... n times to add context states + + In all cases: when the transaction context is left, all before retrieved data is written back to mdib. + """ def actual_descriptor(self, descriptor_handle: str) -> AbstractDescriptorProtocol: """Look for new or updated descriptor in current transaction and in mdib.""" @@ -107,9 +213,6 @@ def get_descriptor(self, descriptor_handle: str) -> AbstractDescriptorProtocol: def has_state(self, descriptor_handle: str) -> bool: """Check if transaction has a state with given handle.""" - def get_state_transaction_item(self, handle: str) -> TransactionItemProtocol | None: - """If transaction has a state with given handle, return the transaction-item, otherwise None.""" - def add_state(self, state_container: AbstractStateProtocol, adjust_state_version: bool = True): """Add a new state to mdib.""" @@ -129,11 +232,17 @@ def mk_context_state(self, descriptor_handle: str, """Create a new ContextStateContainer.""" -class StateTransactionManagerProtocol(AbstractTransactionManagerProtocol): - """Interface of a TransactionManager that modifies states (except context states).""" +class StateTransactionManagerProtocol(EntityStateTransactionManagerProtocol): # pragma: no cover + """The classic Interface of a TransactionManager that modifies states (except context states). - def actual_descriptor(self, descriptor_handle: str) -> AbstractDescriptorProtocol: - """Look for new or updated descriptor in current transaction and in mdib.""" + The classic transaction manager protocol can not be used with EntityGetter methods! + The only working approach is: + 1. Start a descriptor transaction context + 2. call get_state. This returns a copy of the state in mdib + Manipulate the copied state as required + + When the transaction context is left, all before retrieved data is written back to mdib. + """ def has_state(self, descriptor_handle: str) -> bool: """Check if transaction has a state with given handle.""" @@ -141,9 +250,6 @@ def has_state(self, descriptor_handle: str) -> bool: def get_state_transaction_item(self, handle: str) -> TransactionItemProtocol | None: """If transaction has a state with given handle, return the transaction-item, otherwise None.""" - def add_state(self, state_container: AbstractStateProtocol, adjust_state_version: bool = True): - """Add a new state to mdib.""" - def unget_state(self, state_container: AbstractStateProtocol): """Forget a state that was provided before by a get_state or add_state call.""" @@ -151,8 +257,19 @@ def get_state(self, descriptor_handle: str) -> AbstractStateProtocol: """Read a state from mdib and add it to the transaction.""" -class ContextStateTransactionManagerProtocol(StateTransactionManagerProtocol): - """Interface of a TransactionManager that modifies context states.""" +class ContextStateTransactionManagerProtocol(EntityContextStateTransactionManagerProtocol): # pragma: no cover + """The classic Interface of a TransactionManager that modifies context states. + + The classic transaction manager protocol can not be used with EntityGetter methods! + The only working approach is: + 1. Start a descriptor transaction context + 2a.Call get_context_state if you want to manipulate an existing context state. + This returns a copy of the state in mdib. Manipulate the copied state as required. + 2b.Call mk_context_state if you want to create a new context state. + Manipulate the state as required. + + When the transaction context is left, all before retrieved data is written back to mdib. + """ def get_context_state(self, context_state_handle: str) -> AbstractMultiStateProtocol: """Read a ContextState from mdib with given state handle.""" @@ -168,6 +285,10 @@ def disassociate_all(self, ignored_handle: str | None = None) -> list[str]: """Disassociate all associated states in mdib for context_descriptor_handle.""" +AnyEntityTransactionManagerProtocol = Union[EntityContextStateTransactionManagerProtocol, + EntityStateTransactionManagerProtocol, + EntityDescriptorTransactionManagerProtocol] + AnyTransactionManagerProtocol = Union[ContextStateTransactionManagerProtocol, StateTransactionManagerProtocol, diff --git a/src/sdc11073/provider/components.py b/src/sdc11073/provider/components.py index 4e92df6f..1bf72988 100644 --- a/src/sdc11073/provider/components.py +++ b/src/sdc11073/provider/components.py @@ -1,3 +1,7 @@ +"""The module declares the components of a provider. + +This serves as dependency injection. +""" from __future__ import annotations import copy @@ -31,10 +35,10 @@ from lxml import etree from sdc11073 import provider - from sdc11073.mdib.providermdib import ProviderMdib + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol + from sdc11073.namespaces import PrefixNamespace from sdc11073.provider.servicesfactory import HostedServices from sdc11073.xml_types.wsd_types import ScopesType - from sdc11073.namespaces import PrefixNamespace from .sco import AbstractScoOperationsRegistry from .subscriptionmgr_base import SubscriptionManagerProtocol @@ -59,7 +63,7 @@ class SdcProviderComponents: subscriptions_manager_class: dict[str, type[SubscriptionManagerProtocol]] = None role_provider_class: type = None waveform_provider_class: type | None = None - scopes_factory: Callable[[ProviderMdib], ScopesType] = None + scopes_factory: Callable[[ProviderMdibProtocol], ScopesType] = None hosted_services: dict = None additional_schema_specs: list[PrefixNamespace] = field(default_factory=list) diff --git a/src/sdc11073/provider/operations.py b/src/sdc11073/provider/operations.py index 0b900a25..2b33bf14 100644 --- a/src/sdc11073/provider/operations.py +++ b/src/sdc11073/provider/operations.py @@ -1,3 +1,7 @@ +"""The module implements the different kinds of BICEPS operations. + +These operations are the instances inside a provider that perform an operation by changing the mdib content. +""" from __future__ import annotations import inspect @@ -13,6 +17,7 @@ if TYPE_CHECKING: from lxml import etree + from sdc11073.mdib.descriptorcontainers import AbstractDescriptorProtocol from sdc11073.mdib.providermdib import ProviderMdib from sdc11073.pysoap.soapenvelope import ReceivedSoapMessage @@ -84,8 +89,7 @@ def __init__(self, # noqa: PLR0913 """ self._logger = loghelper.get_logger_adapter(f'sdc.device.op.{self.__class__.__name__}', log_prefix) self._mdib: ProviderMdib | None = None - self._descriptor_container = None - self._operation_state_container = None + self._operation_entity = None self.handle: str = handle self.operation_target_handle: str = operation_target_handle # documentation of operation_target_handle: @@ -102,7 +106,7 @@ def __init__(self, # noqa: PLR0913 @property def descriptor_container(self) -> AbstractDescriptorProtocol: # noqa: D102 - return self._descriptor_container + return self._operation_entity.descriptor def execute_operation(self, soap_request: ReceivedSoapMessage, @@ -122,10 +126,10 @@ def check_timeout(self): """Set on_timeout observable if timeout is detected.""" if self.last_called_time is None: return - if self._descriptor_container.InvocationEffectiveTimeout is None: + if self._operation_entity.descriptor.InvocationEffectiveTimeout is None: return age = time.time() - self.last_called_time - if age < self._descriptor_container.InvocationEffectiveTimeout: + if age < self._operation_entity.descriptor.InvocationEffectiveTimeout: return if self._timeout_handler is not None: self._timeout_handler(self) @@ -135,45 +139,40 @@ def set_mdib(self, mdib: ProviderMdib, parent_descriptor_handle: str): """Set mdib reference. The operation needs to know the mdib that it operates on. - This is called by SubscriptionManager on registration. + This method is called by SubscriptionManager on registration. Needs to be implemented by derived classes if specific things have to be initialized. """ if self._mdib is not None: raise ApiUsageError('Mdib is already set') self._mdib = mdib self._logger.log_prefix = mdib.log_prefix # use same prefix as mdib for logging - self._descriptor_container = self._mdib.descriptions.handle.get_one(self.handle, allow_none=True) - if self._descriptor_container is not None: + self._operation_entity = self._mdib.entities.by_handle(self.handle) + if self._operation_entity is not None: # there is already a descriptor self._logger.debug('descriptor for operation "%s" is already present, re-using it', self.handle) else: - cls = mdib.data_model.get_descriptor_container_class(self.OP_DESCR_QNAME) - self._descriptor_container = cls(self.handle, parent_descriptor_handle) + self._operation_entity = self._mdib.entities.new_entity(self.OP_DESCR_QNAME, + self.handle, + parent_descriptor_handle ) self._init_operation_descriptor_container() - # ToDo: transaction context for flexibility to add operations at runtime - mdib.descriptions.add_object(self._descriptor_container) - - self._operation_state_container = self._mdib.states.descriptor_handle.get_one(self.handle, allow_none=True) - if self._operation_state_container is not None: - self._logger.debug('operation state for operation "%s" is already present, re-using it', self.handle) - else: - cls = mdib.data_model.get_state_container_class(self.OP_STATE_QNAME) - self._operation_state_container = cls(self._descriptor_container) - mdib.states.add_object(self._operation_state_container) + with self._mdib.descriptor_transaction() as mgr: + mgr.write_entity(self._operation_entity) def _init_operation_descriptor_container(self): - self._descriptor_container.OperationTarget = self.operation_target_handle + self._operation_entity.descriptor.OperationTarget = self.operation_target_handle if self._coded_value is not None: - self._descriptor_container.Type = self._coded_value + self._operation_entity.descriptor.Type = self._coded_value def set_operating_mode(self, mode: OperatingMode): """Set OperatingMode member in state in transaction context.""" + entity = self._mdib.entities.by_handle(self.handle) + entity.state.OperatingMode = mode + with self._mdib.operational_state_transaction() as mgr: - state = mgr.get_state(self.handle) - state.OperatingMode = mode + mgr.write_entity(entity) def __str__(self): - code = None if self._descriptor_container is None else self._descriptor_container.Type + code = None if self._operation_entity is None else self._operation_entity.descriptor.Type return (f'{self.__class__.__name__} handle={self.handle} code={code} ' f'operation-target={self.operation_target_handle}') diff --git a/src/sdc11073/provider/providerimpl.py b/src/sdc11073/provider/providerimpl.py index e4146d99..68cd86ff 100644 --- a/src/sdc11073/provider/providerimpl.py +++ b/src/sdc11073/provider/providerimpl.py @@ -1,3 +1,4 @@ +"""The module implements the SdcProvider.""" from __future__ import annotations import copy @@ -25,19 +26,20 @@ from sdc11073.xml_types.addressing_types import EndpointReferenceType from sdc11073.xml_types.dpws_types import HostServiceType, ThisDeviceType, ThisModelType from sdc11073.xml_types.wsd_types import ProbeMatchesType, ProbeMatchType -from sdc11073.roles.protocols import ProductProtocol, WaveformProviderProtocol # import here for code cov. :( from .periodicreports import PeriodicReportsHandler, PeriodicReportsNullHandler if TYPE_CHECKING: from enum import Enum + from sdc11073.location import SdcLocation - from sdc11073.mdib.providermdib import ProviderMdib - from sdc11073.mdib.transactionsprotocol import TransactionResultProtocol + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol from sdc11073.mdib.statecontainers import AbstractStateProtocol + from sdc11073.mdib.transactionsprotocol import TransactionResultProtocol from sdc11073.provider.porttypes.localizationservice import LocalizationStorage from sdc11073.pysoap.msgfactory import CreatedMessage from sdc11073.pysoap.soapenvelope import ReceivedSoapMessage + from sdc11073.roles.protocols import ProductProtocol, WaveformProviderProtocol from sdc11073.xml_types.msg_types import AbstractSet from sdc11073.xml_types.pm_types import InstanceIdentifier from sdc11073.xml_types.wsd_types import ScopesType @@ -82,15 +84,16 @@ class SdcProvider: DEFAULT_CONTEXTSTATES_IN_GETMDIB = True # defines weather get_mdib and getMdStates contain context states or not. - def __init__(self, ws_discovery: WsDiscoveryProtocol, + def __init__(self, # noqa: PLR0915, PLR0913 + ws_discovery: WsDiscoveryProtocol, this_model: ThisModelType, this_device: ThisDeviceType, - device_mdib_container: ProviderMdib, + device_mdib_container: ProviderMdibProtocol, epr: str | uuid.UUID | None = None, validate: bool = True, ssl_context_container: sdc11073.certloader.SSLContextContainer | None = None, max_subscription_duration: int = 15, - socket_timeout: int | float | None = None, + socket_timeout: int | float | None = None, # noqa: PYI041 log_prefix: str = '', default_components: SdcProviderComponents | None = None, specific_components: SdcProviderComponents | None = None, @@ -101,7 +104,7 @@ def __init__(self, ws_discovery: WsDiscoveryProtocol, :param ws_discovery: a WsDiscovers instance :param this_model: a ThisModelType instance :param this_device: a ThisDeviceType instance - :param device_mdib_container: a ProviderMdib instance + :param device_mdib_container: a ProviderMdibProtocol instance :param epr: something that serves as a unique identifier of this device for discovery. If epr is a string, it must be usable as a path element in an url (no spaces, ...) :param validate: bool @@ -243,19 +246,19 @@ def _setup_components(self): cls = self._components.sco_operations_registry_class pm_names = self._mdib.data_model.pm_names - sco_descr_list = self._mdib.descriptions.NODETYPE.get(pm_names.ScoDescriptor, []) - for sco_descr in sco_descr_list: + entities = self._mdib.entities.by_node_type(pm_names.ScoDescriptor) + for entity in entities: sco_operations_registry = cls(self.hosted_services.set_service, self._components.operation_cls_getter, self._mdib, - sco_descr, + entity.descriptor, log_prefix=self._log_prefix) - self._sco_operations_registries[sco_descr.Handle] = sco_operations_registry + self._sco_operations_registries[entity.handle] = sco_operations_registry product_roles = self._components.role_provider_class(self._mdib, sco_operations_registry, self._log_prefix) - self.product_lookup[sco_descr.Handle] = product_roles + self.product_lookup[entity.handle] = product_roles product_roles.init_operations() if self._components.waveform_provider_class is not None: self.waveform_provider = self._components.waveform_provider_class(self._mdib, @@ -352,7 +355,7 @@ def publish(self): x_addrs) @property - def mdib(self) -> ProviderMdib: + def mdib(self) -> ProviderMdibProtocol: """Return mdib reference.""" return self._mdib @@ -396,8 +399,10 @@ def handle_operation_request(self, def start_all(self, start_rtsample_loop: bool = True, periodic_reports_interval: float | None = None, - shared_http_server=None): - """:param start_rtsample_loop: flag + shared_http_server=None): # noqa: ANN001 + """Start all background threads. + + :param start_rtsample_loop: flag :param periodic_reports_interval: if provided, a value in seconds :param shared_http_server: if provided, use this http server, else device creates its own. :return: @@ -416,7 +421,7 @@ def start_all(self, if start_rtsample_loop: self.start_rt_sample_loop() - def _start_services(self, shared_http_server=None): + def _start_services(self, shared_http_server=None): # noqa: ANN001 """Start the services.""" self._logger.info('starting services, addr = %r', self._wsdiscovery.get_active_addresses()) for sco in self._sco_operations_registries.values(): @@ -464,6 +469,7 @@ def _start_services(self, shared_http_server=None): subscriptions_manager.set_base_urls(self.base_urls) def stop_all(self, send_subscription_end: bool = True): + """Stop all background threads and clear local data.""" self.stop_realtime_sample_loop() if self._periodic_reports_handler: self._periodic_reports_handler.stop() @@ -482,6 +488,7 @@ def stop_all(self, send_subscription_end: bool = True): self._soap_client_pool.close_all() def start_rt_sample_loop(self): + """Start generating waveform data.""" if self.waveform_provider is None: raise ApiUsageError('no waveform provider configured.') if self.waveform_provider.is_running: @@ -489,10 +496,12 @@ def start_rt_sample_loop(self): self.waveform_provider.start() def stop_realtime_sample_loop(self): + """Stop generating waveform data.""" if self.waveform_provider is not None and self.waveform_provider.is_running: self.waveform_provider.stop() def get_xaddrs(self) -> list[str]: + """Return the addresses of the provider.""" if self._alternative_hostname: addresses = [self._alternative_hostname] else: diff --git a/src/sdc11073/provider/scopesfactory.py b/src/sdc11073/provider/scopesfactory.py index 1050592e..6a2dcec4 100644 --- a/src/sdc11073/provider/scopesfactory.py +++ b/src/sdc11073/provider/scopesfactory.py @@ -1,20 +1,24 @@ +"""The module implements the function mk_scopes.""" from urllib.parse import quote_plus from sdc11073.location import SdcLocation +from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol from sdc11073.xml_types.wsd_types import ScopesType -def mk_scopes(mdib) -> ScopesType: - """ scopes factory +def mk_scopes(mdib: ProviderMdibProtocol) -> ScopesType: + """Return a ScopesType instance. + This method creates the scopes for publishing in wsdiscovery. - :param mdib: - :return: wsdiscovery.Scope """ pm_types = mdib.data_model.pm_types pm_names = mdib.data_model.pm_names scope = ScopesType() - locations = mdib.context_states.NODETYPE.get(pm_names.LocationContextState, []) - assoc_loc = [loc for loc in locations if loc.ContextAssociation == pm_types.ContextAssociation.ASSOCIATED] + loc_entities = mdib.entities.by_node_type(pm_names.LocationContextDescriptor) + assoc_loc = [] + for ent in loc_entities: + assoc_loc.extend([loc for loc in ent.states.values() if + loc.ContextAssociation == pm_types.ContextAssociation.ASSOCIATED]) if len(assoc_loc) == 1: loc = assoc_loc[0] det = loc.LocationDetail @@ -26,9 +30,9 @@ def mk_scopes(mdib) -> ScopesType: (pm_names.EnsembleContextDescriptor, 'sdc.ctxt.ens'), (pm_names.WorkflowContextDescriptor, 'sdc.ctxt.wfl'), (pm_names.MeansContextDescriptor, 'sdc.ctxt.mns')): - descriptors = mdib.descriptions.NODETYPE.get(nodetype, []) - for descriptor in descriptors: - states = mdib.context_states.descriptor_handle.get(descriptor.Handle, []) + entities = mdib.entities.by_node_type(nodetype) + for entity in entities: + states = entity.states assoc_st = [s for s in states if s.ContextAssociation == pm_types.ContextAssociation.ASSOCIATED] for state in assoc_st: for ident in state.Identification: @@ -38,25 +42,25 @@ def mk_scopes(mdib) -> ScopesType: scope.text.append('sdc.mds.pkp:1.2.840.10004.20701.1.1') # key purpose Service provider return scope +def _get_device_component_based_scopes(mdib: ProviderMdibProtocol) -> set[str]: + """Return a set of scope strings. -def _get_device_component_based_scopes(mdib): - """ SDC: For every instance derived from pm:AbstractComplexDeviceComponentDescriptor in the MDIB an SDC SERVICE PROVIDER SHOULD include a URI-encoded pm:AbstractComplexDeviceComponentDescriptor/pm:Type as dpws:Scope of the MDPWS discovery messages. The URI encoding conforms to the given Extended Backus-Naur Form. E.G. sdc.cdc.type:///69650, sdc.cdc.type:/urn:oid:1.3.6.1.4.1.3592.2.1.1.0//DN_VMD After discussion with David: use only MDSDescriptor, VmdDescriptor makes no sense. - :return: a set of scopes + :return: a set of scope strings """ pm_types = mdib.data_model.pm_types pm_names = mdib.data_model.pm_names scopes = set() - descriptors = mdib.descriptions.NODETYPE.get(pm_names.MdsDescriptor) - for descriptor in descriptors: - if descriptor.Type is not None: - coding_systems = '' if descriptor.Type.CodingSystem == pm_types.DEFAULT_CODING_SYSTEM \ - else descriptor.Type.CodingSystem - csv = descriptor.Type.CodingSystemVersion or '' - scope_string = f'sdc.cdc.type:/{coding_systems}/{csv}/{descriptor.Type.Code}' + entities = mdib.entities.by_node_type(pm_names.MdsDescriptor) + for entity in entities: + if entity.descriptor.Type is not None: + coding_systems = '' if entity.descriptor.Type.CodingSystem == pm_types.DEFAULT_CODING_SYSTEM \ + else entity.descriptor.Type.CodingSystem + csv = entity.descriptor.Type.CodingSystemVersion or '' + scope_string = f'sdc.cdc.type:/{coding_systems}/{csv}/{entity.descriptor.Type.Code}' scopes.add(scope_string) return scopes diff --git a/src/sdc11073/provider/subscriptionmgr_async.py b/src/sdc11073/provider/subscriptionmgr_async.py index 98d3b40a..357ec4f2 100644 --- a/src/sdc11073/provider/subscriptionmgr_async.py +++ b/src/sdc11073/provider/subscriptionmgr_async.py @@ -1,7 +1,9 @@ +"""Async implementation of subscriptions manager.""" from __future__ import annotations import asyncio import socket +import time import traceback from collections import defaultdict from threading import Thread @@ -21,14 +23,15 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Iterable + from logging import LoggerAdapter + from sdc11073 import xml_utils from sdc11073.definitions_base import BaseDefinitions from sdc11073.dispatch import RequestData from sdc11073.mdib.mdibbase import MdibVersionGroup from sdc11073.pysoap.msgfactory import MessageFactory from sdc11073.pysoap.msgreader import ReceivedMessage from sdc11073.pysoap.soapclientpool import SoapClientPool - from sdc11073 import xml_utils def _mk_dispatch_identifier(reference_parameters: list, path_suffix: str) -> tuple[str | None, str]: @@ -106,7 +109,6 @@ async def async_send_notification_end_message( # it does not matter that we could not send the message - end is end ;) self._logger.info('exception async send subscription end to {}, subscription = {}', # noqa: PLE1205 url, self) - pass except Exception: # noqa: BLE001 self._logger.error(traceback.format_exc()) finally: @@ -116,28 +118,39 @@ async def async_send_notification_end_message( class AsyncioEventLoopThread(Thread): """Central event loop for provider.""" - def __init__(self, name: str): + def __init__(self, name: str, logger: LoggerAdapter ): super().__init__(name=name) + self._logger = logger self.daemon = True self.loop = asyncio.new_event_loop() - self.running = False + self._running = False + + @property + def running(self) -> bool: + """Return True if the event loop is running.""" + return self._running def run(self): """Run method of thread.""" - self.running = True + self._logger.info('%s started', self.__class__.__name__) + self._running = True self.loop.run_forever() + self._logger.info('%s finished', self.__class__.__name__) def run_coro(self, coro: Awaitable) -> Any: """Run threadsafe.""" - if not self.running: + if not self._running: + self._logger.error('%s: async thread is not running', self.__class__.__name__) return None return asyncio.run_coroutine_threadsafe(coro, loop=self.loop).result() def stop(self): """Stop thread.""" - self.running = False + self._logger.info('%s: stopping now', self.__class__.__name__) + self._running = False self.loop.call_soon_threadsafe(self.loop.stop) self.join() + self._logger.info('%s: stopped', self.__class__.__name__) class BICEPSSubscriptionsManagerBaseAsync(SubscriptionsManagerBase): @@ -153,14 +166,21 @@ def __init__(self, sdc_definitions: BaseDefinitions, msg_factory: MessageFactory, soap_client_pool: SoapClientPool, - max_subscription_duration: [float, None] = None, - log_prefix: str = None, + max_subscription_duration: float | None = None, + log_prefix: str | None = None, ): super().__init__(sdc_definitions, msg_factory, soap_client_pool, max_subscription_duration, log_prefix) if soap_client_pool.async_loop_subscr_mgr is None: - thr = AsyncioEventLoopThread(name='async_loop_subscr_mgr') + thr = AsyncioEventLoopThread(name='async_loop_subscr_mgr', logger=self._logger) soap_client_pool.async_loop_subscr_mgr = thr thr.start() + for _i in range(10): + if not thr.running: + time.sleep(0.1) + else: + break + if not thr.running: + raise RuntimeError('could not start AsyncioEventLoopThread') self._async_send_thread = soap_client_pool.async_loop_subscr_mgr def _mk_subscription_instance(self, request_data: RequestData) -> BicepsSubscriptionAsync: @@ -201,7 +221,8 @@ def send_to_subscribers(self, payload: MessageType | xml_utils.LxmlElement, """Send payload to all subscribers.""" with self._subscriptions.lock: if not self._async_send_thread.running: - self._logger.info('could not send notifications, async send loop is not running.') + self._logger.warning('could not send notifications, async send loop is not running.') + self._logger.warning(traceback.format_stack()) return subscribers = self._get_subscriptions_for_action(action) if isinstance(payload, MessageType): diff --git a/src/sdc11073/roles/alarmprovider.py b/src/sdc11073/roles/alarmprovider.py index 72cac3b4..8e928308 100644 --- a/src/sdc11073/roles/alarmprovider.py +++ b/src/sdc11073/roles/alarmprovider.py @@ -1,3 +1,4 @@ +"""Implementation of alarm provider functionality.""" from __future__ import annotations import time @@ -14,51 +15,22 @@ if TYPE_CHECKING: from collections.abc import Iterable - from sdc11073.mdib import ProviderMdib - from sdc11073.mdib.descriptorcontainers import AbstractDescriptorProtocol, AbstractOperationDescriptorProtocol - from sdc11073.mdib.transactionsprotocol import TransactionManagerProtocol - from sdc11073.provider.operations import OperationDefinitionBase, OperationDefinitionProtocol, ExecuteParameters + from sdc11073.mdib.descriptorcontainers import AbstractOperationDescriptorProtocol + from sdc11073.mdib.entityprotocol import EntityProtocol + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol + from sdc11073.mdib.transactionsprotocol import StateTransactionManagerProtocol + from sdc11073.provider.operations import ExecuteParameters, OperationDefinitionBase, OperationDefinitionProtocol from sdc11073.provider.sco import AbstractScoOperationsRegistry from .providerbase import OperationClassGetter -class GenericAlarmProvider(providerbase.ProviderRole): - """Provide some generic alarm handling functionality. - - - in pre commit handler it updates present alarms list of alarm system states - - runs periodic job to send currently present alarms in AlertSystemState - - supports alert delegation acc. to BICEPS chapter 6.2 - """ +class AlertDelegateProvider(providerbase.ProviderRole): + """Support alert delegation acc. to BICEPS chapter 6.2.""" - WORKER_THREAD_INTERVAL = 1.0 # seconds - self_check_safety_margin = 1.0 # how many seconds before SelfCheckInterval elapses a new self check is performed. - - def __init__(self, mdib: ProviderMdib, log_prefix: str): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str): super().__init__(mdib, log_prefix) - self._stop_worker = Event() - self._worker_thread = None - - def init_operations(self, sco: AbstractScoOperationsRegistry): - """Initialize and start what the provider needs. - - - set initial values of all AlertSystemStateContainers. - - set initial values of all AlertStateContainers. - - start a worker thread that periodically updates AlertSystemStateContainers. - """ - super().init_operations(sco) - self._set_alert_system_states_initial_values() - self._set_alert_states_initial_values() - self._worker_thread = Thread(target=self._worker_thread_loop) - self._worker_thread.daemon = True - self._worker_thread.start() - - def stop(self): - """Stop worker thread.""" - self._stop_worker.set() - self._worker_thread.join() - def make_operation_instance(self, operation_descriptor_container: AbstractOperationDescriptorProtocol, operation_cls_getter: OperationClassGetter) -> OperationDefinitionBase | None: @@ -72,9 +44,10 @@ def make_operation_instance(self, """ pm_names = self._mdib.data_model.pm_names op_target_handle = operation_descriptor_container.OperationTarget - op_target_descr = self._mdib.descriptions.handle.get_one(op_target_handle) + op_target_entity = self._mdib.entities.by_handle(op_target_handle) if pm_names.SetAlertStateOperationDescriptor == operation_descriptor_container.NODETYPE: - if pm_names.AlertSignalDescriptor == op_target_descr.NODETYPE and op_target_descr.SignalDelegationSupported: + if (pm_names.AlertSignalDescriptor == op_target_entity.node_type + and op_target_entity.descriptor.SignalDelegationSupported): # operation_descriptor_container is a SetAlertStateOperationDescriptor set_state_descriptor_container = cast(AbstractSetStateOperationDescriptorContainer, operation_descriptor_container) @@ -89,64 +62,212 @@ def make_operation_instance(self, operation_handler=self._delegate_alert_signal, timeout_handler=self._on_timeout_delegate_alert_signal) - self._logger.debug('GenericAlarmProvider: added handler "self._setAlertState" for %s target=%s', - operation_descriptor_container, op_target_descr) + self._logger.debug('%s: added handler "self._delegate_alert_signal" for %s target=%s', + self.__class__.__name__, operation_descriptor_container, op_target_handle) return operation return None # None == no handler for this operation instantiated - def _set_alert_system_states_initial_values(self): - """Set ActivationState to ON in all alert systems. + def _delegate_alert_signal(self, params: ExecuteParameters) -> ExecuteResult: + """Handle operation call from remote (ExecuteHandler). - Adds audible SystemSignalActivation, state=ON to all AlertSystemState instances. Why???? + Sets ActivationState, Presence and ActualSignalGenerationDelay of the corresponding state in mdib. + If this is a delegable signal, it also sets the ActivationState of the fallback signal. """ + value = params.operation_request.argument + pm_types = self._mdib.data_model.pm_types pm_names = self._mdib.data_model.pm_names + all_alert_signal_entities = self._mdib.entities.by_node_type(pm_names.AlertSignalDescriptor) + + operation_target_handle = params.operation_instance.operation_target_handle + op_target_entity = self._mdib.entities.by_handle(operation_target_handle) + + self._logger.info('delegate alert signal %s of %s from %s to %s', operation_target_handle, + op_target_entity.state, op_target_entity.state.ActivationState, value.ActivationState) + for name in params.operation_instance.descriptor_container.ModifiableData: + tmp = getattr(value, name) + setattr(op_target_entity.state, name, tmp) + modified = [] + if op_target_entity.descriptor.SignalDelegationSupported: + if value.ActivationState == pm_types.AlertActivation.ON: + modified = self._pause_fallback_alert_signals(op_target_entity, + all_alert_signal_entities) + else: + modified = self._activate_fallback_alert_signals(op_target_entity, + all_alert_signal_entities) + with self._mdib.alert_state_transaction() as mgr: + mgr.write_entity(op_target_entity) + mgr.write_entities(modified) + + return ExecuteResult(operation_target_handle, + self._mdib.data_model.msg_types.InvocationState.FINISHED) + + def _on_timeout_delegate_alert_signal(self, operation_instance: OperationDefinitionProtocol): + """TimeoutHandler for delegated signal.""" pm_types = self._mdib.data_model.pm_types + pm_names = self._mdib.data_model.pm_names + + operation_target_handle = operation_instance.operation_target_handle + op_target_entity = self._mdib.entities.by_handle(operation_target_handle) - states = self._mdib.states.NODETYPE.get(pm_names.AlertSystemState, []) - for state in states: - state.ActivationState = pm_types.AlertActivation.ON - state.SystemSignalActivation.append( - pm_types.SystemSignalActivation(manifestation=pm_types.AlertSignalManifestation.AUD, - state=pm_types.AlertActivation.ON)) + all_alert_signal_entities = self._mdib.entities.by_node_type(pm_names.AlertSignalDescriptor) + self._logger.info('timeout alert signal delegate operation=%s target=%s', + operation_instance.handle, operation_target_handle) + op_target_entity.state.ActivationState = pm_types.AlertActivation.OFF + modified = self._activate_fallback_alert_signals(op_target_entity, + all_alert_signal_entities) - def _set_alert_states_initial_values(self): - """Set AlertConditions and AlertSignals. + with self._mdib.alert_state_transaction() as mgr: + mgr.write_entity(op_target_entity) + mgr.write_entities(modified) - - if an AlertCondition.ActivationState is 'On', then the local AlertSignals shall also be 'On' - - all remote alert Signals shall be 'Off' initially (must be explicitly enabled by delegating device). + def _pause_fallback_alert_signals(self, + delegable_signal_entity: EntityProtocol, + all_signal_entities: list[EntityProtocol], + ) -> list[EntityProtocol]: + """Pause fallback signals. + + The idea of the fallback signal is to set it paused when the delegable signal is currently ON, + and to set it back to ON when the delegable signal is not ON. + This method sets the fallback to PAUSED value. + :param delegable_signal_entity: the signal that the fallback signals are looked for. + :param all_signal_entities: list of all signals + :return: list of modified entities """ pm_types = self._mdib.data_model.pm_types - pm_names = self._mdib.data_model.pm_names - for alert_condition in self._mdib.states.NODETYPE.get(pm_names.AlertConditionState, []): - alert_condition.ActivationState = pm_types.AlertActivation.ON - alert_condition.Presence = False - for alert_condition in self._mdib.states.NODETYPE.get(pm_names.LimitAlertConditionState, []): - alert_condition.ActivationState = pm_types.AlertActivation.ON - alert_condition.Presence = False - - for alert_signal_state in self._mdib.states.NODETYPE.get(pm_names.AlertSignalState, []): - alert_signal_descr = self._mdib.descriptions.handle.get_one(alert_signal_state.DescriptorHandle) - if alert_signal_descr.SignalDelegationSupported: - alert_signal_state.Location = pm_types.AlertSignalPrimaryLocation.REMOTE - alert_signal_state.ActivationState = pm_types.AlertActivation.OFF - alert_signal_state.Presence = pm_types.AlertSignalPresence.OFF - else: - alert_signal_state.ActivationState = pm_types.AlertActivation.ON - alert_signal_state.Presence = pm_types.AlertSignalPresence.OFF + modified: list[EntityProtocol] = [] + # look for local fallback signal (same Manifestation), and set it to paused + for fallback_entity in self._get_fallback_signals(delegable_signal_entity, + all_signal_entities): + if fallback_entity.state.ActivationState != pm_types.AlertActivation.PAUSED: + fallback_entity.state.ActivationState = pm_types.AlertActivation.PAUSED + modified.append(fallback_entity) + return modified + + def _activate_fallback_alert_signals(self, delegable_signal_entity: EntityProtocol, + all_signal_entities: list[EntityProtocol], + ) -> list[EntityProtocol]: + pm_types = self._mdib.data_model.pm_types + modified: list[EntityProtocol] = [] + + # look for local fallback signal (same Manifestation), and set it to paused + for fallback_entity in self._get_fallback_signals(delegable_signal_entity, + all_signal_entities): + if fallback_entity.state.ActivationState == pm_types.AlertActivation.PAUSED: + fallback_entity.state.ActivationState = pm_types.AlertActivation.ON + modified.append(fallback_entity) + return modified + + @staticmethod + def _get_fallback_signals(delegable_signal_entity: EntityProtocol, + all_signal_entities: list[EntityProtocol]) -> list[EntityProtocol]: + """Return a list of all fallback signals for descriptor. + + looks in all_signal_descriptors for a signal with same ConditionSignaled and same + Manifestation as delegable_signal_descriptor and SignalDelegationSupported == True. + """ + return [tmp for tmp in all_signal_entities if not tmp.descriptor.SignalDelegationSupported + and tmp.descriptor.Manifestation == delegable_signal_entity.descriptor.Manifestation + and tmp.descriptor.ConditionSignaled == delegable_signal_entity.descriptor.ConditionSignaled] + + +class AlertSystemStateMaintainer(providerbase.ProviderRole): + """Provide some generic alarm handling functionality. - def _get_changed_alert_condition_states(self, - transaction: TransactionManagerProtocol) -> list[AbstractStateProtocol]: + - runs periodic job to update currently present alarms in AlertSystemState + """ + + WORKER_THREAD_INTERVAL = 1.0 # seconds + self_check_safety_margin = 1.0 # how many seconds before SelfCheckInterval elapses a new self check is performed. + + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str): + super().__init__(mdib, log_prefix) + + self._stop_worker = Event() + self._worker_thread = None + + def init_operations(self, sco: AbstractScoOperationsRegistry): + """Start a worker thread that periodically updates AlertSystemStateContainers.""" + super().init_operations(sco) + self._worker_thread = Thread(target=self._worker_thread_loop) + self._worker_thread.daemon = True + self._worker_thread.start() + + def stop(self): + """Stop worker thread.""" + self._stop_worker.set() + self._worker_thread.join() + + def _worker_thread_loop(self): + # delay start of operation + shall_stop = self._stop_worker.wait(timeout=self.WORKER_THREAD_INTERVAL) + if shall_stop: + return + + while True: + shall_stop = self._stop_worker.wait(timeout=self.WORKER_THREAD_INTERVAL) + if shall_stop: + return + self._update_alert_system_state_current_alerts() + + def _update_alert_system_state_current_alerts(self): + """Update AlertSystemState present alarms list.""" + try: + entities_needing_update = self._get_alert_system_entities_needing_update() + if len(entities_needing_update) > 0: + with self._mdib.alert_state_transaction() as mgr: + self._update_alert_system_states(entities_needing_update) + mgr.write_entities(entities_needing_update) + except Exception: # noqa: BLE001 + self._logger.error('_update_alert_system_state_current_alerts: %s', traceback.format_exc()) + + def _get_alert_system_entities_needing_update(self) -> list[EntityProtocol]: pm_names = self._mdib.data_model.pm_names - result = [] - for item in list(transaction.alert_state_updates.values()): - tmp = item.old if item.new is None else item.new - if tmp.NODETYPE in (pm_names.AlertConditionState, - pm_names.LimitAlertConditionState): - result.append(tmp) - return result + entities_needing_update = [] + try: + all_alert_system_entities = self._mdib.entities.by_node_type(pm_names.AlertSystemDescriptor) + for alert_system_entity in all_alert_system_entities: + if alert_system_entity.state is not None: + self_check_period = alert_system_entity.descriptor.SelfCheckPeriod + if self_check_period is not None: + last_self_check = alert_system_entity.state.LastSelfCheck or 0.0 + if time.time() - last_self_check >= self_check_period - self.self_check_safety_margin: + entities_needing_update.append(alert_system_entity) + except Exception: # noqa: BLE001 + self._logger.error('_get_alert_system_entities_needing_update: %s', traceback.format_exc()) + return entities_needing_update + + def _update_alert_system_states(self, alert_system_entities: Iterable[EntityProtocol]): + """Update alert system states.""" + pm_types = self._mdib.data_model.pm_types + + for alert_system_entity in alert_system_entities: + all_child_entities = self._mdib.entities.by_parent_handle(alert_system_entity.handle) + all_alert_condition_entities = [d for d in all_child_entities if d.descriptor.is_alert_condition_descriptor] + # select all state containers with technical alarms present + all_tech_entities = [d for d in all_alert_condition_entities if + d.descriptor.Kind == pm_types.AlertConditionKind.TECHNICAL] + all_present_tech_entities = [s for s in all_tech_entities if s.state.Presence] + # select all state containers with physiological alarms present + all_phys_entities = [d for d in all_alert_condition_entities if + d.descriptor.Kind == pm_types.AlertConditionKind.PHYSIOLOGICAL] + all_present_phys_entities = [s for s in all_phys_entities if s.state.Presence] + + alert_system_entity.state.PresentTechnicalAlarmConditions = [e.handle for e in all_present_tech_entities] + alert_system_entity.state.PresentPhysiologicalAlarmConditions = [e.handle for e in + all_present_phys_entities] + alert_system_entity.state.LastSelfCheck = time.time() + alert_system_entity.state.SelfCheckCount = 1 if alert_system_entity.state.SelfCheckCount is None \ + else alert_system_entity.state.SelfCheckCount + 1 - def on_pre_commit(self, mdib: ProviderMdib, transaction: TransactionManagerProtocol): + +class AlertPreCommitHandler(providerbase.ProviderRole): + """Provide some generic alarm handling functionality. + + - in pre commit handler it updates present alarms list of alarm system states + """ + + def on_pre_commit(self, mdib: ProviderMdibProtocol, transaction: StateTransactionManagerProtocol): """Manipulate the transaction. - Updates alert system states and adds them to transaction, if at least one of its alert @@ -158,104 +279,106 @@ def on_pre_commit(self, mdib: ProviderMdib, transaction: TransactionManagerProto :return: """ if not transaction.alert_state_updates: + # nothing to do return - changed_alert_conditions = self._get_changed_alert_condition_states(transaction) + all_alert_signal_entities = self._mdib.entities.by_node_type(self._mdib.data_model.pm_names.AlertSignalDescriptor) + changed_alert_condition_states = self._get_changed_alert_condition_states(transaction) # change AlertSignal Settings in order to be compliant with changed Alert Conditions - for changed_alert_condition in changed_alert_conditions: - self._update_alert_signals(changed_alert_condition, mdib, transaction) - - # find all alert systems with changed states - alert_system_states = self._find_alert_systems_with_modifications(transaction, changed_alert_conditions) + for changed_alert_condition_state in changed_alert_condition_states: + self._update_alert_signals(changed_alert_condition_state, + all_alert_signal_entities, + mdib, + transaction) + + # find all alert systems for changed_alert_condition_states + alert_system_states = self._find_alert_systems_with_modifications(transaction, + changed_alert_condition_states) if alert_system_states: # add found alert system states to transaction - self._update_alert_system_states(mdib, transaction, alert_system_states, is_self_check=False) + self._update_alert_system_states(mdib, alert_system_states, transaction) @staticmethod - def _find_alert_systems_with_modifications(transaction: TransactionManagerProtocol, - changed_alert_conditions: list[AbstractStateProtocol]) \ - -> set[AbstractStateProtocol]: - # find all alert systems for the changed alert conditions - alert_system_states = set() - for tmp in changed_alert_conditions: - alert_descriptor = transaction.actual_descriptor(tmp.DescriptorHandle) - alert_system_descriptor = transaction.actual_descriptor(alert_descriptor.parent_handle) - if alert_system_descriptor.Handle in transaction.alert_state_updates: - tmp_st = transaction.alert_state_updates[alert_system_descriptor.Handle] - if tmp_st.new is not None: - alert_system_states.add(tmp_st.new) - else: - alert_system_states.add(transaction.get_state(alert_system_descriptor.Handle)) - return alert_system_states - - @staticmethod - def _update_alert_system_states(mdib: ProviderMdib, - transaction: TransactionManagerProtocol, + def _update_alert_system_states(mdib: ProviderMdibProtocol, alert_system_states: Iterable[AbstractStateProtocol], - is_self_check: bool = True): - """Update alert system states.""" + transaction: StateTransactionManagerProtocol): + """Update alert system states PresentTechnicalAlarmConditions and PresentPhysiologicalAlarmConditions.""" pm_types = mdib.data_model.pm_types - def _get_alert_state(descriptor_handle: str) -> AbstractStateProtocol: - alert_state = None - tr_item = transaction.get_state_transaction_item(descriptor_handle) + def _get_alert_state(state: AbstractStateProtocol) -> AbstractStateProtocol: + """Return the equivalent state from current transaction, if it already in transaction.""" + _item = transaction.get_state_transaction_item(state.DescriptorHandle) + if _item is not None: + return _item.new + return state + + for _alert_system_state in alert_system_states: + write_entity = True + tr_item = transaction.get_state_transaction_item(_alert_system_state.DescriptorHandle) + # If the alert system state is already part of the transaction, make changes in that instance instead. if tr_item is not None: - alert_state = tr_item.new - if alert_state is None: - # it is not part of this transaction - alert_state = mdib.states.descriptor_handle.get_one(descriptor_handle, allow_none=True) - if alert_state is None: - raise ValueError(f'there is no alert state for {descriptor_handle}') - return alert_state - - for state in alert_system_states: - all_child_descriptors = mdib.descriptions.parent_handle.get(state.DescriptorHandle, []) - all_child_descriptors.extend( - [i.new for i in transaction.descriptor_updates.values() if - i.new.parent_handle == state.DescriptorHandle]) - all_alert_condition_descr = [d for d in all_child_descriptors if hasattr(d, 'Kind')] + alert_system_state = tr_item.new + write_entity = False + else: + alert_system_state = _alert_system_state + all_child_entities = mdib.entities.by_parent_handle(alert_system_state.DescriptorHandle) + all_alert_condition_entities = [d for d in all_child_entities if d.descriptor.is_alert_condition_descriptor] # select all state containers with technical alarms present - all_tech_descr = [d for d in all_alert_condition_descr if d.Kind == pm_types.AlertConditionKind.TECHNICAL] - _all_tech_states = [_get_alert_state(d.Handle) for d in all_tech_descr] + all_tech_entities = [d for d in all_alert_condition_entities if + d.descriptor.Kind == pm_types.AlertConditionKind.TECHNICAL] + _all_tech_states = [_get_alert_state(d.state) for d in all_tech_entities] all_tech_states = cast(list[AlertConditionStateContainer], _all_tech_states) - all_tech_states = [s for s in all_tech_states if s is not None] all_present_tech_states = [s for s in all_tech_states if s.Presence] - # select all state containers with physiological alarms present - all_phys_descr = [d for d in all_alert_condition_descr if - d.Kind == pm_types.AlertConditionKind.PHYSIOLOGICAL] - _all_phys_states = [_get_alert_state(d.Handle) for d in all_phys_descr] + + all_phys_entities = [d for d in all_alert_condition_entities if + d.descriptor.Kind == pm_types.AlertConditionKind.PHYSIOLOGICAL] + _all_phys_states = [_get_alert_state(d.state) for d in all_phys_entities] all_phys_states = cast(list[AlertConditionStateContainer], _all_phys_states) all_phys_states = [s for s in all_phys_states if s is not None] all_present_phys_states = [s for s in all_phys_states if s.Presence] - state.PresentTechnicalAlarmConditions = [s.DescriptorHandle for s in all_present_tech_states] - state.PresentPhysiologicalAlarmConditions = [s.DescriptorHandle for s in all_present_phys_states] - if is_self_check: - state.LastSelfCheck = time.time() - state.SelfCheckCount = 1 if state.SelfCheckCount is None else state.SelfCheckCount + 1 + alert_system_state.PresentTechnicalAlarmConditions = [s.DescriptorHandle for s in all_present_tech_states] + alert_system_state.PresentPhysiologicalAlarmConditions = [s.DescriptorHandle for s in + all_present_phys_states] + if write_entity: + transaction.write_entity(alert_system_state) + + @staticmethod + def _get_changed_alert_condition_states(transaction: StateTransactionManagerProtocol) -> list[ + AbstractStateProtocol]: + """Return all alert conditions in current transaction.""" + result = [] + for item in list(transaction.alert_state_updates.values()): + tmp = item.old if item.new is None else item.new + if tmp.is_alert_condition: + result.append(tmp) + return result @staticmethod def _update_alert_signals(changed_alert_condition: AbstractStateProtocol, - mdib: ProviderMdib, - transaction: TransactionManagerProtocol): + all_alert_signal_entities: list[EntityProtocol], + mdib: ProviderMdibProtocol, + transaction: StateTransactionManagerProtocol): """Handle alert signals for a changed alert condition. This method only changes states of local signals. Handling of delegated signals is in the responsibility of the delegated device! """ pm_types = mdib.data_model.pm_types - alert_signal_descriptors = mdib.descriptions.condition_signaled.get(changed_alert_condition.DescriptorHandle, - []) + + my_alert_signal_entities = [e for e in all_alert_signal_entities + if e.descriptor.ConditionSignaled == changed_alert_condition.DescriptorHandle] # separate remote from local - remote_alert_signal_descriptors = [a for a in alert_signal_descriptors if a.SignalDelegationSupported] - local_alert_signal_descriptors = [a for a in alert_signal_descriptors if not a.SignalDelegationSupported] + remote_alert_signal_entities = [a for a in my_alert_signal_entities if a.descriptor.SignalDelegationSupported] + local_alert_signal_entities = [a for a in my_alert_signal_entities if + not a.descriptor.SignalDelegationSupported] # look for active delegations (we only need the Manifestation value here) active_delegate_manifestations = [] - for descriptor in remote_alert_signal_descriptors: - alert_signal_state = mdib.states.descriptor_handle.get_one(descriptor.Handle) - if alert_signal_state.Presence != pm_types.AlertSignalPresence.OFF and alert_signal_state.Location == 'Rem': - active_delegate_manifestations.append(descriptor.Manifestation) + for entity in remote_alert_signal_entities: + if (entity.state.Presence != pm_types.AlertSignalPresence.OFF + and entity.state.Location == pm_types.AlertSignalPrimaryLocation.REMOTE): + active_delegate_manifestations.append(entity.descriptor.Manifestation) # this lookup gives the values that a local signal shall have: # key = (Cond.Presence, isDelegated): value = (SignalState.ActivationState, SignalState.Presence) @@ -266,154 +389,31 @@ def _update_alert_signals(changed_alert_condition: AbstractStateProtocol, (False, True): (pm_types.AlertActivation.PAUSED, pm_types.AlertSignalPresence.OFF), (False, False): (pm_types.AlertActivation.ON, pm_types.AlertSignalPresence.OFF), } - for descriptor in local_alert_signal_descriptors: - tr_item = transaction.get_state_transaction_item(descriptor.Handle) + for entity in local_alert_signal_entities: + tr_item = transaction.get_state_transaction_item(entity.handle) if tr_item is None: - is_delegated = descriptor.Manifestation in active_delegate_manifestations # is this local signal delegated? + is_delegated = entity.descriptor.Manifestation in active_delegate_manifestations # is this local signal delegated? activation, presence = lookup[(changed_alert_condition.Presence, is_delegated)] - alert_signal_state = transaction.get_state(descriptor.Handle) - - if alert_signal_state.ActivationState != activation or alert_signal_state.Presence != presence: - alert_signal_state.ActivationState = activation - alert_signal_state.Presence = presence - else: - # don't change - transaction.unget_state(alert_signal_state) - - def _pause_fallback_alert_signals(self, - delegable_signal_descriptor: AbstractDescriptorProtocol, - all_signal_descriptors: list[AbstractDescriptorProtocol] | None, - transaction: TransactionManagerProtocol): - """Pause fallback signals. - The idea of the fallback signal is to set it paused when the delegable signal is currently ON, - and to set it back to ON when the delegable signal is not ON. - This method sets the fallback to PAUSED value. - :param delegable_signal_descriptor: a descriptor container - :param all_signal_descriptors: list of descriptor containers - :param transaction: the current transaction. - :return: - """ - pm_types = self._mdib.data_model.pm_types - # look for local fallback signal (same Manifestation), and set it to paused - for fallback in self._get_fallback_signals(delegable_signal_descriptor, all_signal_descriptors): - ss_fallback = transaction.get_state(fallback.Handle) - if ss_fallback.ActivationState != pm_types.AlertActivation.PAUSED: - ss_fallback.ActivationState = pm_types.AlertActivation.PAUSED - else: - transaction.unget_state(ss_fallback) - - def _activate_fallback_alert_signals(self, delegable_signal_descriptor: AbstractDescriptorProtocol, - all_signal_descriptors: list[AbstractDescriptorProtocol] | None, - transaction: TransactionManagerProtocol): - pm_types = self._mdib.data_model.pm_types - # look for local fallback signal (same Manifestation), and set it to paused - for fallback in self._get_fallback_signals(delegable_signal_descriptor, all_signal_descriptors): - ss_fallback = transaction.get_state(fallback.Handle) - if ss_fallback.ActivationState == pm_types.AlertActivation.PAUSED: - ss_fallback.ActivationState = pm_types.AlertActivation.ON - else: - transaction.unget_state(ss_fallback) + if entity.state.ActivationState != activation or entity.state.Presence != presence: + entity.state.ActivationState = activation + entity.state.Presence = presence + transaction.write_entity(entity) - def _get_fallback_signals(self, - delegable_signal_descriptor: AbstractDescriptorProtocol, - all_signal_descriptors: list[AbstractDescriptorProtocol] | None) -> list[ - AbstractDescriptorProtocol]: - """Return a list of all fallback signals for descriptor. - - looks in all_signal_descriptors for a signal with same ConditionSignaled and same - Manifestation as delegable_signal_descriptor and SignalDelegationSupported == True. - """ - if all_signal_descriptors is None: - all_signal_descriptors = self._mdib.descriptions.condition_signaled.get( - delegable_signal_descriptor.ConditionSignaled, []) - return [tmp for tmp in all_signal_descriptors if not tmp.SignalDelegationSupported - and tmp.Manifestation == delegable_signal_descriptor.Manifestation - and tmp.ConditionSignaled == delegable_signal_descriptor.ConditionSignaled] - - def _delegate_alert_signal(self, params: ExecuteParameters) -> ExecuteResult: - """Handle operation call from remote (ExecuteHandler). - - Sets ActivationState, Presence and ActualSignalGenerationDelay of the corresponding state in mdib. - If this is a delegable signal, it also sets the ActivationState of the fallback signal. - - :param operation_instance: OperationDefinition instance - :param value: AlertSignalStateContainer instance - :return: - """ - value = params.operation_request.argument - pm_types = self._mdib.data_model.pm_types - operation_target_handle = params.operation_instance.operation_target_handle - with self._mdib.alert_state_transaction() as mgr: - state = mgr.get_state(operation_target_handle) - self._logger.info('delegate alert signal %s of %s from %s to %s', operation_target_handle, state, - state.ActivationState, value.ActivationState) - for name in params.operation_instance.descriptor_container.ModifiableData: - tmp = getattr(value, name) - setattr(state, name, tmp) - descr = self._mdib.descriptions.handle.get_one(operation_target_handle) - if descr.SignalDelegationSupported: - if value.ActivationState == pm_types.AlertActivation.ON: - self._pause_fallback_alert_signals(descr, None, mgr) - else: - self._activate_fallback_alert_signals(descr, None, mgr) - return ExecuteResult(operation_target_handle, - self._mdib.data_model.msg_types.InvocationState.FINISHED) - - def _on_timeout_delegate_alert_signal(self, operation_instance: OperationDefinitionProtocol): - """TimeoutHandler for delegated signal.""" - pm_types = self._mdib.data_model.pm_types - operation_target_handle = operation_instance.operation_target_handle - with self._mdib.alert_state_transaction() as mgr: - state = mgr.get_state(operation_target_handle) - self._logger.info('timeout alert signal delegate operation=%s target=%s', - operation_instance.handle, operation_target_handle) - state.ActivationState = pm_types.AlertActivation.OFF - descr = self._mdib.descriptions.handle.get_one(operation_target_handle) - self._activate_fallback_alert_signals(descr, None, mgr) - - def _worker_thread_loop(self): - # delay start of operation - shall_stop = self._stop_worker.wait(timeout=self.WORKER_THREAD_INTERVAL) - if shall_stop: - return - - while True: - shall_stop = self._stop_worker.wait(timeout=self.WORKER_THREAD_INTERVAL) - if shall_stop: - return - self._run_worker_job() - - def _run_worker_job(self): - self._update_alert_system_state_current_alerts() + def _find_alert_systems_with_modifications(self, + transaction: StateTransactionManagerProtocol, + changed_alert_conditions: list[AbstractStateProtocol]) \ + -> set[AbstractStateProtocol]: + # find all alert systems for the changed alert conditions + alert_system_states = set() + for tmp in changed_alert_conditions: + alert_condition_entity = self._mdib.entities.by_handle(tmp.DescriptorHandle) + alert_system_entity = self._mdib.entities.by_handle(alert_condition_entity.parent_handle) - def _update_alert_system_state_current_alerts(self): - """Update AlertSystemState present alarms list.""" - try: - with self._mdib.alert_state_transaction() as mgr: - states_needing_update = self._get_alert_system_states_needing_update() - if len(states_needing_update) > 0: - tr_states = [mgr.get_state(s.DescriptorHandle) for s in states_needing_update] - self._update_alert_system_states(self._mdib, mgr, tr_states) - except Exception: - self._logger.error('_update_alert_system_state_current_alerts: %s', traceback.format_exc()) + if alert_system_entity.handle not in transaction.alert_state_updates: + transaction.write_entity(alert_system_entity) - def _get_alert_system_states_needing_update(self) -> list[AbstractStateProtocol]: - """:return: all AlertSystemStateContainers of those last""" - pm_names = self._mdib.data_model.pm_names - states_needing_update = [] - try: - all_alert_systems_descr = self._mdib.descriptions.NODETYPE.get(pm_names.AlertSystemDescriptor, - []) - for alert_system_descr in all_alert_systems_descr: - alert_system_state = self._mdib.states.descriptor_handle.get_one(alert_system_descr.Handle, - allow_none=True) - if alert_system_state is not None: - self_check_period = alert_system_descr.SelfCheckPeriod - if self_check_period is not None: - last_self_check = alert_system_state.LastSelfCheck or 0.0 - if time.time() - last_self_check >= self_check_period - self.self_check_safety_margin: - states_needing_update.append(alert_system_state) - except Exception: - self._logger.error('_get_alert_system_states_needing_update: %r', traceback.format_exc()) - return states_needing_update + transaction_item = transaction.alert_state_updates[alert_system_entity.handle] + if transaction_item.new is not None: + alert_system_states.add(transaction_item.new) + return alert_system_states diff --git a/src/sdc11073/roles/audiopauseprovider.py b/src/sdc11073/roles/audiopauseprovider.py index 9d0d2c8c..95b71e25 100644 --- a/src/sdc11073/roles/audiopauseprovider.py +++ b/src/sdc11073/roles/audiopauseprovider.py @@ -1,8 +1,8 @@ +"""Implementation of audio pause functionality.""" from __future__ import annotations -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from sdc11073.mdib.statecontainers import AlertSystemStateContainer from sdc11073.provider.operations import ExecuteResult from sdc11073.xml_types.msg_types import InvocationState from sdc11073.xml_types.pm_types import Coding @@ -12,8 +12,8 @@ if TYPE_CHECKING: from sdc11073.mdib.descriptorcontainers import AbstractOperationDescriptorProtocol - from sdc11073.mdib.providermdib import ProviderMdib - from sdc11073.provider.operations import OperationDefinitionBase, ExecuteParameters + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol + from sdc11073.provider.operations import ExecuteParameters, OperationDefinitionBase from sdc11073.provider.sco import AbstractScoOperationsRegistry # coded values for SDC audio pause @@ -29,7 +29,7 @@ class GenericAudioPauseProvider(ProviderRole): Nothing is added to the mdib. If the mdib does not contain these operations, the functionality is not available. """ - def __init__(self, mdib: ProviderMdib, log_prefix: str): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str): super().__init__(mdib, log_prefix) self._set_global_audio_pause_operations = [] self._cancel_global_audio_pause_operations = [] @@ -45,14 +45,18 @@ def make_operation_instance(self, self._logger.debug('instantiating "set audio pause" operation from existing descriptor handle=%s', operation_descriptor_container.Handle) set_ap_operation = self._mk_operation_from_operation_descriptor( - operation_descriptor_container, operation_cls_getter, operation_handler=self._set_global_audio_pause) + operation_descriptor_container, + operation_cls_getter, + operation_handler=self._set_global_audio_pause) self._set_global_audio_pause_operations.append(set_ap_operation) return set_ap_operation if operation_descriptor_container.coding == MDC_OP_SET_CANCEL_ALARMS_AUDIO_PAUSE: self._logger.debug('instantiating "cancel audio pause" operation from existing descriptor handle=%s', operation_descriptor_container.Handle) cancel_ap_operation = self._mk_operation_from_operation_descriptor( - operation_descriptor_container, operation_cls_getter, operation_handler=self._cancel_global_audio_pause) + operation_descriptor_container, + operation_cls_getter, + operation_handler=self._cancel_global_audio_pause) self._cancel_global_audio_pause_operations.append(cancel_ap_operation) return cancel_ap_operation @@ -72,53 +76,49 @@ def _set_global_audio_pause(self, params: ExecuteParameters) -> ExecuteResult: """ pm_types = self._mdib.data_model.pm_types pm_names = self._mdib.data_model.pm_names + + alert_system_entities = self._mdib.entities.by_node_type(pm_names.AlertSystemDescriptor) + if len(alert_system_entities) == 0: + self._logger.warning('_set_global_audio_pause called, but no AlertSystemDescriptor in mdib found') + return ExecuteResult(params.operation_instance.operation_target_handle, InvocationState.FAILED) + with self._mdib.alert_state_transaction() as mgr: - alert_system_descriptors = self._mdib.descriptions.NODETYPE.get(pm_names.AlertSystemDescriptor) - if alert_system_descriptors is None: - self._logger.warning('SDC_SetAudioPauseOperation called, but no AlertSystemDescriptor in mdib found') - return ExecuteResult(params.operation_instance.operation_target_handle, InvocationState.FAILED) + for as_entity in alert_system_entities: + if as_entity.state.ActivationState != pm_types.AlertActivation.ON: + self._logger.info('_set_global_audio_pause: nothing to do for alert system %s', as_entity.handle) + continue + + audible_signals = [ssa for ssa in as_entity.state.SystemSignalActivation if + ssa.Manifestation == pm_types.AlertSignalManifestation.AUD] + active_audible_signals = [ssa for ssa in audible_signals if + ssa.State != pm_types.AlertActivation.PAUSED] + if len(active_audible_signals) > 0: + for ssa in active_audible_signals: + ssa.State = pm_types.AlertActivation.PAUSED + self._logger.info('SetAudioPauseOperation: set alert system "%s" to paused', + as_entity.handle) + + # handle all audible alert signals of this alert system + all_alert_signal_entities = self._mdib.entities.by_node_type( + pm_names.AlertSignalDescriptor) + child_alert_signals = [d for d in all_alert_signal_entities if d.parent_handle == as_entity.handle] + audible_child_alert_signals = [d for d in child_alert_signals if + d.descriptor.Manifestation == pm_types.AlertSignalManifestation.AUD] + + for aud_signal in audible_child_alert_signals: + if aud_signal.descriptor.AcknowledgementSupported: + if aud_signal.state.ActivationState != pm_types.AlertActivation.PAUSED \ + or aud_signal.state.Presence != pm_types.AlertSignalPresence.ACK: + aud_signal.state.ActivationState = pm_types.AlertActivation.PAUSED + aud_signal.state.Presence = pm_types.AlertSignalPresence.ACK + mgr.write_entity(aud_signal) + elif aud_signal.state.ActivationState != pm_types.AlertActivation.PAUSED \ + or aud_signal.state.Presence != pm_types.AlertSignalPresence.OFF: + aud_signal.state.ActivationState = pm_types.AlertActivation.PAUSED + aud_signal.state.Presence = pm_types.AlertSignalPresence.OFF + mgr.write_entity(aud_signal) + mgr.write_entity(as_entity) - for alert_system_descriptor in alert_system_descriptors: - _alert_system_state = mgr.get_state(alert_system_descriptor.Handle) - alert_system_state = cast(AlertSystemStateContainer, _alert_system_state) - if alert_system_state.ActivationState != pm_types.AlertActivation.ON: - self._logger.info('SDC_SetAudioPauseOperation: nothing to do') - mgr.unget_state(_alert_system_state) - else: - audible_signals = [ssa for ssa in alert_system_state.SystemSignalActivation if - ssa.Manifestation == pm_types.AlertSignalManifestation.AUD] - active_audible_signals = [ssa for ssa in audible_signals if - ssa.State != pm_types.AlertActivation.PAUSED] - if not active_audible_signals: - # Alert System has no audible SystemSignalActivations, no action required - mgr.unget_state(_alert_system_state) - else: - for ssa in active_audible_signals: - ssa.State = pm_types.AlertActivation.PAUSED - self._logger.info('SetAudioPauseOperation: set alert system "%s" to paused', - alert_system_descriptor.Handle) - # handle all audible alert signals of this alert system - all_alert_signal_descriptors = self._mdib.descriptions.NODETYPE.get( - pm_names.AlertSignalDescriptor, []) - child_alert_signal_descriptors = [d for d in all_alert_signal_descriptors if - d.parent_handle == alert_system_descriptor.Handle] - audible_child_alert_signal_descriptors = [d for d in child_alert_signal_descriptors if - d.Manifestation == pm_types.AlertSignalManifestation.AUD] - for descriptor in audible_child_alert_signal_descriptors: - alert_signal_state = mgr.get_state(descriptor.Handle) - if descriptor.AcknowledgementSupported: - if alert_signal_state.ActivationState != pm_types.AlertActivation.PAUSED \ - or alert_signal_state.Presence != pm_types.AlertSignalPresence.ACK: - alert_signal_state.ActivationState = pm_types.AlertActivation.PAUSED - alert_signal_state.Presence = pm_types.AlertSignalPresence.ACK - else: - mgr.unget_state(alert_signal_state) - elif alert_signal_state.ActivationState != pm_types.AlertActivation.PAUSED \ - or alert_signal_state.Presence != pm_types.AlertSignalPresence.OFF: - alert_signal_state.ActivationState = pm_types.AlertActivation.PAUSED - alert_signal_state.Presence = pm_types.AlertSignalPresence.OFF - else: - mgr.unget_state(alert_signal_state) return ExecuteResult(params.operation_instance.operation_target_handle, self._mdib.data_model.msg_types.InvocationState.FINISHED) @@ -131,47 +131,39 @@ def _cancel_global_audio_pause(self, params: ExecuteParameters) -> ExecuteResult pm_types = self._mdib.data_model.pm_types pm_names = self._mdib.data_model.pm_names with self._mdib.alert_state_transaction() as mgr: - alert_system_descriptors = self._mdib.descriptions.NODETYPE.get(pm_names.AlertSystemDescriptor) - if alert_system_descriptors is None: - self._logger.warning('SDC_SetAudioPauseOperation called, but no AlertSystemDescriptor in mdib found') + alert_system_entities = self._mdib.entities.by_node_type(pm_names.AlertSystemDescriptor) + if len(alert_system_entities) == 0: + self._logger.warning('_cancel_global_audio_pause called, but no AlertSystemDescriptor in mdib found') return ExecuteResult(params.operation_instance.operation_target_handle, InvocationState.FAILED) - for alert_system_descriptor in alert_system_descriptors: - _alert_system_state = mgr.get_state(alert_system_descriptor.Handle) - alert_system_state = cast(AlertSystemStateContainer, _alert_system_state) - if alert_system_state.ActivationState != pm_types.AlertActivation.ON: - self._logger.info('SDC_CancelAudioPauseOperation: nothing to do') - mgr.unget_state(_alert_system_state) - else: - audible_signals = [ssa for ssa in alert_system_state.SystemSignalActivation if - ssa.Manifestation == pm_types.AlertSignalManifestation.AUD] - paused_audible_signals = [ssa for ssa in audible_signals if - ssa.State == pm_types.AlertActivation.PAUSED] - if not paused_audible_signals: - mgr.unget_state(_alert_system_state) - else: - for ssa in paused_audible_signals: - ssa.State = pm_types.AlertActivation.ON - self._logger.info('SetAudioPauseOperation: set alert system "%s" to ON', - alert_system_descriptor.Handle) - # handle all audible alert signals of this alert system - all_alert_signal_descriptors = self._mdib.descriptions.NODETYPE.get( - pm_names.AlertSignalDescriptor, []) - child_alert_signal_descriptors = [d for d in all_alert_signal_descriptors if - d.parent_handle == alert_system_descriptor.Handle] - audible_child_alert_signal_descriptors = [d for d in child_alert_signal_descriptors if - d.Manifestation == pm_types.AlertSignalManifestation.AUD] - for descriptor in audible_child_alert_signal_descriptors: - alert_signal_state = mgr.get_state(descriptor.Handle) - alert_condition_state = self._mdib.states.descriptor_handle.get_one( - descriptor.ConditionSignaled) - if alert_condition_state.Presence: - # set signal back to 'ON' - if alert_signal_state.ActivationState == pm_types.AlertActivation.PAUSED: - alert_signal_state.ActivationState = pm_types.AlertActivation.ON - alert_signal_state.Presence = pm_types.AlertSignalPresence.ON - else: - mgr.unget_state(alert_signal_state) + for as_entity in alert_system_entities: + if as_entity.state.ActivationState != pm_types.AlertActivation.ON: + self._logger.info('_cancel_global_audio_pause: nothing to do for alert system %s', as_entity.handle) + continue + + audible_signals = [ssa for ssa in as_entity.state.SystemSignalActivation if + ssa.Manifestation == pm_types.AlertSignalManifestation.AUD] + paused_audible_signals = [ssa for ssa in audible_signals if + ssa.State == pm_types.AlertActivation.PAUSED] + if len(paused_audible_signals) > 0: + for ssa in paused_audible_signals: + ssa.State = pm_types.AlertActivation.ON + self._logger.info('_cancel_global_audio_pause: set alert system "%s" to ON', as_entity.handle) + # handle all audible alert signals of this alert system + all_alert_signal_entities = self._mdib.entities.by_node_type(pm_names.AlertSignalDescriptor) + child_alert_signals = [e for e in all_alert_signal_entities if + e.parent_handle == as_entity.handle] + audible_child_alert_signals = [d for d in child_alert_signals if + d.descriptor.Manifestation == pm_types.AlertSignalManifestation.AUD] + for aud_signal in audible_child_alert_signals: + alert_condition_entity = self._mdib.entities.by_handle(aud_signal.descriptor.ConditionSignaled) + if alert_condition_entity.state.Presence: + # set signal back to 'ON' + if aud_signal.state.ActivationState == pm_types.AlertActivation.PAUSED: + aud_signal.state.ActivationState = pm_types.AlertActivation.ON + aud_signal.state.Presence = pm_types.AlertSignalPresence.ON + mgr.write_entity(aud_signal) + mgr.write_entity(as_entity) return ExecuteResult(params.operation_instance.operation_target_handle, self._mdib.data_model.msg_types.InvocationState.FINISHED) @@ -195,28 +187,21 @@ def make_missing_operations(self, sco: AbstractScoOperationsRegistry) -> list[Op ops = [] # in this case only the top level sco shall have the additional operations. # Check if this is the top level sco (parent is mds) - parent_descriptor = self._mdib.descriptions.handle.get_one(sco.sco_descriptor_container.parent_handle) - if pm_names.MdsDescriptor != parent_descriptor.NODETYPE: + parent_entity = self._mdib.entities.by_handle(sco.sco_descriptor_container.parent_handle) + if pm_names.MdsDescriptor != parent_entity.descriptor.NODETYPE: return ops operation_cls_getter = sco.operation_cls_getter # find mds for this sco - mds_descr = None - current_descr = sco.sco_descriptor_container - while mds_descr is None: - parent_descr = self._mdib.descriptions.handle.get_one(current_descr.parent_handle) - if parent_descr is None: - raise ValueError(f'could not find mds descriptor for sco {sco.sco_descriptor_container.Handle}') - if pm_names.MdsDescriptor == parent_descr.NODETYPE: - mds_descr = parent_descr - else: - current_descr = parent_descr - operation_target_container = mds_descr # the operation target is the mds itself + mds_entity = self._mdib.entities.by_handle(parent_entity.descriptor.source_mds) + if mds_entity is None: + raise ValueError(f"no source mds found for entity {parent_entity.handle}") + activate_op_cls = operation_cls_getter(pm_names.ActivateOperationDescriptor) if not self._set_global_audio_pause_operations: self._logger.debug('adding "set audio pause" operation, no descriptor in mdib (looked for code = %s)', NomenclatureCodes.MDC_OP_SET_ALL_ALARMS_AUDIO_PAUSE) set_ap_operation = activate_op_cls('AP__ON', - operation_target_container.Handle, + mds_entity.handle, self._set_global_audio_pause, coded_value=pm_types.CodedValue( NomenclatureCodes.MDC_OP_SET_ALL_ALARMS_AUDIO_PAUSE)) @@ -226,7 +211,7 @@ def make_missing_operations(self, sco: AbstractScoOperationsRegistry) -> list[Op self._logger.debug('adding "cancel audio pause" operation, no descriptor in mdib (looked for code = %s)', NomenclatureCodes.MDC_OP_SET_CANCEL_ALARMS_AUDIO_PAUSE) cancel_ap_operation = activate_op_cls('AP__CANCEL', - operation_target_container.Handle, + mds_entity.handle, self._cancel_global_audio_pause, coded_value=pm_types.CodedValue( NomenclatureCodes.MDC_OP_SET_CANCEL_ALARMS_AUDIO_PAUSE)) diff --git a/src/sdc11073/roles/clockprovider.py b/src/sdc11073/roles/clockprovider.py index 6666469e..f44c54ad 100644 --- a/src/sdc11073/roles/clockprovider.py +++ b/src/sdc11073/roles/clockprovider.py @@ -1,5 +1,7 @@ +"""Implementation of clock provider functionality.""" from __future__ import annotations +import uuid from typing import TYPE_CHECKING from sdc11073.provider.operations import ExecuteResult @@ -9,7 +11,7 @@ if TYPE_CHECKING: from sdc11073.mdib.descriptorcontainers import AbstractDescriptorProtocol, AbstractOperationDescriptorProtocol - from sdc11073.mdib.providermdib import ProviderMdib + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol from sdc11073.provider.operations import ExecuteParameters, OperationDefinitionBase from sdc11073.provider.sco import AbstractScoOperationsRegistry from sdc11073.xml_types.pm_types import CodedValue, SafetyClassification @@ -23,7 +25,7 @@ class GenericSDCClockProvider(ProviderRole): Nothing is added to the mdib. If the mdib does not contain these operations, the functionality is not available. """ - def __init__(self, mdib: ProviderMdib, log_prefix: str): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str): super().__init__(mdib, log_prefix) self._set_ntp_operations = [] self._set_tz_operations = [] @@ -37,22 +39,25 @@ def init_operations(self, sco: AbstractScoOperationsRegistry): super().init_operations(sco) pm_types = self._mdib.data_model.pm_types pm_names = self._mdib.data_model.pm_names - clock_descriptor = self._mdib.descriptions.NODETYPE.get_one(pm_names.ClockDescriptor, - allow_none=True) - if clock_descriptor is None: - mds_container = self._mdib.descriptions.NODETYPE.get_one(pm_names.MdsDescriptor) - clock_descr_handle = 'clock_' + mds_container.Handle + clock_entities = self._mdib.entities.by_node_type(pm_names.ClockDescriptor) + if len(clock_entities) == 0: + mds_entities = self._mdib.entitiesby_.node_type(pm_names.MdsDescriptor) + if len(mds_entities) == 0: + self._logger.info('empty mdib, cannot create a clock descriptor') + return + # create a clock descriptor for the first mds + # Todo: create for all? + my_mds_entity = mds_entities[0] + clock_descr_handle = 'clock_' + uuid.uuid4().hex self._logger.debug('creating a clock descriptor, handle=%s', clock_descr_handle) - clock_descriptor = self._create_clock_descriptor_container( - handle=clock_descr_handle, - parent_handle=mds_container.Handle, - coded_value=pm_types.CodedValue('123'), - safety_classification=pm_types.SafetyClassification.INF) - self._mdib.descriptions.add_object(clock_descriptor) - clock_state = self._mdib.states.descriptor_handle.get_one(clock_descriptor.Handle, allow_none=True) - if clock_state is None: - clock_state = self._mdib.data_model.mk_state_container(clock_descriptor) - self._mdib.states.add_object(clock_state) + model = self._mdib.data_model + clock_entity = self._mdib.entities.new_entity(model.pm_names.ClockDescriptor, + handle = clock_descr_handle, + parent_handle=my_mds_entity.handle) + clock_entity.descriptor.SafetyClassification = pm_types.SafetyClassification.INF + clock_entity.descriptor.Type = pm_types.CodedValue('123') + with self._mdib.descriptor_transaction() as mgr: + mgr.write_entity(clock_entity) def make_operation_instance(self, operation_descriptor_container: AbstractOperationDescriptorProtocol, @@ -86,19 +91,23 @@ def _set_ntp_string(self, params: ExecuteParameters) -> ExecuteResult: self._logger.info('set value %s from %s to %s', params.operation_instance.operation_target_handle, params.operation_instance.current_value, value) + + op_target_entity = self._mdib.entities.by_handle(params.operation_instance.operation_target_handle) + + # look for clock entities that are a direct child of this mds + mds_handle = op_target_entity.descriptor.source_mds or op_target_entity.handle + clock_entities = self._mdib.entities.by_node_type(pm_names.ClockDescriptor) + clock_entities = [c for c in clock_entities if c.parent_handle == mds_handle] + + if len(clock_entities) == 0: + self._logger.warning('_set_ntp_string: no clock entity found') + return ExecuteResult(params.operation_instance.operation_target_handle, + self._mdib.data_model.msg_types.InvocationState.FAILED, + ) + + clock_entities[0].state.ReferenceSource = [value] with self._mdib.component_state_transaction() as mgr: - state = mgr.get_state(params.operation_instance.operation_target_handle) - if pm_names.MdsState == state.NODETYPE: - mds_handle = state.DescriptorHandle - mgr.unget_state(state) - # look for the ClockState child - clock_descriptors = self._mdib.descriptions.NODETYPE.get(pm_names.ClockDescriptor, []) - clock_descriptors = [c for c in clock_descriptors if c.parent_handle == mds_handle] - if len(clock_descriptors) == 1: - state = mgr.get_state(clock_descriptors[0].handle) - if pm_names.ClockState != state.NODETYPE: - raise ValueError(f'_set_ntp_string: expected ClockState, got {state.NODETYPE.localname}') - state.ReferenceSource = [value] + mgr.write_entity(clock_entities[0]) return ExecuteResult(params.operation_instance.operation_target_handle, self._mdib.data_model.msg_types.InvocationState.FINISHED) @@ -109,20 +118,22 @@ def _set_tz_string(self, params: ExecuteParameters) -> ExecuteResult: self._logger.info('set value %s from %s to %s', params.operation_instance.operation_target_handle, params.operation_instance.current_value, value) + + op_target_entity = self._mdib.entities.by_handle(params.operation_instance.operation_target_handle) + + # look for clock entities that are a direct child of this mds + mds_handle = op_target_entity.descriptor.source_mds or op_target_entity.handle + clock_entities = self._mdib.entities.by_node_type(pm_names.ClockDescriptor) + clock_entities = [c for c in clock_entities if c.parent_handle == mds_handle] + + if len(clock_entities) == 0: + self._logger.warning('_set_ntp_string: no clock entity found') + return ExecuteResult(params.operation_instance.operation_target_handle, + self._mdib.data_model.msg_types.InvocationState.FAILED) + + clock_entities[0].state.TimeZone = value with self._mdib.component_state_transaction() as mgr: - state = mgr.get_state(params.operation_instance.operation_target_handle) - if pm_names.MdsState == state.NODETYPE: - mds_handle = state.DescriptorHandle - mgr.unget_state(state) - # look for the ClockState child - clock_descriptors = self._mdib.descriptions.NODETYPE.get(pm_names.ClockDescriptor, []) - clock_descriptors = [c for c in clock_descriptors if c.parent_handle == mds_handle] - if len(clock_descriptors) == 1: - state = mgr.get_state(clock_descriptors[0].handle) - - if pm_names.ClockState != state.NODETYPE: - raise ValueError(f'_set_ntp_string: expected ClockState, got {state.NODETYPE.localname}') - state.TimeZone = value + mgr.write_entity(clock_entities[0]) return ExecuteResult(params.operation_instance.operation_target_handle, self._mdib.data_model.msg_types.InvocationState.FINISHED) @@ -140,7 +151,11 @@ def _create_clock_descriptor_container(self, handle: str, """ model = self._mdib.data_model cls = model.get_descriptor_container_class(model.pm_names.ClockDescriptor) - return self._create_descriptor_container(cls, handle, parent_handle, coded_value, safety_classification) + return self._create_descriptor_container(cls, + handle, + parent_handle, + coded_value, + safety_classification) class SDCClockProvider(GenericSDCClockProvider): diff --git a/src/sdc11073/roles/componentprovider.py b/src/sdc11073/roles/componentprovider.py index e43e5606..c51ac64a 100644 --- a/src/sdc11073/roles/componentprovider.py +++ b/src/sdc11073/roles/componentprovider.py @@ -1,3 +1,4 @@ +"""Implementation of component provider functionality.""" from __future__ import annotations from typing import TYPE_CHECKING @@ -8,8 +9,7 @@ if TYPE_CHECKING: from sdc11073.mdib.descriptorcontainers import AbstractOperationDescriptorProtocol - from sdc11073.provider.operations import OperationDefinitionBase - from sdc11073.provider.operations import ExecuteParameters + from sdc11073.provider.operations import ExecuteParameters, OperationDefinitionBase from .providerbase import OperationClassGetter @@ -28,15 +28,15 @@ def make_operation_instance(self, """ pm_names = self._mdib.data_model.pm_names operation_target_handle = operation_descriptor_container.OperationTarget - op_target_descriptor_container = self._mdib.descriptions.handle.get_one(operation_target_handle) + op_target_entity = self._mdib.entities.by_handle(operation_target_handle) if operation_descriptor_container.NODETYPE == pm_names.SetComponentStateOperationDescriptor: # noqa: SIM300 - if op_target_descriptor_container.NODETYPE in (pm_names.MdsDescriptor, - pm_names.ChannelDescriptor, - pm_names.VmdDescriptor, - pm_names.ClockDescriptor, - pm_names.ScoDescriptor, - ): + if op_target_entity.node_type in (pm_names.MdsDescriptor, + pm_names.ChannelDescriptor, + pm_names.VmdDescriptor, + pm_names.ClockDescriptor, + pm_names.ScoDescriptor, + ): op_cls = operation_cls_getter(pm_names.SetComponentStateOperationDescriptor) return op_cls(operation_descriptor_container.Handle, operation_target_handle, @@ -44,11 +44,11 @@ def make_operation_instance(self, coded_value=operation_descriptor_container.Type) elif operation_descriptor_container.NODETYPE == pm_names.ActivateOperationDescriptor: # noqa: SIM300 # on what can activate be called? - if op_target_descriptor_container.NODETYPE in (pm_names.MdsDescriptor, - pm_names.ChannelDescriptor, - pm_names.VmdDescriptor, - pm_names.ScoDescriptor, - ): + if op_target_entity.node_type in (pm_names.MdsDescriptor, + pm_names.ChannelDescriptor, + pm_names.VmdDescriptor, + pm_names.ScoDescriptor, + ): # no generic handler to be called! op_cls = operation_cls_getter(pm_names.ActivateOperationDescriptor) return op_cls(operation_descriptor_container.Handle, @@ -64,15 +64,16 @@ def _set_component_state(self, params: ExecuteParameters) -> ExecuteResult: params.operation_instance.current_value = value with self._mdib.component_state_transaction() as mgr: for proposed_state in value: - state = mgr.get_state(proposed_state.DescriptorHandle) - if state.is_component_state: - self._logger.info('updating %s with proposed component state', state) - state.update_from_other_container(proposed_state, - skipped_properties=['StateVersion', 'DescriptorVersion']) + entity = self._mdib.entities.by_handle(proposed_state.DescriptorHandle) + if entity.state.is_component_state: + self._logger.info('updating %s with proposed component state', entity.state) + entity.state.update_from_other_container( + proposed_state, skipped_properties=['StateVersion', 'DescriptorVersion']) + mgr.write_entity(entity) else: self._logger.warning( '_set_component_state operation: ignore invalid referenced type %s in operation', - state.NODETYPE.localname) + entity.node_type.localname) return ExecuteResult(params.operation_instance.operation_target_handle, self._mdib.data_model.msg_types.InvocationState.FINISHED) diff --git a/src/sdc11073/roles/contextprovider.py b/src/sdc11073/roles/contextprovider.py index ea351989..6f3326ae 100644 --- a/src/sdc11073/roles/contextprovider.py +++ b/src/sdc11073/roles/contextprovider.py @@ -1,3 +1,4 @@ +"""Implementation of context provider functionality.""" from __future__ import annotations import time @@ -6,13 +7,14 @@ from typing import TYPE_CHECKING from sdc11073.provider.operations import ExecuteResult + from . import providerbase if TYPE_CHECKING: from lxml import etree from sdc11073.mdib.descriptorcontainers import AbstractOperationDescriptorProtocol - from sdc11073.mdib.providermdib import ProviderMdib + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol from sdc11073.provider.operations import ExecuteParameters, OperationDefinitionBase from .providerbase import OperationClassGetter @@ -21,7 +23,7 @@ class GenericContextProvider(providerbase.ProviderRole): """Handles SetContextState operations.""" - def __init__(self, mdib: ProviderMdib, + def __init__(self, mdib: ProviderMdibProtocol, op_target_descr_types: list[etree.QName] | None = None, log_prefix: str | None = None): super().__init__(mdib, log_prefix) @@ -36,17 +38,17 @@ def make_operation_instance(self, """ pm_names = self._mdib.data_model.pm_names if pm_names.SetContextStateOperationDescriptor == operation_descriptor_container.NODETYPE: - op_target_descr_container = self._mdib.descriptions.handle.get_one( - operation_descriptor_container.OperationTarget) + op_target_entity = self._mdib.entities.by_handle( operation_descriptor_container.OperationTarget) + if (not self._op_target_descr_types) or ( - op_target_descr_container.NODETYPE not in self._op_target_descr_types): + op_target_entity.descriptor.NODETYPE not in self._op_target_descr_types): return None # we do not handle this target type return self._mk_operation_from_operation_descriptor(operation_descriptor_container, operation_cls_getter, operation_handler=self._set_context_state) return None - def _set_context_state(self, params: ExecuteParameters) -> ExecuteResult: + def _set_context_state(self, params: ExecuteParameters) -> ExecuteResult: # noqa: C901, PLR0912 """Execute the operation itself (ExecuteHandler). If the proposed context is a new context and ContextAssociation == pm_types.ContextAssociation.ASSOCIATED, @@ -66,13 +68,16 @@ def _set_context_state(self, params: ExecuteParameters) -> ExecuteResult: raise ValueError(f'more than one associated context for descriptor handle {handle}') operation_target_handles = [] + modified_state_handles: dict[str, list[str]] = defaultdict(list) + modified_entities = [] with self._mdib.context_state_transaction() as mgr: for proposed_st in proposed_context_states: + entity = self._mdib.entities.by_handle(proposed_st.DescriptorHandle) + modified_entities.append(entity) old_state_container = None if proposed_st.DescriptorHandle != proposed_st.Handle: - # this is an update for an existing state - old_state_container = self._mdib.context_states.handle.get_one( - proposed_st.Handle, allow_none=True) + # this is an update for an existing state or a new one + old_state_container = entity.states.get(proposed_st.Handle) if old_state_container is None: raise ValueError(f'handle {proposed_st.Handle} not found') if old_state_container is None: @@ -80,37 +85,50 @@ def _set_context_state(self, params: ExecuteParameters) -> ExecuteResult: # create a new unique handle proposed_st.Handle = uuid.uuid4().hex if proposed_st.ContextAssociation == pm_types.ContextAssociation.ASSOCIATED: + # disassociate existing states + handles = self._mdib.xtra.disassociate_all(entity, + unbinding_mdib_version = mgr.new_mdib_version) + operation_target_handles.extend(handles) + modified_state_handles[entity.handle].extend(handles) + # set version and time in new state proposed_st.BindingMdibVersion = mgr.new_mdib_version proposed_st.BindingStartTime = time.time() self._logger.info('new %s, DescriptorHandle=%s Handle=%s', proposed_st.NODETYPE.localname, proposed_st.DescriptorHandle, proposed_st.Handle) - mgr.add_state(proposed_st) - - if proposed_st.ContextAssociation == pm_types.ContextAssociation.ASSOCIATED: - operation_target_handles.extend(mgr.disassociate_all(proposed_st.DescriptorHandle)) + # add to entity, and keep handle for later + entity.states[proposed_st.Handle] = proposed_st else: # this is an update to an existing patient # use "regular" way to update via transaction manager self._logger.info('update %s, handle=%s', proposed_st.NODETYPE.localname, proposed_st.Handle) - old_state = mgr.get_context_state(proposed_st.Handle) # handle changed ContextAssociation - if (old_state.ContextAssociation == pm_types.ContextAssociation.ASSOCIATED + if (old_state_container.ContextAssociation == pm_types.ContextAssociation.ASSOCIATED and proposed_st.ContextAssociation != pm_types.ContextAssociation.ASSOCIATED): proposed_st.UnbindingMdibVersion = mgr.new_mdib_version proposed_st.BindingEndTime = time.time() - elif (old_state.ContextAssociation != pm_types.ContextAssociation.ASSOCIATED + elif (old_state_container.ContextAssociation != pm_types.ContextAssociation.ASSOCIATED and proposed_st.ContextAssociation == pm_types.ContextAssociation.ASSOCIATED): proposed_st.BindingMdibVersion = mgr.new_mdib_version proposed_st.BindingStartTime = time.time() - operation_target_handles.extend(mgr.disassociate_all(proposed_st.DescriptorHandle, - ignored_handle=old_state.Handle)) - - old_state.update_from_other_container(proposed_st, skipped_properties=['BindingMdibVersion', + handles = self._mdib.xtra.disassociate_all(entity, + unbinding_mdib_version = mgr.new_mdib_version, + ignored_handle=old_state_container.Handle) + operation_target_handles.extend(handles) + modified_state_handles[entity.handle].extend(handles) + old_state_container.update_from_other_container(proposed_st, skipped_properties=[ + 'BindingMdibVersion', 'UnbindingMdibVersion', 'BindingStartTime', 'BindingEndTime', 'StateVersion']) + modified_state_handles[entity.handle].append(proposed_st.Handle) operation_target_handles.append(proposed_st.Handle) + + # write changes back to mdib + for entity in modified_entities: + handles = modified_state_handles[entity.handle] + mgr.write_entity(entity, handles) + if len(operation_target_handles) == 1: return ExecuteResult(operation_target_handles[0], self._mdib.data_model.msg_types.InvocationState.FINISHED) @@ -123,7 +141,7 @@ def _set_context_state(self, params: ExecuteParameters) -> ExecuteResult: class EnsembleContextProvider(GenericContextProvider): """EnsembleContextProvider.""" - def __init__(self, mdib: ProviderMdib, log_prefix: str | None = None): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str | None = None): super().__init__(mdib, op_target_descr_types=[mdib.data_model.pm_names.EnsembleContextDescriptor], log_prefix=log_prefix) @@ -132,7 +150,7 @@ def __init__(self, mdib: ProviderMdib, log_prefix: str | None = None): class LocationContextProvider(GenericContextProvider): """LocationContextProvider.""" - def __init__(self, mdib: ProviderMdib, log_prefix: str | None = None): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str | None = None): super().__init__(mdib, op_target_descr_types=[mdib.data_model.pm_names.LocationContextDescriptor], log_prefix=log_prefix) diff --git a/src/sdc11073/roles/metricprovider.py b/src/sdc11073/roles/metricprovider.py index fd08d217..27fddc9f 100644 --- a/src/sdc11073/roles/metricprovider.py +++ b/src/sdc11073/roles/metricprovider.py @@ -1,3 +1,4 @@ +"""Implementation of metric provider functionality.""" from __future__ import annotations from typing import TYPE_CHECKING, cast @@ -11,9 +12,9 @@ if TYPE_CHECKING: from collections.abc import Iterable - from sdc11073.mdib import ProviderMdib from sdc11073.mdib.descriptorcontainers import AbstractOperationDescriptorProtocol - from sdc11073.mdib.transactionsprotocol import TransactionManagerProtocol, TransactionItem + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol + from sdc11073.mdib.transactionsprotocol import StateTransactionManagerProtocol, TransactionItem from sdc11073.provider.operations import ExecuteParameters, OperationDefinitionBase from .providerbase import OperationClassGetter @@ -27,7 +28,7 @@ class GenericMetricProvider(ProviderRole): - SetStringOperation on (enum) string metrics """ - def __init__(self, mdib: ProviderMdib, + def __init__(self, mdib: ProviderMdibProtocol, activation_state_can_remove_metric_value: bool = True, log_prefix: str | None = None): """Create a GenericMetricProvider.""" @@ -49,10 +50,10 @@ def make_operation_instance(self, """ pm_names = self._mdib.data_model.pm_names operation_target_handle = operation_descriptor_container.OperationTarget - op_target_descriptor_container = self._mdib.descriptions.handle.get_one(operation_target_handle) + op_target_entity = self._mdib.entities.by_handle(operation_target_handle) if operation_descriptor_container.NODETYPE == pm_names.SetValueOperationDescriptor: # noqa: SIM300 - if op_target_descriptor_container.NODETYPE == pm_names.NumericMetricDescriptor: # noqa: SIM300 + if op_target_entity.node_type == pm_names.NumericMetricDescriptor: op_cls = operation_cls_getter(pm_names.SetValueOperationDescriptor) return op_cls(operation_descriptor_container.Handle, operation_target_handle, @@ -60,8 +61,8 @@ def make_operation_instance(self, coded_value=operation_descriptor_container.Type) return None if operation_descriptor_container.NODETYPE == pm_names.SetStringOperationDescriptor: # noqa: SIM300 - if op_target_descriptor_container.NODETYPE in (pm_names.StringMetricDescriptor, - pm_names.EnumStringMetricDescriptor): + if op_target_entity.node_type in (pm_names.StringMetricDescriptor, + pm_names.EnumStringMetricDescriptor): op_cls = operation_cls_getter(pm_names.SetStringOperationDescriptor) return op_cls(operation_descriptor_container.Handle, operation_target_handle, @@ -81,21 +82,23 @@ def _set_metric_state(self, params: ExecuteParameters) -> ExecuteResult: # ToDo: consider ModifiableDate attribute proposed_states = params.operation_request.argument params.operation_instance.current_value = proposed_states + with self._mdib.metric_state_transaction() as mgr: for proposed_state in proposed_states: - state = mgr.get_state(proposed_state.DescriptorHandle) - if state.is_metric_state: - self._logger.info('updating %s with proposed metric state', state) - state.update_from_other_container(proposed_state, + target_entity = self._mdib.entities.by_handle(proposed_state.DescriptorHandle) + if target_entity.state.is_metric_state: + self._logger.info('updating %s with proposed metric state', target_entity.state) + target_entity.state.update_from_other_container(proposed_state, skipped_properties=['StateVersion', 'DescriptorVersion']) + mgr.write_entity(target_entity) else: self._logger.warning('_set_metric_state operation: ignore invalid referenced type %s in operation', - state.NODETYPE) + target_entity.state.NODETYPE) return ExecuteResult(params.operation_instance.operation_target_handle, self._mdib.data_model.msg_types.InvocationState.FINISHED) - def on_pre_commit(self, mdib: ProviderMdib, # noqa: ARG002 - transaction: TransactionManagerProtocol): + def on_pre_commit(self, mdib: ProviderMdibProtocol, # noqa: ARG002 + transaction: StateTransactionManagerProtocol): """Set state.MetricValue to None if state.ActivationState requires this.""" if not self.activation_state_can_remove_metric_value: return @@ -129,19 +132,19 @@ def _set_numeric_value(self, params: ExecuteParameters) -> ExecuteResult: params.operation_instance.handle, params.operation_instance.current_value, value) params.operation_instance.current_value = value + entity = self._mdib.entities.by_handle(params.operation_instance.operation_target_handle) + state = cast(MetricStateProtocol, entity.state) + if state.MetricValue is None: + state.mk_metric_value() + state.MetricValue.Value = value + # SF1823: For Metrics with the MetricCategory = Set|Preset that are being modified as a result of a + # SetValue or SetString operation a Metric Provider shall set the MetricQuality / Validity = Vld. + if entity.descriptor.MetricCategory in (pm_types.MetricCategory.SETTING, + pm_types.MetricCategory.PRESETTING): + state.MetricValue.Validity = pm_types.MeasurementValidity.VALID + with self._mdib.metric_state_transaction() as mgr: - _state = mgr.get_state(params.operation_instance.operation_target_handle) - state = cast(MetricStateProtocol, _state) - if state.MetricValue is None: - state.mk_metric_value() - state.MetricValue.Value = value - # SF1823: For Metrics with the MetricCategory = Set|Preset that are being modified as a result of a - # SetValue or SetString operation a Metric Provider shall set the MetricQuality / Validity = Vld. - metric_descriptor_container = self._mdib.descriptions.handle.get_one( - params.operation_instance.operation_target_handle) - if metric_descriptor_container.MetricCategory in (pm_types.MetricCategory.SETTING, - pm_types.MetricCategory.PRESETTING): - state.MetricValue.Validity = pm_types.MeasurementValidity.VALID + mgr.write_entity(entity) return ExecuteResult(params.operation_instance.operation_target_handle, self._mdib.data_model.msg_types.InvocationState.FINISHED) @@ -152,11 +155,13 @@ def _set_string(self, params: ExecuteParameters) -> ExecuteResult: params.operation_instance.operation_target_handle, params.operation_instance.current_value, value) params.operation_instance.current_value = value + entity = self._mdib.entities.by_handle(params.operation_instance.operation_target_handle) + state = cast(MetricStateProtocol, entity.state) + if state.MetricValue is None: + state.mk_metric_value() + state.MetricValue.Value = value + with self._mdib.metric_state_transaction() as mgr: - _state = mgr.get_state(params.operation_instance.operation_target_handle) - state = cast(MetricStateProtocol, _state) - if state.MetricValue is None: - state.mk_metric_value() - state.MetricValue.Value = value + mgr.write_entity(entity) return ExecuteResult(params.operation_instance.operation_target_handle, self._mdib.data_model.msg_types.InvocationState.FINISHED) diff --git a/src/sdc11073/roles/patientcontextprovider.py b/src/sdc11073/roles/patientcontextprovider.py index 10fe4e89..3ed952db 100644 --- a/src/sdc11073/roles/patientcontextprovider.py +++ b/src/sdc11073/roles/patientcontextprovider.py @@ -1,3 +1,4 @@ +"""Implementation of patient provider functionality.""" from __future__ import annotations from typing import TYPE_CHECKING @@ -6,7 +7,7 @@ if TYPE_CHECKING: from sdc11073.mdib.descriptorcontainers import AbstractOperationDescriptorProtocol - from sdc11073.mdib.providermdib import ProviderMdib + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol from sdc11073.provider.operations import OperationDefinitionBase from sdc11073.provider.sco import AbstractScoOperationsRegistry @@ -20,25 +21,26 @@ class GenericPatientContextProvider(GenericContextProvider): Nothing is added to the mdib. If the mdib does not contain these operations, the functionality is not available. """ - def __init__(self, mdib: ProviderMdib, log_prefix: str | None): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str | None): super().__init__(mdib, log_prefix=log_prefix) - self._patient_context_descriptor_container = None + self._patient_context_entity = None self._set_patient_context_operations = [] def init_operations(self, sco: AbstractScoOperationsRegistry): """Find the PatientContextDescriptor.""" super().init_operations(sco) pm_names = self._mdib.data_model.pm_names - descriptor_containers = self._mdib.descriptions.NODETYPE.get(pm_names.PatientContextDescriptor) - if descriptor_containers is not None and len(descriptor_containers) == 1: - self._patient_context_descriptor_container = descriptor_containers[0] + entities = self._mdib.entities.by_node_type(pm_names.PatientContextDescriptor) + # Todo: what to do in multi mds case? + if len(entities) == 1: + self._patient_context_entity = entities[0] def make_operation_instance(self, operation_descriptor_container: AbstractOperationDescriptorProtocol, operation_cls_getter: OperationClassGetter) -> OperationDefinitionBase | None: """Add Operation Handler if operation target is the previously found PatientContextDescriptor.""" - if self._patient_context_descriptor_container and \ - operation_descriptor_container.OperationTarget == self._patient_context_descriptor_container.Handle: + if self._patient_context_entity is not None and \ + operation_descriptor_container.OperationTarget == self._patient_context_entity.handle: pc_operation = self._mk_operation_from_operation_descriptor(operation_descriptor_container, operation_cls_getter, operation_handler=self._set_context_state) @@ -55,10 +57,10 @@ def make_missing_operations(self, sco: AbstractScoOperationsRegistry) -> list[Op pm_names = self._mdib.data_model.pm_names ops = [] operation_cls_getter = sco.operation_cls_getter - if self._patient_context_descriptor_container and not self._set_patient_context_operations: + if self._patient_context_entity is not None and not self._set_patient_context_operations: op_cls = operation_cls_getter(pm_names.SetContextStateOperationDescriptor) pc_operation = op_cls('opSetPatCtx', - self._patient_context_descriptor_container.handle, + self._patient_context_entity.handle, self._set_context_state, coded_value=None) ops.append(pc_operation) diff --git a/src/sdc11073/roles/product.py b/src/sdc11073/roles/product.py index a37dcfef..07f10606 100644 --- a/src/sdc11073/roles/product.py +++ b/src/sdc11073/roles/product.py @@ -1,10 +1,14 @@ +"""Implementation of products. + +A product is a set of role providers that handle operations and other tasks. +""" from __future__ import annotations from typing import TYPE_CHECKING, Protocol from sdc11073 import loghelper -from .alarmprovider import GenericAlarmProvider +from .alarmprovider import AlertDelegateProvider, AlertPreCommitHandler, AlertSystemStateMaintainer from .audiopauseprovider import AudioPauseProvider from .clockprovider import GenericSDCClockProvider from .componentprovider import GenericSetComponentStateOperationProvider @@ -14,9 +18,9 @@ from .patientcontextprovider import GenericPatientContextProvider if TYPE_CHECKING: - from sdc11073.mdib import ProviderMdib from sdc11073.mdib.descriptorcontainers import AbstractOperationDescriptorProtocol - from sdc11073.mdib.transactionsprotocol import TransactionManagerProtocol + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol + from sdc11073.mdib.transactionsprotocol import AnyTransactionManagerProtocol from sdc11073.provider.operations import OperationDefinitionBase from sdc11073.provider.sco import AbstractScoOperationsRegistry @@ -53,10 +57,10 @@ def make_missing_operations(self, sco: AbstractScoOperationsRegistry) -> list[Op If a role provider needs to add operations beyond that, it can do it here. """ - def on_pre_commit(self, mdib: ProviderMdib, transaction: TransactionManagerProtocol): + def on_pre_commit(self, mdib: ProviderMdibProtocol, transaction: AnyTransactionManagerProtocol): """Manipulate operation (e.g. add more states).""" - def on_post_commit(self, mdib: ProviderMdib, transaction: TransactionManagerProtocol): + def on_post_commit(self, mdib: ProviderMdibProtocol, transaction: AnyTransactionManagerProtocol): """Implement actions after the transaction.""" ... @@ -69,14 +73,13 @@ class BaseProduct: """ def __init__(self, - mdib: ProviderMdib, + mdib: ProviderMdibProtocol, sco: AbstractScoOperationsRegistry, log_prefix: str | None = None): """Create a product.""" self._sco = sco self._mdib = mdib - self._model = mdib.data_model - self._ordered_providers: list[ProviderRoleProtocol] = [] # order matters, first come, first serve + self._ordered_providers: list[ProviderRoleProtocol] = [] # order matters, first come, first served # start with most specific providers, end with most general ones self._logger = loghelper.get_logger_adapter(f'sdc.device.{self.__class__.__name__}', log_prefix) @@ -102,8 +105,8 @@ def init_operations(self): for operation in operations: self._sco.register_operation(operation) - all_sco_operations = self._mdib.descriptions.parent_handle.get(self._sco.sco_descriptor_container.Handle, []) - all_op_handles = [op.Handle for op in all_sco_operations] + all_sco_operations = self._mdib.entities.by_parent_handle(self._sco.sco_descriptor_container.Handle) + all_op_handles = [op.handle for op in all_sco_operations] all_not_registered_op_handles = [op_h for op_h in all_op_handles if self._sco.get_operation_by_handle(op_h) is None] @@ -114,7 +117,6 @@ def init_operations(self): sco_handle, all_not_registered_op_handles) else: self._logger.info('sco %s: all operations have a handler.', sco_handle) - self._mdib.xtra.mk_state_containers_for_all_descriptors() self._mdib.pre_commit_handler = self._on_pre_commit self._mdib.post_commit_handler = self._on_post_commit @@ -128,9 +130,8 @@ def make_operation_instance(self, operation_cls_getter: OperationClassGetter) -> OperationDefinitionBase | None: """Call make_operation_instance of all role providers, until the first returns not None.""" operation_target_handle = operation_descriptor_container.OperationTarget - operation_target_descr = self._mdib.descriptions.handle.get_one(operation_target_handle, - allow_none=True) # descriptor container - if operation_target_descr is None: + operation_target_entity = self._mdib.entities.by_handle(operation_target_handle) + if operation_target_entity is None: # this operation is incomplete, the operation target does not exist. Registration not possible. self._logger.warning('Operation %s: target %s does not exist, will not register operation', operation_descriptor_container.Handle, operation_target_handle) @@ -145,23 +146,24 @@ def make_operation_instance(self, return None def _register_existing_mdib_operations(self, sco: AbstractScoOperationsRegistry): - operation_descriptor_containers = self._mdib.descriptions.parent_handle.get( - self._sco.sco_descriptor_container.Handle, []) - for descriptor in operation_descriptor_containers: - registered_op = sco.get_operation_by_handle(descriptor.Handle) + operation_entities = self._mdib.entities.by_parent_handle(self._sco.sco_descriptor_container.Handle) + for operation_entity in operation_entities: + registered_op = sco.get_operation_by_handle(operation_entity.handle) if registered_op is None: self._logger.debug('found unregistered %s in mdib, handle=%s, code=%r target=%s', - descriptor.NODETYPE.localname, descriptor.Handle, descriptor.Type, - descriptor.OperationTarget) - operation = self.make_operation_instance(descriptor, sco.operation_cls_getter) + operation_entity.node_type.localname, operation_entity.handle, + operation_entity.descriptor.Type, + operation_entity.descriptor.OperationTarget) + operation = self.make_operation_instance(operation_entity.descriptor, + sco.operation_cls_getter) if operation is not None: sco.register_operation(operation) - def _on_pre_commit(self, mdib: ProviderMdib, transaction: TransactionManagerProtocol): + def _on_pre_commit(self, mdib: ProviderMdibProtocol, transaction: AnyTransactionManagerProtocol): for provider in self._all_providers_sorted(): provider.on_pre_commit(mdib, transaction) - def _on_post_commit(self, mdib: ProviderMdib, transaction: TransactionManagerProtocol): + def _on_post_commit(self, mdib: ProviderMdibProtocol, transaction: AnyTransactionManagerProtocol): for provider in self._all_providers_sorted(): provider.on_post_commit(mdib, transaction) @@ -170,7 +172,7 @@ class DefaultProduct(BaseProduct): """Default Product.""" def __init__(self, - mdib: ProviderMdib, + mdib: ProviderMdibProtocol, sco: AbstractScoOperationsRegistry, log_prefix: str | None = None): super().__init__(mdib, sco, log_prefix) @@ -178,7 +180,9 @@ def __init__(self, self._ordered_providers.extend([AudioPauseProvider(mdib, log_prefix=log_prefix), GenericSDCClockProvider(mdib, log_prefix=log_prefix), GenericPatientContextProvider(mdib, log_prefix=log_prefix), - GenericAlarmProvider(mdib, log_prefix=log_prefix), + AlertDelegateProvider(mdib, log_prefix=log_prefix), + AlertSystemStateMaintainer(mdib, log_prefix=log_prefix), + AlertPreCommitHandler(mdib, log_prefix=log_prefix), self.metric_provider, OperationProvider(mdib, log_prefix=log_prefix), GenericSetComponentStateOperationProvider(mdib, log_prefix=log_prefix), @@ -189,7 +193,7 @@ class ExtendedProduct(DefaultProduct): """Add EnsembleContextProvider and LocationContextProvider.""" def __init__(self, - mdib: ProviderMdib, + mdib: ProviderMdibProtocol, sco: AbstractScoOperationsRegistry, log_prefix: str | None = None): super().__init__(mdib, sco, log_prefix) diff --git a/src/sdc11073/roles/protocols.py b/src/sdc11073/roles/protocols.py index 7be8c1b0..f37712b6 100644 --- a/src/sdc11073/roles/protocols.py +++ b/src/sdc11073/roles/protocols.py @@ -1,12 +1,12 @@ +"""Declare protocols for a Product and a WaveformProvider.""" from __future__ import annotations from typing import TYPE_CHECKING, Any, Protocol if TYPE_CHECKING: - from sdc11073.mdib import ProviderMdib from sdc11073.mdib.descriptorcontainers import AbstractOperationDescriptorProtocol - from sdc11073.mdib.transactionsprotocol import RtDataMdibUpdateTransaction + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol from sdc11073.provider.operations import OperationDefinitionBase from sdc11073.provider.sco import AbstractScoOperationsRegistry from sdc11073.xml_types.pm_types import ComponentActivation @@ -23,25 +23,21 @@ class ProductProtocol: """ def __init__(self, - mdib: ProviderMdib, + mdib: ProviderMdibProtocol, sco: AbstractScoOperationsRegistry, log_prefix: str | None = None): """Create a product.""" - ... def init_operations(self): """Register all actively provided operations.""" - ... def stop(self): """Stop all role providers.""" - ... def make_operation_instance(self, operation_descriptor_container: AbstractOperationDescriptorProtocol, operation_cls_getter: OperationClassGetter) -> OperationDefinitionBase | None: """Call make_operation_instance of all role providers, until the first returns not None.""" - ... class WaveformProviderProtocol(Protocol): @@ -52,7 +48,7 @@ class WaveformProviderProtocol(Protocol): is_running: bool - def __init__(self, mdib: ProviderMdib, log_prefix: str): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str): ... def register_waveform_generator(self, descriptor_handle: str, wf_generator: WaveformGeneratorBase): @@ -77,8 +73,5 @@ def stop(self): def set_activation_state(self, descriptor_handle: str, component_activation_state: ComponentActivation): """Set the activation state of waveform generator and of Metric state in mdib.""" - def update_all_realtime_samples(self, transaction: RtDataMdibUpdateTransaction): - """Update all realtime sample states that have a waveform generator registered. - - On transaction commit the mdib will call the appropriate send method of the sdc device. - """ + def update_all_realtime_samples(self): + """Update all realtime sample states that have a waveform generator registered.""" diff --git a/src/sdc11073/roles/providerbase.py b/src/sdc11073/roles/providerbase.py index c1ce896c..01ead48a 100644 --- a/src/sdc11073/roles/providerbase.py +++ b/src/sdc11073/roles/providerbase.py @@ -1,3 +1,4 @@ +"""The module implements the base class of role providers.""" from __future__ import annotations from typing import TYPE_CHECKING, Callable @@ -8,9 +9,9 @@ from sdc11073.provider.operations import OperationDefinitionBase if TYPE_CHECKING: - from sdc11073.mdib import ProviderMdib from sdc11073.mdib.descriptorcontainers import AbstractDescriptorProtocol, AbstractOperationDescriptorProtocol - from sdc11073.mdib.transactionsprotocol import TransactionManagerProtocol + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol + from sdc11073.mdib.transactionsprotocol import AbstractTransactionManagerProtocol from sdc11073.provider.operations import ExecuteHandler, TimeoutHandler from sdc11073.provider.sco import AbstractScoOperationsRegistry from sdc11073.xml_types.pm_types import CodedValue, SafetyClassification @@ -21,7 +22,7 @@ class ProviderRole: """Base class for all role implementations.""" - def __init__(self, mdib: ProviderMdib, log_prefix: str): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str): self._mdib = mdib self._logger = loghelper.get_logger_adapter(f'sdc.device.{self.__class__.__name__}', log_prefix) @@ -57,13 +58,13 @@ def make_missing_operations(self, sco: AbstractScoOperationsRegistry) -> list[ """ return [] - def on_pre_commit(self, mdib: ProviderMdib, transaction: TransactionManagerProtocol): + def on_pre_commit(self, mdib: ProviderMdibProtocol, transaction: AbstractTransactionManagerProtocol): """Manipulate the transaction if needed. Derived classes can overwrite this method. """ - def on_post_commit(self, mdib: ProviderMdib, transaction: TransactionManagerProtocol): + def on_post_commit(self, mdib: ProviderMdibProtocol, transaction: AbstractTransactionManagerProtocol): """Run stuff after transaction. Derived classes can overwrite this method. diff --git a/src/sdc11073/roles/waveformprovider/waveformproviderimpl.py b/src/sdc11073/roles/waveformprovider/waveformproviderimpl.py index 8ca91e1e..c95754ad 100644 --- a/src/sdc11073/roles/waveformprovider/waveformproviderimpl.py +++ b/src/sdc11073/roles/waveformprovider/waveformproviderimpl.py @@ -1,3 +1,4 @@ +"""The module implements a waveform provider.""" from __future__ import annotations import time @@ -14,9 +15,9 @@ if TYPE_CHECKING: from sdc11073.definitions_base import AbstractDataModel - from sdc11073.mdib import ProviderMdib + from sdc11073.mdib.entityprotocol import EntityProtocol + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol from sdc11073.mdib.statecontainers import RealTimeSampleArrayMetricStateContainer - from sdc11073.mdib.transactionsprotocol import RtDataMdibUpdateTransaction from sdc11073.xml_types.pm_types import ComponentActivation from .realtimesamples import AnnotatorProtocol @@ -95,7 +96,7 @@ class GenericWaveformProvider: WARN_LIMIT_REALTIMESAMPLES_BEHIND_SCHEDULE = 0.2 # warn limit when real time samples cannot be sent in time WARN_RATE_REALTIMESAMPLES_BEHIND_SCHEDULE = 5 # max. every x seconds a message - def __init__(self, mdib: ProviderMdib, log_prefix: str = ''): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str = ''): self._mdib = mdib self._logger = loghelper.get_logger_adapter(f'sdc.device.{self.__class__.__name__}', log_prefix) @@ -114,12 +115,13 @@ def register_waveform_generator(self, descriptor_handle: str, wf_generator: Wave :param wf_generator: a waveforms.WaveformGenerator instance """ sample_period = wf_generator.sample_period - descriptor_container = self._mdib.descriptions.handle.get_one(descriptor_handle) - if descriptor_container.SamplePeriod != sample_period: + entity = self._mdib.entities.by_handle(descriptor_handle) + if entity.descriptor.SamplePeriod != sample_period: # we must inform subscribers + entity.descriptor.SamplePeriod = sample_period with self._mdib.descriptor_transaction() as mgr: - descr = mgr.get_descriptor(descriptor_handle) - descr.SamplePeriod = sample_period + mgr.write_entity(entity) + if descriptor_handle in self._waveform_generators: self._waveform_generators[descriptor_handle].set_waveform_generator(wf_generator) else: @@ -166,16 +168,19 @@ def set_activation_state(self, descriptor_handle: str, component_activation_stat if not wf_generator.is_active: state.MetricValue = None - def update_all_realtime_samples(self, transaction: RtDataMdibUpdateTransaction): + def update_all_realtime_samples(self) -> list[EntityProtocol]: """Update all realtime sample states that have a waveform generator registered. On transaction commit the mdib will call the appropriate send method of the sdc device. """ + updated_entities = [] for descriptor_handle, wf_generator in self._waveform_generators.items(): if wf_generator.is_active: - state = transaction.get_state(descriptor_handle) - self._update_rt_samples(state) + entity = self._mdib.entities.by_handle(descriptor_handle) + self._update_rt_samples(entity.state) + updated_entities.append(entity) self._add_all_annotations() + return updated_entities def provide_waveforms(self, generator_class: type[WaveformGeneratorProtocol] = waveforms.TriangleGenerator, @@ -188,16 +193,16 @@ def provide_waveforms(self, :param max_waveforms: limit number of waveforms. :return: list of handles of created generators """ - name = self._mdib.data_model.pm_names.RealTimeSampleArrayMetricDescriptor - all_waveforms = self._mdib.descriptions.NODETYPE.get(name) + pm_name = self._mdib.data_model.pm_names.RealTimeSampleArrayMetricDescriptor + all_waveform_entities = self._mdib.entities.by_node_type(pm_name) if max_waveforms: - all_waveforms = all_waveforms[:max_waveforms] - for waveform in all_waveforms: + all_waveform_entities = all_waveform_entities[:max_waveforms] + for entity in all_waveform_entities: min_value = 0 max_value = 1 - sample_period = waveform.SamplePeriod if waveform.SamplePeriod > 0 else 0.01 # guarantee usable value + sample_period = entity.descriptor.SamplePeriod if entity.descriptor.SamplePeriod > 0 else 0.01 # guarantee usable value try: - tech_range = waveform.TechnicalRange[0] + tech_range = entity.descriptor.TechnicalRange[0] except IndexError: pass else: @@ -213,8 +218,8 @@ def provide_waveforms(self, max_value=max_value, waveform_period=2.0, sample_period=sample_period) - self.register_waveform_generator(waveform.Handle, generator) - return [waveform.Handle for waveform in all_waveforms] + self.register_waveform_generator(entity.handle, generator) + return [ent.handle for ent in all_waveform_entities] def _worker_thread_loop(self): timer = IntervalTimer(period_in_seconds=self.notifications_interval) @@ -226,8 +231,9 @@ def _worker_thread_loop(self): behind_schedule_seconds = timer.wait_next_interval_begin() self._log_waveform_timing(behind_schedule_seconds) try: + updated_entities = self.update_all_realtime_samples() with self._mdib.rt_sample_state_transaction() as transaction: - self.update_all_realtime_samples(transaction) + transaction.write_entities(updated_entities) self._log_waveform_timing(behind_schedule_seconds) except Exception: # noqa: BLE001 # catch all to keep loop running diff --git a/src/sdc11073/wsdiscovery/networkingthread.py b/src/sdc11073/wsdiscovery/networkingthread.py index 06f34cd0..692ff1ad 100644 --- a/src/sdc11073/wsdiscovery/networkingthread.py +++ b/src/sdc11073/wsdiscovery/networkingthread.py @@ -20,7 +20,7 @@ from sdc11073 import commlog from sdc11073.exceptions import ValidationError -from sdc11073.wsdiscovery.common import MULTICAST_IPV4_ADDRESS, MULTICAST_OUT_TTL, message_reader +from sdc11073.wsdiscovery.common import MULTICAST_IPV4_ADDRESS, message_reader if TYPE_CHECKING: from logging import Logger @@ -208,6 +208,8 @@ def _run_q_read(self): self._logger.debug('incoming message already known: %s (from %r, Id %s).', received_message.action, addr, mid) continue + self._logger.debug('new incoming message: %s (from %r, Id %s).', + received_message.action, addr, mid) self._known_message_ids.appendleft(mid) self._wsd.handle_received_message(received_message, addr) except Exception: # noqa: BLE001 diff --git a/src/sdc11073/wsdiscovery/wsdimpl.py b/src/sdc11073/wsdiscovery/wsdimpl.py index 1956b916..cc218b87 100644 --- a/src/sdc11073/wsdiscovery/wsdimpl.py +++ b/src/sdc11073/wsdiscovery/wsdimpl.py @@ -13,11 +13,11 @@ from sdc11073.definitions_sdc import SdcV1Definitions from sdc11073.exceptions import ApiUsageError from sdc11073.namespaces import default_ns_helper as nsh +from sdc11073.wsdiscovery import networkingthread from sdc11073.xml_types import wsd_types from sdc11073.xml_types.addressing_types import HeaderInformationBlock from .common import MULTICAST_IPV4_ADDRESS, MULTICAST_PORT, message_factory -from sdc11073.wsdiscovery import networkingthread from .service import Service if TYPE_CHECKING: @@ -184,7 +184,7 @@ def stop(self): def search_services(self, types: Iterable[etree.QName] | None = None, scopes: wsd_types.ScopesType | None = None, - timeout: int | float | None = 5, + timeout: int | float | None = 5, # noqa: PYI041 repeat_probe_interval: int | None = 3) -> list[Service]: """Search for services that match given types and scopes. @@ -212,7 +212,7 @@ def search_services(self, def search_sdc_services(self, scopes: wsd_types.ScopesType | None = None, - timeout: int | float | None = 5, + timeout: int | float | None = 5, # noqa: PYI041 repeat_probe_interval: int | None = 3) -> list[Service]: """Search for sdc services that match given scopes. @@ -226,7 +226,7 @@ def search_sdc_services(self, def search_multiple_types(self, types_list: list[list[etree.QName]], scopes: wsd_types.ScopesType | None = None, - timeout: int | float | None = 10, + timeout: int | float | None = 10, # noqa: PYI041 repeat_probe_interval: int | None = 3) -> list[Service]: """Search for services given the list of TYPES and SCOPES in a given timeout. @@ -276,7 +276,7 @@ def publish_service(self, epr: str, raise ApiUsageError("Server not started") metadata_version = self._local_services[epr].metadata_version + 1 if epr in self._local_services else 1 - instance_id = str(random.randint(1, 0xFFFFFFFF)) # noqa: S311 + instance_id = str(random.randint(1, 0xFFFFFFFF)) service = Service(types, scopes, x_addrs, epr, instance_id, metadata_version=metadata_version) self._logger.info('publishing %r', service) self._local_services[epr] = service @@ -347,7 +347,7 @@ def _add_remote_service(self, service: Service): already_known_service = self._remote_services.get(service.epr) if not already_known_service: self._remote_services[service.epr] = service - self._logger.info('new remote %r', service) + self._logger.info('new remote epr="%s" x_addrs=%r', service.epr, service.x_addrs) return if service.metadata_version == already_known_service.metadata_version: diff --git a/tests/mockstuff.py b/tests/mockstuff.py index 6d2cef7e..06cdda8f 100644 --- a/tests/mockstuff.py +++ b/tests/mockstuff.py @@ -1,30 +1,40 @@ +"""The module implements classes needed for testing. + +It contains simplified replacements for WsDiscovery and Subscription on devices side. +It also contains SdcProvider implementations that simplify the instantiation of a device. +""" from __future__ import annotations import logging import pathlib import threading -from urllib.parse import SplitResult from decimal import Decimal from typing import TYPE_CHECKING +from urllib.parse import SplitResult from lxml import etree +from sdc11073.entity_mdib.entity_providermdib import EntityProviderMdib from sdc11073.mdib import ProviderMdib from sdc11073.namespaces import default_ns_helper as ns_hlp from sdc11073.provider import SdcProvider from sdc11073.provider.subscriptionmgr import BicepsSubscription -from sdc11073.xml_types import pm_types, pm_qnames as pm +from sdc11073.xml_types import pm_qnames as pm +from sdc11073.xml_types import pm_types from sdc11073.xml_types.addressing_types import HeaderInformationBlock -from sdc11073.xml_types.dpws_types import ThisModelType, ThisDeviceType +from sdc11073.xml_types.dpws_types import ThisDeviceType, ThisModelType from sdc11073.xml_types.eventing_types import Subscribe if TYPE_CHECKING: - import sdc11073.certloader + import ipaddress import uuid - from sdc11073.pysoap.soapclientpool import SoapClientPool - from sdc11073.provider.providerimpl import WsDiscoveryProtocol - from sdc11073.provider.components import SdcProviderComponents + import sdc11073.certloader + from sdc11073.provider.components import SdcProviderComponents + from sdc11073.provider.providerimpl import WsDiscoveryProtocol + from sdc11073.pysoap.msgfactory import MessageFactory + from sdc11073.pysoap.soapclientpool import SoapClientPool + from sdc11073.xml_utils import LxmlElement ports_lock = threading.Lock() _ports = 10000 @@ -33,44 +43,45 @@ _logger = logging.getLogger('sdc.mock') -def dec_list(*args): +def dec_list(*args: list[int | float]) -> list[Decimal]: + """Convert a list of numbers to decimal.""" return [Decimal(x) for x in args] -def _findServer(netloc): - dev_addr = netloc.split(':') - dev_addr = tuple([dev_addr[0], int(dev_addr[1])]) # make port number an integer - for key, srv in _mockhttpservers.items(): - if tuple(key) == dev_addr: - return srv - raise KeyError('{} is not in {}'.format(dev_addr, _mockhttpservers.keys())) +class MockWsDiscovery: + """Implementation of a minimal WsDiscovery interface. + The class does nothing except logging. + """ -class MockWsDiscovery: - def __init__(self, ipaddress): + def __init__(self, ipaddress: str | ipaddress.IPv4Address): self._ipaddress = ipaddress - def get_active_addresses(self): + def get_active_addresses(self) -> str: + """Return the ip address.""" return [self._ipaddress] - def clear_service(self, epr): - _logger.info('clear_service "{}"'.format(epr)) + def clear_service(self, epr: str): + """Clear services.""" + _logger.info('clear_service "%r"', epr) class TestDevSubscription(BicepsSubscription): - """ Can be used instead of real Subscription objects""" + """Can be used instead of real Subscription objects.""" + mode = 'SomeMode' notify_to = 'http://self.com:123' identifier = '0815' expires = 60 notify_ref = 'a ref string' - def __init__(self, filter_, + def __init__(self, + filter_: list[str], soap_client_pool: SoapClientPool, - msg_factory): + msg_factory: MessageFactory): notify_ref_node = etree.Element(ns_hlp.WSE.tag('References')) - identNode = etree.SubElement(notify_ref_node, ns_hlp.WSE.tag('Identifier')) - identNode.text = self.notify_ref + ident_node = etree.SubElement(notify_ref_node, ns_hlp.WSE.tag('Identifier')) + ident_node.text = self.notify_ref base_urls = [SplitResult('https', 'www.example.com:222', 'no_uuid', query=None, fragment=None)] accepted_encodings = ['foo'] # not needed here subscribe_request = Subscribe() @@ -84,29 +95,36 @@ def __init__(self, filter_, soap_client_pool, msg_factory=msg_factory, log_prefix='test') self.reports = [] - def send_notification_report(self, body_node, action: str): + def send_notification_report(self, body_node: LxmlElement, action: str): + """Send notification to subscriber.""" info_block = HeaderInformationBlock(action=action, addr_to=self.notify_to_address, reference_parameters=self.notify_ref_params) message = self._mk_notification_message(info_block, body_node) self.reports.append(message) - async def async_send_notification_report(self, body_node, action): + async def async_send_notification_report(self, body_node: LxmlElement, action: str): + """Send notification to subscriber.""" info_block = HeaderInformationBlock(action=action, addr_to=self.notify_to_address, reference_parameters=self.notify_ref_params) message = self._mk_notification_message(info_block, body_node) self.reports.append(message) - async def async_send_notification_end_message(self, code='SourceShuttingDown', - reason='Event source going off line.'): - pass + async def async_send_notification_end_message(self, + code: str = 'SourceShuttingDown', + reason: str = 'Event source going off line.'): + """Do nothing. + + Implementation not needed for tests. + """ class SomeDevice(SdcProvider): """A device used for unit tests. Some values are predefined.""" - def __init__(self, wsdiscovery: WsDiscoveryProtocol, + def __init__(self, # noqa: PLR0913 + wsdiscovery: WsDiscoveryProtocol, mdib_xml_data: bytes, epr: str | uuid.UUID | None = None, validate: bool = True, @@ -130,13 +148,84 @@ def __init__(self, wsdiscovery: WsDiscoveryProtocol, device_mdib_container = ProviderMdib.from_string(mdib_xml_data, log_prefix=log_prefix) device_mdib_container.instance_id = 1 # set the optional value # set Metadata - mdsDescriptors = device_mdib_container.descriptions.NODETYPE.get(pm.MdsDescriptor) - for mdsDescriptor in mdsDescriptors: - if mdsDescriptor.MetaData is not None: - mdsDescriptor.MetaData.Manufacturer.append(pm_types.LocalizedText('Example Manufacturer')) - mdsDescriptor.MetaData.ModelName.append(pm_types.LocalizedText(model.ModelName[0].text)) - mdsDescriptor.MetaData.SerialNumber.append('ABCD-1234') - mdsDescriptor.MetaData.ModelNumber = '0.99' + mds_descriptors = device_mdib_container.descriptions.NODETYPE.get(pm.MdsDescriptor) + for mds_descriptor in mds_descriptors: + if mds_descriptor.MetaData is not None: + mds_descriptor.MetaData.Manufacturer.append(pm_types.LocalizedText('Example Manufacturer')) + mds_descriptor.MetaData.ModelName.append(pm_types.LocalizedText(model.ModelName[0].text)) + mds_descriptor.MetaData.SerialNumber.append('ABCD-1234') + mds_descriptor.MetaData.ModelNumber = '0.99' + super().__init__(wsdiscovery, model, device, device_mdib_container, epr, validate, + ssl_context_container=ssl_context_container, + max_subscription_duration=max_subscription_duration, + log_prefix=log_prefix, + default_components=default_components, + specific_components=specific_components, + chunk_size=chunk_size, + alternative_hostname=alternative_hostname) + + @classmethod + def from_mdib_file(cls, # noqa: PLR0913 + wsdiscovery: WsDiscoveryProtocol, + epr: str | uuid.UUID | None, + mdib_xml_path: str | pathlib.Path, + validate: bool = True, + ssl_context_container: sdc11073.certloader.SSLContextContainer | None = None, + max_subscription_duration: int = 15, + log_prefix: str = '', + default_components: SdcProviderComponents | None = None, + specific_components: SdcProviderComponents | None = None, + chunk_size: int = 0, + alternative_hostname: str | None = None) -> SomeDevice: + """Construct class with path to a mdib file.""" + mdib_xml_path = pathlib.Path(mdib_xml_path) + if not mdib_xml_path.is_absolute(): + mdib_xml_path = pathlib.Path(__file__).parent.joinpath(mdib_xml_path) + return cls(wsdiscovery, mdib_xml_path.read_bytes(), epr, validate, ssl_context_container, + max_subscription_duration=max_subscription_duration, + log_prefix=log_prefix, + default_components=default_components, specific_components=specific_components, + chunk_size=chunk_size, + alternative_hostname=alternative_hostname) + + +class SomeDeviceEntityMdib(SdcProvider): + """A device used for unit tests. Some values are predefined.""" + + def __init__(self, # noqa: PLR0913 + wsdiscovery: WsDiscoveryProtocol, + mdib_xml_data: bytes, + epr: str | uuid.UUID | None = None, + validate: bool = True, + ssl_context_container: sdc11073.certloader.SSLContextContainer | None = None, + max_subscription_duration: int = 15, + log_prefix: str = '', + default_components: SdcProviderComponents | None = None, + specific_components: SdcProviderComponents | None = None, + chunk_size: int = 0, + alternative_hostname: str | None = None): + model = ThisModelType(manufacturer='Example Manufacturer', + manufacturer_url='www.example-manufacturer.com', + model_name='SomeDevice', + model_number='1.0', + model_url='www.example-manufacturer.com/whatever/you/want/model', + presentation_url='www.example-manufacturer.com/whatever/you/want/presentation') + device = ThisDeviceType(friendly_name='Py SomeDevice', + firmware_version='0.99', + serial_number='12345') + + device_mdib_container = EntityProviderMdib.from_string(mdib_xml_data, log_prefix=log_prefix) + device_mdib_container.instance_id = 1 # set the optional value + # set Metadata + mds_entities = device_mdib_container.entities.by_parent_handle(None) + # Todo: write that meta data back to dom tree + for mds_entity in mds_entities: + mds_descriptor = mds_entity.descriptor + if mds_descriptor.MetaData is not None: + mds_descriptor.MetaData.Manufacturer.append(pm_types.LocalizedText('Example Manufacturer')) + mds_descriptor.MetaData.ModelName.append(pm_types.LocalizedText(model.ModelName[0].text)) + mds_descriptor.MetaData.SerialNumber.append('ABCD-1234') + mds_descriptor.MetaData.ModelNumber = '0.99' super().__init__(wsdiscovery, model, device, device_mdib_container, epr, validate, ssl_context_container=ssl_context_container, max_subscription_duration=max_subscription_duration, @@ -147,7 +236,7 @@ def __init__(self, wsdiscovery: WsDiscoveryProtocol, alternative_hostname=alternative_hostname) @classmethod - def from_mdib_file(cls, + def from_mdib_file(cls, # noqa: PLR0913 wsdiscovery: WsDiscoveryProtocol, epr: str | uuid.UUID | None, mdib_xml_path: str | pathlib.Path, @@ -158,7 +247,7 @@ def from_mdib_file(cls, default_components: SdcProviderComponents | None = None, specific_components: SdcProviderComponents | None = None, chunk_size: int = 0, - alternative_hostname: str | None = None): + alternative_hostname: str | None = None) -> SomeDeviceEntityMdib: """Construct class with path to a mdib file.""" mdib_xml_path = pathlib.Path(mdib_xml_path) if not mdib_xml_path.is_absolute(): diff --git a/tests/test_client_device.py b/tests/test_client_device.py index d138648e..b1a11ef6 100644 --- a/tests/test_client_device.py +++ b/tests/test_client_device.py @@ -48,6 +48,8 @@ from tests import utils from tests.mockstuff import SomeDevice, dec_list +# ruff: noqa + ENABLE_COMMLOG = False if ENABLE_COMMLOG: comm_logger = commlog.DirectoryLogger(log_folder=r'c:\temp\sdc_commlog', @@ -483,7 +485,8 @@ def setUp(self): self.wsd.start() self.sdc_device = SomeDevice.from_mdib_file(self.wsd, None, mdib_70041, default_components=default_sdc_provider_components_async, - max_subscription_duration=10) # shorter duration for faster tests + max_subscription_duration=10, # shorter duration for faster tests + log_prefix=f'{self._testMethodName}: ') # in order to test correct handling of default namespaces, we make participant model the default namespace self.sdc_device.start_all(periodic_reports_interval=1.0) self._loc_validators = [pm_types.InstanceIdentifier('Validator', extension_string='System')] @@ -501,7 +504,8 @@ def setUp(self): sdc_definitions=self.sdc_device.mdib.sdc_definitions, ssl_context_container=None, validate=CLIENT_VALIDATE, - specific_components=specific_components) + specific_components=specific_components, + log_prefix=self._testMethodName) self.sdc_client.start_all() # with periodic reports and system error report time.sleep(1) sys.stderr.write('\n############### setUp done {} ##############\n'.format(self._testMethodName)) @@ -1373,12 +1377,18 @@ def setUp(self): self.httpserver.started_evt.wait(timeout=5) self.logger.info('common http server A listens on port {}', self.httpserver.my_port) - self.sdc_device_1 = SomeDevice.from_mdib_file(self.wsd, 'device1', mdib_70041, log_prefix=' ') + self.sdc_device_1 = SomeDevice.from_mdib_file(self.wsd, + 'device1', + mdib_70041, + log_prefix=f'{self._testMethodName}1: ') self.sdc_device_1.start_all(shared_http_server=self.httpserver) self.sdc_device_1.set_location(location, self._loc_validators) provide_realtime_data(self.sdc_device_1) - self.sdc_device_2 = SomeDevice.from_mdib_file(self.wsd, 'device2', mdib_70041, log_prefix=' ') + self.sdc_device_2 = SomeDevice.from_mdib_file(self.wsd, + 'device2', + mdib_70041, + log_prefix=f'{self._testMethodName}2: ') self.sdc_device_2.start_all(shared_http_server=self.httpserver) self.sdc_device_2.set_location(location, self._loc_validators) provide_realtime_data(self.sdc_device_2) @@ -1450,7 +1460,10 @@ def setUp(self): logging.getLogger('sdc').info('############### start setUp {} ##############'.format(self._testMethodName)) self.wsd = WSDiscovery('127.0.0.1') self.wsd.start() - self.sdc_device = SomeDevice.from_mdib_file(self.wsd, None, mdib_70041, log_prefix=' ', + self.sdc_device = SomeDevice.from_mdib_file(self.wsd, + None, + mdib_70041, + log_prefix=f'{self._testMethodName}: ', chunk_size=512) # in order to test correct handling of default namespaces, we make participant model the default namespace @@ -1505,7 +1518,8 @@ def setUp(self): subscriptions_manager_class={'StateEvent': SubscriptionsManagerReferenceParamAsync}, soap_client_class=SoapClientAsync ) - self.sdc_device = SomeDevice.from_mdib_file(self.wsd, None, mdib_70041, log_prefix=' ', + self.sdc_device = SomeDevice.from_mdib_file(self.wsd, None, mdib_70041, + log_prefix=f'{self._testMethodName}: ', specific_components=specific_components, chunk_size=512) # in order to test correct handling of default namespaces, we make participant model the default namespace @@ -1593,7 +1607,8 @@ def setUp(self): self.logger.info('############### start setUp {} ##############'.format(self._testMethodName)) self.wsd = WSDiscovery('127.0.0.1') self.wsd.start() - self.sdc_device = SomeDevice.from_mdib_file(self.wsd, None, mdib_70041, log_prefix='', + self.sdc_device = SomeDevice.from_mdib_file(self.wsd, None, mdib_70041, + log_prefix=f'{self._testMethodName}: ', default_components=default_sdc_provider_components_sync, chunk_size=512) self.sdc_device.start_all(periodic_reports_interval=1.0) @@ -1778,8 +1793,9 @@ def setUp(self): self.sdc_device = SomeDevice.from_mdib_file( self.wsd, None, mdib_70041, default_components=default_sdc_provider_components_async, - max_subscription_duration=10, - alternative_hostname=socket.getfqdn()) # shorter duration for faster tests + max_subscription_duration=10, # shorter duration for faster tests + alternative_hostname=socket.getfqdn(), + log_prefix=f'{self._testMethodName}: ') self.sdc_device.start_all() self._loc_validators = [pm_types.InstanceIdentifier('Validator', extension_string='System')] self.sdc_device.set_location(utils.random_location(), self._loc_validators) @@ -1834,8 +1850,10 @@ def setUp(self): self.sdc_device = SomeDevice.from_mdib_file( self.wsd, None, mdib_70041, default_components=default_sdc_provider_components_async, - max_subscription_duration=10, - alternative_hostname="some_random_invalid_hostname") # shorter duration for faster tests + max_subscription_duration=10, # shorter duration for faster tests + alternative_hostname="some_random_invalid_hostname", + log_prefix=f'{self._testMethodName}: ' + ) self.sdc_device.start_all() self._loc_validators = [pm_types.InstanceIdentifier('Validator', extension_string='System')] self.sdc_device.set_location(utils.random_location(), self._loc_validators) diff --git a/tests/test_entity_client_device.py b/tests/test_entity_client_device.py new file mode 100644 index 00000000..2ba347a2 --- /dev/null +++ b/tests/test_entity_client_device.py @@ -0,0 +1,504 @@ +"""The module tests functionality between consumer and provider, both using entity based mdibs.""" +from __future__ import annotations + +import datetime +import logging +import sys +import time +import traceback +import unittest.mock +from decimal import Decimal +from itertools import cycle +from typing import TYPE_CHECKING + +from lxml import etree as etree_ + +from sdc11073 import loghelper, observableproperties +from sdc11073.consumer import SdcConsumer +from sdc11073.consumer.components import SdcConsumerComponents +from sdc11073.dispatch import RequestDispatcher +from sdc11073.entity_mdib.entities import ConsumerEntity, ConsumerMultiStateEntity, XmlEntity, XmlMultiStateEntity +from sdc11073.entity_mdib.entity_consumermdib import EntityConsumerMdib +from sdc11073.loghelper import basic_logging_setup, get_logger_adapter +from sdc11073.roles.waveformprovider import waveforms +from sdc11073.wsdiscovery import WSDiscovery +from sdc11073.xml_types import pm_qnames, pm_types +from sdc11073.xml_types import pm_qnames as pm +from tests import utils +from tests.mockstuff import SomeDeviceEntityMdib + +if TYPE_CHECKING: + from sdc11073.entity_mdib.entities import ProviderMultiStateEntity + from sdc11073.provider import SdcProvider + +CLIENT_VALIDATE = True +SET_TIMEOUT = 10 # longer timeout than usually needed, but jenkins jobs frequently failed with 3 seconds timeout +NOTIFICATION_TIMEOUT = 5 # also jenkins related value + + +default_mdib_file = 'mdib_two_mds.xml' + + +def provide_realtime_data(sdc_provider: SdcProvider): + waveform_provider = sdc_provider.waveform_provider + if waveform_provider is None: + return + iterator = cycle([waveforms.SawtoothGenerator, + waveforms.SinusGenerator, + waveforms.TriangleGenerator]) + waveform_entities = sdc_provider.mdib.entities.by_node_type(pm_qnames.RealTimeSampleArrayMetricDescriptor) + for i, waveform_entity in enumerate(waveform_entities): + cls = iterator.__next__() + gen = cls(min_value=1, max_value=i+10, waveform_period=1.1, sample_period=0.01) + waveform_provider.register_waveform_generator(waveform_entity.handle, gen) + + if i == 2: + # make this generator the annotator source + waveform_provider.add_annotation_generator(pm_types.CodedValue('a', 'b'), + trigger_handle=waveform_entity.handle, + annotated_handles=[waveform_entities[0].handle], + ) + + +class TestClientSomeDeviceXml(unittest.TestCase): + def setUp(self): + basic_logging_setup() + self.logger = get_logger_adapter('sdc.test') + sys.stderr.write(f'\n############### start setUp {self._testMethodName} ##############\n'.format()) + self.logger.info('############### start setUp %s ##############', self._testMethodName) + self.wsd = WSDiscovery('127.0.0.1') + self.wsd.start() + self.sdc_provider: SomeDeviceEntityMdib | None = None + self.sdc_consumer: SdcConsumer | None = None + self.log_watcher = loghelper.LogWatcher(logging.getLogger('sdc'), level=logging.ERROR) + + def _init_provider_consumer(self, mdib_file: str = default_mdib_file): + self.sdc_provider = SomeDeviceEntityMdib.from_mdib_file(self.wsd, None, mdib_file, + max_subscription_duration=10) # shorter duration for faster tests + # in order to test correct handling of default namespaces, we make participant model the default namespace + self.sdc_provider.start_all(periodic_reports_interval=1.0) + self._loc_validators = [pm_types.InstanceIdentifier('Validator', extension_string='System')] + self.sdc_provider.set_location(utils.random_location(), self._loc_validators) + provide_realtime_data(self.sdc_provider) + + time.sleep(0.5) # allow init of devices to complete + # no deferred action handling for easier debugging + specific_components = SdcConsumerComponents( + action_dispatcher_class=RequestDispatcher, + ) + + x_addr = self.sdc_provider.get_xaddrs() + self.sdc_consumer = SdcConsumer(x_addr[0], + sdc_definitions=self.sdc_provider.mdib.sdc_definitions, + ssl_context_container=None, + validate=CLIENT_VALIDATE, + specific_components=specific_components) + self.sdc_consumer.start_all() # with periodic reports and system error report + time.sleep(1) + sys.stderr.write(f'\n############### setUp done {self._testMethodName} ##############\n') + self.logger.info('############### setUp done %s ##############', self._testMethodName) + time.sleep(0.5) + + def tearDown(self): + sys.stderr.write(f'############### tearDown {self._testMethodName}... ##############\n') + self.log_watcher.setPaused(True) + try: + if self.sdc_provider: + self.sdc_provider.stop_all() + if self.sdc_consumer: + self.sdc_consumer.stop_all(unsubscribe=False) + self.wsd.stop() + except: + sys.stderr.write(traceback.format_exc()) + try: + self.log_watcher.check() + except loghelper.LogWatchError as ex: + sys.stderr.write(repr(ex)) + raise + sys.stderr.write(f'############### tearDown {self._testMethodName} done ##############\n') + + def add_random_patient(self, count: int = 1) -> [ProviderMultiStateEntity, list]: + new_states = [] + entities = self.sdc_provider.mdib.entities.by_node_type(pm.PatientContextDescriptor) + if len(entities) != 1: + raise ValueError(f'cannot handle {len(entities)} instances of PatientContextDescriptor') + entity = entities[0] + handles = [] + for i in range(count): + st = entity.new_state() + st.CoreData.Givenname = f'Max{i}' + st.CoreData.Middlename = ['Willy'] + st.CoreData.Birthname = f'Mustermann{i}' + st.CoreData.Familyname = f'Musterfrau{i}' + st.CoreData.Title = 'Rex' + st.CoreData.Sex = pm_types.Sex.MALE + st.CoreData.PatientType = pm_types.PatientType.ADULT + st.CoreData.Height = pm_types.Measurement(Decimal('88.2'), pm_types.CodedValue('abc', 'def')) + st.CoreData.Weight = pm_types.Measurement(Decimal('68.2'), pm_types.CodedValue('abc')) + st.CoreData.Race = pm_types.CodedValue('123', 'def') + st.CoreData.DateOfBirth = datetime.datetime(2012, 3, 15, 13, 12, 11) + handles.append(st.Handle) + new_states.append(st) + + with self.sdc_provider.mdib.context_state_transaction() as mgr: + mgr.write_entity(entity, handles) + return entity, new_states + + def test_consumer_xml_mdib(self): + self._init_provider_consumer() + patient_descriptor_entity, _ = self.add_random_patient(2) + msg_reader = self.sdc_consumer.msg_reader + consumer_mdib = EntityConsumerMdib(self.sdc_consumer, max_realtime_samples=297) + consumer_mdib.init_mdib() + + # check sequence_id and instance_id + self.assertEqual(consumer_mdib.sequence_id, self.sdc_provider.mdib.sequence_id) + self.assertEqual(consumer_mdib.instance_id, self.sdc_provider.mdib.instance_id) + + # check difference of mdib versions (consumer is allowed to be max. one smaller + self.assertLess(self.sdc_provider.mdib.mdib_version - consumer_mdib.mdib_version, 2) + # check also in DOM tree + self.assertLess(self.sdc_provider.mdib.mdib_version + - int(consumer_mdib._get_mdib_response_node.get('MdibVersion')), 2) + self.assertLess(self.sdc_provider.mdib.mdib_version + - int(consumer_mdib._get_mdib_response_node[0].get('MdibVersion')), 2) + + msg_reader._validate_node(consumer_mdib._get_mdib_response_node) + self.assertEqual(len(self.sdc_provider.mdib.entities), len(consumer_mdib.entities)) + + for xml_entity in consumer_mdib._entities.values(): + self.assertIsInstance(xml_entity, (XmlEntity, XmlMultiStateEntity)) + self.assertIsInstance(xml_entity.node_type, etree_.QName) + self.assertIsInstance(xml_entity.source_mds, str) + + # needed? + for handle in consumer_mdib._entities: + ent = consumer_mdib.entities.by_handle(handle) + self.assertIsInstance(ent, (ConsumerEntity, ConsumerMultiStateEntity)) + + # verify that NODETYPE filter works as expected + consumer_ent_list = consumer_mdib.entities.by_node_type(pm_qnames.VmdDescriptor) + provider_list = self.sdc_provider.mdib.entities.by_node_type(pm_qnames.VmdDescriptor) + self.assertEqual(len(provider_list), len(consumer_ent_list)) + + # test update method of entities + metric_entities = consumer_mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor) + consumer_metric_entity = metric_entities[0] + descriptor_version = consumer_metric_entity.descriptor.DescriptorVersion + state_version = consumer_metric_entity.state.StateVersion + consumer_metric_entity.descriptor.DescriptorVersion += 1 + consumer_metric_entity.state.StateVersion += 1 + consumer_metric_entity.update() + self.assertEqual(descriptor_version, consumer_metric_entity.descriptor.DescriptorVersion) + self.assertEqual(state_version, consumer_metric_entity.state.StateVersion) + + # calling update with deleted xml entity source shall raise an error + del consumer_mdib._entities[consumer_metric_entity.handle] + self.assertRaises(ValueError, consumer_metric_entity.update) + + # same for multi state entity + context_descriptor_handle = patient_descriptor_entity.descriptor.Handle + context_consumer_entity = consumer_mdib.entities.by_handle(context_descriptor_handle) + del consumer_mdib._entities[context_descriptor_handle] + self.assertRaises(ValueError, context_consumer_entity.update) + + def test_metric_update(self): + self._init_provider_consumer() + msg_reader = self.sdc_consumer.msg_reader + consumer_mdib = EntityConsumerMdib(self.sdc_consumer, max_realtime_samples=297) + consumer_mdib.init_mdib() + self.assertEqual(len(self.sdc_provider.mdib.entities), len(consumer_mdib.entities)) + + metric_entities = self.sdc_provider.mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor) + provider_entity = metric_entities[0] + + coll = observableproperties.SingleValueCollector(consumer_mdib, + 'metric_handles') + + # set value of a metric + first_value = Decimal(12) + st = provider_entity.state + old_state_version = st.StateVersion + if st.MetricValue is None: + st.mk_metric_value() + st.MetricValue.Value = first_value + st.MetricValue.MetricQuality.Validity = pm_types.MeasurementValidity.VALID + + with self.sdc_provider.mdib.metric_state_transaction() as mgr: + # mgr automatically increases the StateVersion + mgr.write_entity(provider_entity) + + coll.result(timeout=NOTIFICATION_TIMEOUT) + provider_entity.update() + self.assertEqual(provider_entity.state.StateVersion, old_state_version + 1) + consumer_entity = consumer_mdib.entities.by_handle(provider_entity.handle) + self.assertIsNone(provider_entity.state.diff(consumer_entity.state, max_float_diff=1e-6)) + msg_reader._validate_node(consumer_mdib._get_mdib_response_node) + + def test_alert_update(self): + self._init_provider_consumer() + msg_reader = self.sdc_consumer.msg_reader + consumer_mdib = EntityConsumerMdib(self.sdc_consumer, max_realtime_samples=297) + consumer_mdib.init_mdib() + + self.assertEqual(len(self.sdc_provider.mdib.entities), len(consumer_mdib.entities)) + + provider_entities = self.sdc_provider.mdib.entities.by_node_type(pm_qnames.AlertConditionDescriptor) + provider_entity = provider_entities[0] + coll = observableproperties.SingleValueCollector(consumer_mdib, + 'alert_handles') + + with self.sdc_provider.mdib.alert_state_transaction() as mgr: + # mgr automatically increases the StateVersion + provider_entity.state.ActivationState = pm_types.AlertActivation.PAUSED + provider_entity.state.ActualPriority = pm_types.AlertConditionPriority.MEDIUM + mgr.write_entity(provider_entity) + + coll.result(timeout=NOTIFICATION_TIMEOUT) + provider_entity.update() # update to get correct version counters + consumer_entity = consumer_mdib.entities.by_handle(provider_entity.handle) + self.assertIsNone(provider_entity.state.diff(consumer_entity.state, max_float_diff=1e-6)) + msg_reader._validate_node(consumer_mdib._get_mdib_response_node) + + def test_component_update(self): + self._init_provider_consumer() + msg_reader = self.sdc_consumer.msg_reader + consumer_mdib = EntityConsumerMdib(self.sdc_consumer, max_realtime_samples=297) + consumer_mdib.init_mdib() + + channel_entities = self.sdc_provider.mdib.entities.by_node_type(pm_qnames.ChannelDescriptor) + provider_channel_entity = channel_entities[0] + coll = observableproperties.SingleValueCollector(consumer_mdib, + 'component_handles') + + old_state_version = provider_channel_entity.state.StateVersion + with self.sdc_provider.mdib.component_state_transaction() as mgr: + provider_channel_entity.state.ActivationState = pm_types.ComponentActivation.FAILURE + mgr.write_entity(provider_channel_entity) + + coll.result(timeout=NOTIFICATION_TIMEOUT) + provider_channel_entity.update() + self.assertEqual(provider_channel_entity.state.StateVersion, old_state_version + 1) + consumer_channel_entity = consumer_mdib.entities.by_handle(provider_channel_entity.handle) + self.assertIsNone(provider_channel_entity.state.diff(consumer_channel_entity.state, max_float_diff=1e-6)) + msg_reader._validate_node(consumer_mdib._get_mdib_response_node) + + def test_operational_state_update(self): + self._init_provider_consumer() + msg_reader = self.sdc_consumer.msg_reader + consumer_mdib = EntityConsumerMdib(self.sdc_consumer, max_realtime_samples=297) + consumer_mdib.init_mdib() + self.assertEqual(len(self.sdc_provider.mdib.entities), len(consumer_mdib._entities)) + + entities = self.sdc_provider.mdib.entities.by_node_type(pm_qnames.ActivateOperationDescriptor) + provider_entity = entities[0] + coll = observableproperties.SingleValueCollector(consumer_mdib, + 'operation_handles') + provider_entity.state.OperatingMode = pm_types.OperatingMode.NA + + with self.sdc_provider.mdib.operational_state_transaction() as mgr: + mgr.write_entity(provider_entity) + + coll.result(timeout=NOTIFICATION_TIMEOUT) + + consumer_entity = consumer_mdib.entities.by_handle(provider_entity.handle) + provider_entity.update() + + self.assertIsNone(provider_entity.state.diff(consumer_entity.state, max_float_diff=1e-6)) + msg_reader._validate_node(consumer_mdib._get_mdib_response_node) + + def test_remove_mds(self): + self._init_provider_consumer() + self.sdc_provider.stop_realtime_sample_loop() + time.sleep(0.1) + consumer_mdib = EntityConsumerMdib(self.sdc_consumer, max_realtime_samples=297) + consumer_mdib.init_mdib() + + # get all versions + descriptor_versions = {} + state_versions = {} + for handle, entity in self.sdc_provider.mdib._entities.items(): + descriptor_versions[handle] = entity.descriptor.DescriptorVersion + if entity.is_multi_state: + for state in entity.states.values(): + state_versions[state.Handle] = state.StateVersion + else: + state_versions[handle] = entity.state.StateVersion + + # now remove all + coll = observableproperties.SingleValueCollector(consumer_mdib, 'deleted_descriptors_handles') + mds_entities = self.sdc_provider.mdib.entities.by_node_type(pm.MdsDescriptor) + with self.sdc_provider.mdib.descriptor_transaction() as mgr: + for entity in mds_entities: + mgr.remove_entity(entity) + coll.result(timeout=NOTIFICATION_TIMEOUT) + + # verify both mdibs are empty + self.assertEqual(len(self.sdc_provider.mdib.entities), 0) + self.assertEqual(len(consumer_mdib.entities), 0) + # verify all version info is saved + self.assertEqual(descriptor_versions, self.sdc_provider.mdib.descr_handle_version_lookup) + self.assertEqual(state_versions, self.sdc_provider.mdib.state_handle_version_lookup) + + def test_set_patient_context_on_device(self): + """Verify that device updates patient. + + Verify that a notification device->client updates the client mdib. + """ + self._init_provider_consumer() + consumer_mdib = EntityConsumerMdib(self.sdc_consumer, max_realtime_samples=297) + consumer_mdib.init_mdib() + + entities = self.sdc_provider.mdib.entities.by_node_type(pm.PatientContextDescriptor) + self.assertEqual(len(entities), 1) + coll = observableproperties.SingleValueCollector(consumer_mdib, 'context_handles') + provider_entity, states = self.add_random_patient(1) # this runs a transaction + st_handle = states[0].Handle + coll.result(timeout=NOTIFICATION_TIMEOUT) + provider_entity.update() + provider_state = provider_entity.states[st_handle] + consumer_entity = consumer_mdib.entities.by_handle(provider_entity.descriptor.Handle) + consumer_state = consumer_entity.states[st_handle] + self.assertIsNone(consumer_state.diff(provider_state, max_float_diff=1e-6)) + + # test update of same patient + coll = observableproperties.SingleValueCollector(consumer_mdib, 'context_handles') + provider_entity.update() + + provider_state = provider_entity.states[st_handle] + provider_state.CoreData.Givenname = 'Moritz' + with self.sdc_provider.mdib.context_state_transaction() as mgr: + mgr.write_entity(provider_entity, [st_handle]) + coll.result(timeout=NOTIFICATION_TIMEOUT) + time.sleep(1) + provider_entity.update() + provider_state = provider_entity.states[st_handle] + consumer_entity.update() + consumer_state = consumer_entity.states[st_handle] + self.assertIsNone(consumer_state.diff(provider_state, max_float_diff=1e-6)) + + def test_description_modification(self): + self._init_provider_consumer() + msg_reader = self.sdc_consumer.msg_reader + consumer_mdib = EntityConsumerMdib(self.sdc_consumer, max_realtime_samples=297) + consumer_mdib.init_mdib() + + metric_entities = consumer_mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor) + consumer_entity = metric_entities[0] + + initial_descriptor_version = consumer_entity.descriptor.DescriptorVersion + initial_state_version = consumer_entity.state.StateVersion + + # now update a metric descriptor and wait for the next DescriptionModificationReport + coll = observableproperties.SingleValueCollector(consumer_mdib, + 'updated_descriptors_handles') + + new_determination_period = 3.14159 + provider_entity = self.sdc_provider.mdib.entities.by_handle(consumer_entity.handle) + provider_entity.descriptor.DeterminationPeriod = new_determination_period + with self.sdc_provider.mdib.descriptor_transaction() as mgr: + mgr.write_entity(provider_entity) + coll.result(timeout=NOTIFICATION_TIMEOUT) + + # verify that client got updates + consumer_entity.update() + self.assertEqual(consumer_entity.descriptor.DescriptorVersion, initial_descriptor_version + 1) + self.assertEqual(consumer_entity.descriptor.DeterminationPeriod, new_determination_period) + self.assertEqual(consumer_entity.state.DescriptorVersion, initial_descriptor_version + 1) + self.assertEqual(consumer_entity.state.StateVersion, initial_state_version + 1) + + msg_reader._validate_node(consumer_mdib._get_mdib_response_node) + + # now update a channel descriptor and wait for the next DescriptionModificationReport + channel_descriptor_handle = consumer_entity.descriptor.parent_handle #'2.1.6.1' # a channel + consumer_entity = consumer_mdib.entities.by_handle(channel_descriptor_handle) + coll = observableproperties.SingleValueCollector(consumer_mdib, + 'updated_descriptors_handles') + new_concept_description = 'foo bar' + provider_entity = self.sdc_provider.mdib.entities.by_handle(channel_descriptor_handle) + provider_entity.descriptor.Type.ConceptDescription[0].text = new_concept_description + initial_descriptor_version = provider_entity.descriptor.DescriptorVersion + initial_state_version = provider_entity.state.StateVersion + + with self.sdc_provider.mdib.descriptor_transaction() as mgr: + mgr.write_entity(provider_entity) + + provider_entity.update() + self.assertEqual(provider_entity.descriptor.DescriptorVersion, initial_descriptor_version + 1) + self.assertEqual(provider_entity.descriptor.Type.ConceptDescription[0].text, new_concept_description) + self.assertEqual(provider_entity.state.StateVersion, initial_state_version + 1) + + coll.result(timeout=NOTIFICATION_TIMEOUT) + + consumer_entity.update() + + self.assertEqual(consumer_entity.descriptor.DescriptorVersion, initial_descriptor_version + 1) + self.assertEqual(consumer_entity.descriptor.Type.ConceptDescription[0].text, new_concept_description) + self.assertEqual(consumer_entity.state.DescriptorVersion, consumer_entity.descriptor.DescriptorVersion) + self.assertEqual(consumer_entity.state.StateVersion, initial_state_version + 1) + + msg_reader._validate_node(consumer_mdib._get_mdib_response_node) + + # test creating a numeric descriptor + # coll: wait for the next DescriptionModificationReport + coll = observableproperties.SingleValueCollector(consumer_mdib, + 'new_descriptors_handles') + + new_handle = 'a_generated_descriptor' + + new_entity = self.sdc_provider.mdib.entities.new_entity(pm.NumericMetricDescriptor, + new_handle, + channel_descriptor_handle) + new_entity.descriptor.Type = pm_types.CodedValue('12345') + new_entity.descriptor.Unit = pm_types.CodedValue('hector') + new_entity.descriptor.Resolution = Decimal('0.42') + + # verify that it is possible to create an entity with same handle twice + self.assertRaises(ValueError, self.sdc_provider.mdib.entities.new_entity, + pm.NumericMetricDescriptor, + new_handle, + channel_descriptor_handle, + ) + + with self.sdc_provider.mdib.descriptor_transaction() as mgr: + mgr.write_entity(new_entity) + coll.result(timeout=NOTIFICATION_TIMEOUT) + + new_consumer_entity = consumer_mdib.entities.by_handle(new_handle) + self.assertEqual(new_consumer_entity.descriptor.Resolution, Decimal('0.42')) + msg_reader._validate_node(consumer_mdib._get_mdib_response_node) + + # test creating a battery descriptor + entities = self.sdc_provider.mdib.entities.by_node_type(pm_qnames.MdsDescriptor) + provider_mds_entity = entities[0] + + # coll: wait for the next DescriptionModificationReport + coll = observableproperties.SingleValueCollector(consumer_mdib, + 'new_descriptors_handles') + new_battery_handle = 'new_battery_handle' + node_name = pm.BatteryDescriptor + new_entity = self.sdc_provider.mdib.entities.new_entity(node_name, + new_battery_handle, + provider_mds_entity.handle) + new_entity.descriptor.Type = pm_types.CodedValue('23456') + + with self.sdc_provider.mdib.descriptor_transaction() as mgr: + mgr.write_entity(new_entity) + # long timeout, sometimes high load on jenkins makes these tests fail + coll.result(timeout=NOTIFICATION_TIMEOUT) + consumer_entity = consumer_mdib.entities.by_handle(new_battery_handle) + + msg_reader._validate_node(consumer_mdib._get_mdib_response_node) + + self.assertEqual(consumer_entity.descriptor.Handle, new_battery_handle) + + # test deleting a descriptor + coll = observableproperties.SingleValueCollector(consumer_mdib, + 'deleted_descriptors_handles') + provider_channel_entity = self.sdc_provider.mdib.entities.by_handle(channel_descriptor_handle) + + with self.sdc_provider.mdib.descriptor_transaction() as mgr: + mgr.remove_entity(provider_channel_entity) + coll.result(timeout=NOTIFICATION_TIMEOUT) + entity = consumer_mdib.entities.by_handle(new_handle) + self.assertIsNone(entity) diff --git a/tests/test_entity_operations.py b/tests/test_entity_operations.py new file mode 100644 index 00000000..88887fd7 --- /dev/null +++ b/tests/test_entity_operations.py @@ -0,0 +1,675 @@ +"""The module tests operations with provider and consumer that use entity mdibs.""" +from __future__ import annotations + +import datetime +import logging +import time +import unittest +from decimal import Decimal + +from sdc11073 import commlog, loghelper, observableproperties +from sdc11073.consumer import SdcConsumer +from sdc11073.consumer.components import SdcConsumerComponents +from sdc11073.dispatch import RequestDispatcher +from sdc11073.entity_mdib.entity_consumermdib import EntityConsumerMdib +from sdc11073.loghelper import basic_logging_setup +from sdc11073.roles.nomenclature import NomenclatureCodes +from sdc11073.wsdiscovery import WSDiscovery +from sdc11073.xml_types import msg_types, pm_types +from sdc11073.xml_types import pm_qnames as pm +from tests import utils +from tests.mockstuff import SomeDeviceEntityMdib + +ENABLE_COMMLOG = False +if ENABLE_COMMLOG: + comm_logger = commlog.DirectoryLogger(log_folder=r'c:\temp\sdc_commlog', + log_out=True, + log_in=True, + broadcast_ip_filter=None) + comm_logger.start() + +CLIENT_VALIDATE = True +SET_TIMEOUT = 10 # longer timeout than usually needed, but jenkins jobs frequently failed with 3 seconds timeout +NOTIFICATION_TIMEOUT = 5 # also jenkins related value + +default_mdib_file = 'mdib_two_mds.xml' +mdib_70041_file = '70041_MDIB_Final.xml' + + +class TestEntityOperations(unittest.TestCase): + """Test role providers (located in sdc11073.roles).""" + + def setUp(self): + basic_logging_setup() + self._logger = logging.getLogger('sdc.test') + self._logger.info('############### start setUp %s ##############', self._testMethodName) + self.wsd = WSDiscovery('127.0.0.1') + self.wsd.start() + self.sdc_provider: SomeDeviceEntityMdib | None = None + self.sdc_consumer: SdcConsumer | None = None + self.log_watcher = loghelper.LogWatcher(logging.getLogger('sdc'), level=logging.ERROR) + self._logger.info('############### setUp done %s ##############', self._testMethodName) + + def _init_provider_consumer(self, mdib_file: str = default_mdib_file): + self.sdc_provider = SomeDeviceEntityMdib.from_mdib_file( + self.wsd, None, + mdib_file, + max_subscription_duration=10) # shorter duration for faster tests + # in order to test correct handling of default namespaces, we make participant model the default namespace + self.sdc_provider.start_all(periodic_reports_interval=1.0) + self._loc_validators = [pm_types.InstanceIdentifier('Validator', extension_string='System')] + self.sdc_provider.set_location(utils.random_location(), self._loc_validators) + + time.sleep(0.5) # allow init of devices to complete + # no deferred action handling for easier debugging + specific_components = SdcConsumerComponents( + action_dispatcher_class=RequestDispatcher, + ) + + x_addr = self.sdc_provider.get_xaddrs() + self.sdc_consumer = SdcConsumer(x_addr[0], + sdc_definitions=self.sdc_provider.mdib.sdc_definitions, + ssl_context_container=None, + validate=CLIENT_VALIDATE, + specific_components=specific_components) + self.sdc_consumer.start_all() # with periodic reports and system error report + time.sleep(1) + + def tearDown(self): + self._logger.info('############### tearDown %s... ##############\n', self._testMethodName) + self.log_watcher.setPaused(True) + if self.sdc_consumer: + self.sdc_consumer.stop_all() + if self.sdc_provider: + self.sdc_provider.stop_all() + self.wsd.stop() + try: + self.log_watcher.check() + except loghelper.LogWatchError as ex: + self._logger.warning(repr(ex)) + raise + self._logger.info('############### tearDown %s done ##############\n', self._testMethodName) + + def test_set_patient_context_operation(self): + """Client calls corresponding operation of GenericContextProvider. + + - verify that operation is successful. + - verify that a notification device->client also updates the consumer mdib. + """ + self._init_provider_consumer() + + # delete possible existing states + patient_entities = self.sdc_provider.mdib.entities.by_node_type(pm.PatientContextDescriptor) + with self.sdc_provider.mdib.context_state_transaction() as tr: + for ent in patient_entities: + handles = list(ent.states.keys()) + if len(handles) > 0: + ent.states.clear() + tr.write_entity(ent, handles) + + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + + patient_entities = client_mdib.entities.by_node_type(pm.PatientContextDescriptor) + my_patient_entity = patient_entities[0] + # initially the device shall not have any patient + self.assertEqual(len(my_patient_entity.states), 0) + operation_entities = client_mdib.entities.by_node_type(pm.SetContextStateOperationDescriptor) + pat_op_entities = [ ent for ent in operation_entities if ent.descriptor.OperationTarget == my_patient_entity.handle] + self.assertEqual(len(pat_op_entities), 1) + my_operation = pat_op_entities[0] + self._logger.info('Handle for SetContextState Operation = %s', my_operation.handle) + context = self.sdc_consumer.client('Context') + + # insert a new patient with wrong handle, this shall fail + proposed_context = my_patient_entity.new_state() + proposed_context.ContextAssociation = pm_types.ContextAssociation.ASSOCIATED + proposed_context.Handle = 'some_nonexisting_handle' + proposed_context.CoreData.Givenname = 'Karl' + proposed_context.CoreData.Middlename = ['M.'] + proposed_context.CoreData.Familyname = 'Klammer' + proposed_context.CoreData.Birthname = 'Bourne' + proposed_context.CoreData.Title = 'Dr.' + proposed_context.CoreData.Sex = pm_types.Sex.MALE + proposed_context.CoreData.PatientType = pm_types.PatientType.ADULT + proposed_context.CoreData.set_birthdate('2000-12-12') + proposed_context.CoreData.Height = pm_types.Measurement(Decimal('88.2'), pm_types.CodedValue('abc', 'def')) + proposed_context.CoreData.Weight = pm_types.Measurement(Decimal('68.2'), pm_types.CodedValue('abc')) + proposed_context.CoreData.Race = pm_types.CodedValue('somerace') + self.log_watcher.setPaused(True) + future = context.set_context_state(my_operation.handle, [proposed_context]) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FAILED) + self.assertIsNone(result.OperationTarget) + + # insert two new patients for same descriptor, both associated. This shall fail + proposed_context1 = my_patient_entity.new_state() + + proposed_context1.ContextAssociation = pm_types.ContextAssociation.ASSOCIATED + proposed_context2 = my_patient_entity.new_state() + proposed_context2.ContextAssociation = pm_types.ContextAssociation.ASSOCIATED + future = context.set_context_state(my_operation.handle, [proposed_context1, proposed_context2]) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FAILED) + + self.log_watcher.setPaused(False) + + # insert a new patient with correct handle, this shall succeed + proposed_context.Handle = my_patient_entity.handle + future = context.set_context_state(my_operation.handle, [proposed_context]) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + self.assertIsNotNone(result.OperationTarget) + + # check client side patient context, this shall have been set via notification + consumer_entity = client_mdib.entities.by_handle(my_patient_entity.handle) + patient_context_state_container = list(consumer_entity.states.values())[0] # noqa: RUF015 + self.assertEqual(patient_context_state_container.CoreData.Givenname, 'Karl') + self.assertEqual(patient_context_state_container.CoreData.Middlename, ['M.']) + self.assertEqual(patient_context_state_container.CoreData.Familyname, 'Klammer') + self.assertEqual(patient_context_state_container.CoreData.Birthname, 'Bourne') + self.assertEqual(patient_context_state_container.CoreData.Title, 'Dr.') + self.assertEqual(patient_context_state_container.CoreData.Sex, 'M') + self.assertEqual(patient_context_state_container.CoreData.PatientType, pm_types.PatientType.ADULT) + self.assertEqual(patient_context_state_container.CoreData.Height.MeasuredValue, Decimal('88.2')) + self.assertEqual(patient_context_state_container.CoreData.Weight.MeasuredValue, Decimal('68.2')) + self.assertEqual(patient_context_state_container.CoreData.Race, pm_types.CodedValue('somerace')) + self.assertNotEqual(patient_context_state_container.Handle, + my_patient_entity.handle) # device replaced it with its own handle + self.assertEqual(patient_context_state_container.ContextAssociation, pm_types.ContextAssociation.ASSOCIATED) + self.assertIsNotNone(patient_context_state_container.BindingMdibVersion) + self.assertIsNotNone(patient_context_state_container.BindingStartTime) + + # test update of the patient + patient_context_state_container.CoreData.Givenname = 'Karla' + future = context.set_context_state(my_operation.handle, [patient_context_state_container]) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + self.assertEqual(result.OperationTarget, patient_context_state_container.Handle) + + consumer_entity.update() + patient_context_state_container = list(consumer_entity.states.values())[0] # noqa: RUF015 + self.assertEqual(patient_context_state_container.CoreData.Givenname, 'Karla') + self.assertEqual(patient_context_state_container.CoreData.Familyname, 'Klammer') + + # set new patient, check binding mdib versions and context association + proposed_context = my_patient_entity.new_state() + proposed_context.ContextAssociation = pm_types.ContextAssociation.ASSOCIATED + proposed_context.CoreData.Givenname = 'Heidi' + proposed_context.CoreData.Middlename = ['M.'] + proposed_context.CoreData.Familyname = 'Klammer' + proposed_context.CoreData.Birthname = 'Bourne' + proposed_context.CoreData.Title = 'Dr.' + proposed_context.CoreData.Sex = pm_types.Sex.FEMALE + proposed_context.CoreData.PatientType = pm_types.PatientType.ADULT + proposed_context.CoreData.set_birthdate('2000-12-12') + proposed_context.CoreData.Height = pm_types.Measurement(Decimal('88.2'), pm_types.CodedValue('abc', 'def')) + proposed_context.CoreData.Weight = pm_types.Measurement(Decimal('68.2'), pm_types.CodedValue('abc')) + proposed_context.CoreData.Race = pm_types.CodedValue('somerace') + future = context.set_context_state(my_operation.handle, [proposed_context]) + result = future.result(timeout=SET_TIMEOUT) + invocation_state = result.InvocationInfo.InvocationState + self.assertEqual(invocation_state, msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertIsNotNone(result.OperationTarget) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + consumer_entity.update() + patient_context_state_containers = list(consumer_entity.states.values()) + + # sort by BindingMdibVersion + patient_context_state_containers.sort(key=lambda obj: obj.BindingMdibVersion) + self.assertEqual(len(patient_context_state_containers), 2) + old_patient = patient_context_state_containers[0] + new_patient = patient_context_state_containers[1] + self.assertEqual(old_patient.ContextAssociation, pm_types.ContextAssociation.DISASSOCIATED) + self.assertEqual(new_patient.ContextAssociation, pm_types.ContextAssociation.ASSOCIATED) + + # create a patient locally on device, then test update from client + pat_entity = self.sdc_provider.mdib.entities.by_handle(my_patient_entity.handle) + st = pat_entity.new_state() + st.CoreData.Givenname = 'Max123' + st.CoreData.Middlename = ['Willy'] + st.CoreData.Birthname = 'Mustermann' + st.CoreData.Familyname = 'Musterfrau' + st.CoreData.Title = 'Rex' + st.CoreData.Sex = pm_types.Sex.MALE + st.CoreData.PatientType = pm_types.PatientType.ADULT + st.CoreData.Height = pm_types.Measurement(Decimal('88.2'), pm_types.CodedValue('abc', 'def')) + st.CoreData.Weight = pm_types.Measurement(Decimal('68.2'), pm_types.CodedValue('abc')) + st.CoreData.Race = pm_types.CodedValue('123', 'def') + st.CoreData.DateOfBirth = datetime.datetime(2012, 3, 15, 13, 12, 11) + + coll = observableproperties.SingleValueCollector(self.sdc_consumer, 'episodic_context_report') + with self.sdc_provider.mdib.context_state_transaction() as mgr: + mgr.write_entity(pat_entity, modified_handles=[st.Handle]) + coll.result(timeout=NOTIFICATION_TIMEOUT) + + consumer_entity.update() + patient_context_state_containers = list(consumer_entity.states.values()) + + my_patients = [p for p in patient_context_state_containers if p.CoreData.Givenname == 'Max123'] + self.assertEqual(len(my_patients), 1) + my_patient = my_patients[0] + my_patient.CoreData.Givenname = 'Karl123' + future = context.set_context_state(my_operation.handle, [my_patient]) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + consumer_entity.update() + my_updated_patient = consumer_entity.states[my_patient.Handle] + + self.assertEqual(my_updated_patient.CoreData.Givenname, 'Karl123') + + def test_location_context(self): + # initially the device shall have one location, and the client must have it in its mdib + self._init_provider_consumer() + device_mdib = self.sdc_provider.mdib + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + + dev_location_entities = device_mdib.entities.by_node_type(pm.LocationContextDescriptor) + self.assertEqual(len(dev_location_entities), 1) + dev_location_entity = dev_location_entities[0] + loc_context_handle = dev_location_entity.handle + cl_location_entity = client_mdib.entities.by_handle(loc_context_handle) + self.assertIsNotNone(cl_location_entity) + initial_number_of_states = len(dev_location_entity.states) + self.assertGreater(initial_number_of_states, 0) + + self.assertEqual(len(dev_location_entity.states), len(cl_location_entity.states)) + + for i in range(10): + new_location = utils.random_location() + coll = observableproperties.SingleValueCollector(client_mdib, 'context_handles') + self.sdc_provider.set_location(new_location) + coll.result(timeout=NOTIFICATION_TIMEOUT) + dev_location_entity = device_mdib.entities.by_handle(loc_context_handle) + cl_location_entity = client_mdib.entities.by_handle(loc_context_handle) + self.assertEqual(len(dev_location_entity.states), i + 1 + initial_number_of_states) + self.assertEqual(len(cl_location_entity.states), i + 1 + initial_number_of_states) + + # sort by mdib_version + dev_locations = list(dev_location_entity.states.values()) + cl_locations = list(cl_location_entity.states.values()) + + dev_locations.sort(key=lambda a: a.BindingMdibVersion) + cl_locations.sort(key=lambda a: a.BindingMdibVersion) + # Plausibility check that the new location has expected data + self.assertEqual(dev_locations[-1].LocationDetail.PoC, new_location.poc) + self.assertEqual(cl_locations[-1].LocationDetail.PoC, new_location.poc) + self.assertEqual(dev_locations[-1].LocationDetail.Bed, new_location.bed) + self.assertEqual(cl_locations[-1].LocationDetail.Bed, new_location.bed) + self.assertEqual(dev_locations[-1].ContextAssociation, pm_types.ContextAssociation.ASSOCIATED) + self.assertEqual(cl_locations[-1].ContextAssociation, pm_types.ContextAssociation.ASSOCIATED) + self.assertEqual(dev_locations[-1].UnbindingMdibVersion, None) + self.assertEqual(cl_locations[-1].UnbindingMdibVersion, None) + + for j, loc in enumerate(dev_locations[:-1]): + self.assertEqual(loc.ContextAssociation, pm_types.ContextAssociation.DISASSOCIATED) + self.assertEqual(loc.UnbindingMdibVersion, dev_locations[j + 1].BindingMdibVersion) + + for j, loc in enumerate(cl_locations[:-1]): + self.assertEqual(loc.ContextAssociation, pm_types.ContextAssociation.DISASSOCIATED) + self.assertEqual(loc.UnbindingMdibVersion, cl_locations[j + 1].BindingMdibVersion) + + def test_activate(self): + """Test AudioPauseProvider.""" + # switch one alert system off + self._init_provider_consumer(mdib_70041_file) + alert_system_entity_off = self.sdc_provider.mdib.entities.by_handle('Asy.3208') + self.assertIsNotNone(alert_system_entity_off) + alert_system_entity_off.state.ActivationState = pm_types.AlertActivation.OFF + with self.sdc_provider.mdib.alert_state_transaction() as mgr: + mgr.write_entity(alert_system_entity_off) + + set_service = self.sdc_consumer.client('Set') + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + coding = pm_types.Coding(NomenclatureCodes.MDC_OP_SET_ALL_ALARMS_AUDIO_PAUSE) + operation_pause_entities = self.sdc_provider.mdib.entities.by_coding(coding) + coding = pm_types.Coding(NomenclatureCodes.MDC_OP_SET_CANCEL_ALARMS_AUDIO_PAUSE) + operation_cancel_entities = self.sdc_provider.mdib.entities.by_coding(coding) + self.assertEqual(len(operation_pause_entities), 1) + self.assertEqual(len(operation_cancel_entities), 1) + + pause_entity = operation_pause_entities[0] + cancel_entity = operation_cancel_entities[0] + + future = set_service.activate(operation_handle=pause_entity.handle, arguments=None) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + time.sleep(0.5) # allow notifications to arrive + alert_system_entities = self.sdc_provider.mdib.entities.by_node_type(pm.AlertSystemDescriptor) + for alert_system_entity in alert_system_entities: + if alert_system_entity.handle != alert_system_entity_off.handle: + self.assertEqual(alert_system_entity.state.SystemSignalActivation[0].State, + pm_types.AlertActivation.PAUSED) + + future = set_service.activate(operation_handle=cancel_entity.handle, arguments=None) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + time.sleep(0.5) # allow notifications to arrive + for alert_system_entity in alert_system_entities: + if alert_system_entity.handle != alert_system_entity_off.handle: + alert_system_entity.update() + self.assertEqual(alert_system_entity.state.SystemSignalActivation[0].State, + pm_types.AlertActivation.ON) + + # now remove all alert systems from provider mdib and verify that operation now fails + alert_system_entities = self.sdc_provider.mdib.entities.by_node_type(pm.AlertSystemDescriptor) + + with self.sdc_provider.mdib.descriptor_transaction() as mgr: + for ent in alert_system_entities: + mgr.remove_entity(ent) + future = set_service.activate(operation_handle=pause_entity.handle, arguments=None) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FAILED) + + future = set_service.activate(operation_handle=cancel_entity.handle, arguments=None) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FAILED) + + def test_set_ntp_server(self): + self._init_provider_consumer() + set_service = self.sdc_consumer.client('Set') + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + coding = pm_types.Coding(NomenclatureCodes.MDC_OP_SET_TIME_SYNC_REF_SRC) + operation_entities = self.sdc_provider.mdib.entities.by_coding(coding) + self.assertGreater(len(operation_entities), 0) + my_operation_entity = operation_entities[0] + operation_handle = my_operation_entity.handle + for value in ('169.254.0.199', '169.254.0.199:1234'): + self._logger.info('ntp server = %s', value) + future = set_service.set_string(operation_handle=operation_handle, requested_string=value) + result = future.result(timeout=SET_TIMEOUT) + invocation_state = result.InvocationInfo.InvocationState + self.assertEqual(invocation_state, msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + + # verify that the corresponding state has been updated + op_target_entity = client_mdib.entities.by_handle(my_operation_entity.descriptor.OperationTarget) + if op_target_entity.node_type == pm.MdsState: + # look for the ClockState child + clock_entities = client_mdib.entities.by_node_type(pm.ClockDescriptor) + clock_entities = [c for c in clock_entities if c.parent_handle == op_target_entity.handle] + if len(clock_entities) == 1: + op_target_entity = clock_entities[0] + self.assertEqual(op_target_entity.state.ReferenceSource[0], value) + + def test_set_time_zone(self): + self._init_provider_consumer() + set_service = self.sdc_consumer.client('Set') + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + + coding = pm_types.Coding(NomenclatureCodes.MDC_ACT_SET_TIME_ZONE) + operation_entities = self.sdc_provider.mdib.entities.by_coding(coding) + self.assertGreater(len(operation_entities), 0) + my_operation_entity = operation_entities[0] + operation_handle = my_operation_entity.handle + for value in ('+03:00', '-03:00'): # are these correct values? + self._logger.info('time zone = %s', value) + future = set_service.set_string(operation_handle=operation_handle, requested_string=value) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + + # verify that the corresponding state has been updated + op_target_entity = client_mdib.entities.by_handle(my_operation_entity.descriptor.OperationTarget) + if op_target_entity.node_type == pm.MdsState: + # look for the ClockState child + clock_entities = client_mdib.entities.by_node_type(pm.ClockDescriptor) + clock_entities = [c for c in clock_entities if c.parent_handle == op_target_entity.handle] + if len(clock_entities) == 1: + op_target_entity = clock_entities[0] + self.assertEqual(op_target_entity.state.TimeZone, value) + + def test_set_metric_state(self): + # first we need to add a set_metric_state Operation + self._init_provider_consumer() + sco_entities = self.sdc_provider.mdib.entities.by_node_type(pm.ScoDescriptor) + my_sco = sco_entities[0] + + metric_entities = self.sdc_provider.mdib.entities.by_node_type(pm.NumericMetricDescriptor) + my_metric_entity = metric_entities[0] + + new_operation_entity = self.sdc_provider.mdib.entities.new_entity(pm.SetMetricStateOperationDescriptor, + handle='HANDLE_FOR_MY_TEST', + parent_handle=my_sco.handle) + my_code = pm_types.CodedValue('99999') + new_operation_entity.descriptor.Type = my_code + new_operation_entity.descriptor.SafetyClassification = pm_types.SafetyClassification.INF + new_operation_entity.descriptor.OperationTarget = my_metric_entity.handle + + with self.sdc_provider.mdib.descriptor_transaction() as mgr: + mgr.write_entity(new_operation_entity) + + sco = self.sdc_provider._sco_operations_registries[my_sco.handle] + role_provider = self.sdc_provider.product_lookup[my_sco.handle] + + op = role_provider.metric_provider.make_operation_instance( + new_operation_entity.descriptor, sco.operation_cls_getter) + sco.register_operation(op) + self.sdc_provider.mdib.xtra.mk_state_containers_for_all_descriptors() + set_service = self.sdc_consumer.client('Set') + consumer_mdib = EntityConsumerMdib(self.sdc_consumer) + consumer_mdib.init_mdib() + + consumer_entity = consumer_mdib.entities.by_handle(my_metric_entity.handle) + self.assertIsNotNone(consumer_entity) + + # modify entity.state as new proposed state + before_state_version = consumer_entity.state.StateVersion + + operation_handle = new_operation_entity.handle + new_lifetime_period = 42.5 + consumer_entity.state.LifeTimePeriod = new_lifetime_period + future = set_service.set_metric_state(operation_handle=operation_handle, + proposed_metric_states=[consumer_entity.state]) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + consumer_entity.update() + self.assertEqual(consumer_entity.state.StateVersion, before_state_version + 1) + self.assertAlmostEqual(consumer_entity.state.LifeTimePeriod, new_lifetime_period) + + def test_set_component_state(self): + """Test GenericSetComponentStateOperationProvider.""" + # Use a single mds mdib. This makes test easier because source_mds of channel and sco are the same. + self._init_provider_consumer('mdib_tns.xml') + channels = self.sdc_provider.mdib.entities.by_node_type(pm.ChannelDescriptor) + my_channel_entity = channels[0] + # first we need to add a set_component_state Operation + sco_entities = self.sdc_provider.mdib.entities.by_node_type(pm.ScoDescriptor) + my_sco_entity = sco_entities[0] + + operation_entity = self.sdc_provider.mdib.entities.new_entity(pm.SetComponentStateOperationDescriptor, + 'HANDLE_FOR_MY_TEST', + my_sco_entity.handle) + + operation_entity.descriptor.SafetyClassification = pm_types.SafetyClassification.INF + operation_entity.descriptor.OperationTarget = my_channel_entity.handle + operation_entity.descriptor.Type = pm_types.CodedValue('999998') + with self.sdc_provider.mdib.descriptor_transaction() as tr: + tr.write_entity(operation_entity) + + sco = self.sdc_provider._sco_operations_registries[my_sco_entity.handle] + role_provider = self.sdc_provider.product_lookup[my_sco_entity.handle] + op = role_provider.make_operation_instance(operation_entity.descriptor, + sco.operation_cls_getter) + sco.register_operation(op) + self.sdc_provider.mdib.xtra.mk_state_containers_for_all_descriptors() + set_service = self.sdc_consumer.client('Set') + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + + proposed_component_state = my_channel_entity.state #client_mdib.xtra.mk_proposed_state(my_channel_entity.handle) + self.assertIsNone( + proposed_component_state.OperatingHours) # just to be sure that we know the correct intitial value + before_state_version = proposed_component_state.StateVersion + new_operating_hours = 42 + proposed_component_state.OperatingHours = new_operating_hours + future = set_service.set_component_state(operation_handle=operation_entity.handle, + proposed_component_states=[proposed_component_state]) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + + updated_channel_entity = self.sdc_provider.mdib.entities.by_handle(my_channel_entity.handle) + self.assertEqual(updated_channel_entity.state.OperatingHours, new_operating_hours) + self.assertEqual(updated_channel_entity.state.StateVersion, before_state_version + 1) + + def test_operation_without_handler(self): + """Verify that a correct response is sent.""" + self._init_provider_consumer() + set_service = self.sdc_consumer.client('Set') + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + + operation_handle = 'SVO.42.2.1.1.2.0-6' + value = 'foobar' + future = set_service.set_string(operation_handle=operation_handle, requested_string=value) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FAILED) + self.assertIsNotNone(result.InvocationInfo.InvocationError) + # Verify that transaction id increases even with "invalid" calls. + future2 = set_service.set_string(operation_handle=operation_handle, requested_string=value) + result2 = future2.result(timeout=SET_TIMEOUT) + self.assertGreater(result2.InvocationInfo.TransactionId, result.InvocationInfo.TransactionId) + + def test_delayed_processing(self): + """Verify that flag 'delayed_processing' changes responses as expected.""" + self._init_provider_consumer() + logging.getLogger('sdc.client.op_mgr').setLevel(logging.DEBUG) + logging.getLogger('sdc.device.op_reg').setLevel(logging.DEBUG) + logging.getLogger('sdc.device.SetService').setLevel(logging.DEBUG) + logging.getLogger('sdc.device.subscrMgr').setLevel(logging.DEBUG) + set_service = self.sdc_consumer.client('Set') + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + coding = pm_types.Coding(NomenclatureCodes.MDC_OP_SET_TIME_SYNC_REF_SRC) + entities = self.sdc_provider.mdib.entities.by_coding(coding) + self.assertEqual(len(entities), 1) + my_operation_entity = entities[0] + + operation = self.sdc_provider.get_operation_by_handle(my_operation_entity.handle) + for value in ('169.254.0.199', '169.254.0.199:1234'): + self._logger.info('ntp server = %s', value) + operation.delayed_processing = True # first OperationInvokedReport shall have InvocationState.WAIT + coll = observableproperties.SingleValueCollector(self.sdc_consumer, 'operation_invoked_report') + future = set_service.set_string(operation_handle=my_operation_entity.handle, requested_string=value) + result = future.result(timeout=SET_TIMEOUT) + received_message = coll.result(timeout=5) + my_msg_types = received_message.msg_reader.msg_types + operation_invoked_report = my_msg_types.OperationInvokedReport.from_node(received_message.p_msg.msg_node) + self.assertEqual(operation_invoked_report.ReportPart[0].InvocationInfo.InvocationState, + my_msg_types.InvocationState.WAIT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, my_msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + time.sleep(0.5) + # disable delayed processing + self._logger.info("disable delayed processing") + operation.delayed_processing = False # first OperationInvokedReport shall have InvocationState.FINISHED + coll = observableproperties.SingleValueCollector(self.sdc_consumer, 'operation_invoked_report') + future = set_service.set_string(operation_handle=my_operation_entity.handle, requested_string=value) + result = future.result(timeout=SET_TIMEOUT) + received_message = coll.result(timeout=5) + my_msg_types = received_message.msg_reader.msg_types + operation_invoked_report = my_msg_types.OperationInvokedReport.from_node(received_message.p_msg.msg_node) + self.assertEqual(operation_invoked_report.ReportPart[0].InvocationInfo.InvocationState, + my_msg_types.InvocationState.FINISHED) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, my_msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + + def test_set_operating_mode(self): + self._init_provider_consumer() + logging.getLogger('sdc.device.subscrMgr').setLevel(logging.DEBUG) + logging.getLogger('ssdc.client.subscr').setLevel(logging.DEBUG) + consumer_mdib = EntityConsumerMdib(self.sdc_consumer) + consumer_mdib.init_mdib() + + operation_entities = consumer_mdib.entities.by_node_type(pm.ActivateOperationDescriptor) + my_operation_entity = operation_entities[0] + operation = self.sdc_provider.get_operation_by_handle(my_operation_entity.handle) + for op_mode in (pm_types.OperatingMode.NA, pm_types.OperatingMode.ENABLED): + operation.set_operating_mode(op_mode) + time.sleep(1) + my_operation_entity.update() + self.assertEqual(my_operation_entity.state.OperatingMode, op_mode) + + def test_set_string_value(self): + """Verify that metric provider instantiated an operation for SetString call. + + OperationTarget of operation 0815 is an EnumStringMetricState. + """ + self._init_provider_consumer(mdib_70041_file) + set_service = self.sdc_consumer.client('Set') + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + coding = pm_types.Coding('0815') + my_operation_entities = self.sdc_provider.mdib.entities.by_coding(coding) + self.assertEqual(len(my_operation_entities), 1) + my_operation_entity = my_operation_entities[0] + operation_handle = my_operation_entity.handle + for value in ('ADULT', 'PEDIATRIC'): + self._logger.info('string value = %s', value) + future = set_service.set_string(operation_handle=operation_handle, requested_string=value) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + + # verify that the corresponding state has been updated + consumer_entity = client_mdib.entities.by_handle(my_operation_entity.descriptor.OperationTarget) + self.assertEqual(consumer_entity.state.MetricValue.Value, value) + + def test_set_metric_value(self): + """Verify that metric provider instantiated an operation for SetNumericValue call. + + OperationTarget of operation 0815-1 is a NumericMetricState. + """ + self._init_provider_consumer(mdib_70041_file) + set_service = self.sdc_consumer.client('Set') + client_mdib = EntityConsumerMdib(self.sdc_consumer) + client_mdib.init_mdib() + coding = pm_types.Coding('0815-1') + my_operation_entities = self.sdc_provider.mdib.entities.by_coding(coding) + my_operation_entity = my_operation_entities[0] + + for value in (Decimal(1), Decimal(42), 1.1, 10, "12"): + self._logger.info('metric value = %s', value) + future = set_service.set_numeric_value(operation_handle=my_operation_entity.handle, + requested_numeric_value=value) + result = future.result(timeout=SET_TIMEOUT) + state = result.InvocationInfo.InvocationState + self.assertEqual(state, msg_types.InvocationState.FINISHED) + self.assertIsNone(result.InvocationInfo.InvocationError) + self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) + + # verify that the corresponding state has been updated + ent = client_mdib.entities.by_handle(my_operation_entity.descriptor.OperationTarget) + self.assertEqual(ent.state.MetricValue.Value, Decimal(str(value))) diff --git a/tests/test_entity_transaction.py b/tests/test_entity_transaction.py new file mode 100644 index 00000000..79570bea --- /dev/null +++ b/tests/test_entity_transaction.py @@ -0,0 +1,214 @@ +"""Tests for transactions of EntityProviderMdib.""" +import pathlib +import unittest + +from sdc11073.definitions_sdc import SdcV1Definitions +from sdc11073.entity_mdib.entity_providermdib import EntityProviderMdib +from sdc11073.exceptions import ApiUsageError +from sdc11073.xml_types import pm_qnames, pm_types + +mdib_file = str(pathlib.Path(__file__).parent.joinpath('mdib_tns.xml')) + + +class TestEntityTransactions(unittest.TestCase): + """Test all kinds of transactions for entity interface of EntityProviderMdib.""" + + def setUp(self): + self._mdib = EntityProviderMdib.from_mdib_file(mdib_file, + protocol_definition=SdcV1Definitions) + + def test_alert_state_update(self): + """Verify that alert_state_transaction works as expected. + + - mdib_version is incremented + - StateVersion is incremented in mdib state + - updated state is referenced in transaction_result + - observable alert_by_handle is updated + - ApiUsageError is thrown if state of wrong kind is added, + """ + mdib_version = self._mdib.mdib_version + old_ac_entity = self._mdib.entities.by_node_type(pm_qnames.AlertConditionDescriptor)[0] + old_ac_entity.state.Presence = True + with self._mdib.alert_state_transaction() as mgr: + mgr.write_entity(old_ac_entity) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + self.assertEqual(len(self._mdib.transaction.alert_updates), 1) # this causes an EpisodicAlertReport + + new_ac_entity = self._mdib.entities.by_handle(old_ac_entity.handle) + self.assertEqual(new_ac_entity.state.StateVersion, old_ac_entity.state.StateVersion + 1) + + metric_entities = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor) + with self._mdib.alert_state_transaction() as mgr: + self.assertRaises(ApiUsageError, mgr.write_entity, metric_entities[0]) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + + def test_metric_state_update(self): + """Verify that metric_state_transaction works as expected. + + - mdib_version is incremented + - StateVersion is incremented in mdib state + - updated state is referenced in transaction_result + - observable metric_by_handle is updated + - ApiUsageError is thrown if state of wrong kind is added, + """ + mdib_version = self._mdib.mdib_version + old_metric_entity = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor)[0] + old_metric_entity.state.LifeTimePeriod = 2 + with self._mdib.metric_state_transaction() as mgr: + mgr.write_entity(old_metric_entity) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + self.assertEqual(len(self._mdib.transaction.metric_updates), 1) + + new_metric_entity = self._mdib.entities.by_handle(old_metric_entity.handle) + self.assertEqual(new_metric_entity.state.StateVersion, old_metric_entity.state.StateVersion + 1) + + ac_entities = self._mdib.entities.by_node_type(pm_qnames.AlertConditionDescriptor) + with self._mdib.metric_state_transaction() as mgr: + self.assertRaises(ApiUsageError, mgr.write_entity, ac_entities[0]) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + + def test_operational_state_update(self): + """Verify that operational_state_transaction works as expected. + + - mdib_version is incremented + - StateVersion is incremented in mdib state + - updated state is referenced in transaction_result + - observable operation_by_handle is updated + - ApiUsageError is thrown if state of wrong kind is added, + """ + mdib_version = self._mdib.mdib_version + old_op_entity = self._mdib.entities.by_node_type(pm_qnames.SetAlertStateOperationDescriptor)[0] + old_op_entity.state.OperationMode = pm_types.OperatingMode.DISABLED + with self._mdib.operational_state_transaction() as mgr: + mgr.write_entity(old_op_entity) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + self.assertEqual(len(self._mdib.transaction.op_updates), 1) + self.assertTrue(old_op_entity.handle in self._mdib.operation_by_handle) + + new_op_entity = self._mdib.entities.by_handle(old_op_entity.handle) + self.assertEqual(new_op_entity.state.StateVersion, old_op_entity.state.StateVersion + 1) + + metric_entities = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor) + with self._mdib.alert_state_transaction() as mgr: + self.assertRaises(ApiUsageError, mgr.write_entity, metric_entities[0]) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + + def test_context_state_transaction(self): + """Verify that context_state_transaction works as expected. + + - mk_context_state method works as expected + - mdib_version is incremented + - StateVersion is incremented in mdib state + - updated state is referenced in transaction_result + - observable context_by_handle is updated + - ApiUsageError is thrown if state of wrong kind is added + """ + mdib_version = self._mdib.mdib_version + old_pat_entity = self._mdib.entities.by_node_type(pm_qnames.PatientContextDescriptor)[0] + new_state = old_pat_entity.new_state() + self.assertIsNotNone(new_state.Handle) + new_state.CoreData.Givenname = 'foo' + new_state.CoreData.Familyname = 'bar' + context_handle = new_state.Handle + + with self._mdib.context_state_transaction() as mgr: + mgr.write_entity(old_pat_entity, [context_handle]) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + self.assertEqual(len(self._mdib.transaction.ctxt_updates), 1) + + new_pat_entity = self._mdib.entities.by_handle(old_pat_entity.handle) + self.assertEqual(new_pat_entity.states[context_handle].StateVersion,0) + self.assertEqual(new_pat_entity.states[context_handle].CoreData.Givenname,'foo') + self.assertEqual(new_pat_entity.states[context_handle].CoreData.Familyname,'bar') + + new_pat_entity.states[context_handle].CoreData.Familyname = 'foobar' + + with self._mdib.context_state_transaction() as mgr: + mgr.write_entity(new_pat_entity, [context_handle]) + + newest_pat_entity = self._mdib.entities.by_handle(old_pat_entity.handle) + self.assertEqual(newest_pat_entity.states[context_handle].StateVersion,1) + self.assertEqual(newest_pat_entity.states[context_handle].CoreData.Familyname,'foobar') + + metric_entities = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor) + with self._mdib.alert_state_transaction() as mgr: + self.assertRaises(ApiUsageError, mgr.write_entity, metric_entities[0]) + self.assertEqual(mdib_version + 2, self._mdib.mdib_version) + + def test_description_modification(self): + """Verify that descriptor_transaction works as expected. + + - mdib_version is incremented + - observable updated_descriptors_by_handle is updated + - corresponding states for descriptor modifications are also updated + - ApiUsageError is thrown if data of wrong kind is requested + """ + mdib_version = self._mdib.mdib_version + old_ac_entity = self._mdib.entities.by_node_type(pm_qnames.AlertConditionDescriptor)[0] + old_metric_entity = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor)[0] + old_op_entity = self._mdib.entities.by_node_type(pm_qnames.SetAlertStateOperationDescriptor)[0] + old_comp_entity = self._mdib.entities.by_node_type(pm_qnames.ChannelDescriptor)[0] + old_rt_entity = self._mdib.entities.by_node_type(pm_qnames.RealTimeSampleArrayMetricDescriptor)[0] + old_ctx_entity = self._mdib.entities.by_node_type(pm_qnames.PatientContextDescriptor)[0] + + with self._mdib.descriptor_transaction() as mgr: + # verify that updating descriptors of different kinds and accessing corresponding states works + mgr.write_entity(old_ac_entity) + mgr.write_entity(old_metric_entity) + mgr.write_entity(old_op_entity) + mgr.write_entity(old_comp_entity) + mgr.write_entity(old_rt_entity) + mgr.write_entity(old_ctx_entity) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + transaction_result = self._mdib.transaction + self.assertEqual(len(transaction_result.metric_updates), 1) + self.assertEqual(len(transaction_result.alert_updates), 1) + self.assertEqual(len(transaction_result.op_updates), 1) + self.assertEqual(len(transaction_result.comp_updates), 1) + self.assertEqual(len(transaction_result.rt_updates), 1) + self.assertEqual(len(transaction_result.ctxt_updates), 1) + self.assertEqual(len(transaction_result.descr_updated), 6) + + self.assertTrue(old_ac_entity.handle in self._mdib.updated_descriptors_by_handle) + self.assertTrue(old_ac_entity.handle in self._mdib.alert_by_handle) + self.assertTrue(old_metric_entity.handle in self._mdib.updated_descriptors_by_handle) + self.assertTrue(old_metric_entity.handle in self._mdib.metrics_by_handle) + self.assertTrue(old_op_entity.handle in self._mdib.updated_descriptors_by_handle) + self.assertTrue(old_op_entity.handle in self._mdib.operation_by_handle) + self.assertTrue(old_comp_entity.handle in self._mdib.updated_descriptors_by_handle) + self.assertTrue(old_comp_entity.handle in self._mdib.component_by_handle) + self.assertTrue(old_rt_entity.handle in self._mdib.updated_descriptors_by_handle) + + def test_remove_add(self): + """Verify that removing descriptors / states and adding them later again results in correct versions.""" + # remove all root descriptors + all_entities = {} + for handle in self._mdib._entities: + all_entities[handle] = self._mdib.entities.by_handle(handle) # get external representation + + root_entities = self._mdib.entities.by_parent_handle(None) + with self._mdib.descriptor_transaction() as mgr: + for ent in root_entities: + mgr.remove_entity(ent) + + self.assertEqual(0, len(self._mdib._entities)) + + # add all entities again + with self._mdib.descriptor_transaction() as mgr: + mgr.write_entities(all_entities.values()) + + # verify that the number of entities is the same as before + self.assertEqual(len(all_entities), len(self._mdib.internal_entities)) + + # verify that all descriptors and states have incremented version counters + for current_ent in self._mdib.internal_entities.values(): + old_ent = all_entities[current_ent.handle] + self.assertEqual(current_ent.descriptor.DescriptorVersion, old_ent.descriptor.DescriptorVersion + 1) + if current_ent.is_multi_state: + for handle, current_state in current_ent.states.items(): + old_state = old_ent.states[handle] + self.assertEqual(current_state.StateVersion, old_state.StateVersion + 1) + self.assertEqual(current_state.DescriptorVersion, current_ent.descriptor.DescriptorVersion) + else: + self.assertEqual(current_ent.state.StateVersion, old_ent.state.StateVersion + 1) + self.assertEqual(current_ent.state.DescriptorVersion, current_ent.descriptor.DescriptorVersion) diff --git a/tests/test_operations.py b/tests/test_operations.py index 594cb600..1c94cfce 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -1,20 +1,20 @@ +"""The module tests operations with provider and consumer that use old mdibs.""" import datetime import logging import time import unittest from decimal import Decimal -from sdc11073 import commlog -from sdc11073 import loghelper -from sdc11073 import observableproperties -from sdc11073.xml_types import pm_types, msg_types, pm_qnames as pm +from sdc11073 import commlog, loghelper, observableproperties +from sdc11073.consumer import SdcConsumer +from sdc11073.consumer.components import SdcConsumerComponents +from sdc11073.dispatch import RequestDispatcher from sdc11073.loghelper import basic_logging_setup from sdc11073.mdib import ConsumerMdib from sdc11073.roles.nomenclature import NomenclatureCodes -from sdc11073.consumer import SdcConsumer from sdc11073.wsdiscovery import WSDiscovery -from sdc11073.consumer.components import SdcConsumerComponents -from sdc11073.dispatch import RequestDispatcher +from sdc11073.xml_types import msg_types, pm_types +from sdc11073.xml_types import pm_qnames as pm from tests import utils from tests.mockstuff import SomeDevice @@ -31,7 +31,7 @@ NOTIFICATION_TIMEOUT = 5 # also jenkins related value -class Test_BuiltinOperations(unittest.TestCase): +class TestBuiltinOperations(unittest.TestCase): """Test role providers (located in sdc11073.roles).""" def setUp(self): @@ -51,7 +51,7 @@ def setUp(self): x_addr = self.sdc_device.get_xaddrs() # no deferred action handling for easier debugging specific_components = SdcConsumerComponents( - action_dispatcher_class=RequestDispatcher + action_dispatcher_class=RequestDispatcher, ) self.sdc_client = SdcConsumer(x_addr[0], sdc_definitions=self.sdc_device.mdib.sdc_definitions, @@ -80,18 +80,20 @@ def tearDown(self): self._logger.info('############### tearDown %s done ##############\n', self._testMethodName) def test_set_patient_context_operation(self): - """client calls corresponding operation of GenericContextProvider. + """Verify that consumer calls corresponding operation of GenericContextProvider. + - verify that operation is successful. - verify that a notification device->client also updates the client mdib.""" - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() - patient_descriptor_container = client_mdib.descriptions.NODETYPE.get_one(pm.PatientContextDescriptor) + - verify that a notification device->client also updates the client mdib. + """ + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() + patient_descriptor_container = consumer_mdib.descriptions.NODETYPE.get_one(pm.PatientContextDescriptor) # initially the device shall not have any patient - patient_context_state_container = client_mdib.context_states.NODETYPE.get_one( + patient_context_state_container = consumer_mdib.context_states.NODETYPE.get_one( pm.PatientContext, allow_none=True) self.assertIsNone(patient_context_state_container) - my_operations = client_mdib.get_operation_descriptors_for_descriptor_handle( + my_operations = consumer_mdib.get_operation_descriptors_for_descriptor_handle( patient_descriptor_container.Handle, NODETYPE=pm.SetContextStateOperationDescriptor) self.assertEqual(len(my_operations), 1) @@ -146,7 +148,7 @@ def test_set_patient_context_operation(self): self.assertIsNotNone(result.OperationTarget) # check client side patient context, this shall have been set via notification - patient_context_state_container = client_mdib.context_states.NODETYPE.get_one(pm.PatientContextState) + patient_context_state_container = consumer_mdib.context_states.NODETYPE.get_one(pm.PatientContextState) self.assertEqual(patient_context_state_container.CoreData.Givenname, 'Karl') self.assertEqual(patient_context_state_container.CoreData.Middlename, ['M.']) self.assertEqual(patient_context_state_container.CoreData.Familyname, 'Klammer') @@ -172,7 +174,7 @@ def test_set_patient_context_operation(self): state = result.InvocationInfo.InvocationState self.assertEqual(state, msg_types.InvocationState.FINISHED) self.assertEqual(result.OperationTarget, proposed_context.Handle) - patient_context_state_container = client_mdib.context_states.handle.get_one( + patient_context_state_container = consumer_mdib.context_states.handle.get_one( patient_context_state_container.Handle) self.assertEqual(patient_context_state_container.CoreData.Givenname, 'Karla') self.assertEqual(patient_context_state_container.CoreData.Familyname, 'Klammer') @@ -198,7 +200,7 @@ def test_set_patient_context_operation(self): self.assertIsNone(result.InvocationInfo.InvocationError) self.assertIsNotNone(result.OperationTarget) self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) - patient_context_state_containers = client_mdib.context_states.NODETYPE.get(pm.PatientContextState, []) + patient_context_state_containers = consumer_mdib.context_states.NODETYPE.get(pm.PatientContextState, []) # sort by BindingMdibVersion patient_context_state_containers.sort(key=lambda obj: obj.BindingMdibVersion) self.assertEqual(len(patient_context_state_containers), 2) @@ -223,7 +225,7 @@ def test_set_patient_context_operation(self): st.CoreData.Race = pm_types.CodedValue('123', 'def') st.CoreData.DateOfBirth = datetime.datetime(2012, 3, 15, 13, 12, 11) coll.result(timeout=NOTIFICATION_TIMEOUT) - patient_context_state_containers = client_mdib.context_states.NODETYPE.get(pm.PatientContextState) + patient_context_state_containers = consumer_mdib.context_states.NODETYPE.get(pm.PatientContextState) my_patients = [p for p in patient_context_state_containers if p.CoreData.Givenname == 'Max123'] self.assertEqual(len(my_patients), 1) my_patient = my_patients[0] @@ -239,11 +241,11 @@ def test_set_patient_context_operation(self): def test_location_context(self): # initially the device shall have one location, and the client must have it in its mdib device_mdib = self.sdc_device.mdib - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() dev_locations = device_mdib.context_states.NODETYPE.get(pm.LocationContextState) - cl_locations = client_mdib.context_states.NODETYPE.get(pm.LocationContextState) + cl_locations = consumer_mdib.context_states.NODETYPE.get(pm.LocationContextState) self.assertEqual(len(dev_locations), 1) self.assertEqual(len(cl_locations), 1) self.assertEqual(dev_locations[0].Handle, cl_locations[0].Handle) @@ -253,11 +255,11 @@ def test_location_context(self): for i in range(10): new_location = utils.random_location() - coll = observableproperties.SingleValueCollector(client_mdib, 'context_by_handle') + coll = observableproperties.SingleValueCollector(consumer_mdib, 'context_by_handle') self.sdc_device.set_location(new_location) coll.result(timeout=NOTIFICATION_TIMEOUT) dev_locations = device_mdib.context_states.NODETYPE.get(pm.LocationContextState) - cl_locations = client_mdib.context_states.NODETYPE.get(pm.LocationContextState) + cl_locations = consumer_mdib.context_states.NODETYPE.get(pm.LocationContextState) self.assertEqual(len(dev_locations), i + 2) self.assertEqual(len(cl_locations), i + 2) @@ -283,9 +285,7 @@ def test_location_context(self): self.assertEqual(loc.UnbindingMdibVersion, cl_locations[j + 1].BindingMdibVersion) def test_audio_pause(self): - """Tests AudioPauseProvider - - """ + """Test AudioPauseProvider.""" # switch one alert system off alert_system_off = 'Asy.3208' with self.sdc_device.mdib.alert_state_transaction() as mgr: @@ -296,8 +296,8 @@ def test_audio_pause(self): self.assertGreater(len(alert_system_descriptors), 0) set_service = self.sdc_client.client('Set') - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() coding = pm_types.Coding(NomenclatureCodes.MDC_OP_SET_ALL_ALARMS_AUDIO_PAUSE) operation_pause = self.sdc_device.mdib.descriptions.coding.get_one(coding) coding = pm_types.Coding(NomenclatureCodes.MDC_OP_SET_CANCEL_ALARMS_AUDIO_PAUSE) @@ -360,7 +360,7 @@ def test_audio_pause_two_clients(self): x_addr = self.sdc_device.get_xaddrs() # no deferred action handling for easier debugging specific_components = SdcConsumerComponents( - action_dispatcher_class=RequestDispatcher + action_dispatcher_class=RequestDispatcher, ) sdc_client2 = SdcConsumer(x_addr[0], sdc_definitions=self.sdc_device.mdib.sdc_definitions, @@ -410,8 +410,8 @@ def test_audio_pause_two_clients(self): def test_set_ntp_server(self): set_service = self.sdc_client.client('Set') - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() coding = pm_types.Coding(NomenclatureCodes.MDC_OP_SET_TIME_SYNC_REF_SRC) my_operation_descriptor = self.sdc_device.mdib.descriptions.coding.get_one(coding, allow_none=True) @@ -426,19 +426,19 @@ def test_set_ntp_server(self): self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) # verify that the corresponding state has been updated - state = client_mdib.states.descriptor_handle.get_one(my_operation_descriptor.OperationTarget) - if state.NODETYPE == pm.MdsState: + state = consumer_mdib.states.descriptor_handle.get_one(my_operation_descriptor.OperationTarget) + if pm.MdsState == state.NODETYPE: # look for the ClockState child - clock_descriptors = client_mdib.descriptions.NODETYPE.get(pm.ClockDescriptor, []) + clock_descriptors = consumer_mdib.descriptions.NODETYPE.get(pm.ClockDescriptor, []) clock_descriptors = [c for c in clock_descriptors if c.descriptor_handle == state.descriptor_handle] if len(clock_descriptors) == 1: - state = client_mdib.states.descriptor_handle.get_one(clock_descriptors[0].Handle) + state = consumer_mdib.states.descriptor_handle.get_one(clock_descriptors[0].Handle) self.assertEqual(state.ReferenceSource[0], value) def test_set_time_zone(self): set_service = self.sdc_client.client('Set') - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() coding = pm_types.Coding(NomenclatureCodes.MDC_ACT_SET_TIME_ZONE) my_operation_descriptor = self.sdc_device.mdib.descriptions.coding.get_one(coding, allow_none=True) @@ -454,26 +454,31 @@ def test_set_time_zone(self): self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) # verify that the corresponding state has been updated - state = client_mdib.states.descriptor_handle.get_one(my_operation_descriptor.OperationTarget) - if state.NODETYPE == pm.MdsState: + state = consumer_mdib.states.descriptor_handle.get_one(my_operation_descriptor.OperationTarget) + if pm.MdsState == state.NODETYPE: # look for the ClockState child - clock_descriptors = client_mdib.descriptions.NODETYPE.get(pm.ClockDescriptor, []) + clock_descriptors = consumer_mdib.descriptions.NODETYPE.get(pm.ClockDescriptor, []) clock_descriptors = [c for c in clock_descriptors if c.parent_handle == state.DescriptorHandle] if len(clock_descriptors) == 1: - state = client_mdib.states.descriptor_handle.get_one(clock_descriptors[0].Handle) + state = consumer_mdib.states.descriptor_handle.get_one(clock_descriptors[0].Handle) self.assertEqual(state.TimeZone, value) def test_set_metric_state(self): # first we need to add a set_metric_state Operation sco_descriptors = self.sdc_device.mdib.descriptions.NODETYPE.get(pm.ScoDescriptor) - cls = self.sdc_device.mdib.data_model.get_descriptor_container_class(pm.SetMetricStateOperationDescriptor) + descr_cls = self.sdc_device.mdib.data_model.get_descriptor_container_class(pm.SetMetricStateOperationDescriptor) + state_cls = self.sdc_device.mdib.data_model.get_state_container_class(descr_cls.STATE_QNAME) operation_target_handle = '0x34F001D5' my_code = pm_types.CodedValue('99999') - my_operation_descriptor = cls('HANDLE_FOR_MY_TEST', sco_descriptors[0].Handle) + my_operation_descriptor = descr_cls('HANDLE_FOR_MY_TEST', sco_descriptors[0].Handle) my_operation_descriptor.Type = my_code my_operation_descriptor.SafetyClassification = pm_types.SafetyClassification.INF my_operation_descriptor.OperationTarget = operation_target_handle self.sdc_device.mdib.descriptions.add_object(my_operation_descriptor) + + my_operation_state = state_cls(my_operation_descriptor) + self.sdc_device.mdib.states.add_object(my_operation_state) + sco_handle = 'Sco.mds0' sco = self.sdc_device._sco_operations_registries[sco_handle] role_provider = self.sdc_device.product_lookup[sco_handle] @@ -482,40 +487,45 @@ def test_set_metric_state(self): my_operation_descriptor, sco.operation_cls_getter) sco.register_operation(op) self.sdc_device.mdib.xtra.mk_state_containers_for_all_descriptors() - setService = self.sdc_client.client('Set') - clientMdib = ConsumerMdib(self.sdc_client) - clientMdib.init_mdib() + set_service = self.sdc_client.client('Set') + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() operation_handle = my_operation_descriptor.Handle - proposed_metric_state = clientMdib.xtra.mk_proposed_state(operation_target_handle) + proposed_metric_state = consumer_mdib.xtra.mk_proposed_state(operation_target_handle) self.assertIsNone( proposed_metric_state.LifeTimePeriod) # just to be sure that we know the correct intitial value before_state_version = proposed_metric_state.StateVersion - newLifeTimePeriod = 42.5 - proposed_metric_state.LifeTimePeriod = newLifeTimePeriod - future = setService.set_metric_state(operation_handle=operation_handle, + new_lifetimeperiod = 42.5 + proposed_metric_state.LifeTimePeriod = new_lifetimeperiod + future = set_service.set_metric_state(operation_handle=operation_handle, proposed_metric_states=[proposed_metric_state]) result = future.result(timeout=SET_TIMEOUT) state = result.InvocationInfo.InvocationState self.assertEqual(state, msg_types.InvocationState.FINISHED) self.assertIsNone(result.InvocationInfo.InvocationError) self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) - updated_metric_state = clientMdib.states.descriptor_handle.get_one(operation_target_handle) + updated_metric_state = consumer_mdib.states.descriptor_handle.get_one(operation_target_handle) self.assertEqual(updated_metric_state.StateVersion, before_state_version + 1) - self.assertAlmostEqual(updated_metric_state.LifeTimePeriod, newLifeTimePeriod) + self.assertAlmostEqual(updated_metric_state.LifeTimePeriod, new_lifetimeperiod) def test_set_component_state(self): - """ tests GenericSetComponentStateOperationProvider""" + """Test GenericSetComponentStateOperationProvider.""" operation_target_handle = '2.1.2.1' # a channel # first we need to add a set_component_state Operation sco_descriptors = self.sdc_device.mdib.descriptions.NODETYPE.get(pm.ScoDescriptor) - cls = self.sdc_device.mdib.data_model.get_descriptor_container_class(pm.SetComponentStateOperationDescriptor) - my_operation_descriptor = cls('HANDLE_FOR_MY_TEST', sco_descriptors[0].Handle) - my_operation_descriptor.SafetyClassification = pm_types.SafetyClassification.INF + descr_cls = self.sdc_device.mdib.data_model.get_descriptor_container_class(pm.SetComponentStateOperationDescriptor) + state_cls = self.sdc_device.mdib.data_model.get_state_container_class(descr_cls.STATE_QNAME) + my_operation_descriptor = descr_cls('HANDLE_FOR_MY_TEST', sco_descriptors[0].Handle) + my_operation_descriptor.SafetyClassification = pm_types.SafetyClassification.INF my_operation_descriptor.OperationTarget = operation_target_handle my_operation_descriptor.Type = pm_types.CodedValue('999998') self.sdc_device.mdib.descriptions.add_object(my_operation_descriptor) + + my_operation_state = state_cls(my_operation_descriptor) + self.sdc_device.mdib.states.add_object(my_operation_state) + sco_handle = 'Sco.mds0' sco = self.sdc_device._sco_operations_registries[sco_handle] role_provider = self.sdc_device.product_lookup[sco_handle] @@ -523,11 +533,11 @@ def test_set_component_state(self): sco.register_operation(op) self.sdc_device.mdib.xtra.mk_state_containers_for_all_descriptors() set_service = self.sdc_client.client('Set') - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() operation_handle = my_operation_descriptor.Handle - proposed_component_state = client_mdib.xtra.mk_proposed_state(operation_target_handle) + proposed_component_state = consumer_mdib.xtra.mk_proposed_state(operation_target_handle) self.assertIsNone( proposed_component_state.OperatingHours) # just to be sure that we know the correct intitial value before_state_version = proposed_component_state.StateVersion @@ -540,15 +550,15 @@ def test_set_component_state(self): self.assertEqual(state, msg_types.InvocationState.FINISHED) self.assertIsNone(result.InvocationInfo.InvocationError) self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) - updated_component_state = client_mdib.states.descriptor_handle.get_one(operation_target_handle) + updated_component_state = consumer_mdib.states.descriptor_handle.get_one(operation_target_handle) self.assertEqual(updated_component_state.StateVersion, before_state_version + 1) self.assertEqual(updated_component_state.OperatingHours, new_operating_hours) def test_operation_without_handler(self): """Verify that a correct response is sent.""" set_service = self.sdc_client.client('Set') - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() operation_handle = 'SVO.42.2.1.1.2.0-6' value = 'foobar' @@ -569,8 +579,8 @@ def test_delayed_processing(self): logging.getLogger('sdc.device.SetService').setLevel(logging.DEBUG) logging.getLogger('sdc.device.subscrMgr').setLevel(logging.DEBUG) set_service = self.sdc_client.client('Set') - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() coding = pm_types.Coding(NomenclatureCodes.MDC_OP_SET_TIME_SYNC_REF_SRC) my_operation_descriptor = self.sdc_device.mdib.descriptions.coding.get_one(coding, allow_none=True) @@ -583,12 +593,12 @@ def test_delayed_processing(self): future = set_service.set_string(operation_handle=operation_handle, requested_string=value) result = future.result(timeout=SET_TIMEOUT) received_message = coll.result(timeout=5) - msg_types = received_message.msg_reader.msg_types - operation_invoked_report = msg_types.OperationInvokedReport.from_node(received_message.p_msg.msg_node) + my_msg_types = received_message.msg_reader.msg_types + operation_invoked_report = my_msg_types.OperationInvokedReport.from_node(received_message.p_msg.msg_node) self.assertEqual(operation_invoked_report.ReportPart[0].InvocationInfo.InvocationState, - msg_types.InvocationState.WAIT) + my_msg_types.InvocationState.WAIT) state = result.InvocationInfo.InvocationState - self.assertEqual(state, msg_types.InvocationState.FINISHED) + self.assertEqual(state, my_msg_types.InvocationState.FINISHED) self.assertIsNone(result.InvocationInfo.InvocationError) self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) time.sleep(0.5) @@ -599,37 +609,37 @@ def test_delayed_processing(self): future = set_service.set_string(operation_handle=operation_handle, requested_string=value) result = future.result(timeout=SET_TIMEOUT) received_message = coll.result(timeout=5) - msg_types = received_message.msg_reader.msg_types - operation_invoked_report = msg_types.OperationInvokedReport.from_node(received_message.p_msg.msg_node) + my_msg_types = received_message.msg_reader.msg_types + operation_invoked_report = my_msg_types.OperationInvokedReport.from_node(received_message.p_msg.msg_node) self.assertEqual(operation_invoked_report.ReportPart[0].InvocationInfo.InvocationState, - msg_types.InvocationState.FINISHED) + my_msg_types.InvocationState.FINISHED) state = result.InvocationInfo.InvocationState - self.assertEqual(state, msg_types.InvocationState.FINISHED) + self.assertEqual(state, my_msg_types.InvocationState.FINISHED) self.assertIsNone(result.InvocationInfo.InvocationError) self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) def test_set_operating_mode(self): logging.getLogger('sdc.device.subscrMgr').setLevel(logging.DEBUG) logging.getLogger('ssdc.client.subscr').setLevel(logging.DEBUG) - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() operation_handle = 'SVO.37.3569' operation = self.sdc_device.get_operation_by_handle(operation_handle) for op_mode in (pm_types.OperatingMode.NA, pm_types.OperatingMode.ENABLED): operation.set_operating_mode(op_mode) time.sleep(1) - operation_state = client_mdib.states.descriptor_handle.get_one(operation_handle) + operation_state = consumer_mdib.states.descriptor_handle.get_one(operation_handle) self.assertEqual(operation_state.OperatingMode, op_mode) def test_set_string_value(self): """Verify that metricprovider instantiated an operation for SetString call. - OperationTarget of operation 0815 is an EnumStringMetricState. - """ + OperationTarget of operation 0815 is an EnumStringMetricState. + """ set_service = self.sdc_client.client('Set') - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() coding = pm_types.Coding('0815') my_operation_descriptor = self.sdc_device.mdib.descriptions.coding.get_one(coding, allow_none=True) @@ -644,17 +654,17 @@ def test_set_string_value(self): self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) # verify that the corresponding state has been updated - state = client_mdib.states.descriptor_handle.get_one(my_operation_descriptor.OperationTarget) + state = consumer_mdib.states.descriptor_handle.get_one(my_operation_descriptor.OperationTarget) self.assertEqual(state.MetricValue.Value, value) def test_set_metric_value(self): """Verify that metricprovider instantiated an operation for SetNumericValue call. - OperationTarget of operation 0815-1 is a NumericMetricState. - """ + OperationTarget of operation 0815-1 is a NumericMetricState. + """ set_service = self.sdc_client.client('Set') - client_mdib = ConsumerMdib(self.sdc_client) - client_mdib.init_mdib() + consumer_mdib = ConsumerMdib(self.sdc_client) + consumer_mdib.init_mdib() coding = pm_types.Coding('0815-1') my_operation_descriptor = self.sdc_device.mdib.descriptions.coding.get_one(coding, allow_none=True) @@ -670,5 +680,5 @@ def test_set_metric_value(self): self.assertEqual(0, len(result.InvocationInfo.InvocationErrorMessage)) # verify that the corresponding state has been updated - state = client_mdib.states.descriptor_handle.get_one(my_operation_descriptor.OperationTarget) + state = consumer_mdib.states.descriptor_handle.get_one(my_operation_descriptor.OperationTarget) self.assertEqual(state.MetricValue.Value, Decimal(str(value))) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 614ee098..76a3c658 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,13 +1,15 @@ -import copy +"""The module contains tests for transactions for ProviderMdib. + +It tests classic transactions and entity transactions. +""" import pathlib import unittest from sdc11073.definitions_sdc import SdcV1Definitions from sdc11073.exceptions import ApiUsageError from sdc11073.mdib.providermdib import ProviderMdib -from sdc11073.mdib.transactions import mk_transaction from sdc11073.mdib.statecontainers import NumericMetricStateContainer - +from sdc11073.mdib.transactions import mk_transaction from sdc11073.xml_types import pm_qnames, pm_types mdib_file = str(pathlib.Path(__file__).parent.joinpath('mdib_tns.xml')) @@ -263,3 +265,210 @@ def test_remove_add(self): for state in transaction_result.all_states(): self.assertEqual(state.DescriptorVersion, current_descriptors[state.DescriptorHandle].DescriptorVersion) + + + +class TestEntityTransactions(unittest.TestCase): + """Test all kinds of transactions for entity interface of ProviderMdib.""" + + def setUp(self): + self._mdib = ProviderMdib.from_mdib_file(mdib_file, + protocol_definition=SdcV1Definitions) + + def test_alert_state_update(self): + """Verify that alert_state_transaction works as expected. + + - mdib_version is incremented + - StateVersion is incremented in mdib state + - updated state is referenced in transaction_result + - observable alert_by_handle is updated + - ApiUsageError is thrown if state of wrong kind is added, + """ + mdib_version = self._mdib.mdib_version + old_ac_entity = self._mdib.entities.by_node_type(pm_qnames.AlertConditionDescriptor)[0] + old_ac_entity.state.Presence = True + with self._mdib.alert_state_transaction() as mgr: + mgr.write_entity(old_ac_entity) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + self.assertEqual(len(self._mdib.transaction.alert_updates), 1) # this causes an EpisodicAlertReport + + new_ac_entity = self._mdib.entities.by_handle(old_ac_entity.handle) + self.assertEqual(new_ac_entity.state.StateVersion, old_ac_entity.state.StateVersion + 1) + + metric_entities = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor) + with self._mdib.alert_state_transaction() as mgr: + self.assertRaises(ApiUsageError, mgr.write_entity, metric_entities[0]) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + + def test_metric_state_update(self): + """Verify that metric_state_transaction works as expected. + + - mdib_version is incremented + - StateVersion is incremented in mdib state + - updated state is referenced in transaction_result + - observable metric_by_handle is updated + - ApiUsageError is thrown if state of wrong kind is added, + """ + mdib_version = self._mdib.mdib_version + old_metric_entity = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor)[0] + old_metric_entity.state.LifeTimePeriod = 2 + with self._mdib.metric_state_transaction() as mgr: + mgr.write_entity(old_metric_entity) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + self.assertEqual(len(self._mdib.transaction.metric_updates), 1) + + new_metric_entity = self._mdib.entities.by_handle(old_metric_entity.handle) + self.assertEqual(new_metric_entity.state.StateVersion, old_metric_entity.state.StateVersion + 1) + + ac_entities = self._mdib.entities.by_node_type(pm_qnames.AlertConditionDescriptor) + with self._mdib.metric_state_transaction() as mgr: + self.assertRaises(ApiUsageError, mgr.write_entity, ac_entities[0]) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + + def test_operational_state_update(self): + """Verify that operational_state_transaction works as expected. + + - mdib_version is incremented + - StateVersion is incremented in mdib state + - updated state is referenced in transaction_result + - observable operation_by_handle is updated + - ApiUsageError is thrown if state of wrong kind is added, + """ + mdib_version = self._mdib.mdib_version + old_op_entity = self._mdib.entities.by_node_type(pm_qnames.SetAlertStateOperationDescriptor)[0] + old_op_entity.state.OperationMode = pm_types.OperatingMode.DISABLED + with self._mdib.operational_state_transaction() as mgr: + mgr.write_entity(old_op_entity) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + self.assertEqual(len(self._mdib.transaction.op_updates), 1) + self.assertTrue(old_op_entity.handle in self._mdib.operation_by_handle) + + new_op_entity = self._mdib.entities.by_handle(old_op_entity.handle) + self.assertEqual(new_op_entity.state.StateVersion, old_op_entity.state.StateVersion + 1) + + metric_entities = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor) + with self._mdib.alert_state_transaction() as mgr: + self.assertRaises(ApiUsageError, mgr.write_entity, metric_entities[0]) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + + def test_context_state_transaction(self): + """Verify that context_state_transaction works as expected. + + - mk_context_state method works as expected + - mdib_version is incremented + - StateVersion is incremented in mdib state + - updated state is referenced in transaction_result + - observable context_by_handle is updated + - ApiUsageError is thrown if state of wrong kind is added + """ + mdib_version = self._mdib.mdib_version + old_pat_entity = self._mdib.entities.by_node_type(pm_qnames.PatientContextDescriptor)[0] + new_state = old_pat_entity.new_state() + self.assertIsNotNone(new_state.Handle) + new_state.CoreData.Givenname = 'foo' + new_state.CoreData.Familyname = 'bar' + context_handle = new_state.Handle + + with self._mdib.context_state_transaction() as mgr: + mgr.write_entity(old_pat_entity, [context_handle]) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + self.assertEqual(len(self._mdib.transaction.ctxt_updates), 1) + + new_pat_entity = self._mdib.entities.by_handle(old_pat_entity.handle) + self.assertEqual(new_pat_entity.states[context_handle].StateVersion,0) + self.assertEqual(new_pat_entity.states[context_handle].CoreData.Givenname,'foo') + self.assertEqual(new_pat_entity.states[context_handle].CoreData.Familyname,'bar') + + new_pat_entity.states[context_handle].CoreData.Familyname = 'foobar' + + with self._mdib.context_state_transaction() as mgr: + mgr.write_entity(new_pat_entity, [context_handle]) + + newest_pat_entity = self._mdib.entities.by_handle(old_pat_entity.handle) + self.assertEqual(newest_pat_entity.states[context_handle].StateVersion,1) + self.assertEqual(newest_pat_entity.states[context_handle].CoreData.Familyname,'foobar') + + metric_entities = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor) + with self._mdib.alert_state_transaction() as mgr: + self.assertRaises(ApiUsageError, mgr.write_entity, metric_entities[0]) + self.assertEqual(mdib_version + 2, self._mdib.mdib_version) + + def test_description_modification(self): + """Verify that descriptor_transaction works as expected. + + - mdib_version is incremented + - observable updated_descriptors_by_handle is updated + - corresponding states for descriptor modifications are also updated + - ApiUsageError is thrown if data of wrong kind is requested + """ + mdib_version = self._mdib.mdib_version + old_ac_entity = self._mdib.entities.by_node_type(pm_qnames.AlertConditionDescriptor)[0] + old_metric_entity = self._mdib.entities.by_node_type(pm_qnames.NumericMetricDescriptor)[0] + old_op_entity = self._mdib.entities.by_node_type(pm_qnames.SetAlertStateOperationDescriptor)[0] + old_comp_entity = self._mdib.entities.by_node_type(pm_qnames.ChannelDescriptor)[0] + old_rt_entity = self._mdib.entities.by_node_type(pm_qnames.RealTimeSampleArrayMetricDescriptor)[0] + old_ctx_entity = self._mdib.entities.by_node_type(pm_qnames.PatientContextDescriptor)[0] + + with self._mdib.descriptor_transaction() as mgr: + # verify that updating descriptors of different kinds and accessing corresponding states works + mgr.write_entity(old_ac_entity) + mgr.write_entity(old_metric_entity) + mgr.write_entity(old_op_entity) + mgr.write_entity(old_comp_entity) + mgr.write_entity(old_rt_entity) + mgr.write_entity(old_ctx_entity) + self.assertEqual(mdib_version + 1, self._mdib.mdib_version) + transaction_result = self._mdib.transaction + self.assertEqual(len(transaction_result.metric_updates), 1) + self.assertEqual(len(transaction_result.alert_updates), 1) + self.assertEqual(len(transaction_result.op_updates), 1) + self.assertEqual(len(transaction_result.comp_updates), 1) + self.assertEqual(len(transaction_result.rt_updates), 1) + self.assertEqual(len(transaction_result.ctxt_updates), 1) + self.assertEqual(len(transaction_result.descr_updated), 6) + + self.assertTrue(old_ac_entity.handle in self._mdib.updated_descriptors_by_handle) + self.assertTrue(old_ac_entity.handle in self._mdib.alert_by_handle) + self.assertTrue(old_metric_entity.handle in self._mdib.updated_descriptors_by_handle) + self.assertTrue(old_metric_entity.handle in self._mdib.metrics_by_handle) + self.assertTrue(old_op_entity.handle in self._mdib.updated_descriptors_by_handle) + self.assertTrue(old_op_entity.handle in self._mdib.operation_by_handle) + self.assertTrue(old_comp_entity.handle in self._mdib.updated_descriptors_by_handle) + self.assertTrue(old_comp_entity.handle in self._mdib.component_by_handle) + self.assertTrue(old_rt_entity.handle in self._mdib.updated_descriptors_by_handle) + + def test_remove_add(self): + """Verify that removing descriptors / states and adding them later again results in correct versions.""" + # remove all root descriptors + all_entities = {} + for descr in self._mdib.descriptions.objects: + all_entities[descr.Handle] = self._mdib.entities.by_handle(descr.Handle) # get external representation + + root_entities = self._mdib.entities.by_parent_handle(None) + with self._mdib.descriptor_transaction() as mgr: + for ent in root_entities: + mgr.remove_entity(ent) + + self.assertEqual(0, len(self._mdib.descriptions.objects)) + + # add all entities again + with self._mdib.descriptor_transaction() as mgr: + mgr.write_entities(all_entities.values()) + + # verify that the number of entities is the same as before + self.assertEqual(len(all_entities), len(self._mdib.descriptions.objects)) + + # verify that all descriptors and states have incremented version counters + # for current_ent in self._mdib.internal_entities.values(): + for handle in all_entities: + current_ent = self._mdib.entities.by_handle(handle) + old_ent = all_entities[current_ent.handle] + self.assertEqual(current_ent.descriptor.DescriptorVersion, old_ent.descriptor.DescriptorVersion + 1) + if current_ent.is_multi_state: + for state_handle, current_state in current_ent.states.items(): + old_state = old_ent.states[state_handle] + self.assertEqual(current_state.StateVersion, old_state.StateVersion + 1) + self.assertEqual(current_state.DescriptorVersion, current_ent.descriptor.DescriptorVersion) + else: + self.assertEqual(current_ent.state.StateVersion, old_ent.state.StateVersion + 1) + self.assertEqual(current_ent.state.DescriptorVersion, current_ent.descriptor.DescriptorVersion) diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 19a61ff8..b5e2139f 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -1,6 +1,8 @@ +"""The module contains example how to use sdc provider and consumer.""" from __future__ import annotations import os +import time import unittest import uuid from decimal import Decimal @@ -8,11 +10,10 @@ from sdc11073 import network from sdc11073.consumer import SdcConsumer -from sdc11073.definitions_base import ProtocolsRegistry from sdc11073.definitions_sdc import SdcV1Definitions +from sdc11073.entity_mdib.entity_consumermdib import EntityConsumerMdib +from sdc11073.entity_mdib.entity_providermdib import EntityProviderMdib from sdc11073.loghelper import basic_logging_setup, get_logger_adapter -from sdc11073.mdib import ProviderMdib -from sdc11073.mdib.consumermdib import ConsumerMdib from sdc11073.provider import SdcProvider from sdc11073.provider.components import SdcProviderComponents from sdc11073.provider.operations import ExecuteResult @@ -21,16 +22,18 @@ from sdc11073.wsdiscovery import WSDiscovery, WSDiscoverySingleAdapter from sdc11073.xml_types import msg_types, pm_types from sdc11073.xml_types import pm_qnames as pm +from sdc11073.xml_types.actions import periodic_actions_and_system_error_report from sdc11073.xml_types.dpws_types import ThisDeviceType, ThisModelType from sdc11073.xml_types.msg_types import InvocationState from sdc11073.xml_types.pm_types import CodedValue from sdc11073.xml_types.wsd_types import ScopesType -from sdc11073.xml_types.actions import periodic_actions_and_system_error_report from tests import utils if TYPE_CHECKING: from sdc11073.mdib.descriptorcontainers import AbstractOperationDescriptorProtocol + from sdc11073.mdib.mdibprotocol import ProviderMdibProtocol from sdc11073.provider.operations import ExecuteParameters, OperationDefinitionBase + from sdc11073.provider.sco import AbstractScoOperationsRegistry from sdc11073.roles.providerbase import OperationClassGetter loopback_adapter = next(adapter for adapter in network.get_adapters() if adapter.is_loopback) @@ -38,12 +41,15 @@ SEARCH_TIMEOUT = 2 # in real world applications this timeout is too short, 10 seconds is a good value. # Here this short timeout is used to accelerate the test. -here = os.path.dirname(__file__) -my_mdib_path = os.path.join(here, '70041_MDIB_Final.xml') +here = os.path.dirname(__file__) # noqa: PTH120 +my_mdib_path = os.path.join(here, '70041_MDIB_Final.xml') # noqa: PTH118 -def createGenericDevice(wsdiscovery_instance, location, mdib_path, specific_components=None): - my_mdib = ProviderMdib.from_mdib_file(mdib_path) +def create_generic_provider(wsdiscovery_instance: WSDiscovery, + location: str, + mdib_path: str, + specific_components: SdcProviderComponents | None = None) -> SdcProvider: + my_mdib = EntityProviderMdib.from_mdib_file(mdib_path) my_epr = uuid.uuid4().hex this_model = ThisModelType(manufacturer='Draeger', manufacturer_url='www.draeger.com', @@ -55,18 +61,20 @@ def createGenericDevice(wsdiscovery_instance, location, mdib_path, specific_comp this_device = ThisDeviceType(friendly_name='TestDevice', firmware_version='Version1', serial_number='12345') - sdc_device = SdcProvider(wsdiscovery_instance, - this_model, - this_device, - my_mdib, - epr=my_epr, - specific_components=specific_components) - for desc in sdc_device.mdib.descriptions.objects: - desc.SafetyClassification = pm_types.SafetyClassification.MED_A - sdc_device.start_all(start_rtsample_loop=False) + sdc_provider = SdcProvider(wsdiscovery_instance, + this_model, + this_device, + my_mdib, + epr=my_epr, + specific_components=specific_components) + with sdc_provider.mdib.descriptor_transaction() as tr: + for _, ent in sdc_provider.mdib.entities.items(): + ent.descriptor.SafetyClassification = pm_types.SafetyClassification.MED_A + tr.write_entity(ent) + sdc_provider.start_all(start_rtsample_loop=False) validators = [pm_types.InstanceIdentifier('Validator', extension_string='System')] - sdc_device.set_location(location, validators) - return sdc_device + sdc_provider.set_location(location, validators) + return sdc_provider MY_CODE_1 = CodedValue('196279') # refers to an activate operation in mdib @@ -76,12 +84,14 @@ def createGenericDevice(wsdiscovery_instance, location, mdib_path, specific_comp class MyProvider1(ProviderRole): - """This provider handles operations with code == MY_CODE_1 and MY_CODE_2. + """The provider handles operations with code == MY_CODE_1 and MY_CODE_2. Operations with these codes already exist in the mdib that is used for this test. """ - def __init__(self, mdib, log_prefix): + def __init__(self, + mdib: ProviderMdibProtocol, + log_prefix: str): super().__init__(mdib, log_prefix) self.operation1_called = 0 self.operation1_args = None @@ -91,7 +101,9 @@ def __init__(self, mdib, log_prefix): def make_operation_instance(self, operation_descriptor_container: AbstractOperationDescriptorProtocol, operation_cls_getter: OperationClassGetter) -> OperationDefinitionBase | None: - """If the role provider is responsible for handling of calls to this operation_descriptor_container, + """Create an operation instance if operation_descriptor_container matches this operation. + + If the role provider is responsible for handling of calls to this operation_descriptor_container, it creates an operation instance and returns it, otherwise it returns None. """ if operation_descriptor_container.coding == MY_CODE_1.coding: @@ -101,45 +113,45 @@ def make_operation_instance(self, # # The following line shows how to provide your callback (in this case self._handle_operation_1). # This callback is called when a consumer calls the operation. - operation = self._mk_operation_from_operation_descriptor(operation_descriptor_container, - operation_cls_getter, - self._handle_operation_1) - return operation + return self._mk_operation_from_operation_descriptor(operation_descriptor_container, + operation_cls_getter, + self._handle_operation_1) if operation_descriptor_container.coding == MY_CODE_2.coding: - operation = self._mk_operation_from_operation_descriptor(operation_descriptor_container, - operation_cls_getter, - self._handle_operation_2) - return operation + return self._mk_operation_from_operation_descriptor(operation_descriptor_container, + operation_cls_getter, + self._handle_operation_2) return None def _handle_operation_1(self, params: ExecuteParameters) -> ExecuteResult: - """This operation does not manipulate the mdib at all, it only registers the call.""" + """Do not manipulate the mdib at all, it only increment the call counter.""" argument = params.operation_request.argument self.operation1_called += 1 self.operation1_args = argument - self._logger.info('_handle_operation_1 called arg={}', argument) + self._logger.info('_handle_operation_1 called arg=%r', argument) return ExecuteResult(params.operation_instance.operation_target_handle, InvocationState.FINISHED) def _handle_operation_2(self, params: ExecuteParameters) -> ExecuteResult: - """This operation manipulate it operation target, and only registers the call.""" + """Manipulate the operation target, and increments the call counter.""" argument = params.operation_request.argument self.operation2_called += 1 self.operation2_args = argument - self._logger.info('_handle_operation_2 called arg={}', argument) + self._logger.info('_handle_operation_2 called arg=%r', argument) + op_target_entity = self._mdib.entities.by_handle(params.operation_instance.operation_target_handle) + if op_target_entity.state.MetricValue is None: + op_target_entity.state.mk_metric_value() + op_target_entity.state.MetricValue.Value = argument with self._mdib.metric_state_transaction() as mgr: - my_state = mgr.get_state(params.operation_instance.operation_target_handle) - if my_state.MetricValue is None: - my_state.mk_metric_value() - my_state.MetricValue.Value = argument + mgr.write_entity(op_target_entity) return ExecuteResult(params.operation_instance.operation_target_handle, InvocationState.FINISHED) class MyProvider2(ProviderRole): - """This provider handles operations with code == MY_CODE_3. + """The provider handles operations with code == MY_CODE_3. + Operations with these codes already exist in the mdib that is used for this test. """ - def __init__(self, mdib, log_prefix): + def __init__(self, mdib: ProviderMdibProtocol, log_prefix: str): super().__init__(mdib, log_prefix) self.operation3_args = None self.operation3_called = 0 @@ -149,38 +161,40 @@ def make_operation_instance(self, operation_cls_getter: OperationClassGetter) -> OperationDefinitionBase | None: if operation_descriptor_container.coding == MY_CODE_3.coding: - self._logger.info( - 'instantiating operation 3 from existing descriptor handle={}'.format( - operation_descriptor_container.Handle)) - operation = self._mk_operation_from_operation_descriptor(operation_descriptor_container, - operation_cls_getter, - self._handle_operation_3) - return operation - else: - return None + self._logger.info('instantiating operation 3 from existing descriptor handle=%s', + operation_descriptor_container.Handle) + return self._mk_operation_from_operation_descriptor(operation_descriptor_container, + operation_cls_getter, + self._handle_operation_3) + return None def _handle_operation_3(self, params: ExecuteParameters) -> ExecuteResult: - """This operation manipulate it operation target, and only registers the call.""" + """Manipulate the operation target, and increments the call counter.""" self.operation3_called += 1 argument = params.operation_request.argument self.operation3_args = argument self._logger.info('_handle_operation_3 called') + op_target_entity = self._mdib.entities.by_handle(params.operation_instance.operation_target_handle) + if op_target_entity.state.MetricValue is None: + op_target_entity.state.mk_metric_value() + op_target_entity.state.MetricValue.Value = argument with self._mdib.metric_state_transaction() as mgr: - my_state = mgr.get_state(params.operation_instance.operation_target_handle) - if my_state.MetricValue is None: - my_state.mk_metric_value() - my_state.MetricValue.Value = argument + mgr.write_entity(op_target_entity) return ExecuteResult(params.operation_instance.operation_target_handle, InvocationState.FINISHED) class MyProductImpl(BaseProduct): - """This class provides all handlers of the fictional product. + """The class provides all handlers of the fictional product. + It instantiates 2 role providers. The number of role providers does not matter, it is a question of how the code is organized. Each role provider should handle one specific role, e.g. audio pause provider, clock provider, ... """ - def __init__(self, mdib, sco, log_prefix=None): + def __init__(self, + mdib: ProviderMdibProtocol, + sco: AbstractScoOperationsRegistry, + log_prefix: str | None = None): super().__init__(mdib, sco, log_prefix) self.my_provider_1 = MyProvider1(mdib, log_prefix=log_prefix) self._ordered_providers.append(self.my_provider_1) @@ -188,15 +202,15 @@ def __init__(self, mdib, sco, log_prefix=None): self._ordered_providers.append(self.my_provider_2) -class Test_Tutorial(unittest.TestCase): +class TestTutorial(unittest.TestCase): """run tutorial examples as unit tests, so that broken examples are automatically detected.""" def setUp(self) -> None: self.my_location = utils.random_location() self.my_location2 = utils.random_location() # tests fill these lists with what they create, teardown cleans up after them. - self.my_devices = [] - self.my_clients = [] + self.my_providers = [] + self.my_consumers = [] self.my_ws_discoveries = [] basic_logging_setup() @@ -205,17 +219,17 @@ def setUp(self) -> None: def tearDown(self) -> None: self._logger.info('###### tearDown ... ##########') - for cl in self.my_clients: - self._logger.info('stopping {}', cl) - cl.stop_all() - for d in self.my_devices: - self._logger.info('stopping {}', d) - d.stop_all() - for w in self.my_ws_discoveries: - self._logger.info('stopping {}', w) - w.stop() - - def test_createDevice(self): + for consumer in self.my_consumers: + self._logger.info('stopping %r', consumer) + consumer.stop_all() + for provider in self.my_providers: + self._logger.info('stopping %r', provider) + provider.stop_all() + for discovery in self.my_ws_discoveries: + self._logger.info('stopping %r', discovery) + discovery.stop() + + def test_create_provider(self): # A WsDiscovery instance is needed to publish devices on the network. # In this case we want to publish them only on localhost 127.0.0.1. my_ws_discovery = WSDiscovery('127.0.0.1') @@ -223,20 +237,20 @@ def test_createDevice(self): my_ws_discovery.start() # to create a device, this what you usually do: - my_generic_device = createGenericDevice(my_ws_discovery, self.my_location, my_mdib_path) - self.my_devices.append(my_generic_device) + my_generic_provider = create_generic_provider(my_ws_discovery, self.my_location, my_mdib_path) + self.my_providers.append(my_generic_provider) - def test_searchDevice(self): + def test_search_provider(self): # create one discovery and two device that we can then search for my_ws_discovery = WSDiscovery('127.0.0.1') self.my_ws_discoveries.append(my_ws_discovery) my_ws_discovery.start() - my_generic_device1 = createGenericDevice(my_ws_discovery, self.my_location, my_mdib_path) - self.my_devices.append(my_generic_device1) + my_generic_provider1 = create_generic_provider(my_ws_discovery, self.my_location, my_mdib_path) + self.my_providers.append(my_generic_provider1) - my_generic_device2 = createGenericDevice(my_ws_discovery, self.my_location2, my_mdib_path) - self.my_devices.append(my_generic_device2) + my_generic_provider2 = create_generic_provider(my_ws_discovery, self.my_location2, my_mdib_path) + self.my_providers.append(my_generic_provider2) # Search for devices # ------------------ @@ -252,31 +266,24 @@ def test_searchDevice(self): # (that can even be printers). # TODO: enable this step once https://github.com/Draegerwerk/sdc11073/issues/223 has been fixed - # now search only for devices in my_location2 + # search for any device at my_location2 services = my_client_ws_discovery.search_services(scopes=ScopesType(self.my_location2.scope_string), timeout=SEARCH_TIMEOUT) self.assertEqual(len(services), 1) - # search for medical devices only (BICEPS Final version only) + # search for medical devices at any location services = my_client_ws_discovery.search_services(types=SdcV1Definitions.MedicalDeviceTypesFilter, timeout=SEARCH_TIMEOUT) self.assertGreaterEqual(len(services), 2) - # search for medical devices only all known protocol versions - all_types = [p.MedicalDeviceTypesFilter for p in ProtocolsRegistry.protocols] - services = my_client_ws_discovery.search_multiple_types(types_list=all_types, - timeout=SEARCH_TIMEOUT) - - self.assertGreaterEqual(len(services), 2) - - def test_createClient(self): + def test_create_client(self): # create one discovery and one device that we can then search for my_ws_discovery = WSDiscovery('127.0.0.1') self.my_ws_discoveries.append(my_ws_discovery) my_ws_discovery.start() - my_generic_device1 = createGenericDevice(my_ws_discovery, self.my_location, my_mdib_path) - self.my_devices.append(my_generic_device1) + my_generic_provider1 = create_generic_provider(my_ws_discovery, self.my_location, my_mdib_path) + self.my_providers.append(my_generic_provider1) my_client_ws_discovery = WSDiscovery('127.0.0.1') self.my_ws_discoveries.append(my_client_ws_discovery) @@ -289,27 +296,24 @@ def test_createClient(self): scopes=ScopesType(self.my_location.scope_string)) self.assertEqual(len(services), 1) # both devices found - my_client = SdcConsumer.from_wsd_service(services[0], ssl_context_container=None) - self.my_clients.append(my_client) - my_client.start_all(not_subscribed_actions=periodic_actions_and_system_error_report) + my_consumer = SdcConsumer.from_wsd_service(services[0], ssl_context_container=None) + self.my_consumers.append(my_consumer) + my_consumer.start_all(not_subscribed_actions=periodic_actions_and_system_error_report) ############# Mdib usage ############################## # In data oriented tests a mdib instance is very handy: # The mdib collects all data and makes it easily available for the test # The MdibBase wraps data in "container" objects. # The basic idea is that every node that has a handle becomes directly accessible via its handle. - my_mdib = ConsumerMdib(my_client) + my_mdib = EntityConsumerMdib(my_consumer) my_mdib.init_mdib() # my_mdib keeps itself now updated # now query some data # mdib has three lookups: descriptions, states and context_states # each lookup can be searched by different keys, # e.g. looking for a descriptor by type looks like this: - location_context_descriptor_containers = my_mdib.descriptions.NODETYPE.get(pm.LocationContextDescriptor) - self.assertEqual(len(location_context_descriptor_containers), 1) - # we can look for the corresponding state by handle: - location_context_state_containers = my_mdib.context_states.descriptor_handle.get( - location_context_descriptor_containers[0].Handle) - self.assertEqual(len(location_context_state_containers), 1) + location_context_entities = my_mdib.entities.by_node_type(pm.LocationContextDescriptor) + self.assertEqual(len(location_context_entities), 1) + self.assertEqual(len(location_context_entities[0].states), 1) def test_call_operation(self): # create one discovery and one device that we can then search for @@ -317,8 +321,8 @@ def test_call_operation(self): self.my_ws_discoveries.append(my_ws_discovery) my_ws_discovery.start() - my_generic_device1 = createGenericDevice(my_ws_discovery, self.my_location, my_mdib_path) - self.my_devices.append(my_generic_device1) + my_generic_provider1 = create_generic_provider(my_ws_discovery, self.my_location, my_mdib_path) + self.my_providers.append(my_generic_provider1) my_client_ws_discovery = WSDiscovery('127.0.0.1') self.my_ws_discoveries.append(my_client_ws_discovery) @@ -331,36 +335,42 @@ def test_call_operation(self): scopes=ScopesType(self.my_location.scope_string)) self.assertEqual(len(services), 1) # both devices found - my_client = SdcConsumer.from_wsd_service(services[0], ssl_context_container=None) - self.my_clients.append(my_client) - my_client.start_all(not_subscribed_actions=periodic_actions_and_system_error_report) - my_mdib = ConsumerMdib(my_client) + my_consumer = SdcConsumer.from_wsd_service(services[0], ssl_context_container=None) + self.my_consumers.append(my_consumer) + my_consumer.start_all(not_subscribed_actions=periodic_actions_and_system_error_report) + my_mdib = EntityConsumerMdib(my_consumer) my_mdib.init_mdib() # we want to set a patient. # first we must find the operation that has PatientContextDescriptor as operation target - patient_context_descriptor_containers = my_mdib.descriptions.NODETYPE.get(pm.PatientContextDescriptor) - self.assertEqual(len(patient_context_descriptor_containers), 1) - my_patient_context_descriptor_container = patient_context_descriptor_containers[0] - all_operations = my_mdib.descriptions.NODETYPE.get(pm.SetContextStateOperationDescriptor, []) - my_operations = [op for op in all_operations if - op.OperationTarget == my_patient_context_descriptor_container.Handle] + patient_context_entities = my_mdib.entities.by_node_type(pm.PatientContextDescriptor) + self.assertEqual(len(patient_context_entities), 1) + my_patient_context_entity = patient_context_entities[0] + all_operation_entities = my_mdib.entities.by_node_type(pm.SetContextStateOperationDescriptor) + my_operations = [op for op in all_operation_entities if + op.descriptor.OperationTarget == my_patient_context_entity.handle] self.assertEqual(len(my_operations), 1) my_operation = my_operations[0] - # make a proposed patient context: - context_service = my_client.context_service_client - proposed_patient = context_service.mk_proposed_context_object(my_patient_context_descriptor_container.Handle) + # make a proposed new patient context: + context_service = my_consumer.context_service_client + proposed_patient = my_patient_context_entity.new_state() + # The new state has as a placeholder the descriptor handle as handle + # => provider shall create a new state proposed_patient.Firstname = 'Jack' proposed_patient.Lastname = 'Miller' - future = context_service.set_context_state(operation_handle=my_operation.Handle, + future = context_service.set_context_state(operation_handle=my_operation.handle, proposed_context_states=[proposed_patient]) result = future.result(timeout=5) self.assertEqual(result.InvocationInfo.InvocationState, msg_types.InvocationState.FINISHED) + my_patient_context_entity.update() + # provider should have replaced the placeholder handle with a new one. + self.assertFalse(proposed_patient.Handle in my_patient_context_entity.states) def test_operation_handler(self): - """This example shows how to implement own handlers for operations, and it shows multiple ways how a client can - find the desired operation. + """The example shows how to implement own handlers for operations. + + It shows multiple ways how a client can find the desired operation. """ # Create a device like in the examples above, but provide an own role provider. # This role provider is used instead of the default one. @@ -370,15 +380,15 @@ def test_operation_handler(self): specific_components = SdcProviderComponents(role_provider_class=MyProductImpl) # use the minimalistic mdib from reference test: - mdib_path = os.path.join(here, '../examples/ReferenceTest/reference_mdib.xml') - my_generic_device = createGenericDevice(my_ws_discovery, - self.my_location, - mdib_path, - specific_components=specific_components) + mdib_path = os.path.join(here, '../examples/ReferenceTest/reference_mdib.xml') # noqa: PTH118 + my_generic_provider = create_generic_provider(my_ws_discovery, + self.my_location, + mdib_path, + specific_components=specific_components) - self.my_devices.append(my_generic_device) + self.my_providers.append(my_generic_provider) - # connect a client to this device: + # connect a consumer to this provider: my_client_ws_discovery = WSDiscovery('127.0.0.1') self.my_ws_discoveries.append(my_client_ws_discovery) my_client_ws_discovery.start() @@ -388,24 +398,24 @@ def test_operation_handler(self): self.assertEqual(len(services), 1) self.service = SdcConsumer.from_wsd_service(services[0], ssl_context_container=None) - my_client = self.service - self.my_clients.append(my_client) - my_client.start_all(not_subscribed_actions=periodic_actions_and_system_error_report) - my_mdib = ConsumerMdib(my_client) + my_consumer = self.service + self.my_consumers.append(my_consumer) + my_consumer.start_all(not_subscribed_actions=periodic_actions_and_system_error_report) + my_mdib = EntityConsumerMdib(my_consumer) my_mdib.init_mdib() sco_handle = 'sco.mds0' - my_product_impl = my_generic_device.product_lookup[sco_handle] + my_product_impl = my_generic_provider.product_lookup[sco_handle] # call activate operation: # A client should NEVER! use the handle of the operation directly, always use the code(s) to identify things. # Handles are random values without any meaning, they are only unique id's in the mdib. - operations = my_mdib.descriptions.coding.get(MY_CODE_1.coding) + operation_entities = my_mdib.entities.by_coding(MY_CODE_1.coding) # the mdib contains 2 operations with the same code. To keep things simple, just use the first one here. - self._logger.info('looking for operations with code {}', MY_CODE_1.coding) - op = operations[0] + self._logger.info('looking for operations with code %r', MY_CODE_1.coding) + op_entity = operation_entities[0] argument = 'foo' - self._logger.info('calling operation {}, argument = {}', op, argument) - future = my_client.set_service_client.activate(op.Handle, arguments=[argument]) + self._logger.info('calling operation %s, argument = %r', op_entity.handle, argument) + future = my_consumer.set_service_client.activate(op_entity.handle, arguments=[argument]) result = future.result() print(result) self.assertEqual(my_product_impl.my_provider_1.operation1_called, 1) @@ -415,27 +425,32 @@ def test_operation_handler(self): # call set_string operation sco_handle = 'sco.vmd1.mds0' - my_product_impl = my_generic_device.product_lookup[sco_handle] + my_product_impl = my_generic_provider.product_lookup[sco_handle] - self._logger.info('looking for operations with code {}', MY_CODE_2.coding) - op = my_mdib.descriptions.coding.get_one(MY_CODE_2.coding) + self._logger.info('looking for operations with code %r', MY_CODE_2.coding) + op_entities = my_mdib.entities.by_coding(MY_CODE_2.coding) + my_op = op_entities[0] for value in ('foo', 'bar'): - self._logger.info('calling operation {}, argument = {}', op, value) - future = my_client.set_service_client.set_string(op.Handle, value) + self._logger.info('calling operation %s, argument = %r', my_op.handle, value) + future = my_consumer.set_service_client.set_string(my_op.handle, value) result = future.result() print(result) + time.sleep(1) self.assertEqual(my_product_impl.my_provider_1.operation2_args, value) - state = my_mdib.states.descriptor_handle.get_one(op.OperationTarget) - self.assertEqual(state.MetricValue.Value, value) + op_target_entity = my_mdib.entities.by_handle(my_op.descriptor.OperationTarget) + self.assertEqual(op_target_entity.state.MetricValue.Value, value) self.assertEqual(my_product_impl.my_provider_1.operation2_called, 2) # call setValue operation - state_descr = my_mdib.descriptions.coding.get_one(MY_CODE_3_TARGET.coding) - operations = my_mdib.get_operation_descriptors_for_descriptor_handle(state_descr.Handle) - op = operations[0] - future = my_client.set_service_client.set_numeric_value(op.Handle, Decimal('42')) + op_target_entities = my_mdib.entities.by_coding(MY_CODE_3_TARGET.coding) + op_target_entity = op_target_entities[0] + + all_operations = my_mdib.entities.by_node_type(pm.SetValueOperationDescriptor) + my_ops = [op for op in all_operations if op.descriptor.OperationTarget == op_target_entity.handle] + + future = my_consumer.set_service_client.set_numeric_value(my_ops[0].handle, Decimal('42')) result = future.result() print(result) self.assertEqual(my_product_impl.my_provider_2.operation3_args, 42) - state = my_mdib.states.descriptor_handle.get_one(op.OperationTarget) - self.assertEqual(state.MetricValue.Value, 42) + ent = my_mdib.entities.by_handle(op_target_entity.handle) + self.assertEqual(ent.state.MetricValue.Value, 42)