diff --git a/nexus_constructor/json/json_warnings.py b/nexus_constructor/json/json_warnings.py index cd0835f34..2d83cefa0 100644 --- a/nexus_constructor/json/json_warnings.py +++ b/nexus_constructor/json/json_warnings.py @@ -11,22 +11,22 @@ @attr.s class InvalidJson: - message = attr.ib(type=str) + message: str = attr.ib() @attr.s class InvalidShape: - message = attr.ib(type=str) + message: str = attr.ib() @attr.s class InvalidTransformation: - message = attr.ib(type=str) + message: str = attr.ib() @attr.s class TransformDependencyMissing: - message = attr.ib(type=str) + message: str = attr.ib() @attr.s @@ -36,12 +36,12 @@ class RelativeDependsonWrong: @attr.s class NameFieldMissing: - message = attr.ib(type=str) + message: str = attr.ib() @attr.s class NXClassAttributeMissing: - message = attr.ib(type=str) + message: str = attr.ib() JsonWarning = Union[ diff --git a/nexus_constructor/json/load_from_json.py b/nexus_constructor/json/load_from_json.py index 0eb03b76c..b05464212 100644 --- a/nexus_constructor/json/load_from_json.py +++ b/nexus_constructor/json/load_from_json.py @@ -72,7 +72,6 @@ class JSONReader: def __init__(self): - self.entry_node: Group = None self.model = Model() self.sample_name: str = "" self.warnings = JsonWarningsContainer() @@ -221,14 +220,15 @@ def load_model_from_json(self, filename: str) -> bool: return self._load_from_json_dict(json_dict) def _load_from_json_dict(self, json_dict: Dict) -> bool: - self.entry_node = self._read_json_object(json_dict[CommonKeys.CHILDREN][0]) - self.model.entry.attributes = self.entry_node.attributes - for child in self.entry_node.children: - if isinstance(child, (Dataset, Link, FileWriter, Group)): - self.model.entry[child.name] = child - else: - self.model.entry.children.append(child) - child.parent_node = self.model.entry + entry_node = self._read_json_object(json_dict[CommonKeys.CHILDREN][0]) + if entry_node: + self.model.entry.attributes = entry_node.attributes + for child in entry_node.children: + if isinstance(child, (Dataset, Link, FileWriter, Group)): + self.model.entry[child.name] = child + else: + self.model.entry.children.append(child) + child.parent_node = self.model.entry self._set_transforms_depends_on() self._set_components_depends_on() self._append_transformations_to_nx_group() @@ -253,76 +253,87 @@ def _replace_placeholder(self, placeholder: str): } return None - def _read_json_object(self, json_object: Dict, parent_node: Group = None): + def _read_json_object( + self, json_object: Optional[Dict], parent_node: Optional[Group] = None + ): """ Tries to create a component based on the contents of the JSON file. :param json_object: A component from the JSON dictionary. :param parent_name: The name of the parent object. Used for warning messages if something goes wrong. """ - nexus_object: Union[Group, FileWriterModule] = None + nexus_object: Union[Group, FileWriterModule, None] = None use_placeholder = False if isinstance(json_object, str) and json_object in PLACEHOLDER_WITH_NX_CLASSES: json_object = self._replace_placeholder(json_object) if not json_object: - return + return None use_placeholder = True - if ( - CommonKeys.TYPE in json_object - and json_object[CommonKeys.TYPE] == NodeType.GROUP - ): - try: - name = json_object[CommonKeys.NAME] - except KeyError: - self._add_object_warning(CommonKeys.NAME, parent_node) + if json_object: + if ( + CommonKeys.TYPE in json_object + and json_object[CommonKeys.TYPE] == NodeType.GROUP + ): + try: + name = json_object[CommonKeys.NAME] + except KeyError: + self._add_object_warning(CommonKeys.NAME, parent_node) + return None + nx_class = _find_nx_class(json_object.get(CommonKeys.ATTRIBUTES)) + if nx_class == SAMPLE_CLASS_NAME: + self.sample_name = name + if not self._validate_nx_class(name, nx_class): + self._add_object_warning( + f"valid Nexus class {nx_class}", parent_node + ) + if nx_class in COMPONENT_TYPES: + nexus_object = Component(name=name, parent_node=parent_node) + children_dict = json_object[CommonKeys.CHILDREN] + self._add_transform_and_shape_to_component( + nexus_object, children_dict + ) + self.model.append_component(nexus_object) + else: + nexus_object = Group(name=name, parent_node=parent_node) + if nexus_object: + nexus_object.nx_class = nx_class + if CommonKeys.CHILDREN in json_object: + for child in json_object[CommonKeys.CHILDREN]: + node = self._read_json_object(child, nexus_object) + if node and isinstance(node, StreamModule): + nexus_object.children.append(node) + nexus_object.remove_stream_module(node.writer_module) + elif node and node.name not in nexus_object: + nexus_object[node.name] = node + elif CommonKeys.MODULE in json_object and NodeType.CONFIG in json_object: + module_type = json_object[CommonKeys.MODULE] + if ( + module_type == WriterModules.DATASET.value + or module_type == WriterModules.FILEWRITER.value + ) and json_object[NodeType.CONFIG][ + CommonKeys.NAME + ] == CommonAttrs.DEPENDS_ON: + nexus_object = None + elif module_type in [x.value for x in WriterModules]: + nexus_object = create_fw_module_object( + module_type, json_object[NodeType.CONFIG], parent_node + ) + if nexus_object: + nexus_object.parent_node = parent_node + else: + self._add_object_warning("valid module type", parent_node) + return None + elif json_object == USERS_PLACEHOLDER: + self.model.entry.users_placeholder = True return None - nx_class = _find_nx_class(json_object.get(CommonKeys.ATTRIBUTES)) - if nx_class == SAMPLE_CLASS_NAME: - self.sample_name = name - if not self._validate_nx_class(name, nx_class): - self._add_object_warning(f"valid Nexus class {nx_class}", parent_node) - if nx_class in COMPONENT_TYPES: - nexus_object = Component(name=name, parent_node=parent_node) - children_dict = json_object[CommonKeys.CHILDREN] - self._add_transform_and_shape_to_component(nexus_object, children_dict) - self.model.append_component(nexus_object) else: - nexus_object = Group(name=name, parent_node=parent_node) - nexus_object.nx_class = nx_class - if CommonKeys.CHILDREN in json_object: - for child in json_object[CommonKeys.CHILDREN]: - node = self._read_json_object(child, nexus_object) - if node and isinstance(node, StreamModule): - nexus_object.children.append(node) - nexus_object.remove_stream_module(node.writer_module) - elif node and node.name not in nexus_object: - nexus_object[node.name] = node - elif CommonKeys.MODULE in json_object and NodeType.CONFIG in json_object: - module_type = json_object[CommonKeys.MODULE] - if ( - module_type == WriterModules.DATASET.value - or module_type == WriterModules.FILEWRITER.value - ) and json_object[NodeType.CONFIG][ - CommonKeys.NAME - ] == CommonAttrs.DEPENDS_ON: - nexus_object = None - elif module_type in [x.value for x in WriterModules]: - nexus_object = create_fw_module_object( - module_type, json_object[NodeType.CONFIG], parent_node + self._add_object_warning( + f"valid {CommonKeys.TYPE} or {CommonKeys.MODULE}", parent_node ) - nexus_object.parent_node = parent_node - else: - self._add_object_warning("valid module type", parent_node) - return None - elif json_object == USERS_PLACEHOLDER: - self.model.entry.users_placeholder = True - return None else: - self._add_object_warning( - f"valid {CommonKeys.TYPE} or {CommonKeys.MODULE}", parent_node - ) + self._add_object_warning("!!No json_object!!", parent_node) # Add attributes to nexus_object. - if nexus_object: + if nexus_object and json_object: json_attrs = json_object.get(CommonKeys.ATTRIBUTES) if json_attrs: attributes = Attributes() diff --git a/nexus_constructor/json/load_from_json_utils.py b/nexus_constructor/json/load_from_json_utils.py index a14e5d874..9b980d935 100644 --- a/nexus_constructor/json/load_from_json_utils.py +++ b/nexus_constructor/json/load_from_json_utils.py @@ -64,7 +64,7 @@ def _find_attribute_from_dict(attribute_name: str, entry: dict) -> Any: def _find_attribute_from_list_or_dict( attribute_name: str, - entry: Union[list, dict], + entry: Optional[Union[list, dict, str]], ) -> Any: """ Attempts to determine the value of an attribute in a dictionary or a list of dictionaries. @@ -81,7 +81,7 @@ def _find_attribute_from_list_or_dict( return _find_attribute_from_dict(attribute_name, entry) -def _find_nx_class(entry: Union[list, dict]) -> str: +def _find_nx_class(entry: Union[list, dict, str, None]) -> str: """ Tries to find the NX class value from a dictionary or a list of dictionaries. :param entry: A dictionary or list of dictionaries. diff --git a/nexus_constructor/json/transform_id.py b/nexus_constructor/json/transform_id.py index 7f85f4453..7d58ee2ea 100644 --- a/nexus_constructor/json/transform_id.py +++ b/nexus_constructor/json/transform_id.py @@ -8,5 +8,5 @@ class TransformId: Uniquely identifies a Transformation """ - component_name = attr.ib(type=str) - transform_name = attr.ib(type=str) + component_name: str = attr.ib() + transform_name: str = attr.ib() diff --git a/nexus_constructor/model/component.py b/nexus_constructor/model/component.py index b225398f6..c4e939b04 100644 --- a/nexus_constructor/model/component.py +++ b/nexus_constructor/model/component.py @@ -84,10 +84,12 @@ class Component(Group): Base class for a component object. In the NeXus file this would translate to the component group. """ - _depends_on = attr.ib(type=Transformation, default=None) - has_link = attr.ib(type=bool, default=None) - component_info: "ComponentInfo" = None - stored_transforms: list = None + _depends_on: Transformation = attr.ib(default=None) + has_link: bool = attr.ib(default=None) + component_info: Optional["ComponentInfo"] = None + stored_transforms: list = [] + name: str = "" + parent_node: Optional[Group] = None @property def stored_items(self) -> List: @@ -271,7 +273,7 @@ def _create_and_add_transform( units: str, vector: QVector3D, depends_on: Transformation, - values: Union[Dataset, Group], + values: Union[Dataset, Group, StreamModule], target_pos: int = -1, ) -> Transformation: if name is None: diff --git a/nexus_constructor/model/group.py b/nexus_constructor/model/group.py index 8893254fc..f14731208 100644 --- a/nexus_constructor/model/group.py +++ b/nexus_constructor/model/group.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import attr @@ -19,7 +19,7 @@ from nexus_constructor.model.value_type import ValueTypes if TYPE_CHECKING: - from nexus_constructor.model.module import FileWriterModule # noqa: F401 + from nexus_constructor.model.module import FileWriterModule TRANSFORMS_GROUP_NAME = "transformations" @@ -37,15 +37,15 @@ class Group: Base class for any group which has a set of children and an nx_class attribute. """ - name = attr.ib(type=str) - parent_node = attr.ib(type="Group", default=None) - children: List[Union["FileWriterModule", "Group"]] = attr.ib( # noqa: F821 + name: str = attr.ib() + parent_node: Optional["Group"] = attr.ib(default=None) + children: List[Union["FileWriterModule", "Group"]] = attr.ib( factory=list, init=False ) - attributes = attr.ib(type=Attributes, factory=Attributes, init=False) + attributes: Attributes = attr.ib(factory=Attributes, init=False) values = None - possible_stream_modules = attr.ib( - type=List[str], default=attr.Factory(create_list_of_possible_streams) + possible_stream_modules: List[str] = attr.ib( + default=attr.Factory(create_list_of_possible_streams) ) _group_placeholder: bool = False diff --git a/nexus_constructor/model/module.py b/nexus_constructor/model/module.py index 1219310b2..2fb70b623 100644 --- a/nexus_constructor/model/module.py +++ b/nexus_constructor/model/module.py @@ -1,6 +1,6 @@ from abc import ABC from enum import Enum -from typing import TYPE_CHECKING, Dict, List, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union import attr import h5py @@ -11,7 +11,7 @@ from nexus_constructor.model.helpers import get_absolute_path if TYPE_CHECKING: - from nexus_constructor.model.group import Group # noqa: F401 + from nexus_constructor.model.group import Group from nexus_constructor.model.value_type import JsonSerialisableType, ValueType @@ -57,9 +57,11 @@ class StreamModules(Enum): @attr.s class FileWriterModule(ABC): - attributes = attr.ib(type=Attributes, factory=Attributes, init=False) - writer_module = attr.ib(type=str, init=False) - parent_node = attr.ib(type="Group") + name = "" + attributes: Attributes = attr.ib(factory=Attributes, init=False) + writer_module: str = attr.ib(init=False) + parent_node: Optional["Group"] = attr.ib() + children = [] def as_dict(self, error_collector: List[str]): raise NotImplementedError() @@ -74,8 +76,8 @@ def absolute_path(self): @attr.s class StreamModule(FileWriterModule): - source = attr.ib(type=str) - topic = attr.ib(type=str) + source: str = attr.ib() + topic: str = attr.ib() def as_dict(self, error_collector: List[str]): return_dict = { @@ -91,13 +93,13 @@ def as_dict(self, error_collector: List[str]): @attr.s class NS10Stream(StreamModule): - writer_module = attr.ib(type=str, default=WriterModules.NS10.value, init=False) + writer_module: str = attr.ib(default=WriterModules.NS10.value, init=False) @attr.s class SE00Stream(StreamModule): - type = attr.ib(type=str) - writer_module = attr.ib(type=str, default=WriterModules.SE00.value, init=False) + type: str = attr.ib() + writer_module: str = attr.ib(default=WriterModules.SE00.value, init=False) def as_dict(self, error_collector: List[str]): module_dict = StreamModule.as_dict(self, error_collector) @@ -108,20 +110,20 @@ def as_dict(self, error_collector: List[str]): @attr.s class SENVStream(StreamModule): - writer_module = attr.ib(type=str, default=WriterModules.SENV.value, init=False) + writer_module: str = attr.ib(default=WriterModules.SENV.value, init=False) @attr.s class TDCTStream(StreamModule): - writer_module = attr.ib(type=str, default=WriterModules.TDCTIME.value, init=False) + writer_module: str = attr.ib(default=WriterModules.TDCTIME.value, init=False) @attr.s class EV42Stream(StreamModule): - writer_module = attr.ib(type=str, default=WriterModules.EV42.value, init=False) - adc_pulse_debug = attr.ib(type=bool, default=None) - cue_interval = attr.ib(type=int, default=None) - chunk_size = attr.ib(type=int, default=None) + writer_module: str = attr.ib(default=WriterModules.EV42.value, init=False) + adc_pulse_debug: bool = attr.ib(default=None) + cue_interval: int = attr.ib(default=None) + chunk_size: int = attr.ib(default=None) def as_dict(self, error_collector: List[str]): module_dict = StreamModule.as_dict(self, error_collector) @@ -136,17 +138,17 @@ def as_dict(self, error_collector: List[str]): @attr.s class EV44Stream(EV42Stream): - writer_module = attr.ib(type=str, default=WriterModules.EV44.value, init=False) + writer_module: str = attr.ib(default=WriterModules.EV44.value, init=False) @attr.s class F142Stream(StreamModule): - type = attr.ib(type=str) - cue_interval = attr.ib(type=int, default=None) - chunk_size = attr.ib(type=int, default=None) - value_units = attr.ib(type=str, default=None) - array_size = attr.ib(type=list, default=None) - writer_module = attr.ib(type=str, default=WriterModules.F142.value, init=False) + type: str = attr.ib() + cue_interval: int = attr.ib(default=None) + chunk_size: int = attr.ib(default=None) + value_units: str = attr.ib(default=None) + array_size: list = attr.ib(default=None) + writer_module: str = attr.ib(default=WriterModules.F142.value, init=False) def as_dict(self, error_collector: List[str]): module_dict = StreamModule.as_dict(self, error_collector) @@ -165,17 +167,15 @@ def as_dict(self, error_collector: List[str]): @attr.s class F144Stream(F142Stream): - writer_module = attr.ib(type=str, default=WriterModules.F144.value, init=False) + writer_module: str = attr.ib(default=WriterModules.F144.value, init=False) @attr.s class FileWriter(FileWriterModule): - name = attr.ib(type=str) - type = attr.ib(type=str, default="string") - values = attr.ib(type=str, default=None) - writer_module = attr.ib( - type=str, default=WriterModules.FILEWRITER.value, init=False - ) + name: str = attr.ib() + type: str = attr.ib(default="string") + values: str = attr.ib(default=None) + writer_module: str = attr.ib(default=WriterModules.FILEWRITER.value, init=False) def as_dict(self, error_collector: List[str]): return { @@ -190,9 +190,9 @@ def as_nexus(self, nexus_node, error_collector: List[str]): @attr.s class Link(FileWriterModule): - name = attr.ib(type=str) - source = attr.ib(type=str) - writer_module = attr.ib(type=str, default=WriterModules.LINK.value, init=False) + name: str = attr.ib() + source: str = attr.ib() + writer_module: str = attr.ib(default=WriterModules.LINK.value, init=False) values = None def as_dict(self, error_collector: List[str]): @@ -208,10 +208,10 @@ def as_nexus(self, nexus_node, error_collector: List[str]): @attr.s class Dataset(FileWriterModule): - name = attr.ib(type=str) - values = attr.ib(type=Union[List[ValueType], ValueType]) - type = attr.ib(type=str, default=None) - writer_module = attr.ib(type=str, default=WriterModules.DATASET.value, init=False) + name: str = attr.ib() + values: Union[List[ValueType], ValueType] = attr.ib() + type: str = attr.ib(default=None) + writer_module: str = attr.ib(default=WriterModules.DATASET.value, init=False) def as_dict(self, error_collector: List[str]): values = self.values @@ -250,8 +250,8 @@ def _cast_to_type(self, data): @attr.s class ADARStream(StreamModule): - array_size = attr.ib(type=list, init=False) - writer_module = attr.ib(type=str, default=WriterModules.ADAR.value, init=False) + array_size: list = attr.ib(init=False) + writer_module: str = attr.ib(default=WriterModules.ADAR.value, init=False) def as_dict(self, error_collector: List[str]): module_dict = StreamModule.as_dict(self, error_collector) @@ -273,11 +273,11 @@ def as_dict(self, error_collector: List[str]): @attr.s class HS01Shape: - size = attr.ib(type=int) - label = attr.ib(type=str) - unit = attr.ib(type=str) - edges = attr.ib(type=List[int]) - dataset_name = attr.ib(type=str) + size: int = attr.ib() + label: str = attr.ib() + unit: str = attr.ib() + edges: List[int] = attr.ib() + dataset_name: str = attr.ib() def as_dict(self, error_collector: List[str]): return { @@ -291,11 +291,11 @@ def as_dict(self, error_collector: List[str]): @attr.s class HS01Stream(StreamModule): - type = attr.ib(type=str, default=None) - error_type = attr.ib(type=str, default=None) - edge_type = attr.ib(type=str, default=None) - shape = attr.ib(type=List[HS01Shape], default=[]) - writer_module = attr.ib(type=str, default=WriterModules.HS01.value, init=False) + type: str = attr.ib(default=None) + error_type: str = attr.ib(default=None) + edge_type: str = attr.ib(default=None) + shape: List[HS01Shape] = attr.ib(default=[]) + writer_module: str = attr.ib(default=WriterModules.HS01.value, init=False) def as_dict(self, error_collector: List[str]): module_dict = StreamModule.as_dict(self, error_collector) diff --git a/tests/json/test_load_from_json.py b/tests/json/test_load_from_json.py index b5a8f5909..940c5eb10 100644 --- a/tests/json/test_load_from_json.py +++ b/tests/json/test_load_from_json.py @@ -280,7 +280,7 @@ def test_GIVEN_json_with_sample_WHEN_loading_from_json_THEN_new_model_contains_n nexus_json_dictionary["children"][0]["children"][1]["name"] = sample_name json_reader._load_from_json_dict(nexus_json_dictionary) - assert json_reader.entry_node[sample_name].name == sample_name + assert json_reader.model.entry.children[1].name == sample_name def test_GIVEN_component_with_name_WHEN_loading_from_json_THEN_new_model_contains_component_with_json_name(