Skip to content

Commit

Permalink
Added testing for model_method.py classes related with TB
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Jun 5, 2024
1 parent b58c9ca commit 55963e8
Show file tree
Hide file tree
Showing 2 changed files with 538 additions and 58 deletions.
118 changes: 60 additions & 58 deletions src/nomad_simulations/model_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
Context,
)

from .numerical_settings import NumericalSettings
from .model_system import ModelSystem
from .atoms_state import OrbitalsState, CoreHole
from .utils import is_not_representative
from nomad_simulations.numerical_settings import NumericalSettings
from nomad_simulations.model_system import ModelSystem, AtomicCell
from nomad_simulations.atoms_state import OrbitalsState, CoreHole
from nomad_simulations.utils import is_not_representative


class ModelMethod(ArchiveSection):
Expand Down Expand Up @@ -448,14 +448,11 @@ class TB(ModelMethodElectronic):
""",
)

def resolve_type(self, logger: BoundLogger) -> Optional[str]:
def resolve_type(self) -> Optional[str]:
"""
Resolves the `type` of the `TB` section if it is not already defined, and from the
`m_def.name` of the section.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[str]): The resolved `type` of the `TB` section.
"""
Expand All @@ -482,23 +479,29 @@ def resolve_orbital_references(
Returns:
Optional[List[OrbitalsState]]: The resolved references to the `OrbitalsState` sections.
"""
model_system = model_systems[model_index]
try:
model_system = model_systems[model_index]
except IndexError:
logger.warning(
f'The `ModelSystem` section with index {model_index} was not found.'
)
return None

# If `ModelSystem` is not representative, the normalization will not run
if is_not_representative(model_system, logger):
if is_not_representative(model_system=model_system, logger=logger):
return None

# If `AtomicCell` is not found, the normalization will not run
atomic_cell = model_system.cell[0]
if atomic_cell is None:
if not model_system.cell:
logger.warning('`AtomicCell` section was not found.')
return None
atomic_cell = model_system.cell[0]

# If there is no child `ModelSystem`, the normalization will not run
atoms_state = atomic_cell.atoms_state
model_system_child = model_system.model_system
if model_system_child is None:
logger.warning('No child `ModelSystem` section was found.')
if not atoms_state or not model_system_child:
logger.warning('No `AtomsState` or child `ModelSystem` section were found.')
return None

# We flatten the `OrbitalsState` sections from the `ModelSystem` section
Expand All @@ -509,7 +512,13 @@ def resolve_orbital_references(
continue
indices = active_atom.atom_indices
for index in indices:
active_atoms_state = atoms_state[index]
try:
active_atoms_state = atoms_state[index]
except IndexError:
logger.warning(
f'The `AtomsState` section with index {index} was not found.'
)
continue
orbitals_state = active_atoms_state.orbitals_state
for orbital in orbitals_state:
orbitals_ref.append(orbital)
Expand All @@ -522,11 +531,7 @@ def normalize(self, archive, logger) -> None:
self.name = 'TB'

# Resolve `type` to be defined by the lower level class (Wannier, DFTB, xTB or SlaterKoster) if it is not already defined
self.type = (
self.resolve_type(logger)
if (self.type is None or self.type == 'unavailable')
else self.type
)
self.type = self.resolve_type()

# Resolve `orbitals_ref` from the info in the child `ModelSystem` section and the `OrbitalsState` sections
model_systems = self.m_xpath('m_parent.model_system', dict=False)
Expand All @@ -536,8 +541,10 @@ def normalize(self, archive, logger) -> None:
)
return
# This normalization only considers the last `ModelSystem` (default `model_index` argument set to -1)
orbitals_ref = self.resolve_orbital_references(model_systems, logger)
if orbitals_ref is not None and self.orbitals_ref is None:
orbitals_ref = self.resolve_orbital_references(
model_systems=model_systems, logger=logger
)
if orbitals_ref is not None and len(orbitals_ref) > 0 and not self.orbitals_ref:
self.n_orbitals = len(orbitals_ref)
self.orbitals_ref = orbitals_ref

