From 908a34b1cb3716ffadc3ca759d37120c35e8c905 Mon Sep 17 00:00:00 2001 From: JosePizarro3 Date: Wed, 5 Jun 2024 11:56:44 +0200 Subject: [PATCH] Moved testing to test_general.py Added explicitly input for methods in ModelSystem and testing Added testing for _set_system_branch_depth and TODO for the functions in these testings which should be implemented in the datamodel --- src/nomad_simulations/general.py | 34 +-- src/nomad_simulations/model_system.py | 62 +++--- tests/conftest.py | 17 +- tests/test_general.py | 292 ++++++++++++++++++++++++++ tests/test_model_system.py | 253 ++++------------------ 5 files changed, 397 insertions(+), 261 deletions(-) create mode 100644 tests/test_general.py diff --git a/src/nomad_simulations/general.py b/src/nomad_simulations/general.py index bbceac87..791addf3 100644 --- a/src/nomad_simulations/general.py +++ b/src/nomad_simulations/general.py @@ -18,18 +18,16 @@ import numpy as np from typing import List -from structlog.stdlib import BoundLogger -from nomad.units import ureg -from nomad.metainfo import SubSection, Quantity, MEnum, Section, Datetime +from nomad.metainfo import SubSection, Quantity, Section, Datetime from nomad.datamodel.metainfo.annotations import ELNAnnotation from nomad.datamodel.data import EntryData from nomad.datamodel.metainfo.basesections import Entity, Activity -from .model_system import ModelSystem -from .model_method import ModelMethod -from .outputs import Outputs -from .utils import is_not_representative, get_composition +from nomad_simulations.model_system import ModelSystem +from nomad_simulations.model_method import ModelMethod +from nomad_simulations.outputs import Outputs +from nomad_simulations.utils import is_not_representative, get_composition class Program(Entity): @@ -178,7 +176,9 @@ def _set_system_branch_depth( ): for system_child in system_parent.model_system: system_child.branch_depth = branch_depth + 1 - self._set_system_branch_depth(system_child, branch_depth + 1) + 
self._set_system_branch_depth( + system_parent=system_child, branch_depth=branch_depth + 1 + ) def resolve_composition_formula(self, system_parent: ModelSystem) -> None: """Determine and set the composition formula for `system_parent` and all of its @@ -217,7 +217,9 @@ def set_composition_formula( for subsystem in subsystems ] if system.composition_formula is None: - system.composition_formula = get_composition(subsystem_labels) + system.composition_formula = get_composition( + children_names=subsystem_labels + ) def get_composition_recurs(system: ModelSystem, atom_labels: List[str]) -> None: """Traverse the system hierarchy downward and set the branch composition for @@ -229,10 +231,12 @@ def get_composition_recurs(system: ModelSystem, atom_labels: List[str]) -> None: to the atom indices stored in system. """ subsystems = system.model_system - set_composition_formula(system, subsystems, atom_labels) + set_composition_formula( + system=system, subsystems=subsystems, atom_labels=atom_labels + ) if subsystems: for subsystem in subsystems: - get_composition_recurs(subsystem, atom_labels) + get_composition_recurs(system=subsystem, atom_labels=atom_labels) atoms_state = ( system_parent.cell[0].atoms_state if system_parent.cell is not None else [] @@ -242,7 +246,7 @@ def get_composition_recurs(system: ModelSystem, atom_labels: List[str]) -> None: if atoms_state is not None else [] ) - get_composition_recurs(system_parent, atom_labels) + get_composition_recurs(system=system_parent, atom_labels=atom_labels) def normalize(self, archive, logger) -> None: super(EntryData, self).normalize(archive, logger) @@ -263,8 +267,8 @@ def normalize(self, archive, logger) -> None: system_parent.branch_depth = 0 if len(system_parent.model_system) == 0: continue - self._set_system_branch_depth(system_parent) + self._set_system_branch_depth(system_parent=system_parent) - if is_not_representative(system_parent, logger): + if is_not_representative(model_system=system_parent, logger=logger): 
continue - self.resolve_composition_formula(system_parent) + self.resolve_composition_formula(system_parent=system_parent) diff --git a/src/nomad_simulations/model_system.py b/src/nomad_simulations/model_system.py index 1b6f0842..a1551a29 100644 --- a/src/nomad_simulations/model_system.py +++ b/src/nomad_simulations/model_system.py @@ -19,7 +19,7 @@ import re import numpy as np import ase -from typing import Tuple, Optional, List +from typing import Tuple, Optional from structlog.stdlib import BoundLogger from matid import SymmetryAnalyzer, Classifier # pylint: disable=import-error @@ -39,7 +39,6 @@ Formula, get_normalized_wyckoff, search_aflow_prototype, - get_composition, ) from nomad.metainfo import Quantity, SubSection, SectionProxy, MEnum, Section, Context @@ -47,8 +46,8 @@ from nomad.datamodel.metainfo.basesections import Entity, System from nomad.datamodel.metainfo.annotations import ELNAnnotation -from .atoms_state import AtomsState -from .utils import get_sibling_section, is_not_representative +from nomad_simulations.atoms_state import AtomsState +from nomad_simulations.utils import get_sibling_section, is_not_representative class GeometricSpace(Entity): @@ -185,7 +184,7 @@ def get_geometric_space_for_atomic_cell(self, logger: BoundLogger) -> None: Args: logger (BoundLogger): The logger to log messages. 
""" - atoms = self.to_ase_atoms(logger) # function defined in AtomicCell + atoms = self.to_ase_atoms(logger=logger) # function defined in AtomicCell cell = atoms.get_cell() self.length_vector_a, self.length_vector_b, self.length_vector_c = ( cell.lengths() * ureg.angstrom @@ -198,7 +197,7 @@ def get_geometric_space_for_atomic_cell(self, logger: BoundLogger) -> None: def normalize(self, archive, logger) -> None: # Skip normalization for `Entity` try: - self.get_geometric_space_for_atomic_cell(logger) + self.get_geometric_space_for_atomic_cell(logger=logger) except Exception: logger.warning( 'Could not extract the geometric space information from ASE Atoms object.', @@ -348,11 +347,11 @@ def to_ase_atoms(self, logger: BoundLogger) -> Optional[ase.Atoms]: 'Could not find `AtomicCell.periodic_boundary_conditions`. They will be set to [False, False, False].' ) self.periodic_boundary_conditions = [False, False, False] - ase_atoms.set_pbc(self.periodic_boundary_conditions) + ase_atoms.set_pbc(pbc=self.periodic_boundary_conditions) # Lattice vectors if self.lattice_vectors is not None: - ase_atoms.set_cell(self.lattice_vectors.to('angstrom').magnitude) + ase_atoms.set_cell(cell=self.lattice_vectors.to('angstrom').magnitude) else: logger.info('Could not find `AtomicCell.lattice_vectors`.') @@ -363,7 +362,9 @@ def to_ase_atoms(self, logger: BoundLogger) -> Optional[ase.Atoms]: 'Length of `AtomicCell.positions` does not coincide with the length of the `AtomicCell.atoms_state`.' ) return None - ase_atoms.set_positions(self.positions.to('angstrom').magnitude) + ase_atoms.set_positions( + newpositions=self.positions.to('angstrom').magnitude + ) else: logger.warning('Could not find `AtomicCell.positions`.') return None @@ -528,7 +529,7 @@ def resolve_analyzed_atomic_cell( ) # ? 
why do we need to pass units atomic_cell.wyckoff_letters = wyckoff atomic_cell.equivalent_atoms = equivalent_atoms - atomic_cell.get_geometric_space_for_atomic_cell(logger) + atomic_cell.get_geometric_space_for_atomic_cell(logger=logger) return atomic_cell def resolve_bulk_symmetry( @@ -550,7 +551,7 @@ def resolve_bulk_symmetry( """ symmetry = {} try: - ase_atoms = original_atomic_cell.to_ase_atoms(logger) + ase_atoms = original_atomic_cell.to_ase_atoms(logger=logger) symmetry_analyzer = SymmetryAnalyzer( ase_atoms, symmetry_tol=config.normalize.symmetry_tolerance ) @@ -584,12 +585,12 @@ def resolve_bulk_symmetry( # Populating the primitive AtomState information primitive_atomic_cell = self.resolve_analyzed_atomic_cell( - symmetry_analyzer, 'primitive', logger + symmetry_analyzer=symmetry_analyzer, cell_type='primitive', logger=logger ) # Populating the conventional AtomState information conventional_atomic_cell = self.resolve_analyzed_atomic_cell( - symmetry_analyzer, 'conventional', logger + symmetry_analyzer=symmetry_analyzer, cell_type='conventional', logger=logger ) # Getting prototype_formula, prototype_aflow_id, and strukturbericht designation from @@ -606,10 +607,11 @@ def resolve_bulk_symmetry( conventional_wyckoff = conventional_atomic_cell.wyckoff_letters # Normalize wyckoff letters norm_wyckoff = get_normalized_wyckoff( - conventional_num, conventional_wyckoff + atomic_numbers=conventional_num, wyckoff_letters=conventional_wyckoff ) aflow_prototype = search_aflow_prototype( - symmetry.get('space_group_number'), norm_wyckoff + space_group=symmetry.get('space_group_number'), + norm_wyckoff=norm_wyckoff, ) if aflow_prototype: strukturbericht = aflow_prototype.get('Strukturbericht Designation') @@ -642,7 +644,9 @@ def normalize(self, archive, logger) -> None: ( primitive_atomic_cell, conventional_atomic_cell, - ) = self.resolve_bulk_symmetry(atomic_cell, logger) + ) = self.resolve_bulk_symmetry( + original_atomic_cell=atomic_cell, logger=logger + ) 
self.m_parent.m_add_sub_section(ModelSystem.cell, primitive_atomic_cell) self.m_parent.m_add_sub_section(ModelSystem.cell, conventional_atomic_cell) # Reference to the standarized cell, and if not, fallback to the originally parsed one @@ -712,11 +716,11 @@ def resolve_chemical_formulas(self, formula: Formula) -> None: Args: formula (Formula): The Formula object from NOMAD atomutils containing the chemical formulas. """ - self.descriptive = formula.format('descriptive') - self.reduced = formula.format('reduced') - self.iupac = formula.format('iupac') - self.hill = formula.format('hill') - self.anonymous = formula.format('anonymous') + self.descriptive = formula.format(fmt='descriptive') + self.reduced = formula.format(fmt='reduced') + self.iupac = formula.format(fmt='iupac') + self.hill = formula.format(fmt='hill') + self.anonymous = formula.format(fmt='anonymous') def normalize(self, archive, logger) -> None: super().normalize(archive, logger) @@ -727,10 +731,10 @@ def normalize(self, archive, logger) -> None: if atomic_cell is None: logger.warning('Could not resolve the sibling `AtomicCell` section.') return - ase_atoms = atomic_cell.to_ase_atoms(logger) + ase_atoms = atomic_cell.to_ase_atoms(logger=logger) formula = None try: - formula = Formula(ase_atoms.get_chemical_formula()) + formula = Formula(formula=ase_atoms.get_chemical_formula()) # self.chemical_composition = ase_atoms.get_chemical_formula(mode="all") except ValueError as e: logger.warning( @@ -739,7 +743,7 @@ def normalize(self, archive, logger) -> None: error=str(e), ) if formula: - self.resolve_chemical_formulas(formula) + self.resolve_chemical_formulas(formula=formula) self.m_cache['elemental_composition'] = formula.elemental_composition() @@ -957,7 +961,7 @@ def resolve_system_type_and_dimensionality( radii='covalent', cluster_threshold=config.normalize.cluster_threshold, ) - classification = classifier.classify(ase_atoms) + classification = classifier.classify(input_system=ase_atoms) except 
Exception as e: logger.warning( 'MatID system classification failed.', exc_info=e, error=str(e) @@ -993,7 +997,7 @@ def normalize(self, archive, logger) -> None: super().normalize(archive, logger) # We don't need to normalize if the system is not representative - if is_not_representative(self, logger): + if is_not_representative(model_system=self, logger=logger): return # Extracting ASE Atoms object from the originally parsed AtomicCell section @@ -1004,7 +1008,7 @@ def normalize(self, archive, logger) -> None: return if self.cell[0].name == 'AtomicCell': self.cell[0].type = 'original' - ase_atoms = self.cell[0].to_ase_atoms(logger) + ase_atoms = self.cell[0].to_ase_atoms(logger=logger) if not ase_atoms: return @@ -1016,7 +1020,9 @@ def normalize(self, archive, logger) -> None: ( self.type, self.dimensionality, - ) = self.resolve_system_type_and_dimensionality(ase_atoms, logger) + ) = self.resolve_system_type_and_dimensionality( + ase_atoms=ase_atoms, logger=logger + ) # Creating and normalizing Symmetry section if self.type == 'bulk' and self.symmetry is not None: sec_symmetry = self.m_create(Symmetry) diff --git a/tests/conftest.py b/tests/conftest.py index 5dce6b7a..2f0baf04 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -120,11 +120,11 @@ def generate_model_system( def generate_atomic_cell( - lattice_vectors: List = [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - positions=None, - periodic_boundary_conditions=None, - chemical_symbols: List = ['H', 'H', 'O'], - atomic_numbers: List = [1, 1, 8], + lattice_vectors: List[List[float]] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + positions: Optional[list] = None, + periodic_boundary_conditions: List[bool] = [False, False, False], + chemical_symbols: List[str] = ['H', 'H', 'O'], + atomic_numbers: List[int] = [1, 1, 8], ) -> AtomicCell: """ Generate an `AtomicCell` section with the given parameters. 
@@ -133,18 +133,13 @@ def generate_atomic_cell( if positions is None and chemical_symbols is not None: n_atoms = len(chemical_symbols) positions = [[i / n_atoms, i / n_atoms, i / n_atoms] for i in range(n_atoms)] - # Define periodic boundary conditions if not provided - if periodic_boundary_conditions is None: - periodic_boundary_conditions = [False, False, False] # Define the atomic cell - atomic_cell = AtomicCell() + atomic_cell = AtomicCell(periodic_boundary_conditions=periodic_boundary_conditions) if lattice_vectors: atomic_cell.lattice_vectors = lattice_vectors * ureg('angstrom') if positions: atomic_cell.positions = positions * ureg('angstrom') - if periodic_boundary_conditions: - atomic_cell.periodic_boundary_conditions = periodic_boundary_conditions # Add the elements information for index, atom in enumerate(chemical_symbols): diff --git a/tests/test_general.py b/tests/test_general.py new file mode 100644 index 00000000..8d83b3c6 --- /dev/null +++ b/tests/test_general.py @@ -0,0 +1,292 @@ +# +# Copyright The NOMAD Authors. +# +# This file is part of NOMAD. See https://nomad-lab.eu for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from typing import List + +from nomad.datamodel import EntryArchive + +from . 
import logger + +from nomad_simulations.general import Simulation +from nomad_simulations.model_system import ( + ModelSystem, + AtomicCell, + AtomsState, +) + + +class TestSimulation: + """ + Test the `Simulation` class defined in general.py + """ + + @pytest.mark.parametrize( + 'system, result', + [ + ([ModelSystem(name='depth 0')], [0]), + ( + [ + ModelSystem( + name='depth 0', + model_system=[ + ModelSystem(name='depth 1'), + ModelSystem(name='depth 1'), + ], + ) + ], + [0, 1, 1], + ), + ( + [ + ModelSystem( + name='depth 0', + model_system=[ + ModelSystem( + name='depth 1', + model_system=[ModelSystem(name='depth 2')], + ), + ModelSystem(name='depth 1'), + ], + ) + ], + [0, 1, 2, 1], + ), + ], + ) + def test_set_system_branch_depth( + self, system: List[ModelSystem], result: List[int] + ): + """ + Test the `_set_system_branch_depth` method. + + Args: + system (List[ModelSystem]): The system hierarchy to set the branch depths for. + result (List[int]): The expected branch depths for each system in the hierarchy. 
+ """ + simulation = Simulation(model_system=system) + for system_parent in simulation.model_system: + system_parent.branch_depth = 0 + if len(system_parent.model_system) == 0: + continue + simulation._set_system_branch_depth(system_parent=system_parent) + + # TODO move this into its own method to handle `ModelSystem` hierarchies (see below `get_system_recurs`) + def get_flat_depths( + system_parent: ModelSystem, quantity_name: str, value: list = [] + ): + for system_child in system_parent.model_system: + val = getattr(system_child, quantity_name) + value.append(val) + get_flat_depths( + system_parent=system_child, quantity_name=quantity_name, value=value + ) + return value + + value = get_flat_depths( + system_parent=simulation.model_system[0], + quantity_name='branch_depth', + value=[0], + ) + assert value == result + + @pytest.mark.parametrize( + 'is_representative, has_atom_indices, mol_label_list, n_mol_list, atom_labels_list, composition_formula_list, custom_formulas', + [ + ( + True, + True, + ['H20'], + [3], + [['H', 'O', 'O']], + ['group_H20(1)', 'H20(3)', 'H(1)O(2)', 'H(1)O(2)', 'H(1)O(2)'], + [None, None, None, None, None], + ), # pure system + ( + False, + True, + ['H20'], + [3], + [['H', 'O', 'O']], + [None, None, None, None, None], + [None, None, None, None, None], + ), # non-representative system + ( + True, + True, + [None], + [3], + [['H', 'O', 'O']], + ['Unknown(1)', 'Unknown(3)', 'H(1)O(2)', 'H(1)O(2)', 'H(1)O(2)'], + [None, None, None, None, None], + ), # missing branch labels + ( + True, + True, + ['H20'], + [3], + [[None, None, None]], + ['group_H20(1)', 'H20(3)', 'Unknown(3)', 'Unknown(3)', 'Unknown(3)'], + [None, None, None, None, None], + ), # missing atom labels + ( + True, + False, + ['H20'], + [3], + [['H', 'O', 'O']], + ['group_H20(1)', 'H20(3)', None, None, None], + [None, None, None, None, None], + ), # missing atom indices + ( + True, + True, + ['H20'], + [3], + [['H', 'O', 'O']], + ['waters(1)', 'water_molecules(3)', 'H(1)O(2)', 
'H(1)O(2)', 'H(1)O(2)'], + ['waters(1)', 'water_molecules(3)', None, None, None], + ), # custom formulas + ( + True, + True, + ['H20', 'Methane'], + [5, 2], + [['H', 'O', 'O'], ['C', 'H', 'H', 'H', 'H']], + [ + 'group_H20(1)group_Methane(1)', + 'H20(5)', + 'H(1)O(2)', + 'H(1)O(2)', + 'H(1)O(2)', + 'H(1)O(2)', + 'H(1)O(2)', + 'Methane(2)', + 'C(1)H(4)', + 'C(1)H(4)', + ], + [None, None, None, None, None, None, None, None, None, None], + ), # binary mixture + ], + ) + def test_system_hierarchy_for_molecules( + self, + is_representative: bool, + has_atom_indices: bool, + mol_label_list: List[str], + n_mol_list: List[int], + atom_labels_list: List[str], + composition_formula_list: List[str], + custom_formulas: List[str], + ): + """ + Test the `Simulation` normalization for obtaining `ModelSystem.composition_formula` for atoms and molecules. + + Args: + is_representative (bool): Specifies if branch_depth = 0 is representative or not. + If not representative, the composition formulas should not be generated. + has_atom_indices (bool): Specifies if the atom_indices should be populated during parsing. + Without atom_indices, the composition formulas for the deepest level of the hierarchy + should not be populated. + mol_label_list (List[str]): Molecule types for generating the hierarchy. + n_mol_list (List[int]): Number of molecules for each molecule type. Should be same + length as mol_label_list. + atom_labels_list (List[str]): Atom labels for each molecule type. Should be same length as + mol_label_list, with each entry being a list of corresponding atom labels. + composition_formula_list (List[str]): Resulting composition formulas after normalization. 
The + ordering is dictated by the recursive traversing of the hierarchy in get_system_recurs(), + which follows each branch to its deepest level before moving to the next branch, i.e., + [model_system.composition_formula, + model_system.model_system[0].composition_formula, + model_system.model_system[0].model_system[0].composition_formula, + model_system.model_system[0].model_system[1].composition_formula, ..., + model_system.model_system[1].composition_formula, ...] + custom_formulas (List[str]): Custom composition formulas that can be set in the generation + of the hierarchy, which will cause `normalize` to ignore (i.e., not overwrite) these formula entries. + The ordering is as described above. + """ + + ### Generate the system hierarchy ### + simulation = Simulation() + model_system = ModelSystem(is_representative=True) + simulation.model_system.append(model_system) + model_system.branch_label = 'Total System' + model_system.is_representative = is_representative + model_system.composition_formula = custom_formulas[0] + ctr_comp = 1 + atomic_cell = AtomicCell() + model_system.cell.append(atomic_cell) + if has_atom_indices: + model_system.atom_indices = [] + for mol_label, n_mol, atom_labels in zip( + mol_label_list, n_mol_list, atom_labels_list + ): + # Create a branch in the hierarchy for this molecule type + model_system_mol_group = ModelSystem() + if has_atom_indices: + model_system_mol_group.atom_indices = [] + model_system_mol_group.branch_label = ( + f'group_{mol_label}' if mol_label is not None else None + ) + model_system_mol_group.composition_formula = custom_formulas[ctr_comp] + ctr_comp += 1 + model_system.model_system.append(model_system_mol_group) + for _ in range(n_mol): + # Create a branch in the hierarchy for this molecule + model_system_mol = ModelSystem(branch_label=mol_label) + model_system_mol.branch_label = mol_label + model_system_mol.composition_formula = custom_formulas[ctr_comp] + ctr_comp += 1 + 
model_system_mol_group.model_system.append(model_system_mol) + # add the corresponding atoms to the global atom list + for atom_label in atom_labels: + if atom_label is not None: + atomic_cell.atoms_state.append( + AtomsState(chemical_symbol=atom_label) + ) + n_atoms = len(atomic_cell.atoms_state) + atom_indices = np.arange(n_atoms - len(atom_labels), n_atoms) + if has_atom_indices: + model_system_mol.atom_indices = atom_indices + model_system_mol_group.atom_indices = np.append( + model_system_mol_group.atom_indices, atom_indices + ) + model_system.atom_indices = np.append( + model_system.atom_indices, atom_indices + ) + + simulation.normalize(EntryArchive(), logger) + + ### Traverse the hierarchy recursively and check the results ### + assert model_system.composition_formula == composition_formula_list[0] + ctr_comp = 1 + + # TODO move this into its own method to handle `ModelSystem` hierarchies (see above `get_flat_depths`) + def get_system_recurs(systems: List[ModelSystem], ctr_comp: int) -> int: + for sys in systems: + assert sys.composition_formula == composition_formula_list[ctr_comp] + ctr_comp += 1 + subsystems = sys.model_system + if subsystems: + ctr_comp = get_system_recurs(subsystems, ctr_comp) + return ctr_comp + + new_ctr_comp = get_system_recurs( + systems=model_system.model_system, ctr_comp=ctr_comp + ) diff --git a/tests/test_model_system.py b/tests/test_model_system.py index 34caf9c3..c317daf7 100644 --- a/tests/test_model_system.py +++ b/tests/test_model_system.py @@ -25,14 +25,7 @@ from . import logger from .conftest import generate_atomic_cell -from nomad_simulations.model_system import ( - Symmetry, - ChemicalFormula, - ModelSystem, - AtomicCell, - AtomsState, -) -from nomad_simulations.general import Simulation +from nomad_simulations.model_system import Symmetry, ChemicalFormula, ModelSystem class TestAtomicCell: @@ -96,17 +89,25 @@ def test_generate_ase_atoms( ): """ Test the creation of `ase.Atoms` from `AtomicCell`. 
+ + Args: + chemical_symbols (List[str]): List of chemical symbols. + atomic_numbers (List[int]): List of atomic numbers. + formula (str): Chemical formula. + lattice_vectors (List[List[float]]): Lattice vectors. + positions (List[List[float]]): Atomic positions. + periodic_boundary_conditions (List[bool]): Periodic boundary conditions. """ atomic_cell = generate_atomic_cell( - lattice_vectors, - positions, - periodic_boundary_conditions, - chemical_symbols, - atomic_numbers, + lattice_vectors=lattice_vectors, + positions=positions, + periodic_boundary_conditions=periodic_boundary_conditions, + chemical_symbols=chemical_symbols, + atomic_numbers=atomic_numbers, ) # Test `to_ase_atoms` function - ase_atoms = atomic_cell.to_ase_atoms(logger) + ase_atoms = atomic_cell.to_ase_atoms(logger=logger) if not chemical_symbols or len(chemical_symbols) != len(positions): assert ase_atoms is None else: @@ -190,14 +191,21 @@ def test_geometric_space( ): """ Test the `GeometricSpace` quantities normalization from `AtomicCell`. + + Args: + chemical_symbols (List[str]): List of chemical symbols. + atomic_numbers (List[int]): List of atomic numbers. + lattice_vectors (List[List[float]]): Lattice vectors. + positions (List[List[float]]): Atomic positions. + vectors_results (List[Optional[float]]): Expected lengths of cell vectors. + angles_results (List[Optional[float]]): Expected angles between cell vectors. + volume (Optional[float]): Expected volume of the cell. """ - pbc = [False, False, False] atomic_cell = generate_atomic_cell( - lattice_vectors, - positions, - pbc, - chemical_symbols, - atomic_numbers, + lattice_vectors=lattice_vectors, + positions=positions, + chemical_symbols=chemical_symbols, + atomic_numbers=atomic_numbers, ) # Get `GeometricSpace` quantities via normalization of `AtomicCell` @@ -275,6 +283,11 @@ def test_chemical_formula( ): """ Test the `ChemicalFormula` normalization if a sibling `AtomicCell` is created, and thus the `Formula` class can be used. 
+ + Args: + chemical_symbols (List[str]): List of chemical symbols. + atomic_numbers (List[int]): List of atomic numbers. + formulas (List[str]): List of expected formulas. """ atomic_cell = generate_atomic_cell( chemical_symbols=chemical_symbols, atomic_numbers=atomic_numbers @@ -332,17 +345,25 @@ def test_system_type_and_dimensionality( ): """ Test the `ModelSystem` normalization of `type` and `dimensionality` from `AtomicCell`. + + Args: + positions (List[List[float]]): Atomic positions. + pbc (Optional[List[bool]]): Periodic boundary conditions. + system_type (str): Expected system type. + dimensionality (int): Expected dimensionality. """ atomic_cell = generate_atomic_cell( positions=positions, periodic_boundary_conditions=pbc ) - ase_atoms = atomic_cell.to_ase_atoms(logger) + ase_atoms = atomic_cell.to_ase_atoms(logger=logger) model_system = ModelSystem() model_system.cell.append(atomic_cell) ( resolved_system_type, resolved_dimensionality, - ) = model_system.resolve_system_type_and_dimensionality(ase_atoms, logger) + ) = model_system.resolve_system_type_and_dimensionality( + ase_atoms=ase_atoms, logger=logger + ) assert resolved_system_type == system_type assert resolved_dimensionality == dimensionality @@ -360,7 +381,9 @@ def test_symmetry(self): ) ).all() symmetry = Symmetry() - primitive, conventional = symmetry.resolve_bulk_symmetry(atomic_cell, logger) + primitive, conventional = symmetry.resolve_bulk_symmetry( + original_atomic_cell=atomic_cell, logger=logger + ) assert symmetry.bravais_lattice == 'hR' assert symmetry.hall_symbol == '-R 3 2"' assert symmetry.point_group_symbol == '-3m' @@ -442,187 +465,3 @@ def test_normalize(self): assert np.isclose(model_system.elemental_composition[0].atomic_fraction, 2 / 3) assert model_system.elemental_composition[1].element == 'O' assert np.isclose(model_system.elemental_composition[1].atomic_fraction, 1 / 3) - - @pytest.mark.parametrize( - 'is_representative, has_atom_indices, mol_label_list, n_mol_list, 
atom_labels_list, composition_formula_list, custom_formulas', - [ - ( - True, - True, - ['H20'], - [3], - [['H', 'O', 'O']], - ['group_H20(1)', 'H20(3)', 'H(1)O(2)', 'H(1)O(2)', 'H(1)O(2)'], - [None, None, None, None, None], - ), # pure system - ( - False, - True, - ['H20'], - [3], - [['H', 'O', 'O']], - [None, None, None, None, None], - [None, None, None, None, None], - ), # non-representative system - ( - True, - True, - [None], - [3], - [['H', 'O', 'O']], - ['Unknown(1)', 'Unknown(3)', 'H(1)O(2)', 'H(1)O(2)', 'H(1)O(2)'], - [None, None, None, None, None], - ), # missing branch labels - ( - True, - True, - ['H20'], - [3], - [[None, None, None]], - ['group_H20(1)', 'H20(3)', 'Unknown(3)', 'Unknown(3)', 'Unknown(3)'], - [None, None, None, None, None], - ), # missing atom labels - ( - True, - False, - ['H20'], - [3], - [['H', 'O', 'O']], - ['group_H20(1)', 'H20(3)', None, None, None], - [None, None, None, None, None], - ), # missing atom indices - ( - True, - True, - ['H20'], - [3], - [['H', 'O', 'O']], - ['waters(1)', 'water_molecules(3)', 'H(1)O(2)', 'H(1)O(2)', 'H(1)O(2)'], - ['waters(1)', 'water_molecules(3)', None, None, None], - ), # custom formulas - ( - True, - True, - ['H20', 'Methane'], - [5, 2], - [['H', 'O', 'O'], ['C', 'H', 'H', 'H', 'H']], - [ - 'group_H20(1)group_Methane(1)', - 'H20(5)', - 'H(1)O(2)', - 'H(1)O(2)', - 'H(1)O(2)', - 'H(1)O(2)', - 'H(1)O(2)', - 'Methane(2)', - 'C(1)H(4)', - 'C(1)H(4)', - ], - [None, None, None, None, None, None, None, None, None, None], - ), # binary mixture - ], - ) - def test_system_hierarchy_for_molecules( - self, - is_representative: bool, - has_atom_indices: bool, - mol_label_list: List[str], - n_mol_list: List[int], - atom_labels_list: List[str], - composition_formula_list: List[str], - custom_formulas: List[str], - ): - """Test the `ModelSystem` normalization of 'composition_formula' for atoms and molecules. - - Args: - is_representative (bool): Specifies if branch_depth = 0 is representative or not. 
- If not representative, the composition formulas should not be generated. - has_atom_indices (bool): Specifies if the atom_indices should be populated during parsing. - Without atom_indices, the composition formulas for the deepest level of the hierarchy - should not be populated. - mol_label_list (List[str]): Molecule types for generating the hierarchy. - n_mol_list (List[int]): Number of molecules for each molecule type. Should be same - length as mol_label_list. - atom_labels_list (List[str]): Atom labels for each molecule type. Should be same length as - mol_label_list, with each entry being a list of corresponding atom labels. - composition_formula_list (List[str]): Resulting composition formulas after normalization. The - ordering is dictated by the recursive traversing of the hierarchy in get_system_recurs(), - which follows each branch to its deepest level before moving to the next branch, i.e., - [model_system.composition_formula, - model_system.model_system[0].composition_formula], - model_system.model_system[0].model_system[0].composition_formula, - model_system.model_system[0].model_system[1].composition_formula, ..., - model_system.model_system[1].composition_formula, ...] - custom_formulas (List[str]): Custom composition formulas that can be set in the generation - of the hierarchy, which will cause the normalize to ignore (i.e., not overwrite) these formula entries. - The ordering is as described above. 
- - Returns: - None - """ - - ### Generate the system hierarchy ### - simulation = Simulation() - model_system = ModelSystem(is_representative=True) - simulation.model_system.append(model_system) - model_system.branch_label = 'Total System' - model_system.is_representative = is_representative - model_system.composition_formula = custom_formulas[0] - ctr_comp = 1 - atomic_cell = AtomicCell() - model_system.cell.append(atomic_cell) - if has_atom_indices: - model_system.atom_indices = [] - for mol_label, n_mol, atom_labels in zip( - mol_label_list, n_mol_list, atom_labels_list - ): - # Create a branch in the hierarchy for this molecule type - model_system_mol_group = ModelSystem() - if has_atom_indices: - model_system_mol_group.atom_indices = [] - model_system_mol_group.branch_label = ( - f'group_{mol_label}' if mol_label is not None else None - ) - model_system_mol_group.composition_formula = custom_formulas[ctr_comp] - ctr_comp += 1 - model_system.model_system.append(model_system_mol_group) - for _ in range(n_mol): - # Create a branch in the hierarchy for this molecule - model_system_mol = ModelSystem(branch_label=mol_label) - model_system_mol.branch_label = mol_label - model_system_mol.composition_formula = custom_formulas[ctr_comp] - ctr_comp += 1 - model_system_mol_group.model_system.append(model_system_mol) - # add the corresponding atoms to the global atom list - for atom_label in atom_labels: - if atom_label is not None: - atomic_cell.atoms_state.append( - AtomsState(chemical_symbol=atom_label) - ) - n_atoms = len(atomic_cell.atoms_state) - atom_indices = np.arange(n_atoms - len(atom_labels), n_atoms) - if has_atom_indices: - model_system_mol.atom_indices = atom_indices - model_system_mol_group.atom_indices = np.append( - model_system_mol_group.atom_indices, atom_indices - ) - model_system.atom_indices = np.append( - model_system.atom_indices, atom_indices - ) - - simulation.normalize(EntryArchive(), logger) - - ### Traverse the hierarchy recursively and check 
the results ### - assert model_system.composition_formula == composition_formula_list[0] - ctr_comp = 1 - - def get_system_recurs(sec_system, ctr_comp): - for sys in sec_system: - assert sys.composition_formula == composition_formula_list[ctr_comp] - ctr_comp += 1 - sec_subsystem = sys.model_system - if sec_subsystem: - ctr_comp = get_system_recurs(sec_subsystem, ctr_comp) - return ctr_comp - - get_system_recurs(model_system.model_system, ctr_comp)