diff --git a/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py b/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py index c60094f8..d6735db8 100644 --- a/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py +++ b/src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py @@ -23,18 +23,15 @@ from nomad.datamodel.datamodel import EntryArchive from structlog.stdlib import BoundLogger - from nomad_simulations.schema_packages.workflow import SinglePoint - from nomad.datamodel.metainfo.workflow import Link, TaskReference from nomad.metainfo import Quantity, Reference -from nomad_simulations.schema_packages.model_method import DFT, TB, ModelMethod -from nomad_simulations.schema_packages.workflow import ( - BeyondDFT, - BeyondDFTMethod, -) +from nomad_simulations.schema_packages.model_method import DFT, TB +from nomad_simulations.schema_packages.workflow import BeyondDFT, BeyondDFTMethod from nomad_simulations.schema_packages.workflow.base_workflows import check_n_tasks +from .single_point import SinglePoint + class DFTPlusTBMethod(BeyondDFTMethod): """ @@ -106,7 +103,15 @@ class DFTPlusTB(BeyondDFT): """ @check_n_tasks(n_tasks=2) - def link_task_inputs_outputs(self, tasks: list[TaskReference]) -> None: + def link_task_inputs_outputs( + self, tasks: list[TaskReference], logger: 'BoundLogger' + ) -> None: + if not self.inputs or not self.outputs: + logger.warning( + 'The `DFTPlusTB` workflow needs to have `inputs` and `outputs` defined in order to link with the `tasks`.' + ) + return None + dft_task = tasks[0] tb_task = tasks[1] @@ -144,7 +149,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: 'A `DFTPlusTB` workflow must have two `SinglePoint` tasks references.' ) return - if not isinstance(task.task, 'SinglePoint'): + if not isinstance(task.task, SinglePoint): logger.error( 'The referenced tasks in the `DFTPlusTB` workflow must be of type `SinglePoint`.' 
) @@ -158,11 +163,14 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: tasks=self.tasks, tasks_names=['DFT SinglePoint Task', 'TB SinglePoint Task'], ) - if method_refs is not None and len(method_refs) == 2: - self.method = DFTPlusTBMethod( - dft_method_ref=method_refs[0], - tb_method_ref=method_refs[1], - ) + if method_refs is not None: + method_workflow = DFTPlusTBMethod() + for method in method_refs: + if isinstance(method, DFT): + method_workflow.dft_method_ref = method + elif isinstance(method, TB): + method_workflow.tb_method_ref = method + self.method = method_workflow # Resolve `tasks[*].inputs` and `tasks[*].outputs` - self.link_task_inputs_outputs(tasks=self.tasks) + self.link_task_inputs_outputs(tasks=self.tasks, logger=logger) diff --git a/tests/workflow/test_base_workflows.py b/tests/workflow/test_base_workflows.py index a961f4a4..99a97562 100644 --- a/tests/workflow/test_base_workflows.py +++ b/tests/workflow/test_base_workflows.py @@ -84,12 +84,16 @@ def test_resolve_all_outputs(self, tasks: list[Task], result: list[Outputs]): [ # no task (None, None), + # only one task ([TaskReference()], []), + # two empty tasks ([TaskReference(), TaskReference()], []), + # two tasks with only empty task ( [TaskReference(task=SinglePoint()), TaskReference(task=SinglePoint())], [], ), + # two tasks with task with one input ModelSystem each ( [ TaskReference( @@ -101,6 +105,7 @@ def test_resolve_all_outputs(self, tasks: list[Task], result: list[Outputs]): ], [], ), + # two tasks with task with one input ModelSystem each and only DFT input ( [ TaskReference( @@ -121,6 +126,7 @@ def test_resolve_all_outputs(self, tasks: list[Task], result: list[Outputs]): ], [DFT], ), + # two tasks with task with inputs for ModelSystem and DFT and TB ( [ TaskReference( diff --git a/tests/workflow/test_dft_plus_tb.py b/tests/workflow/test_dft_plus_tb.py index f74dafaa..f37bc4b0 100644 --- a/tests/workflow/test_dft_plus_tb.py +++ 
b/tests/workflow/test_dft_plus_tb.py @@ -37,14 +37,330 @@ class TestDFTPlusTB: - def test_link_task_inputs_outputs(self): + @pytest.mark.parametrize( + 'inputs, outputs, tasks, result_tasks', + [ + # no inputs, outputs, tasks + (None, None, None, []), + # only 1 task + (None, None, [TaskReference()], []), + # empty tasks + ( + None, + None, + [TaskReference(), TaskReference()], + [], + ), + # only one task is populated + ( + None, + None, + [ + TaskReference(task=SinglePoint()), + TaskReference(), + ], + [], + ), + # only one task is populated with inputs + ( + None, + None, + [ + TaskReference(task=SinglePoint(inputs=[Link()])), + TaskReference(task=SinglePoint()), + ], + [], + ), + # only one task is populated with outputs + ( + None, + None, + [ + TaskReference(task=SinglePoint(outputs=[Link(name='output dft')])), + TaskReference(task=SinglePoint()), + ], + [], + ), + # positive testing + ( + [Link(name='input system')], + [Link(name='output tb')], + [ + TaskReference(task=SinglePoint(outputs=[Link(name='output dft')])), + TaskReference(task=SinglePoint()), + ], + [ + TaskReference( + task=SinglePoint(outputs=[Link(name='output dft')]), + inputs=[Link(name='Input Model System')], + outputs=[Link(name='Output DFT Data')], + ), + TaskReference( + task=SinglePoint(), + inputs=[Link(name='Output DFT Data')], + outputs=[Link(name='Output TB Data')], + ), + ], + ), + ], + ) + def test_link_task_inputs_outputs( + self, + inputs: list[Link], + outputs: list[Link], + tasks: list[TaskReference], + result_tasks: list[TaskReference], + ): """ Test the `link_task_inputs_outputs` method of the `DFTPlusTB` section. 
""" - assert True + workflow = DFTPlusTB() + workflow.tasks = tasks + workflow.inputs = inputs + workflow.outputs = outputs - def test_normalize(self): + workflow.link_task_inputs_outputs(tasks=workflow.tasks, logger=logger) + + if not result_tasks: + assert not workflow.m_xpath('tasks[0].inputs') and not workflow.m_xpath( + 'tasks[0].outputs' + ) + assert not workflow.m_xpath('tasks[1].inputs') and not workflow.m_xpath( + 'tasks[1].outputs' + ) + else: + for i, task in enumerate(workflow.tasks): + assert task.inputs[0].name == result_tasks[i].inputs[0].name + assert task.outputs[0].name == result_tasks[i].outputs[0].name + + @pytest.mark.parametrize( + 'inputs, outputs, tasks, result_name, result_methods, result_tasks', + [ + # all none + (None, None, None, None, None, []), + # only one task + (None, None, [TaskReference()], None, None, []), + # two empty tasks + (None, None, [TaskReference(), TaskReference()], None, None, []), + # only one task has a task + ( + None, + None, + [TaskReference(task=SinglePoint()), TaskReference()], + None, + None, + [], + ), + # both tasks with empty task sections, one is not SinglePoint + ( + None, + None, + [TaskReference(task=DFTPlusTB()), TaskReference(task=SinglePoint())], + None, + None, + [], + ), + # both tasks with empty SinglePoint task sections; name is resolved + ( + None, + None, + [TaskReference(task=SinglePoint()), TaskReference(task=SinglePoint())], + 'DFT+TB', + None, + [], + ), + # both tasks have input for ModelSystem + ( + None, + None, + [ + TaskReference( + task=SinglePoint( + inputs=[Link(name='input system', section=ModelSystem())] + ) + ), + TaskReference( + task=SinglePoint( + inputs=[Link(name='input system', section=ModelSystem())] + ) + ), + ], + 'DFT+TB', + None, + [], + ), + # one task has an input with a ref to DFT section + ( + None, + None, + [ + TaskReference( + task=SinglePoint( + inputs=[ + Link(name='input system', section=ModelSystem()), + Link(name='dft method', section=DFT()), + ] + ) + ), + 
TaskReference( + task=SinglePoint( + inputs=[Link(name='input system', section=ModelSystem())] + ) + ), + ], + 'DFT+TB', + [DFT], + [], + ), + # both tasks have inputs with refs to DFT and TB sections + ( + None, + None, + [ + TaskReference( + task=SinglePoint( + inputs=[ + Link(name='input system', section=ModelSystem()), + Link(name='dft method', section=DFT()), + ] + ) + ), + TaskReference( + task=SinglePoint( + inputs=[ + Link(name='input system', section=ModelSystem()), + Link(name='tb method', section=TB()), + ] + ) + ), + ], + 'DFT+TB', + [DFT, TB], + [], + ), + # one task has an output, but the workflow inputs and outputs are empty + ( + None, + None, + [ + TaskReference( + task=SinglePoint( + inputs=[ + Link(name='input system', section=ModelSystem()), + Link(name='dft method', section=DFT()), + ], + outputs=[Link(name='output dft', section=Outputs())], + ) + ), + TaskReference( + task=SinglePoint( + inputs=[ + Link(name='input system', section=ModelSystem()), + Link(name='tb method', section=TB()), + ], + ) + ), + ], + 'DFT+TB', + [DFT, TB], + [], + ), + # positive testing + ( + [Link(name='input system')], + [Link(name='output tb')], + [ + TaskReference( + task=SinglePoint( + inputs=[ + Link(name='input system', section=ModelSystem()), + Link(name='dft method', section=DFT()), + ], + outputs=[Link(name='output dft', section=Outputs())], + ) + ), + TaskReference( + task=SinglePoint( + inputs=[ + Link(name='input system', section=ModelSystem()), + Link(name='tb method', section=TB()), + ], + outputs=[Link(name='output tb', section=Outputs())], + ) + ), + ], + 'DFT+TB', + [DFT, TB], + [ + TaskReference( + task=SinglePoint(outputs=[Link(name='output dft')]), + inputs=[Link(name='Input Model System')], + outputs=[Link(name='Output DFT Data')], + ), + TaskReference( + task=SinglePoint(), + inputs=[Link(name='Output DFT Data')], + outputs=[Link(name='Output TB Data')], + ), + ], + ), + ], + ) + def test_normalize( + self, + inputs: list[Link], + outputs: 
list[Link], + tasks: list[TaskReference], + result_name: Optional[str], + result_methods: Optional[list[ModelMethod]], + result_tasks: Optional[list[TaskReference]], + ): """ Test the `normalize` method of the `DFTPlusTB` section. """ - assert True + archive = EntryArchive() + + # Add `Simulation` to archive + simulation = generate_simulation( + model_system=ModelSystem(), model_method=ModelMethod(), outputs=Outputs() + ) + archive.data = simulation + + # Add `DFTPlusTB` workflow to archive + workflow = DFTPlusTB() + workflow.inputs = inputs + workflow.outputs = outputs + workflow.tasks = tasks + archive.workflow2 = workflow + + workflow.normalize(archive=archive, logger=logger) + + # Test `name` of the workflow + assert workflow.name == result_name + + # Test `method` of the workflow + if len(result_tasks) > 0: + assert workflow.tasks[0].name == 'DFT SinglePoint Task' + assert workflow.tasks[1].name == 'TB SinglePoint Task' + if not result_methods: + assert not workflow.m_xpath( + 'method.dft_method_ref' + ) and not workflow.m_xpath('method.tb_method_ref') + else: + # ! comparing directly does not work because one is a section, the other a reference + assert isinstance(workflow.method.dft_method_ref, result_methods[0]) + if len(result_methods) == 2: + assert isinstance(workflow.method.tb_method_ref, result_methods[1]) + + # Test `tasks` of the workflow + if not result_tasks: + assert not workflow.m_xpath('tasks[0].inputs') and not workflow.m_xpath( + 'tasks[0].outputs' + ) + assert not workflow.m_xpath('tasks[1].inputs') and not workflow.m_xpath( + 'tasks[1].outputs' + ) + else: + for i, task in enumerate(workflow.tasks): + assert task.inputs[0].name == result_tasks[i].inputs[0].name + assert task.outputs[0].name == result_tasks[i].outputs[0].name