Expand Down Expand Up @@ -586,32 +593,16 @@ class Wannier(TB):
""",
)

def resolve_localization_type(self, logger: BoundLogger) -> Optional[str]:
"""
Resolves the `localization_type` of the `Wannier` section if it is not already defined, and from the
`is_maximally_localized` boolean.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[str]): The resolved `localization_type` of the `Wannier` section.
"""
if self.localization_type is None:
if self.is_maximally_localized:
return 'maximally_localized'
else:
return 'single_shot'
logger.info(
'Could not find if the Wannier tight-binding model is maximally localized or not.'
)
return None

def normalize(self, archive, logger):
super().normalize(archive, logger)

# Resolve `localization_type` from `is_maximally_localized`
self.localization_type = self.resolve_localization_type(logger)
if self.localization_type is None:
if self.is_maximally_localized is not None:
if self.is_maximally_localized:
self.localization_type = 'maximally_localized'
else:
self.localization_type = 'single_shot'


class SlaterKosterBond(ArchiveSection):
Expand Down Expand Up @@ -680,29 +671,39 @@ def __init__(self, m_def: Section = None, m_context: Context = None, **kwargs):

def resolve_bond_name_from_references(
self,
orbital_1: OrbitalsState,
orbital_2: OrbitalsState,
bravais_vector: tuple,
orbital_1: Optional[OrbitalsState],
orbital_2: Optional[OrbitalsState],
bravais_vector: Optional[tuple],
logger: BoundLogger,
) -> Optional[str]:
"""
Resolves the `name` of the `SlaterKosterBond` from the references to the `OrbitalsState` sections.
Args:
orbital_1 (OrbitalsState): The first `OrbitalsState` section.
orbital_2 (OrbitalsState): The second `OrbitalsState` section.
bravais_vector (tuple): The bravais vector of the cell.
orbital_1 (Optional[OrbitalsState]): The first `OrbitalsState` section.
orbital_2 (Optional[OrbitalsState]): The second `OrbitalsState` section.
bravais_vector (Optional[tuple]): The bravais vector of the cell.
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[str]): The resolved `name` of the `SlaterKosterBond`.
"""
bond_name = None
# Initial check
if orbital_1 is None or orbital_2 is None:
logger.warning('The `OrbitalsState` sections are not defined.')
return None
if bravais_vector is None:
logger.warning('The `bravais_vector` is not defined.')
return None

# Check for `l_quantum_symbol` in `OrbitalsState` sections
if orbital_1.l_quantum_symbol is None or orbital_2.l_quantum_symbol is None:
logger.warning(
'The `l_quantum_symbol` of the `OrbitalsState` bonds are not defined.'
)
return None

bond_name = None
value = [orbital_1.l_quantum_symbol, orbital_2.l_quantum_symbol, bravais_vector]
# Check if `value` is found in the `self._bond_name_map` and return the key
for key, val in self._bond_name_map.items():
Expand All @@ -714,14 +715,15 @@ def resolve_bond_name_from_references(
def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# Resolve the SK bond `name` from the `OrbitalsState` references and the `bravais_vector`
if self.orbital_1 and self.orbital_2 and self.bravais_vector is not None:
bravais_vector = tuple(self.bravais_vector) # transformed for comparing
self.name = (
self.resolve_bond_name_from_references(
self.orbital_1, self.orbital_2, bravais_vector, logger
)
if self.name is None
else self.name
if self.bravais_vector is not None:
bravais_vector = tuple(self.bravais_vector) # transformed for comparing
self.name = self.resolve_bond_name_from_references(
orbital_1=self.orbital_1,
orbital_2=self.orbital_2,
bravais_vector=bravais_vector,
logger=logger,
)


Expand Down
Loading

0 comments on commit 55963e8

Please sign in to comment.