diff --git a/rascal2/core/commands.py b/rascal2/core/commands.py index ebf8717..acd87ee 100644 --- a/rascal2/core/commands.py +++ b/rascal2/core/commands.py @@ -1,8 +1,10 @@ """File for Qt commands.""" +import copy from enum import IntEnum, unique from typing import Callable +import RATapi from PyQt6 import QtGui from RATapi import ClassList @@ -96,3 +98,80 @@ def update_attribute(self): def id(self): return CommandID.EditProject + + +class SaveResults(QtGui.QUndoCommand): + """Command for saving the Results object. + + Parameters + ---------- + problem : RATapi.rat_core.ProblemDefinition + The problem + results : Union[RATapi.outputs.Results, RATapi.outputs.BayesResults] + The calculation results. + log : str + log text from the given calculation. + """ + + def __init__(self, problem, results, log: str, presenter): + super().__init__() + self.presenter = presenter + self.results = results + self.log = log + self.problem = self.get_parameter_values(problem) + self.old_problem = self.get_parameter_values(RATapi.inputs.make_problem(self.presenter.model.project)) + self.old_results = copy.deepcopy(self.presenter.model.results) + self.old_log = self.presenter.model.result_log + self.setText("Save calculation results") + + def get_parameter_values(self, problem_definition: RATapi.rat_core.ProblemDefinition): + """Get parameter values from problem definition.""" + parameter_field = { + "parameters": "params", + "bulk_in": "bulkIn", + "bulk_out": "bulkOut", + "scalefactors": "scalefactors", + "domain_ratios": "domainRatio", + "background_parameters": "backgroundParams", + "resolution_parameters": "resolutionParams", + } + + values = {} + for class_list in RATapi.project.parameter_class_lists: + entry = values.setdefault(class_list, []) + entry.extend(getattr(problem_definition, parameter_field[class_list])) + return values + + def set_parameter_values(self, values): + """Update the project given a set of results.""" + + for key, value in values.items(): + for index in range(len(value)): + getattr(self.presenter.model.project, key)[index].value = value[index] + return values + + def undo(self): + self.swap_results(self.old_problem, self.old_results, self.old_log) + + def redo(self): + self.swap_results(self.problem, self.results, self.log) + + def swap_results(self, problem, results, log): + """Swap problem, result and log in model with given one + + Parameters + ---------- + problem : RATapi.rat_core.ProblemDefinition + The problem definition + results : Union[RATapi.outputs.Results, RATapi.outputs.BayesResults] + The calculation results. + log : str + log text from the given calculation. + """ + self.set_parameter_values(problem) + self.presenter.model.update_results(copy.deepcopy(results)) + self.presenter.model.result_log = log + chi_text = "" if results is None else f"{results.calculationResults.sumChi:.6g}" + self.presenter.view.controls_widget.chi_squared.setText(chi_text) + self.presenter.view.terminal_widget.clear() + self.presenter.view.terminal_widget.write(log) diff --git a/rascal2/core/runner.py b/rascal2/core/runner.py index c7dc4f3..c55ef39 100644 --- a/rascal2/core/runner.py +++ b/rascal2/core/runner.py @@ -80,6 +80,7 @@ def run(queue, rat_inputs: tuple, procedure: str, display: bool): if display: RAT.events.register(RAT.events.EventTypes.Message, queue.put) RAT.events.register(RAT.events.EventTypes.Progress, queue.put) + RAT.events.register(RAT.events.EventTypes.Plot, queue.put) queue.put(LogData(INFO, "Starting RAT")) try: diff --git a/rascal2/static/images/hide-settings.png b/rascal2/static/images/hide-settings.png new file mode 100644 index 0000000..20f784b Binary files /dev/null and b/rascal2/static/images/hide-settings.png differ diff --git a/rascal2/ui/model.py b/rascal2/ui/model.py index 171bdfe..c8ffd94 100644 --- a/rascal2/ui/model.py +++ b/rascal2/ui/model.py @@ -1,6 +1,8 @@ from pathlib import Path +from typing import Union import RATapi as RAT +import RATapi.outputs from PyQt6 import QtCore @@ -9,12 +11,14 @@ class MainWindowModel(QtCore.QObject): project_updated = QtCore.pyqtSignal() controls_updated = QtCore.pyqtSignal() + results_updated = QtCore.pyqtSignal() def __init__(self): super().__init__() self.project = None self.results = None + self.result_log = "" self.controls = None self.save_path = "" @@ -33,21 +37,16 @@ def create_project(self, name: str, save_path: str): self.controls = RAT.Controls() self.save_path = save_path - def handle_results(self, problem_definition: RAT.rat_core.ProblemDefinition): - """Update the project given a set of results.""" - parameter_field = { - "parameters": "params", - "bulk_in": "bulkIn", - "bulk_out": "bulkOut", - "scalefactors": "scalefactors", - "domain_ratios": "domainRatio", - "background_parameters": "backgroundParams", - "resolution_parameters": "resolutionParams", - } - - for class_list in RAT.project.parameter_class_lists: - for index, value in enumerate(getattr(problem_definition, parameter_field[class_list])): - getattr(self.project, class_list)[index].value = value + def update_results(self, results: Union[RATapi.outputs.Results, RATapi.outputs.BayesResults]): + """Update the project given a set of results. + + Parameters + ---------- + results : Union[RATapi.outputs.Results, RATapi.outputs.BayesResults] + The calculation results. + """ + self.results = results + self.results_updated.emit() def update_project(self, new_values: dict) -> None: """Replaces the project with a new project. @@ -111,12 +110,12 @@ def load_r1_project(self, load_path: str): self.controls = RAT.Controls() self.save_path = str(Path(load_path).parent) - def update_controls(self, new_values): - """ + def update_controls(self, new_values: dict): + """Updates the control attributes. Parameters ---------- - new_values: Dict + new_values: dict The attribute name-value pair to updated on the controls. """ vars(self.controls).update(new_values) diff --git a/rascal2/ui/presenter.py b/rascal2/ui/presenter.py index 6ec7f06..38d800c 100644 --- a/rascal2/ui/presenter.py +++ b/rascal2/ui/presenter.py @@ -138,7 +138,7 @@ def run(self): self.view.terminal_widget.progress_bar.setVisible(False) if self.view.settings.clear_terminal: self.view.terminal_widget.clear() - + self.model.project, _ = RAT.examples.non_polarised.DSPC_standard_layers.DSPC_standard_layers() rat_inputs = RAT.inputs.make_input(self.model.project, self.model.controls) display_on = self.model.controls.display != RAT.utils.enums.Display.Off @@ -150,7 +150,14 @@ def run(self): def handle_results(self): """Handle a RAT run being finished.""" - self.model.handle_results(self.runner.updated_problem) + self.view.undo_stack.push( + commands.SaveResults( + self.runner.updated_problem, + self.runner.results, + self.view.terminal_widget.text_area.toPlainText(), + self, + ) + ) self.view.handle_results(self.runner.results) def handle_interrupt(self): @@ -171,6 +178,8 @@ def handle_event(self): self.view.controls_widget.chi_squared.setText(chi_squared) elif isinstance(event, RAT.events.ProgressEventData): self.view.terminal_widget.update_progress(event) + elif isinstance(event, RAT.events.PlotEventData): + self.view.plotting_widget.plot_event(event) elif isinstance(event, LogData): self.view.logging.log(event.level, event.msg) diff --git a/rascal2/ui/view.py b/rascal2/ui/view.py index 49f8076..8109d18 100644 --- a/rascal2/ui/view.py +++ b/rascal2/ui/view.py @@ -6,7 +6,7 @@ from rascal2.core.settings import MDIGeometries, Settings from rascal2.dialogs.project_dialog import PROJECT_FILES, LoadDialog, LoadR1Dialog, NewProjectDialog, StartupDialog from rascal2.dialogs.settings_dialog import SettingsDialog -from rascal2.widgets import ControlsWidget, TerminalWidget +from rascal2.widgets import ControlsWidget, PlotWidget, TerminalWidget from rascal2.widgets.project import ProjectWidget from rascal2.widgets.startup import StartUpWidget @@ -32,11 +32,9 @@ def __init__(self): self.presenter = MainWindowPresenter(self) self.mdi = QtWidgets.QMdiArea() - # TODO replace the widgets below - # plotting: NO ISSUE YET - # https://github.com/RascalSoftware/RasCAL-2/issues/5 - self.plotting_widget = QtWidgets.QWidget() - self.terminal_widget = TerminalWidget(self) + + self.plotting_widget = PlotWidget(self) + self.terminal_widget = TerminalWidget() self.controls_widget = ControlsWidget(self) self.project_widget = ProjectWidget(self) @@ -255,7 +253,6 @@ def setup_mdi(self): "Fitting Controls": self.controls_widget, } self.setup_mdi_widgets() - self.terminal_widget.text_area.setVisible(True) for title, widget in reversed(widgets.items()): widget.setWindowTitle(title) diff --git a/rascal2/widgets/__init__.py b/rascal2/widgets/__init__.py index ce729e5..4d032a1 100644 --- a/rascal2/widgets/__init__.py +++ b/rascal2/widgets/__init__.py @@ -1,5 +1,6 @@ from rascal2.widgets.controls import ControlsWidget from rascal2.widgets.inputs import AdaptiveDoubleSpinBox, get_validated_input +from rascal2.widgets.plotter import PlotWidget from rascal2.widgets.terminal import TerminalWidget -__all__ = ["ControlsWidget", "AdaptiveDoubleSpinBox", "get_validated_input", "TerminalWidget"] +__all__ = ["ControlsWidget", "AdaptiveDoubleSpinBox", "get_validated_input", "PlotWidget", "TerminalWidget"] diff --git a/rascal2/widgets/controls.py b/rascal2/widgets/controls.py index dea5974..09cab59 100644 --- a/rascal2/widgets/controls.py +++ b/rascal2/widgets/controls.py @@ -15,7 +15,7 @@ class ControlsWidget(QtWidgets.QWidget): """Widget for editing the Controls window.""" def __init__(self, parent): - super().__init__(parent) + super().__init__() self.presenter = parent.presenter self.presenter.model.controls_updated.connect(self.update_ui) diff --git a/rascal2/widgets/plotter.py b/rascal2/widgets/plotter.py new file mode 100644 index 0000000..91c7a16 --- /dev/null +++ b/rascal2/widgets/plotter.py @@ -0,0 +1,145 @@ +from typing import Optional, Union + +import RATapi +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure +from PyQt6 import QtCore, QtGui, QtWidgets + +from rascal2.config import path_for + + +class PlotWidget(QtWidgets.QWidget): + """Creates a UI for displaying the path lengths from the simulation result""" + + def __init__(self, parent): + super().__init__() + + self.current_plot_data = None + + self.parent_model = parent.presenter.model + main_layout = QtWidgets.QHBoxLayout() + control_layout = QtWidgets.QHBoxLayout() + plot_layout = QtWidgets.QVBoxLayout() + main_layout.addLayout(control_layout, 0) + main_layout.addLayout(plot_layout, 4) + self.setLayout(main_layout) + + self.create_plot_control() + control_layout.addWidget(self.plot_controls) + + slider_layout = QtWidgets.QVBoxLayout() + self.toggle_button = QtWidgets.QToolButton() + self.toggle_button.toggled.connect(self.toggle_settings) + self.toggle_button.setCheckable(True) + self.toggle_settings(self.toggle_button.isChecked()) + self.slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Vertical) + slider_layout.addWidget(self.toggle_button) + slider_layout.addWidget(self.slider) + slider_layout.setAlignment(self.slider, QtCore.Qt.AlignmentFlag.AlignHCenter) + control_layout.addLayout(slider_layout) + + self.figure = Figure() + self.figure.subplots(1, 2) + self.canvas = FigureCanvas(self.figure) + self.canvas.setParent(self) + plot_layout.addWidget(self.canvas) + self.setMinimumHeight(300) + + self.parent_model.results_updated.connect( + lambda: self.plot(self.parent_model.project, self.parent_model.results) + ) + + def create_plot_control(self): + """Creates the controls for customising plot""" + self.plot_controls = QtWidgets.QWidget() + self.x_axis = QtWidgets.QComboBox() + self.x_axis.addItems(["Log", "Linear"]) + self.x_axis.currentTextChanged.connect(lambda: self.plot_event()) + self.y_axis = QtWidgets.QComboBox() + self.y_axis.addItems(["Ref", "Q^4"]) + self.y_axis.currentTextChanged.connect(lambda: self.plot_event()) + self.show_error_bar = QtWidgets.QCheckBox("Show Error Bars") + self.show_error_bar.setChecked(True) + self.show_error_bar.checkStateChanged.connect(lambda: self.plot_event()) + self.show_grid = QtWidgets.QCheckBox("Show Grid") + self.show_grid.checkStateChanged.connect(lambda: self.plot_event()) + self.show_legend = QtWidgets.QCheckBox("Show Legend") + self.show_legend.setChecked(True) + self.show_legend.checkStateChanged.connect(lambda: self.plot_event()) + + layout = QtWidgets.QVBoxLayout() + layout.addWidget(QtWidgets.QLabel("X-Axis")) + layout.addWidget(self.x_axis) + layout.addWidget(QtWidgets.QLabel("Y-Axis")) + layout.addWidget(self.y_axis) + layout.addWidget(self.show_error_bar) + layout.addWidget(self.show_grid) + layout.addWidget(self.show_legend) + layout.addStretch(1) + self.plot_controls.setLayout(layout) + + def toggle_settings(self, toggled_on: bool): + """Toggles the visibility of the plot controls""" + self.plot_controls.setVisible(toggled_on) + if toggled_on: + self.toggle_button.setIcon(QtGui.QIcon(path_for("hide-settings.png"))) + else: + self.toggle_button.setIcon(QtGui.QIcon(path_for("settings.png"))) + + def plot(self, project: RATapi.Project, results: Union[RATapi.outputs.Results, RATapi.outputs.BayesResults]): + """Plots the reflectivity and SLD profiles. + + Parameters + ---------- + problem : RATapi.Project + The project + results : Union[RATapi.outputs.Results, RATapi.outputs.BayesResults] + The calculation results. + """ + if project is None or results is None: + for axis in self.figure.axes: + axis.clear() + self.canvas.draw() + return + + data = RATapi.events.PlotEventData() + + data.modelType = project.model + data.reflectivity = results.reflectivity + data.shiftedData = results.shiftedData + data.sldProfiles = results.sldProfiles + data.resampledLayers = results.resampledLayers + data.dataPresent = RATapi.inputs.make_data_present(project) + data.subRoughs = results.contrastParams.subRoughs + data.resample = RATapi.inputs.make_resample(project) + data.contrastNames = [contrast.name for contrast in project.contrasts] + self.plot_event(data) + + def plot_event(self, data: Optional[RATapi.events.PlotEventData] = None): + """Updates the ref and SLD plots from a provided or cached plot event + + Parameters + ---------- + data : Optional[RATapi.events.PlotEventData] + plot event data, cached data is used if none is provided + """ + + if data is not None: + self.current_plot_data = data + + if self.current_plot_data is None: + return + + show_legend = self.show_legend.isChecked() if self.current_plot_data.contrastNames else False + RATapi.plotting.plot_ref_sld_helper( + self.current_plot_data, + self.figure, + delay=False, + linear_x=self.x_axis.currentText() == "Linear", + q4=self.y_axis.currentText() == "Q^4", + show_error_bar=self.show_error_bar.isChecked(), + show_grid=self.show_grid.isChecked(), + show_legend=show_legend, + ) + self.figure.tight_layout(pad=1) + self.canvas.draw() diff --git a/rascal2/widgets/project/project.py b/rascal2/widgets/project/project.py index 14528ed..7ab38c2 100644 --- a/rascal2/widgets/project/project.py +++ b/rascal2/widgets/project/project.py @@ -24,7 +24,7 @@ def __init__(self, parent): parent: MainWindowView An instance of the MainWindowView """ - super().__init__(parent) + super().__init__() self.parent = parent self.parent_model = self.parent.presenter.model diff --git a/rascal2/widgets/terminal.py b/rascal2/widgets/terminal.py index 15999a4..68ad3bf 100644 --- a/rascal2/widgets/terminal.py +++ b/rascal2/widgets/terminal.py @@ -8,14 +8,10 @@ class TerminalWidget(QtWidgets.QWidget): """Widget for displaying program output.""" - def __init__(self, parent=None): - super().__init__(parent) + def __init__(self): + super().__init__() self.text_area = QtWidgets.QPlainTextEdit() - # Something wierd is going on where the text area shows up in the top - # left of the main window under the menus not sure why but this a workaround. - # So far only happening on Windows 11 - self.text_area.setVisible(False) self.text_area.setReadOnly(True) font = QtGui.QFont() font.setFamily("Courier") diff --git a/tests/test_plotter.py b/tests/test_plotter.py new file mode 100644 index 0000000..d61aeea --- /dev/null +++ b/tests/test_plotter.py @@ -0,0 +1,99 @@ +from unittest.mock import MagicMock, patch + +import pytest +import RATapi +from PyQt6 import QtWidgets + +from rascal2.widgets.plotter import PlotWidget + + +class MockWindowView(QtWidgets.QMainWindow): + """A mock MainWindowView class.""" + + def __init__(self): + super().__init__() + self.presenter = MagicMock() + self.presenter.model = MagicMock() + + +view = MockWindowView() + + +@pytest.fixture +def plot_widget(): + plot_widget = PlotWidget(view) + plot_widget.canvas = MagicMock() + + return plot_widget + + +def test_toggle_setting(plot_widget): + """Test that plot settings are hidden when the button is toggled.""" + assert not plot_widget.plot_controls.isVisibleTo(plot_widget) + plot_widget.toggle_button.toggle() + assert plot_widget.plot_controls.isVisibleTo(plot_widget) + plot_widget.toggle_button.toggle() + assert not plot_widget.plot_controls.isVisibleTo(plot_widget) + + +@patch("RATapi.plotting.RATapi.plotting.plot_ref_sld_helper") +def test_plot_event(mock_plot_sld, plot_widget): + """Test that plot helper recieved correct flags from UI .""" + data = RATapi.events.PlotEventData() + data.contrastNames = ["Hello"] + + assert plot_widget.current_plot_data is None + plot_widget.plot_event(data) + assert plot_widget.current_plot_data is data + mock_plot_sld.assert_called_with( + data, + plot_widget.figure, + delay=False, + linear_x=False, + q4=False, + show_error_bar=True, + show_grid=False, + show_legend=True, + ) + plot_widget.canvas.draw.assert_called_once() + data.contrastNames = [] + plot_widget.plot_event(data) + mock_plot_sld.assert_called_with( + data, + plot_widget.figure, + delay=False, + linear_x=False, + q4=False, + show_error_bar=True, + show_grid=False, + show_legend=False, + ) + data.contrastNames = ["Hello"] + plot_widget.x_axis.setCurrentText("Linear") + plot_widget.y_axis.setCurrentText("Q^4") + plot_widget.show_error_bar.setChecked(False) + plot_widget.show_grid.setChecked(True) + plot_widget.show_legend.setChecked(False) + mock_plot_sld.assert_called_with( + data, + plot_widget.figure, + delay=False, + linear_x=True, + q4=True, + show_error_bar=False, + show_grid=True, + show_legend=False, + ) + + +@patch("RATapi.inputs.make_input") +def test_plot(mock_inputs, plot_widget): + """Test that plot settings are hidden when the button is toggled.""" + project = MagicMock() + result = MagicMock() + data = MagicMock + with patch("RATapi.events.PlotEventData", return_value=data): + assert plot_widget.current_plot_data is None + plot_widget.plot(project, result) + assert plot_widget.current_plot_data is data + plot_widget.canvas.draw.assert_called_once() diff --git a/tests/test_presenter.py b/tests/test_presenter.py index d6b41e1..0d869da 100644 --- a/tests/test_presenter.py +++ b/tests/test_presenter.py @@ -7,6 +7,7 @@ from PyQt6 import QtWidgets from RATapi import Controls from RATapi.events import ProgressEventData +from RATapi.inputs import ProblemDefinition from rascal2.core.runner import LogData from rascal2.ui.presenter import MainWindowPresenter @@ -81,13 +82,16 @@ def test_run_and_interrupt(mock_runner, mock_inputs, presenter): presenter.runner.interrupt.assert_called_once() -def test_handle_results(presenter): +@patch("RATapi.inputs.make_problem") +def test_handle_results(mock_problem_def, presenter): """Test that results are handed to the view correctly.""" presenter.runner = MagicMock() - presenter.runner.results = "TEST RESULTS" + presenter.runner.updated_problem = ProblemDefinition() + presenter.runner.results = MagicMock() + presenter.runner.results.calculationResults.sumChi = 0.04 presenter.handle_results() - presenter.view.handle_results.assert_called_once_with("TEST RESULTS") + presenter.view.handle_results.assert_called_once_with(presenter.runner.results) def test_stop_run(presenter): diff --git a/tests/test_view.py b/tests/test_view.py index 5051ee2..0163e6b 100644 --- a/tests/test_view.py +++ b/tests/test_view.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest +from PyQt6 import QtWidgets from rascal2.core.settings import MDIGeometries, Settings from rascal2.ui.view import MainWindowView @@ -13,7 +14,8 @@ @pytest.fixture def test_view(): """An instance of MainWindowView.""" - return MainWindowView() + with patch("rascal2.widgets.plotter.FigureCanvas", return_value=QtWidgets.QWidget()): + yield MainWindowView() @pytest.mark.parametrize(