diff --git a/src/nomad_simulations/model_method.py b/src/nomad_simulations/model_method.py index f2d4ce40..68619892 100644 --- a/src/nomad_simulations/model_method.py +++ b/src/nomad_simulations/model_method.py @@ -31,10 +31,10 @@ Context, ) -from .numerical_settings import NumericalSettings -from .model_system import ModelSystem -from .atoms_state import OrbitalsState, CoreHole -from .utils import is_not_representative +from nomad_simulations.numerical_settings import NumericalSettings +from nomad_simulations.model_system import ModelSystem, AtomicCell +from nomad_simulations.atoms_state import OrbitalsState, CoreHole +from nomad_simulations.utils import is_not_representative class ModelMethod(ArchiveSection): @@ -448,14 +448,11 @@ class TB(ModelMethodElectronic): """, ) - def resolve_type(self, logger: BoundLogger) -> Optional[str]: + def resolve_type(self) -> Optional[str]: """ Resolves the `type` of the `TB` section if it is not already defined, and from the `m_def.name` of the section. - Args: - logger (BoundLogger): The logger to log messages. - Returns: (Optional[str]): The resolved `type` of the `TB` section. """ @@ -482,23 +479,29 @@ def resolve_orbital_references( Returns: Optional[List[OrbitalsState]]: The resolved references to the `OrbitalsState` sections. """ - model_system = model_systems[model_index] + try: + model_system = model_systems[model_index] + except IndexError: + logger.warning( + f'The `ModelSystem` section with index {model_index} was not found.' + ) + return None # If `ModelSystem` is not representative, the normalization will not run - if is_not_representative(model_system, logger): + if is_not_representative(model_system=model_system, logger=logger): return None # If `AtomicCell` is not found, the normalization will not run - atomic_cell = model_system.cell[0] - if atomic_cell is None: + if not model_system.cell: logger.warning('`AtomicCell` section was not found.') return None + atomic_cell = model_system.cell[0] # If there is no child `ModelSystem`, the normalization will not run atoms_state = atomic_cell.atoms_state model_system_child = model_system.model_system - if model_system_child is None: - logger.warning('No child `ModelSystem` section was found.') + if not atoms_state or not model_system_child: + logger.warning('No `AtomsState` or child `ModelSystem` section were found.') return None # We flatten the `OrbitalsState` sections from the `ModelSystem` section @@ -509,7 +512,13 @@ def resolve_orbital_references( continue indices = active_atom.atom_indices for index in indices: - active_atoms_state = atoms_state[index] + try: + active_atoms_state = atoms_state[index] + except IndexError: + logger.warning( + f'The `AtomsState` section with index {index} was not found.' + ) + continue orbitals_state = active_atoms_state.orbitals_state for orbital in orbitals_state: orbitals_ref.append(orbital) @@ -522,11 +531,7 @@ def normalize(self, archive, logger) -> None: self.name = 'TB' # Resolve `type` to be defined by the lower level class (Wannier, DFTB, xTB or SlaterKoster) if it is not already defined - self.type = ( - self.resolve_type(logger) - if (self.type is None or self.type == 'unavailable') - else self.type - ) + self.type = self.resolve_type() # Resolve `orbitals_ref` from the info in the child `ModelSystem` section and the `OrbitalsState` sections model_systems = self.m_xpath('m_parent.model_system', dict=False) @@ -536,8 +541,10 @@ def normalize(self, archive, logger) -> None: ) return # This normalization only considers the last `ModelSystem` (default `model_index` argument set to -1) - orbitals_ref = self.resolve_orbital_references(model_systems, logger) - if orbitals_ref is not None and self.orbitals_ref is None: + orbitals_ref = self.resolve_orbital_references( + model_systems=model_systems, logger=logger + ) + if orbitals_ref is not None and len(orbitals_ref) > 0 and not self.orbitals_ref: self.n_orbitals = len(orbitals_ref) self.orbitals_ref = orbitals_ref @@ -586,32 +593,16 @@ class Wannier(TB): """, ) - def resolve_localization_type(self, logger: BoundLogger) -> Optional[str]: - """ - Resolves the `localization_type` of the `Wannier` section if it is not already defined, and from the - `is_maximally_localized` boolean. - - Args: - logger (BoundLogger): The logger to log messages. - - Returns: - (Optional[str]): The resolved `localization_type` of the `Wannier` section. - """ - if self.localization_type is None: - if self.is_maximally_localized: - return 'maximally_localized' - else: - return 'single_shot' - logger.info( - 'Could not find if the Wannier tight-binding model is maximally localized or not.' - ) - return None - def normalize(self, archive, logger): super().normalize(archive, logger) # Resolve `localization_type` from `is_maximally_localized` - self.localization_type = self.resolve_localization_type(logger) + if self.localization_type is None: + if self.is_maximally_localized is not None: + if self.is_maximally_localized: + self.localization_type = 'maximally_localized' + else: + self.localization_type = 'single_shot' class SlaterKosterBond(ArchiveSection): @@ -680,29 +671,39 @@ def __init__(self, m_def: Section = None, m_context: Context = None, **kwargs): def resolve_bond_name_from_references( self, - orbital_1: OrbitalsState, - orbital_2: OrbitalsState, - bravais_vector: tuple, + orbital_1: Optional[OrbitalsState], + orbital_2: Optional[OrbitalsState], + bravais_vector: Optional[tuple], logger: BoundLogger, ) -> Optional[str]: """ Resolves the `name` of the `SlaterKosterBond` from the references to the `OrbitalsState` sections. Args: - orbital_1 (OrbitalsState): The first `OrbitalsState` section. - orbital_2 (OrbitalsState): The second `OrbitalsState` section. - bravais_vector (tuple): The bravais vector of the cell. + orbital_1 (Optional[OrbitalsState]): The first `OrbitalsState` section. + orbital_2 (Optional[OrbitalsState]): The second `OrbitalsState` section. + bravais_vector (Optional[tuple]): The bravais vector of the cell. logger (BoundLogger): The logger to log messages. Returns: (Optional[str]): The resolved `name` of the `SlaterKosterBond`. """ - bond_name = None + # Initial check + if orbital_1 is None or orbital_2 is None: + logger.warning('The `OrbitalsState` sections are not defined.') + return None + if bravais_vector is None: + logger.warning('The `bravais_vector` is not defined.') + return None + + # Check for `l_quantum_symbol` in `OrbitalsState` sections if orbital_1.l_quantum_symbol is None or orbital_2.l_quantum_symbol is None: logger.warning( 'The `l_quantum_symbol` of the `OrbitalsState` bonds are not defined.' ) return None + + bond_name = None value = [orbital_1.l_quantum_symbol, orbital_2.l_quantum_symbol, bravais_vector] # Check if `value` is found in the `self._bond_name_map` and return the key for key, val in self._bond_name_map.items(): @@ -714,14 +715,15 @@ def resolve_bond_name_from_references( def normalize(self, archive, logger) -> None: super().normalize(archive, logger) + # Resolve the SK bond `name` from the `OrbitalsState` references and the `bravais_vector` if self.orbital_1 and self.orbital_2 and self.bravais_vector is not None: - bravais_vector = tuple(self.bravais_vector) # transformed for comparing - self.name = ( - self.resolve_bond_name_from_references( - self.orbital_1, self.orbital_2, bravais_vector, logger - ) - if self.name is None - else self.name + if self.bravais_vector is not None: + bravais_vector = tuple(self.bravais_vector) # transformed for comparing + self.name = self.resolve_bond_name_from_references( + orbital_1=self.orbital_1, + orbital_2=self.orbital_2, + bravais_vector=bravais_vector, + logger=logger, ) diff --git a/tests/test_model_method.py b/tests/test_model_method.py new file mode 100644 index 00000000..80071673 --- /dev/null +++ b/tests/test_model_method.py @@ -0,0 +1,478 @@ +# +# Copyright The NOMAD Authors. +# +# This file is part of NOMAD. See https://nomad-lab.eu for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from typing import List, Union, Optional, Tuple + +from nomad.datamodel import EntryArchive + +from nomad_simulations.model_method import TB, Wannier, SlaterKoster, SlaterKosterBond +from nomad_simulations.atoms_state import OrbitalsState, AtomsState +from nomad_simulations import Simulation +from nomad_simulations.model_system import ModelSystem, AtomicCell + +from . import logger +from .conftest import generate_simulation + + +class TestTB: + """ + Test the `TB` class defined in `model_method.py`. + """ + + @pytest.mark.parametrize( + 'tb_section, result', + [(Wannier(), 'Wannier'), (SlaterKoster(), 'SlaterKoster'), (TB(), None)], + ) + def test_resolve_type(self, tb_section: TB, result: Optional[str]): + """ + Test the `resolve_type` method. + + Args: + tb_section (TB): The TB section to resolve the type from. + result (Optional[str]): The expected type of the TB section. + """ + assert tb_section.resolve_type() == result + + @pytest.mark.parametrize( + 'model_systems, model_index, result', + [ + # no `ModelSystem` sections + ([], 0, None), + # `model_index` out of range + ([ModelSystem()], 1, None), + # no `is_representative` in `ModelSystem` + ([ModelSystem(is_representative=False)], 0, None), + # no `cell` section in `ModelSystem` + ([ModelSystem(is_representative=True)], 0, None), + # no `AtomsState` in `AtomicCell` + ([ModelSystem(is_representative=True, cell=[AtomicCell()])], 0, None), + # no `model_system` child section under `ModelSystem` + ( + [ + ModelSystem( + is_representative=True, + cell=[AtomicCell(atoms_state=[AtomsState()])], + ) + ], + 0, + None, + ), + # `type` for the `model_system` child is not `'active_atom'` + ( + [ + ModelSystem( + is_representative=True, + cell=[AtomicCell(atoms_state=[AtomsState()])], + model_system=[ModelSystem(type='bulk')], + ) + ], + 0, + [], + ), + # wrong index for `AtomsState in active atom + ( + [ + ModelSystem( + is_representative=True, + cell=[AtomicCell(atoms_state=[AtomsState()])], + model_system=[ + ModelSystem(type='active_atom', atom_indices=[2]) + ], + ) + ], + 0, + [], + ), + # empty `OrbitalsState` in `AtomsState` + ( + [ + ModelSystem( + is_representative=True, + cell=[AtomicCell(atoms_state=[AtomsState(orbitals_state=[])])], + model_system=[ + ModelSystem(type='active_atom', atom_indices=[0]) + ], + ) + ], + 0, + [], + ), + # valid case + ( + [ + ModelSystem( + is_representative=True, + cell=[ + AtomicCell( + atoms_state=[ + AtomsState( + orbitals_state=[ + OrbitalsState(l_quantum_symbol='s') + ] + ) + ] + ) + ], + model_system=[ + ModelSystem(type='active_atom', atom_indices=[0]) + ], + ) + ], + 0, + [OrbitalsState(l_quantum_symbol='s')], + ), + ], + ) + def test_resolve_orbital_references( + self, + model_systems: Optional[List[ModelSystem]], + model_index: int, + result: Optional[List[OrbitalsState]], + ): + """ + Test the `resolve_orbital_references` method. + + Args: + model_systems (Optional[List[ModelSystem]]): The `model_system` section to add to `Simulation`. + model_index (int): The index of the `ModelSystem` section to resolve the orbital references from. + result (Optional[List[OrbitalsState]]): The expected orbital references. + """ + tb_method = TB() + simulation = generate_simulation(model_method=tb_method) + simulation.model_system = model_systems + orbitals_ref = tb_method.resolve_orbital_references( + model_systems=model_systems, + logger=logger, + model_index=model_index, + ) + if not orbitals_ref: + assert orbitals_ref == result + else: + assert orbitals_ref[0].l_quantum_symbol == result[0].l_quantum_symbol + + @pytest.mark.parametrize( + 'tb_section, result_type, model_systems, result', + [ + # no method `type` extracted + (TB(), 'unavailable', [], None), + # method `type` extracted + (Wannier(), 'Wannier', [], None), + # no `ModelSystem` sections + (Wannier(), 'Wannier', [], None), + # no `is_representative` in `ModelSystem` + (Wannier(), 'Wannier', [ModelSystem(is_representative=False)], None), + # no `cell` section in `ModelSystem` + (Wannier(), 'Wannier', [ModelSystem(is_representative=True)], None), + # no `AtomsState` in `AtomicCell` + ( + Wannier(), + 'Wannier', + [ModelSystem(is_representative=True, cell=[AtomicCell()])], + None, + ), + # no `model_system` child section under `ModelSystem` + ( + Wannier(), + 'Wannier', + [ + ModelSystem( + is_representative=True, + cell=[AtomicCell(atoms_state=[AtomsState()])], + ) + ], + None, + ), + # `type` for the `model_system` child is not `'active_atom'` + ( + Wannier(), + 'Wannier', + [ + ModelSystem( + is_representative=True, + cell=[AtomicCell(atoms_state=[AtomsState()])], + model_system=[ModelSystem(type='bulk')], + ) + ], + None, + ), + # wrong index for `AtomsState in active atom + ( + Wannier(), + 'Wannier', + [ + ModelSystem( + is_representative=True, + cell=[AtomicCell(atoms_state=[AtomsState()])], + model_system=[ + ModelSystem(type='active_atom', atom_indices=[2]) + ], + ) + ], + None, + ), + # empty `OrbitalsState` in `AtomsState` + ( + Wannier(), + 'Wannier', + [ + ModelSystem( + is_representative=True, + cell=[AtomicCell(atoms_state=[AtomsState(orbitals_state=[])])], + model_system=[ + ModelSystem(type='active_atom', atom_indices=[0]) + ], + ) + ], + None, + ), + # `Wannier.orbitals_ref` already set up + ( + Wannier(orbitals_ref=[OrbitalsState(l_quantum_symbol='d')]), + 'Wannier', + [ + ModelSystem( + is_representative=True, + cell=[ + AtomicCell( + atoms_state=[ + AtomsState( + orbitals_state=[ + OrbitalsState(l_quantum_symbol='s') + ] + ) + ] + ) + ], + model_system=[ + ModelSystem(type='active_atom', atom_indices=[0]) + ], + ) + ], + [OrbitalsState(l_quantum_symbol='d')], + ), + # valid case + ( + Wannier(), + 'Wannier', + [ + ModelSystem( + is_representative=True, + cell=[ + AtomicCell( + atoms_state=[ + AtomsState( + orbitals_state=[ + OrbitalsState(l_quantum_symbol='s') + ] + ) + ] + ) + ], + model_system=[ + ModelSystem(type='active_atom', atom_indices=[0]) + ], + ) + ], + [OrbitalsState(l_quantum_symbol='s')], + ), + ], + ) + def test_normalize( + self, + tb_section: TB, + result_type: Optional[str], + model_systems: Optional[List[ModelSystem]], + result: Optional[List[OrbitalsState]], + ): + """ + Test the `resolve_orbital_references` method. + + Args: + tb_section (TB): The TB section to resolve the type from. + result_type (Optional[str]): The expected type of the TB section. + model_systems (Optional[List[ModelSystem]]): The `model_system` section to add to `Simulation`. + result (Optional[List[OrbitalsState]]): The expected orbital references. + """ + simulation = generate_simulation(model_method=tb_section) + simulation.model_system = model_systems + tb_section.normalize(EntryArchive(), logger) + assert tb_section.type == result_type + if tb_section.orbitals_ref is not None: + assert len(tb_section.orbitals_ref) == 1 + assert ( + tb_section.orbitals_ref[0].l_quantum_symbol + == result[0].l_quantum_symbol + ) + else: + assert tb_section.orbitals_ref == result + + +class TestWannier: + """ + Test the `Wannier` class defined in `model_method.py`. + """ + + @pytest.mark.parametrize( + 'localization_type, is_maximally_localized, result_localization_type', + [ + # `localization_type` and `is_maximally_localized` are `None` + (None, None, None), + # `localization_type` set while `is_maximally_localized` is `None` + ('single_shot', None, 'single_shot'), + # normalizing from `is_maximally_localized` + (None, True, 'maximally_localized'), + (None, False, 'single_shot'), + ], + ) + def test_normalize( + self, + localization_type: Optional[str], + is_maximally_localized: bool, + result_localization_type: Optional[str], + ): + """ + Test the `normalize` method . + + Args: + localization_type (Optional[str]): The localization type. + is_maximally_localized (bool): If the localization is maximally-localized or a single-shot. + result_localization_type (Optional[str]): The expected `localization_type` after normalization. + """ + wannier = Wannier( + localization_type=localization_type, + is_maximally_localized=is_maximally_localized, + ) + wannier.normalize(EntryArchive(), logger) + assert wannier.localization_type == result_localization_type + + +class TestSlaterKosterBond: + """ + Test the `SlaterKosterBond` class defined in `model_method.py`. + """ + + @pytest.mark.parametrize( + 'orbital_1, orbital_2, bravais_vector, result', + [ + # no `OrbitalsState` sections + (None, None, (), None), + (None, OrbitalsState(), (), None), + (OrbitalsState(), None, (), None), + # no `bravais_vector` + (OrbitalsState(), OrbitalsState(), None, None), + # no `l_quantum_symbol` in `OrbitalsState` + (OrbitalsState(), OrbitalsState(), (0, 0, 0), None), + # valid cases + ( + OrbitalsState(l_quantum_symbol='s'), + OrbitalsState(l_quantum_symbol='s'), + (0, 0, 0), + 'sss', + ), + ( + OrbitalsState(l_quantum_symbol='s'), + OrbitalsState(l_quantum_symbol='p'), + (0, 0, 0), + 'sps', + ), + ], + ) + def test_resolve_bond_name_from_references( + self, + orbital_1: Optional[OrbitalsState], + orbital_2: Optional[OrbitalsState], + bravais_vector: Optional[tuple], + result: Optional[str], + ): + """ + Test the `resolve_bond_name_from_references` method. + + Args: + orbital_1 (Optional[OrbitalsState]): The first `OrbitalsState` section. + orbital_2 (Optional[OrbitalsState]): The second `OrbitalsState` section. + bravais_vector (Optional[tuple]): The bravais vector. + result (Optional[str]): The expected bond name. + """ + sk_bond = SlaterKosterBond() + bond_name = sk_bond.resolve_bond_name_from_references( + orbital_1=orbital_1, + orbital_2=orbital_2, + bravais_vector=bravais_vector, + logger=logger, + ) + assert bond_name == result + + @pytest.mark.parametrize( + 'orbital_1, orbital_2, bravais_vector, result', + [ + # no `OrbitalsState` sections + (None, None, [], None), + (None, OrbitalsState(), [], None), + (OrbitalsState(), None, [], None), + # no `bravais_vector` + (OrbitalsState(), OrbitalsState(), None, None), + # no `l_quantum_symbol` in `OrbitalsState` + (OrbitalsState(), OrbitalsState(), (0, 0, 0), None), + # valid cases + ( + OrbitalsState(l_quantum_symbol='s'), + OrbitalsState(l_quantum_symbol='s'), + (0, 0, 0), + 'sss', + ), + ( + OrbitalsState(l_quantum_symbol='s'), + OrbitalsState(l_quantum_symbol='p'), + (0, 0, 0), + 'sps', + ), + ], + ) + def test_normalize( + self, + orbital_1: Optional[OrbitalsState], + orbital_2: Optional[OrbitalsState], + bravais_vector: Optional[tuple], + result: Optional[str], + ): + """ + Test the `normalize` method. + + Args: + orbital_1 (Optional[OrbitalsState]): The first `OrbitalsState` section. + orbital_2 (Optional[OrbitalsState]): The second `OrbitalsState` section. + bravais_vector (Optional[tuple]): The bravais vector. + result (Optional[str]): The expected SK bond `name` after normalization. + """ + sk_bond = SlaterKosterBond() + atoms_state = AtomsState() + simulation = Simulation( + model_system=[ModelSystem(cell=[AtomicCell(atoms_state=[atoms_state])])] + ) + if orbital_1 is not None: + atoms_state.orbitals_state.append(orbital_1) + sk_bond.orbital_1 = atoms_state.orbitals_state[0] + if orbital_2 is not None: + atoms_state.orbitals_state.append(orbital_2) + sk_bond.orbital_2 = atoms_state.orbitals_state[-1] + if bravais_vector is not None and len(bravais_vector) != 0: + sk_bond.bravais_vector = bravais_vector + sk_bond.normalize(EntryArchive(), logger) + assert sk_bond.name == result