diff --git a/nexus_constructor/component/component.py b/nexus_constructor/component/component.py index 05adbaad2..c176424e0 100644 --- a/nexus_constructor/component/component.py +++ b/nexus_constructor/component/component.py @@ -8,7 +8,11 @@ from nexus_constructor.component.pixel_shape import PixelShape from nexus_constructor.component.transformations_list import TransformationsList from nexus_constructor.nexus import nexus_wrapper as nx -from nexus_constructor.nexus.nexus_wrapper import get_nx_class, to_string + +from nexus_constructor.nexus.nexus_wrapper import ( + get_nx_class, + get_name_of_node, +) from nexus_constructor.field_utils import get_fields_with_update_functions from nexus_constructor.pixel_data import PixelMapping, PixelGrid, PixelData from nexus_constructor.pixel_data_to_nexus_utils import ( @@ -170,9 +174,11 @@ def _get_transform( if ( transforms and depends_on - == to_string(transforms[-1].dataset.attrs[CommonAttrs.DEPENDS_ON]) + == self.file.get_attribute_value( + transforms[-1].dataset, CommonAttrs.DEPENDS_ON + ) and depends_on - in [x.split("/")[-1] for x in transforms[-1].dataset.parent.keys()] + in [get_name_of_node(x) for x in transforms[-1].dataset.parent.values()] ): # depends_on is recursive, ie one transformation in this group depends on another transformation in the group, and it is also relative transform_dataset = self.file.nexus_file[ @@ -193,7 +199,10 @@ def _get_transform( new_transform = create_transformation(self.file, transform_dataset) new_transform.parent = transforms transforms.append(new_transform) - if CommonAttrs.DEPENDS_ON in transform_dataset.attrs.keys(): + if ( + self.file.get_attribute_value(transform_dataset, CommonAttrs.DEPENDS_ON) + is not None + ): self._get_transform( self.file.get_attribute_value( transform_dataset, CommonAttrs.DEPENDS_ON @@ -235,27 +244,15 @@ def add_translation( :param name: name of the translation group (Optional) :param depends_on: existing transformation which the new one depends on (otherwise relative to origin) """ - transforms_group = self.file.create_transformations_group_if_does_not_exist( - self.group - ) - if name is None: - name = _generate_incremental_name( - TransformationType.TRANSLATION, transforms_group - ) unit_vector, magnitude = _normalise(vector) - field = self.file.set_field_value(transforms_group, name, magnitude, float) - self.file.set_attribute_value(field, CommonAttrs.UNITS, "m") - self.file.set_attribute_value( - field, CommonAttrs.VECTOR, qvector3d_to_numpy_array(unit_vector) + return self._create_transform( + name, + TransformationType.TRANSLATION, + magnitude, + "m", + unit_vector, + depends_on, ) - self.file.set_attribute_value( - field, CommonAttrs.TRANSFORMATION_TYPE, TransformationType.TRANSLATION - ) - - translation_transform = create_transformation(self.file, field) - translation_transform.ui_value = magnitude - translation_transform.depends_on = depends_on - return translation_transform def add_rotation( self, @@ -271,25 +268,39 @@ def add_rotation( :param name: Name of the rotation group (Optional) :param depends_on: existing transformation which the new one depends on (otherwise relative to origin) """ + return self._create_transform( + name, TransformationType.ROTATION, angle, "degrees", axis, depends_on + ) + + def _create_transform( + self, + name: str, + transformation_type: TransformationType, + angle_or_magnitude: float, + units: str, + vector: QVector3D, + depends_on: Transformation, + ): transforms_group = self.file.create_transformations_group_if_does_not_exist( self.group ) if name is None: - name = _generate_incremental_name( - TransformationType.ROTATION, transforms_group - ) - field = self.file.set_field_value(transforms_group, name, angle, float) - self.file.set_attribute_value(field, CommonAttrs.UNITS, "degrees") + name = _generate_incremental_name(transformation_type, transforms_group) + + field = self.file.set_field_value( + transforms_group, name, angle_or_magnitude, float + ) + self.file.set_attribute_value(field, CommonAttrs.UNITS, units) self.file.set_attribute_value( - field, CommonAttrs.VECTOR, qvector3d_to_numpy_array(axis) + field, CommonAttrs.VECTOR, qvector3d_to_numpy_array(vector) ) self.file.set_attribute_value( - field, CommonAttrs.TRANSFORMATION_TYPE, TransformationType.ROTATION + field, CommonAttrs.TRANSFORMATION_TYPE, transformation_type ) - rotation_transform = create_transformation(self.file, field) - rotation_transform.depends_on = depends_on - rotation_transform.ui_value = angle - return rotation_transform + transform = create_transformation(self.file, field) + transform.ui_value = angle_or_magnitude + transform.depends_on = depends_on + return transform def _transform_is_in_this_component(self, transform: Transformation) -> bool: return transform._dataset.parent.parent.name == self.absolute_path @@ -299,24 +310,23 @@ def remove_transformation(self, transform: Transformation): raise PermissionError( "Transform is not in this component, do not have permission to delete" ) - - dependents = transform.get_dependents() + dependents = transform.dependents if dependents: raise DependencyError( f"Cannot delete transformation, it is a dependency of {dependents}" ) # Remove whole transformations group if this is the only transformation in it - if len(transform._dataset.parent.keys()) == 1: - self.file.delete_node(transform._dataset.parent) + if len(transform.dataset.parent.keys()) == 1: + self.file.delete_node(transform.dataset.parent) # Otherwise just remove the transformation from the group else: - self.file.delete_node(transform._dataset) + self.file.delete_node(transform.dataset) @property def depends_on(self) -> Optional[Transformation]: depends_on_path = self.file.get_field_value(self.group, CommonAttrs.DEPENDS_ON) - if depends_on_path is None: + if depends_on_path in [None, "."]: return None return create_transformation(self.file, self.file.nexus_file[depends_on_path]) diff --git a/nexus_constructor/component/component_factory.py b/nexus_constructor/component/component_factory.py index c6897fda0..e9c964c77 100644 --- a/nexus_constructor/component/component_factory.py +++ b/nexus_constructor/component/component_factory.py @@ -1,4 +1,3 @@ -from nexus_constructor.common_attrs import CommonAttrs from nexus_constructor.component.component import Component from nexus_constructor.component.chopper_shape import ChopperShape from nexus_constructor.component.pixel_shape import PixelShape @@ -6,25 +5,19 @@ CHOPPER_CLASS_NAME, PIXEL_COMPONENT_TYPES, ) -from nexus_constructor.nexus.nexus_wrapper import NexusWrapper +from nexus_constructor.nexus.nexus_wrapper import NexusWrapper, get_nx_class import h5py def create_component( nexus_wrapper: NexusWrapper, component_group: h5py.Group ) -> Component: - if ( - nexus_wrapper.get_attribute_value(component_group, CommonAttrs.NX_CLASS) - == CHOPPER_CLASS_NAME - ): + nx_class = get_nx_class(component_group) + if nx_class == CHOPPER_CLASS_NAME: return Component( nexus_wrapper, component_group, ChopperShape(nexus_wrapper, component_group) ) - if ( - nexus_wrapper.get_attribute_value(component_group, CommonAttrs.NX_CLASS) - in PIXEL_COMPONENT_TYPES - and "pixel_shape" in component_group - ): + if nx_class in PIXEL_COMPONENT_TYPES and "pixel_shape" in component_group: return Component( nexus_wrapper, component_group, PixelShape(nexus_wrapper, component_group) ) diff --git a/nexus_constructor/component_tree_model.py b/nexus_constructor/component_tree_model.py index fec4be763..199d9a707 100644 --- a/nexus_constructor/component_tree_model.py +++ b/nexus_constructor/component_tree_model.py @@ -129,7 +129,7 @@ def _remove_transformation(self, index: QModelIndex): def _remove_component(self, index: QModelIndex): component = index.internalPointer() transforms = component.transforms - if transforms and transforms[0].get_dependents(): + if transforms and transforms[0].dependents(): reply = QMessageBox.question( None, "Delete component?", diff --git a/nexus_constructor/field_utils.py b/nexus_constructor/field_utils.py index dff7b53e6..8cfd273b0 100644 --- a/nexus_constructor/field_utils.py +++ b/nexus_constructor/field_utils.py @@ -7,7 +7,10 @@ from nexus_constructor.common_attrs import CommonAttrs from nexus_constructor.field_widget import FieldWidget from nexus_constructor.invalid_field_names import INVALID_FIELD_NAMES -from nexus_constructor.nexus.nexus_wrapper import get_name_of_node +from nexus_constructor.nexus.nexus_wrapper import ( + get_name_of_node, + get_nx_class, +) from nexus_constructor.validators import FieldType from nexus_constructor.nexus.nexus_wrapper import h5Node @@ -87,10 +90,7 @@ def find_field_type(item: h5Node) -> Callable: elif isinstance(item, h5py.Group): if isinstance(item.parent.get(item.name, getlink=True), h5py.SoftLink): return update_existing_link_field - elif ( - CommonAttrs.NX_CLASS in item.attrs.keys() - and item.attrs[CommonAttrs.NX_CLASS] == CommonAttrs.NC_STREAM - ): + elif get_nx_class(item) == CommonAttrs.NC_STREAM: return update_existing_stream_field logging.debug( f"Object {get_name_of_node(item)} not handled as field - could be used for other parts of UI instead" diff --git a/nexus_constructor/file_writer_ctrl_window.py b/nexus_constructor/file_writer_ctrl_window.py index 9854fc054..7513d826d 100644 --- a/nexus_constructor/file_writer_ctrl_window.py +++ b/nexus_constructor/file_writer_ctrl_window.py @@ -7,8 +7,8 @@ from nexus_constructor.validators import BrokerAndTopicValidator from ui.led import Led from ui.filewriter_ctrl_frame import Ui_FilewriterCtrl -from PySide2.QtWidgets import QMainWindow, QLineEdit -from PySide2.QtCore import QTimer, QAbstractItemModel +from PySide2.QtWidgets import QMainWindow, QLineEdit, QApplication +from PySide2.QtCore import QTimer, QAbstractItemModel, QSettings from PySide2.QtGui import QStandardItemModel, QCloseEvent from PySide2 import QtCore from nexus_constructor.instrument import Instrument @@ -39,9 +39,25 @@ class File: last_time = attr.ib(default=0) +class FileWriterSettings: + STATUS_BROKER_ADDR = "status_broker_addr" + COMMAND_BROKER_ADDR = "command_broker_addr" + FILE_BROKER_ADDR = "file_broker_addr" + USE_START_TIME = "use_start_time" + USE_STOP_TIME = "use_stop_time" + FILE_NAME = "file_name" + + +def extract_bool_from_qsettings(setting: Union[str, bool]): + if type(setting) == str: + setting = setting == "True" + return setting + + class FileWriterCtrl(Ui_FilewriterCtrl, QMainWindow): - def __init__(self, instrument: Instrument): + def __init__(self, instrument: Instrument, settings: QSettings): super().__init__() + self.settings = settings self.instrument = instrument self.setupUi() self.known_writers = {} @@ -49,6 +65,60 @@ def __init__(self, instrument: Instrument): self.status_consumer = None self.command_producer = None + def _restore_settings(self): + """ + Restore persistent broker config settings from file. + """ + self.status_broker_edit.setText( + self.settings.value(FileWriterSettings.STATUS_BROKER_ADDR) + ) + self.command_broker_edit.setText( + self.settings.value(FileWriterSettings.COMMAND_BROKER_ADDR) + ) + self.command_widget.broker_line_edit.setText( + self.settings.value(FileWriterSettings.FILE_BROKER_ADDR) + ) + self.command_widget.start_time_enabled.setChecked( + extract_bool_from_qsettings( + self.settings.value(FileWriterSettings.USE_START_TIME, False) + ) + ) + self.command_widget.stop_time_enabled.setChecked( + extract_bool_from_qsettings( + self.settings.value(FileWriterSettings.USE_STOP_TIME, False) + ) + ) + self.command_widget.nexus_file_name_edit.setText( + self.settings.value(FileWriterSettings.FILE_NAME) + ) + + def _store_settings(self): + """ + Store persistent broker config settings to file. + """ + self.settings.setValue( + FileWriterSettings.STATUS_BROKER_ADDR, self.status_broker_edit.text() + ) + self.settings.setValue( + FileWriterSettings.COMMAND_BROKER_ADDR, self.command_broker_edit.text() + ) + self.settings.setValue( + FileWriterSettings.FILE_BROKER_ADDR, + self.command_widget.broker_line_edit.text(), + ) + self.settings.setValue( + FileWriterSettings.USE_START_TIME, + self.command_widget.start_time_enabled.isChecked(), + ) + self.settings.setValue( + FileWriterSettings.USE_STOP_TIME, + self.command_widget.stop_time_enabled.isChecked(), + ) + self.settings.setValue( + FileWriterSettings.FILE_NAME, + self.command_widget.nexus_file_name_edit.text(), + ) + def setupUi(self): super().setupUi(self) @@ -96,6 +166,8 @@ def setupUi(self): self.file_list_model.setHeaderData(1, QtCore.Qt.Horizontal, "Last seen") self.file_list_model.setHeaderData(2, QtCore.Qt.Horizontal, "File writer") self.files_list.setModel(self.file_list_model) + self._restore_settings() + QApplication.instance().aboutToQuit.connect(self._store_settings) @staticmethod def _set_up_broker_fields( @@ -215,7 +287,6 @@ def send_command(self): stop_time, service_id, abort_on_uninitialised_stream, - use_swmr, ) = self.command_widget.get_arguments() self.command_producer.send_command( bytes( diff --git a/nexus_constructor/filewriter_command_widget.py b/nexus_constructor/filewriter_command_widget.py index e7c65772f..1a0b6ad78 100644 --- a/nexus_constructor/filewriter_command_widget.py +++ b/nexus_constructor/filewriter_command_widget.py @@ -35,8 +35,6 @@ def __init__(self, parent=None): self.nexus_file_name_edit = QLineEdit() self.ok_button = QPushButton("Ok") - if parent is not None: - self.ok_button.clicked.connect(parent.close) self.broker_line_edit = QLineEdit() self.broker_line_edit.setPlaceholderText("broker:port") @@ -84,8 +82,6 @@ def __init__(self, parent=None): self.service_id_lineedit = QLineEdit() self.service_id_lineedit.setPlaceholderText("(Optional)") self.abort_on_uninitialised_stream_checkbox = QCheckBox() - self.use_swmr_checkbox = QCheckBox() - self.use_swmr_checkbox.setChecked(True) self.layout().addRow("nexus_file_name", self.nexus_file_name_edit) self.layout().addRow("broker", self.broker_line_edit) @@ -97,7 +93,6 @@ def __init__(self, parent=None): self.layout().addRow( "abort_on_uninitialised_stream", self.abort_on_uninitialised_stream_checkbox ) - self.layout().addRow("use_hdf_swmr", self.use_swmr_checkbox) self.layout().addRow(self.ok_button) def state_changed(self, is_start_time: bool, state: Qt.CheckState): @@ -112,7 +107,7 @@ def state_changed(self, is_start_time: bool, state: Qt.CheckState): def get_arguments( self, - ) -> Tuple[str, str, Union[str, None], Union[str, None], str, bool, bool]: + ) -> Tuple[str, str, Union[str, None], Union[str, None], str, bool]: """ gets the arguments of required and optional fields for the filewriter command. :return: Tuple containing all of the fields. @@ -129,5 +124,4 @@ def get_arguments( self.service_id_lineedit.text(), self.abort_on_uninitialised_stream_checkbox.checkState() == Qt.CheckState.Checked, - self.use_swmr_checkbox.checkState() == Qt.CheckState.Checked, ) diff --git a/nexus_constructor/geometry/cylindrical_geometry.py b/nexus_constructor/geometry/cylindrical_geometry.py index 52f3b1901..d3600e9ff 100644 --- a/nexus_constructor/geometry/cylindrical_geometry.py +++ b/nexus_constructor/geometry/cylindrical_geometry.py @@ -150,7 +150,7 @@ def axis_direction(self) -> QVector3D: return QVector3D(0, 0, 1) @property - def off_geometry(self, steps: int = 20) -> OFFGeometry: + def off_geometry(self, steps: int = 10) -> OFFGeometry: unit_conversion_factor = calculate_unit_conversion_factor(self.units, METRES) # A list of vertices describing the circle at the bottom of the cylinder diff --git a/nexus_constructor/instrument.py b/nexus_constructor/instrument.py index 1fd69c88b..e60bf185d 100644 --- a/nexus_constructor/instrument.py +++ b/nexus_constructor/instrument.py @@ -1,8 +1,5 @@ from typing import List - import h5py - -from nexus_constructor.common_attrs import CommonAttrs from nexus_constructor.nexus import nexus_wrapper as nx from nexus_constructor.component.component import Component from nexus_constructor.nexus.nexus_wrapper import get_nx_class @@ -12,7 +9,7 @@ COMPONENTS_IN_ENTRY = ["NXmonitor", "NXsample"] -def _convert_name_with_spaces(component_name): +def _convert_name_with_spaces(component_name: str) -> str: return component_name.replace(" ", "_") @@ -43,13 +40,12 @@ def refresh_depends_on(_, node): Refresh the depends_on attribute of each transformation, which also results in registering dependents """ if isinstance(node, h5py.Group): - if CommonAttrs.NX_CLASS in node.attrs.keys(): - if node.attrs[CommonAttrs.NX_CLASS] == "NXtransformations": - for transformation_name, transformation_node in node.items(): - transform = create_transformation( - self.nexus, node[transformation_name] - ) - transform.depends_on = transform.depends_on + if get_nx_class(node) == "NXtransformations": + for transformation_name, transformation_node in node.items(): + transform = create_transformation( + self.nexus, node[transformation_name] + ) + transform.depends_on = transform.depends_on self.nexus.nexus_file.visititems(refresh_depends_on) @@ -83,10 +79,8 @@ def get_component_list(self) -> List[Component]: def find_components(_, node): if isinstance(node, h5py.Group): - if CommonAttrs.NX_CLASS in node.attrs.keys(): - nx_class = get_nx_class(node) - if nx_class and nx_class in self.nx_component_classes: - component_list.append(create_component(self.nexus, node)) + if get_nx_class(node) in self.nx_component_classes: + component_list.append(create_component(self.nexus, node)) self.nexus.entry.visititems(find_components) return component_list diff --git a/nexus_constructor/instrument_view.py b/nexus_constructor/instrument_view.py index d0d49921e..406c6bf28 100644 --- a/nexus_constructor/instrument_view.py +++ b/nexus_constructor/instrument_view.py @@ -243,6 +243,7 @@ def delete_component(self, name: str): try: self.component_entities[name].setParent(None) self.component_entities.pop(name) + self.transformations.pop(name) except KeyError: logging.error( f"Unable to delete component {name} because it doesn't exist." diff --git a/nexus_constructor/main_window.py b/nexus_constructor/main_window.py index 579db3c56..0fc730aca 100644 --- a/nexus_constructor/main_window.py +++ b/nexus_constructor/main_window.py @@ -1,5 +1,7 @@ import uuid from typing import Dict + +from PySide2.QtCore import QSettings from PySide2.QtWidgets import ( QMainWindow, QApplication, @@ -94,7 +96,9 @@ def show_control_file_writer_window(self): if self.file_writer_control_window is None: from nexus_constructor.file_writer_ctrl_window import FileWriterCtrl - self.file_writer_ctrl_window = FileWriterCtrl(self.instrument) + self.file_writer_ctrl_window = FileWriterCtrl( + self.instrument, QSettings("ess", "nexus-constructor") + ) self.file_writer_ctrl_window.show() def show_edit_component_dialog(self): diff --git a/nexus_constructor/nexus/nexus_wrapper.py b/nexus_constructor/nexus/nexus_wrapper.py index 8ec2f449b..d0ad32752 100644 --- a/nexus_constructor/nexus/nexus_wrapper.py +++ b/nexus_constructor/nexus/nexus_wrapper.py @@ -141,12 +141,8 @@ def find_entries_in_file(self, nexus_file: h5py.File): def append_nx_entries_to_list(name, node): if isinstance(node, h5py.Group): - if CommonAttrs.NX_CLASS in node.attrs.keys(): - if ( - node.attrs[CommonAttrs.NX_CLASS] == b"NXentry" - or node.attrs[CommonAttrs.NX_CLASS] == "NXentry" - ): - entries_in_root[name] = node + if get_nx_class(node) == "NXentry": + entries_in_root[name] = node nexus_file["/"].visititems(append_nx_entries_to_list) if len(entries_in_root.keys()) > 1: @@ -349,12 +345,8 @@ def delete_attribute(self, node: h5Node, name: str): def create_transformations_group_if_does_not_exist(self, parent_group: h5Node): for child in parent_group: - if CommonAttrs.NX_CLASS in parent_group[child].attrs.keys(): - if ( - parent_group[child].attrs[CommonAttrs.NX_CLASS] - == "NXtransformations" - ): - return parent_group[child] + if get_nx_class(parent_group[child]) == "NXtransformations": + return parent_group[child] return self.create_nx_group( "transformations", "NXtransformations", parent_group ) @@ -367,10 +359,5 @@ def get_instrument_group_from_entry(entry: h5py.Group) -> h5py.Group: :return: the instrument group object. """ for node in entry.values(): - if isinstance(node, h5py.Group): - if CommonAttrs.NX_CLASS in node.attrs.keys(): - if node.attrs[CommonAttrs.NX_CLASS] in [ - "NXinstrument", - b"NXinstrument", - ]: - return node + if isinstance(node, h5py.Group) and get_nx_class(node) == "NXinstrument": + return node diff --git a/nexus_constructor/off_renderer.py b/nexus_constructor/off_renderer.py index 0675b3cc2..cb81ae4e9 100644 --- a/nexus_constructor/off_renderer.py +++ b/nexus_constructor/off_renderer.py @@ -49,7 +49,7 @@ def convert_faces_into_triangles(faces): return triangles -def create_vertex_buffer(vertices, faces): +def create_vertex_buffer(vertices, triangles): """ For each point in each triangle in each face, add its points to the vertices list. To do this we: @@ -57,11 +57,9 @@ def create_vertex_buffer(vertices, faces): Get the vertices that are in the triangles Adding them into a flat list of points :param vertices: The vertices in the mesh - :param faces: The faces in the mesh + :param triangles: A list of the triangles that make up each face in the mesh :return: A list of the points in the faces """ - triangles = convert_faces_into_triangles(faces) - flattened_triangles = flatten(triangles) return flatten( @@ -69,15 +67,14 @@ def create_vertex_buffer(vertices, faces): ) -def create_normal_buffer(vertices, faces): +def create_normal_buffer(vertices, triangles): """ Creates normal vectors for each vertex on the mesh. Qt requires each vertex to have it's own normal. :param vertices: The vertices for the mesh - :param faces: The faces in the mesh + :param triangles: A list of the triangles that make up each face in the mesh :return: A list of the normal points for the faces """ - triangles = convert_faces_into_triangles(faces) normal_buffer_values = [] for triangle in triangles: # Get the vertices of each triangle @@ -131,8 +128,9 @@ def __init__( faces, vertices = repeat_shape_over_positions(model, positions) - vertex_buffer_values = list(create_vertex_buffer(vertices, faces)) - normal_buffer_values = create_normal_buffer(vertices, faces) + triangles = convert_faces_into_triangles(faces) + vertex_buffer_values = list(create_vertex_buffer(vertices, triangles)) + normal_buffer_values = create_normal_buffer(vertices, triangles) positionAttribute = self.create_attribute( vertex_buffer_values, self.q_attribute.defaultPositionAttributeName() diff --git a/nexus_constructor/stream_fields_widget.py b/nexus_constructor/stream_fields_widget.py index 419bee554..378cb2758 100644 --- a/nexus_constructor/stream_fields_widget.py +++ b/nexus_constructor/stream_fields_widget.py @@ -99,6 +99,7 @@ def __init__(self, parent): self.type_label = QLabel("Type: ") self.type_combo = QComboBox() self.type_combo.addItems(F142_TYPES) + self.type_combo.setCurrentText("double") self.value_units_edit = QLineEdit() self.value_units_label = QLabel("Value Units:") diff --git a/nexus_constructor/transformations.py b/nexus_constructor/transformations.py index fd269a104..38b4efd8d 100644 --- a/nexus_constructor/transformations.py +++ b/nexus_constructor/transformations.py @@ -9,7 +9,7 @@ from nexus_constructor.nexus import nexus_wrapper as nx from typing import TypeVar, Union, List, Optional -from nexus_constructor.nexus.nexus_wrapper import h5Node, to_string +from nexus_constructor.nexus.nexus_wrapper import h5Node, get_nx_class from nexus_constructor.transformation_types import TransformationType TransformationOrComponent = TypeVar( @@ -34,18 +34,18 @@ def __eq__(self, other): @property def name(self): - return nx.get_name_of_node(self._dataset) + return nx.get_name_of_node(self.dataset) @name.setter def name(self, new_name: str): - self.file.rename_node(self._dataset, new_name) + self.file.rename_node(self.dataset, new_name) self._update_dependent_depends_on() def _update_dependent_depends_on(self): """ Updates all of the directly dependent "depends_on" fields for this transformation. """ - for dependent in self.get_dependents(): + for dependent in self.dependents: dependent.depends_on = self @property @@ -72,7 +72,7 @@ def absolute_path(self): this is guaranteed to be unique so it can be used as an ID for this Transformation :return: absolute path of the transform dataset in the NeXus file, """ - return self._dataset.name + return self.dataset.name @property def type(self): @@ -80,7 +80,7 @@ def type(self): Get transformation type, should be "Translation" or "Rotation" """ return self.file.get_attribute_value( - self._dataset, CommonAttrs.TRANSFORMATION_TYPE + self.dataset, CommonAttrs.TRANSFORMATION_TYPE ).capitalize() @type.setter @@ -89,16 +89,16 @@ def type(self, new_type: str): Set transformation type, should be "Translation" or "Rotation" """ self.file.set_attribute_value( - self._dataset, CommonAttrs.TRANSFORMATION_TYPE, new_type.capitalize() + self.dataset, CommonAttrs.TRANSFORMATION_TYPE, new_type.capitalize() ) @property def units(self): - return self.file.get_attribute_value(self._dataset, CommonAttrs.UNITS) + return self.file.get_attribute_value(self.dataset, CommonAttrs.UNITS) @units.setter def units(self, new_units): - self.file.set_attribute_value(self._dataset, CommonAttrs.UNITS, new_units) + self.file.set_attribute_value(self.dataset, CommonAttrs.UNITS, new_units) @property def vector(self): @@ -106,7 +106,7 @@ def vector(self): Returns rotation axis or translation direction as a QVector3D """ vector_as_np_array = self.file.get_attribute_value( - self._dataset, CommonAttrs.VECTOR + self.dataset, CommonAttrs.VECTOR ) return QVector3D( vector_as_np_array[0], vector_as_np_array[1], vector_as_np_array[2] @@ -116,7 +116,7 @@ def vector(self): def vector(self, new_vector: QVector3D): vector_as_np_array = np.array([new_vector.x(), new_vector.y(), new_vector.z()]) self.file.set_attribute_value( - self._dataset, CommonAttrs.VECTOR, vector_as_np_array + self.dataset, CommonAttrs.VECTOR, vector_as_np_array ) @property @@ -124,15 +124,15 @@ def dataset(self) -> h5Node: return self._dataset @dataset.setter - def dataset(self, new_data): + def dataset(self, new_data: h5Node): """ Used for setting the transformation dataset to a stream group, link or scalar/array field :param new_data: the new data being set """ old_attrs = {} - for k, v in self._dataset.attrs.items(): + for k, v in self.dataset.attrs.items(): old_attrs[k] = v - dataset_name = self._dataset.name + dataset_name = self.dataset.name del self.file.nexus_file[dataset_name] if isinstance(new_data, h5py.Dataset): @@ -146,10 +146,10 @@ def dataset(self, new_data): ) self._dataset = self.file.nexus_file[dataset_name] for k, v in old_attrs.items(): - self._dataset.attrs[k] = v + self.file.set_attribute_value(self.dataset, k, v) - if CommonAttrs.UI_VALUE not in self._dataset.attrs: - self._dataset.attrs[CommonAttrs.UI_VALUE] = 0 + if self.file.get_attribute_value(self.dataset, CommonAttrs.UI_VALUE) is None: + self.file.set_attribute_value(self.dataset, CommonAttrs.UI_VALUE, 0) @property def ui_value(self) -> float: @@ -160,8 +160,8 @@ def ui_value(self) -> float: if isinstance(self.dataset, h5py.Dataset): if np.isscalar(self.dataset[()]): try: - self.ui_value = float(self._dataset[()]) - return float(self._dataset[()]) + self.ui_value = float(self.dataset[()]) + return float(self.dataset[()]) except ValueError: logging.debug( "transformation value is not cast-able to float/int, using UI placeholder value instead." @@ -169,19 +169,19 @@ def ui_value(self) -> float: else: # Dataset value is array - try to use the first value of the array as the UI value try: - self.ui_value = float(self._dataset[...][0]) - return float(self._dataset[...][0]) + self.ui_value = float(self.dataset[...][0]) + return float(self.dataset[...][0]) except ValueError: - # Not castable to float - either return the UI value if present in the group or the default value if not. + # Not cast-able to float - either return the UI value if it's present in the group or the default + # value if not. pass - - if CommonAttrs.UI_VALUE not in self._dataset.attrs: + if self.file.get_attribute_value(self.dataset, CommonAttrs.UI_VALUE) is None: # Link or stream default_value = 0.0 self.ui_value = default_value return default_value - return self.file.get_attribute_value(self._dataset, CommonAttrs.UI_VALUE)[()] + return self.file.get_attribute_value(self.dataset, CommonAttrs.UI_VALUE)[()] @ui_value.setter def ui_value(self, new_value: float): @@ -189,20 +189,20 @@ def ui_value(self, new_value: float): Used for setting the magnitude of the transformation in the 3d view :param new_value: the placeholder magnitude for the 3d view """ - self.file.set_attribute_value(self._dataset, CommonAttrs.UI_VALUE, new_value) + self.file.set_attribute_value(self.dataset, CommonAttrs.UI_VALUE, new_value) @property def depends_on(self) -> Optional["Transformation"]: depends_on_path = self.file.get_attribute_value( - self._dataset, CommonAttrs.DEPENDS_ON + self.dataset, CommonAttrs.DEPENDS_ON ) if depends_on_path not in (None, "."): - if f"{self._dataset.parent.name}/{depends_on_path}" in self.file.nexus_file: + if f"{self.dataset.parent.name}/{depends_on_path}" in self.file.nexus_file: # depends_on is relative return create_transformation( self.file, self.file.nexus_file[ - f"{self._dataset.parent.name}/{depends_on_path}" + f"{self.dataset.parent.name}/{depends_on_path}" ], ) return create_transformation( @@ -216,7 +216,7 @@ def depends_on(self, depends_on: "Transformation"): to use string for depends_on type here, because the current class is not defined yet """ existing_depends_on = self.file.get_attribute_value( - self._dataset, CommonAttrs.DEPENDS_ON + self.dataset, CommonAttrs.DEPENDS_ON ) if ( @@ -228,10 +228,10 @@ def depends_on(self, depends_on: "Transformation"): ).deregister_dependent(self) if depends_on is None: - self.file.set_attribute_value(self._dataset, CommonAttrs.DEPENDS_ON, ".") + self.file.set_attribute_value(self.dataset, CommonAttrs.DEPENDS_ON, ".") else: self.file.set_attribute_value( - self._dataset, CommonAttrs.DEPENDS_ON, depends_on.absolute_path + self.dataset, CommonAttrs.DEPENDS_ON, depends_on.absolute_path ) depends_on.register_dependent(self) @@ -242,13 +242,13 @@ def register_dependent(self, dependent: TransformationOrComponent): :param dependent: transform or component that depends on this one """ - if CommonAttrs.DEPENDEE_OF not in self._dataset.attrs.keys(): + if self.file.get_attribute_value(self.dataset, CommonAttrs.DEPENDEE_OF) is None: self.file.set_attribute_value( - self._dataset, CommonAttrs.DEPENDEE_OF, dependent.absolute_path + self.dataset, CommonAttrs.DEPENDEE_OF, dependent.absolute_path ) else: dependee_of_list = self.file.get_attribute_value( - self._dataset, CommonAttrs.DEPENDEE_OF + self.dataset, CommonAttrs.DEPENDEE_OF ) if not isinstance(dependee_of_list, np.ndarray): dependee_of_list = np.array([dependee_of_list]) @@ -258,7 +258,7 @@ def register_dependent(self, dependent: TransformationOrComponent): dependee_of_list, np.array([dependent.absolute_path]) ) self.file.set_attribute_value( - self._dataset, CommonAttrs.DEPENDEE_OF, dependee_of_list + self.dataset, CommonAttrs.DEPENDEE_OF, dependee_of_list ) def deregister_dependent(self, former_dependent: TransformationOrComponent): @@ -267,29 +267,33 @@ def deregister_dependent(self, former_dependent: TransformationOrComponent): Note, "dependee_of" attribute is not part of the NeXus format :param former_dependent: transform or component that used to depend on this one """ - if CommonAttrs.DEPENDEE_OF in self._dataset.attrs.keys(): + if ( + self.file.get_attribute_value(self.dataset, CommonAttrs.DEPENDEE_OF) + is not None + ): dependee_of_list = self.file.get_attribute_value( - self._dataset, CommonAttrs.DEPENDEE_OF + self.dataset, CommonAttrs.DEPENDEE_OF ) if ( not isinstance(dependee_of_list, np.ndarray) and dependee_of_list == former_dependent.absolute_path ): # Must be a single string rather than a list, so simply delete it - self.file.delete_attribute(self._dataset, CommonAttrs.DEPENDEE_OF) + self.file.delete_attribute(self.dataset, CommonAttrs.DEPENDEE_OF) elif isinstance(dependee_of_list, np.ndarray): dependee_of_list = dependee_of_list[ dependee_of_list != former_dependent.absolute_path ] self.file.set_attribute_value( - self._dataset, CommonAttrs.DEPENDEE_OF, dependee_of_list + self.dataset, CommonAttrs.DEPENDEE_OF, dependee_of_list ) else: logging.warning( f"Unable to de-register dependent {former_dependent.absolute_path} from {self.absolute_path} due to it not being registered." ) - def get_dependents(self) -> List[Union["Component", "Transformation"]]: + @property + def dependents(self) -> List[Union["Component", "Transformation"]]: """ Returns the direct dependents of a transform, i.e. anything that has depends_on pointing to this transformation. """ @@ -297,9 +301,12 @@ def get_dependents(self) -> List[Union["Component", "Transformation"]]: return_dependents = [] - if CommonAttrs.DEPENDEE_OF in self._dataset.attrs.keys(): + if ( + self.file.get_attribute_value(self.dataset, CommonAttrs.DEPENDEE_OF) + is not None + ): dependents = self.file.get_attribute_value( - self._dataset, CommonAttrs.DEPENDEE_OF + self.dataset, CommonAttrs.DEPENDEE_OF ) if not isinstance(dependents, np.ndarray): dependents = [dependents] @@ -314,28 +321,36 @@ def get_dependents(self) -> List[Union["Component", "Transformation"]]: return return_dependents def remove_from_dependee_chain(self): - all_dependees = self.get_dependents() - new_depends_on = self.depends_on - if self.depends_on is not None and self.depends_on.absolute_path == "/": - new_depends_on = None - else: - for dependee in all_dependees: + """ + Remove this transformation from the depends_on chain by pointing any dependees to this transformation's depends_on. + If this transformation either has no depends_on or points to itself, just deregister it as a dependent. + """ + for dependee in self.dependents: + if self.depends_on not in [None, "."]: + # This transformation has a depends_on, so update the dependee to point to that instead if isinstance(dependee, Transformation): - new_depends_on.register_dependent(dependee) - for dependee in all_dependees: - dependee.depends_on = new_depends_on + # If a dependee is a transformation, register the dependee of this transform as a dependee to this + # transform's depends_on + self.depends_on.register_dependent(dependee) + # Update the dependee to point to this transformation's depends_on + dependee.depends_on = self.depends_on + # Regardless of if this transformation has a depends_on field, deregister it from any dependees. self.deregister_dependent(dependee) + # Set this transformation's depends_on to None to remove it from the chain. self.depends_on = None class NXLogTransformation(Transformation): @property def ui_value(self) -> float: - if "value" not in self._dataset.keys(): - if CommonAttrs.UI_VALUE not in self._dataset.attrs: + if "value" not in self.dataset.keys(): + if ( + self.file.get_attribute_value(self.dataset, CommonAttrs.UI_VALUE) + is None + ): self.ui_value = 0 - return self.file.get_attribute_value(self._dataset, CommonAttrs.UI_VALUE) - value_group = self._dataset["value"] + return self.file.get_attribute_value(self.dataset, CommonAttrs.UI_VALUE) + value_group = self.dataset["value"] if np.isscalar(value_group): return value_group[()] else: @@ -343,15 +358,17 @@ def ui_value(self) -> float: @ui_value.setter def ui_value(self, new_value): - self.file.set_attribute_value(self._dataset, CommonAttrs.UI_VALUE, new_value) + self.file.set_attribute_value(self.dataset, CommonAttrs.UI_VALUE, new_value) @property - def units(self) -> str: - self.file.get_attribute_value(self._dataset["value"], "units") + def units(self) -> Optional[str]: + return self.file.get_attribute_value(self.dataset["value"], CommonAttrs.UNITS) @units.setter def units(self, new_units: str): - self.file.set_attribute_value(self._dataset["value"], "units", new_units) + self.file.set_attribute_value( + self.dataset["value"], CommonAttrs.UNITS, new_units + ) @property def dataset(self) -> h5Node: @@ -367,9 +384,6 @@ def create_transformation(wrapper: nx.NexusWrapper, node: h5Node) -> Transformat Factory for creating different types of transform. If it is an NXlog group then the magnitude and units fields will be different to a normal transformation dataset. """ - if ( - CommonAttrs.NX_CLASS in node.attrs - and to_string(node.attrs[CommonAttrs.NX_CLASS]) == "NXlog" - ): + if get_nx_class(node) == "NXlog": return NXLogTransformation(wrapper, node) return Transformation(wrapper, node) diff --git a/nexus_constructor/treeview_utils.py b/nexus_constructor/treeview_utils.py index c6e2752d9..21aa216cc 100644 --- a/nexus_constructor/treeview_utils.py +++ b/nexus_constructor/treeview_utils.py @@ -78,14 +78,14 @@ def set_button_states( ) else: selected_object = selection_indices[0].internalPointer() - - zoom_action.setEnabled(isinstance(selected_object, Component)) + selected_object_is_component = isinstance(selected_object, Component) + zoom_action.setEnabled(selected_object_is_component) selected_object_is_component_or_transform = isinstance( selected_object, (Component, Transformation) ) duplicate_action.setEnabled(selected_object_is_component_or_transform) - edit_component_action.setEnabled(selected_object_is_component_or_transform) + edit_component_action.setEnabled(selected_object_is_component) selected_object_is_not_link_transform = not isinstance( selected_object, LinkTransformation diff --git a/tests/test_instrument.py b/tests/test_instrument.py index 52f17e529..8c95e642e 100644 --- a/tests/test_instrument.py +++ b/tests/test_instrument.py @@ -99,12 +99,12 @@ def test_dependents_list_is_created_by_instrument(file, nexus_wrapper): transform_1_loaded = Transformation(nexus_wrapper, transform_1) assert ( - len(transform_1_loaded.get_dependents()) == 1 + len(transform_1_loaded.dependents) == 1 ), "Expected transform 1 to have a registered dependent (transform 2)" transform_2_loaded = Transformation(nexus_wrapper, transform_2) assert ( - len(transform_2_loaded.get_dependents()) == 2 + len(transform_2_loaded.dependents) == 2 ), "Expected transform 2 to have 2 registered dependents (transforms 3 and 4)" diff --git a/tests/test_off_renderer.py b/tests/test_off_renderer.py index 51977a444..9908a09c5 100644 --- a/tests/test_off_renderer.py +++ b/tests/test_off_renderer.py @@ -52,9 +52,10 @@ def test_GIVEN_a_square_WHEN_creating_vertex_buffer_THEN_length_is_correct(): QVector3D(0, 1, 0), QVector3D(1, 1, 0), ] - faces = [[0, 1, 2, 3]] + # 2 triangles make up the square + triangles = [[0, 1, 2], [2, 3, 0]] - vertex_buffer = create_vertex_buffer(vertices, faces) + vertex_buffer = create_vertex_buffer(vertices, triangles) assert ( len(list(vertex_buffer)) @@ -80,9 +81,10 @@ def test_GIVEN_a_square_face_WHEN_creating_normal_buffer_THEN_output_is_correct( QVector3D(1, 1, 0), QVector3D(1, 0, 0), ] - faces = [[0, 1, 2, 3]] + # 2 triangles make up the square + triangles = [[0, 1, 2], [2, 3, 0]] - normal = create_normal_buffer(vertices, faces) + normal = create_normal_buffer(vertices, triangles) expected_output = [0.0, 0.0, -1.0] * TRIANGLES_IN_SQUARE * VERTICES_IN_TRIANGLE diff --git a/tests/test_remove_from_dependee_chain.py b/tests/test_remove_from_dependee_chain.py index e7329e508..08b8b4fb3 100644 --- a/tests/test_remove_from_dependee_chain.py +++ b/tests/test_remove_from_dependee_chain.py @@ -6,9 +6,9 @@ def test_remove_from_beginning_1(nexus_wrapper): component1 = add_component_to_file(nexus_wrapper, "field", 42, "component1") rot = component1.add_rotation(QVector3D(1.0, 0.0, 0.0), 90.0) component1.depends_on = rot - assert len(rot.get_dependents()) == 1 + assert len(rot.dependents) == 1 rot.remove_from_dependee_chain() - assert component1.depends_on.absolute_path == "/" + assert component1.depends_on is None def test_remove_from_beginning_2(nexus_wrapper): @@ -17,10 +17,10 @@ def test_remove_from_beginning_2(nexus_wrapper): rot2 = component1.add_rotation(QVector3D(1.0, 0.0, 0.0), 90.0) component1.depends_on = rot1 rot1.depends_on = rot2 - assert len(rot2.get_dependents()) == 1 + assert len(rot2.dependents) == 1 rot1.remove_from_dependee_chain() - assert len(rot2.get_dependents()) == 1 - assert rot2.get_dependents()[0] == component1 + assert len(rot2.dependents) == 1 + assert rot2.dependents[0] == component1 assert component1.depends_on == rot2 @@ -32,11 +32,11 @@ def test_remove_from_beginning_3(nexus_wrapper): component1.depends_on = rot1 component2.depends_on = rot2 rot1.depends_on = rot2 - assert len(rot2.get_dependents()) == 2 + assert len(rot2.dependents) == 2 rot1.remove_from_dependee_chain() - assert len(rot2.get_dependents()) == 2 - assert component2 in rot2.get_dependents() - assert component1 in rot2.get_dependents() + assert len(rot2.dependents) == 2 + assert component2 in rot2.dependents + assert component1 in rot2.dependents assert component1.depends_on == rot2 assert component1.transforms.link.linked_component == component2 @@ -56,5 +56,24 @@ def test_remove_from_middle(nexus_wrapper): rot2.remove_from_dependee_chain() assert rot1.depends_on == rot3 assert component1.transforms.link.linked_component == component3 - assert rot1 in rot3.get_dependents() - assert component3 in rot3.get_dependents() + assert rot1 in rot3.dependents + assert component3 in rot3.dependents + + +def test_remove_from_end(nexus_wrapper): + component1 = add_component_to_file(nexus_wrapper, "field", 42, "component1") + rot1 = component1.add_rotation(QVector3D(1.0, 0.0, 0.0), 90.0) + rot2 = component1.add_rotation(QVector3D(1.0, 0.0, 0.0), 90.0, depends_on=rot1) + rot3 = component1.add_rotation(QVector3D(1.0, 0.0, 0.0), 90.0, depends_on=rot2) + + component1.depends_on = rot3 + + rot1.remove_from_dependee_chain() + + assert rot1.depends_on is None + assert not rot1.dependents + + assert component1.depends_on.absolute_path == rot3.absolute_path + + assert rot2.dependents[0].absolute_path == rot3.absolute_path + assert len(component1.transforms) == 2 diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 04528b4fe..daba17b86 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -182,7 +182,7 @@ def test_set_one_dependent(nexus_wrapper): transform1.register_dependent(transform2) - set_dependents = transform1.get_dependents() + set_dependents = transform1.dependents assert len(set_dependents) == 1 assert set_dependents[0] == transform2 @@ -199,7 +199,7 @@ def test_set_two_dependents(nexus_wrapper): transform1.register_dependent(transform2) transform1.register_dependent(transform3) - set_dependents = transform1.get_dependents() + set_dependents = transform1.dependents assert len(set_dependents) == 2 assert set_dependents[0] == transform2 @@ -220,7 +220,7 @@ def test_set_three_dependents(nexus_wrapper): transform1.register_dependent(transform3) transform1.register_dependent(transform4) - set_dependents = transform1.get_dependents() + set_dependents = transform1.dependents assert len(set_dependents) == 3 assert set_dependents[0] == transform2 @@ -237,7 +237,7 @@ def test_deregister_dependent(nexus_wrapper): transform1.register_dependent(transform2) transform1.deregister_dependent(transform2) - set_dependents = transform1.get_dependents() + set_dependents = transform1.dependents assert not set_dependents @@ -249,7 +249,7 @@ def test_deregister_unregistered_dependent_alt1(nexus_wrapper): transform1.deregister_dependent(transform2) - assert not transform1.get_dependents() + assert not transform1.dependents def test_deregister_unregistered_dependent_alt2(nexus_wrapper): @@ -261,8 +261,8 @@ def test_deregister_unregistered_dependent_alt2(nexus_wrapper): transform1.register_dependent(transform3) transform1.deregister_dependent(transform2) - assert len(transform1.get_dependents()) == 1 - assert transform1.get_dependents()[0] == transform3 + assert len(transform1.dependents) == 1 + assert transform1.dependents[0] == transform3 def test_deregister_unregistered_dependent_alt3(nexus_wrapper): @@ -276,9 +276,9 @@ def test_deregister_unregistered_dependent_alt3(nexus_wrapper): transform1.register_dependent(transform4) transform1.deregister_dependent(transform2) - assert len(transform1.get_dependents()) == 2 - assert transform1.get_dependents()[0] == transform3 - assert transform1.get_dependents()[1] == transform4 + assert len(transform1.dependents) == 2 + assert transform1.dependents[0] == transform3 + assert transform1.dependents[1] == transform4 def test_reregister_dependent(nexus_wrapper): @@ -293,7 +293,7 @@ def test_reregister_dependent(nexus_wrapper): transform1.deregister_dependent(transform2) transform1.register_dependent(transform3) - set_dependents = transform1.get_dependents() + set_dependents = transform1.dependents assert len(set_dependents) == 1 assert set_dependents[0] == transform3 @@ -307,7 +307,7 @@ def test_set_one_dependent_component(nexus_wrapper): transform.register_dependent(component) - set_dependents = transform.get_dependents() + set_dependents = transform.dependents assert len(set_dependents) == 1 assert set_dependents[0] == component @@ -323,7 +323,7 @@ def test_set_two_dependent_components(nexus_wrapper): transform.register_dependent(component1) transform.register_dependent(component2) - set_dependents = transform.get_dependents() + set_dependents = transform.dependents assert len(set_dependents) == 2 assert set_dependents[0] == component1 @@ -342,7 +342,7 @@ def test_set_three_dependent_components(nexus_wrapper): transform.register_dependent(component2) transform.register_dependent(component3) - set_dependents = transform.get_dependents() + set_dependents = transform.dependents assert len(set_dependents) == 3 assert set_dependents[0] == component1 @@ -366,7 +366,7 @@ def test_deregister_three_dependent_components(nexus_wrapper): transform.deregister_dependent(component2) transform.deregister_dependent(component3) - set_dependents = transform.get_dependents() + set_dependents = transform.dependents assert len(set_dependents) == 0 @@ -380,7 +380,7 @@ def test_register_dependent_twice(nexus_wrapper): transform.register_dependent(component1) transform.register_dependent(component1) - set_dependents = transform.get_dependents() + set_dependents = transform.dependents assert len(set_dependents) == 1 diff --git a/tests/ui_tests/test_main_window_utils.py b/tests/ui_tests/test_main_window_utils.py index c3f6645be..abbda3b5c 100644 --- a/tests/ui_tests/test_main_window_utils.py +++ b/tests/ui_tests/test_main_window_utils.py @@ -348,7 +348,6 @@ def test_GIVEN_transformation_is_selected_WHEN_changing_button_states_THEN_expec transformation_selected_actions = { delete_action, duplicate_action, - edit_component_action, new_rotation_action, new_translation_action, create_link_action, diff --git a/tests/ui_tests/test_ui_file_writer_ctrl_window.py b/tests/ui_tests/test_ui_file_writer_ctrl_window.py index 99ea601d6..5f5955792 100644 --- a/tests/ui_tests/test_ui_file_writer_ctrl_window.py +++ b/tests/ui_tests/test_ui_file_writer_ctrl_window.py @@ -1,15 +1,35 @@ import pytest +from PySide2.QtCore import QSettings from PySide2.QtGui import QStandardItemModel from mock import Mock from streaming_data_types import run_start_pl72 -from nexus_constructor.file_writer_ctrl_window import FileWriterCtrl, File, FileWriter from nexus_constructor.validators import BrokerAndTopicValidator +from nexus_constructor.file_writer_ctrl_window import ( + FileWriterCtrl, + FileWriterSettings, + extract_bool_from_qsettings, + File, + FileWriter, +) + + +@pytest.fixture() +def settings(): + settings = QSettings("testing", "NCui_tests") + yield settings + settings.setValue(FileWriterSettings.STATUS_BROKER_ADDR, "") + settings.setValue(FileWriterSettings.COMMAND_BROKER_ADDR, "") + settings.setValue(FileWriterSettings.FILE_NAME, "") + settings.setValue(FileWriterSettings.USE_START_TIME, False) + settings.setValue(FileWriterSettings.USE_STOP_TIME, False) + settings.setValue(FileWriterSettings.FILE_BROKER_ADDR, "") + del settings def test_UI_GIVEN_nothing_WHEN_creating_filewriter_control_window_THEN_broker_field_defaults_are_set_correctly( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) assert not window.command_broker_edit.text() @@ -22,9 +42,9 @@ def test_UI_GIVEN_nothing_WHEN_creating_filewriter_control_window_THEN_broker_fi def test_UI_GIVEN_nothing_WHEN_creating_filewriter_control_window_THEN_broker_validators_are_set_correctly( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) assert isinstance(window.status_broker_edit.validator(), BrokerAndTopicValidator) @@ -49,9 +69,9 @@ def test_UI_GIVEN_time_string_WHEN_setting_time_THEN_last_time_is_stored( def test_UI_GIVEN_no_files_WHEN_stop_file_writing_is_clicked_THEN_button_is_disabled( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.files_list.selectedIndexes = lambda: [] @@ -61,9 +81,9 @@ def test_UI_GIVEN_no_files_WHEN_stop_file_writing_is_clicked_THEN_button_is_disa def test_UI_GIVEN_files_WHEN_stop_file_writing_is_clicked_THEN_button_is_enabled( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.files_list.selectedIndexes = lambda: [ 1, @@ -77,10 +97,10 @@ def test_UI_GIVEN_files_WHEN_stop_file_writing_is_clicked_THEN_button_is_enabled def test_UI_GIVEN_valid_command_WHEN_sending_command_THEN_command_producer_sends_command( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.command_producer = Mock() @@ -108,9 +128,9 @@ def test_UI_GIVEN_valid_command_WHEN_sending_command_THEN_command_producer_sends def test_UI_GIVEN_no_status_consumer_and_no_command_producer_WHEN_checking_status_connection_THEN_both_leds_are_turned_off( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.status_consumer = None window.command_producer = None @@ -122,9 +142,9 @@ def test_UI_GIVEN_no_status_consumer_and_no_command_producer_WHEN_checking_statu def test_UI_GIVEN_status_consumer_but_no_command_producer_WHEN_checking_status_connection_THEN_status_led_is_turned_on( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.command_producer = None window.status_consumer = Mock() @@ -137,9 +157,9 @@ def test_UI_GIVEN_status_consumer_but_no_command_producer_WHEN_checking_status_c def test_UI_GIVEN_command_producer_WHEN_checking_connection_status_THEN_command_led_is_turned_on( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.command_producer = Mock() window.status_consumer = None @@ -155,11 +175,13 @@ def __init__(self, address, topic): self.topic = topic -@pytest.mark.skip(reason="qtbot interferes with other tests") +@pytest.mark.skip( + reason="this test passes, but qtbot unexpectedly interferes with other tests after it has run" +) def test_UI_GIVEN_invalid_broker_WHEN_status_broker_timer_callback_is_called_THEN_nothing_happens( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.status_consumer = None window.status_broker_edit.setText("invalid") @@ -168,11 +190,13 @@ def test_UI_GIVEN_invalid_broker_WHEN_status_broker_timer_callback_is_called_THE assert window.status_consumer is None -@pytest.mark.skip(reason="qtbot interferes with other tests") +@pytest.mark.skip( + reason="this test passes, but qtbot unexpectedly interferes with other tests after it has run" +) def test_UI_GIVEN_invalid_broker_WHEN_command_broker_timer_callback_is_called_THEN_nothing_happens( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.command_producer = None window.command_broker_edit.setText("invalid") @@ -181,11 +205,13 @@ def test_UI_GIVEN_invalid_broker_WHEN_command_broker_timer_callback_is_called_TH assert window.command_producer is None -@pytest.mark.skip(reason="qtbot interferes with other tests") +@pytest.mark.skip( + reason="this test passes, but qtbot unexpectedly interferes with other tests after it has run" +) def test_UI_GIVEN_valid_broker_WHEN_command_broker_timer_callback_is_called_THEN_producer_is_created( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.command_broker_change_timer.stop() window.status_broker_change_timer.stop() @@ -197,11 +223,13 @@ def test_UI_GIVEN_valid_broker_WHEN_command_broker_timer_callback_is_called_THEN assert isinstance(window.command_producer, DummyInterface) -@pytest.mark.skip(reason="qtbot interferes with other tests") +@pytest.mark.skip( + reason="this test passes, but qtbot unexpectedly interferes with other tests after it has run" +) def test_UI_GIVEN_valid_broker_WHEN_status_broker_timer_callback_is_called_THEN_consumer_is_created( - qtbot, instrument + qtbot, instrument, settings ): - window = FileWriterCtrl(instrument) + window = FileWriterCtrl(instrument, settings) qtbot.addWidget(window) window.command_broker_change_timer.stop() window.status_broker_change_timer.stop() @@ -211,3 +239,84 @@ def test_UI_GIVEN_valid_broker_WHEN_status_broker_timer_callback_is_called_THEN_ window.status_broker_timer_changed(DummyInterface) assert isinstance(window.status_consumer, DummyInterface) + + +@pytest.mark.skip( + reason="this test passes, but qtbot unexpectedly interferes with other tests after it has run" +) +def test_UI_settings_are_saved_when_store_settings_is_called( + qtbot, instrument, settings +): + window = FileWriterCtrl(instrument, settings) + qtbot.addWidget(window) + + command_broker = "broker:9092/topic1" + window.command_broker_edit.setText(command_broker) + + status_broker = "broker2:9092/topic2" + window.status_broker_edit.setText(status_broker) + + file_broker = "broker3:9092/topic3" + window.command_widget.broker_line_edit.setText(file_broker) + + use_start_time = True + window.command_widget.start_time_enabled.setChecked(use_start_time) + + use_stop_time = True + window.command_widget.stop_time_enabled.setChecked(use_stop_time) + + filename = "test.nxs" + window.command_widget.nexus_file_name_edit.setText(filename) + + window._store_settings() + + assert settings.value(FileWriterSettings.COMMAND_BROKER_ADDR) == command_broker + assert settings.value(FileWriterSettings.STATUS_BROKER_ADDR) == status_broker + assert settings.value(FileWriterSettings.FILE_BROKER_ADDR) == file_broker + assert ( + extract_bool_from_qsettings(settings.value(FileWriterSettings.USE_START_TIME)) + == use_start_time + ) + assert ( + extract_bool_from_qsettings(settings.value(FileWriterSettings.USE_STOP_TIME)) + == use_stop_time + ) + assert settings.value(FileWriterSettings.FILE_NAME) == filename + + +@pytest.mark.skip( + reason="this test passes, but qtbot unexpectedly interferes with other tests after it has run" +) +def test_UI_stored_settings_are_shown_in_window(qtbot, instrument, settings): + command_broker = "broker:9092/topic2" + status_broker = "broker2:9092/topic3" + file_broker = "broker3:9092/topic4" + use_start_time = True + use_stop_time = False + filename = "test2.nxs" + + settings.setValue(FileWriterSettings.STATUS_BROKER_ADDR, status_broker) + settings.setValue(FileWriterSettings.COMMAND_BROKER_ADDR, command_broker) + settings.setValue(FileWriterSettings.FILE_NAME, filename) + settings.setValue(FileWriterSettings.USE_START_TIME, use_start_time) + settings.setValue(FileWriterSettings.USE_STOP_TIME, use_stop_time) + settings.setValue(FileWriterSettings.FILE_BROKER_ADDR, file_broker) + + # _restore_settings should be called on construction + window = FileWriterCtrl(instrument, settings) + qtbot.addWidget(window) + + assert window.status_broker_edit.text() == status_broker + assert window.command_broker_edit.text() == command_broker + assert use_start_time == window.command_widget.start_time_enabled.isChecked() + assert use_stop_time == window.command_widget.stop_time_enabled.isChecked() + assert filename == window.command_widget.nexus_file_name_edit.text() + assert file_broker == window.command_widget.broker_line_edit.text() + + +def test_UI_disable_stop_button_when_no_files_are_selected(qtbot, instrument, settings): + window = FileWriterCtrl(instrument, settings) + qtbot.addWidget(window) + + assert not window.files_list.selectedIndexes() + assert not window.stop_file_writing_button.isEnabled() diff --git a/ui/filewriter_ctrl_frame.py b/ui/filewriter_ctrl_frame.py index 21f7c945c..2db3fafdf 100644 --- a/ui/filewriter_ctrl_frame.py +++ b/ui/filewriter_ctrl_frame.py @@ -91,7 +91,7 @@ def setupUi(self, FilewriterCtrl): self.command_broker_edit.setPlaceholderText(broker_placeholder_text) self.command_broker_layout.addWidget(self.command_broker_edit) self.command_layout.addLayout(self.command_broker_layout) - self.command_widget = FilewriterCommandWidget() + self.command_widget = FilewriterCommandWidget(FilewriterCtrl) self.command_layout.addWidget(self.command_widget) self.horizontal_layout.addLayout(self.command_layout)