Skip to content

Commit

Permalink
Moved testing to test_general.py
Browse files Browse the repository at this point in the history
Added explicit keyword arguments for method calls in ModelSystem and testing

Added tests for _set_system_branch_depth and a TODO for the functions used in these tests, which should be implemented in the datamodel
  • Loading branch information
JosePizarro3 committed Jun 5, 2024
1 parent d4fec92 commit 908a34b
Show file tree
Hide file tree
Showing 5 changed files with 397 additions and 261 deletions.
34 changes: 19 additions & 15 deletions src/nomad_simulations/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@

import numpy as np
from typing import List
from structlog.stdlib import BoundLogger

from nomad.units import ureg
from nomad.metainfo import SubSection, Quantity, MEnum, Section, Datetime
from nomad.metainfo import SubSection, Quantity, Section, Datetime
from nomad.datamodel.metainfo.annotations import ELNAnnotation
from nomad.datamodel.data import EntryData
from nomad.datamodel.metainfo.basesections import Entity, Activity

from .model_system import ModelSystem
from .model_method import ModelMethod
from .outputs import Outputs
from .utils import is_not_representative, get_composition
from nomad_simulations.model_system import ModelSystem
from nomad_simulations.model_method import ModelMethod
from nomad_simulations.outputs import Outputs
from nomad_simulations.utils import is_not_representative, get_composition


class Program(Entity):
Expand Down Expand Up @@ -178,7 +176,9 @@ def _set_system_branch_depth(
):
for system_child in system_parent.model_system:
system_child.branch_depth = branch_depth + 1
self._set_system_branch_depth(system_child, branch_depth + 1)
self._set_system_branch_depth(
system_parent=system_child, branch_depth=branch_depth + 1
)

def resolve_composition_formula(self, system_parent: ModelSystem) -> None:
"""Determine and set the composition formula for `system_parent` and all of its
Expand Down Expand Up @@ -217,7 +217,9 @@ def set_composition_formula(
for subsystem in subsystems
]
if system.composition_formula is None:
system.composition_formula = get_composition(subsystem_labels)
system.composition_formula = get_composition(
children_names=subsystem_labels
)

def get_composition_recurs(system: ModelSystem, atom_labels: List[str]) -> None:
"""Traverse the system hierarchy downward and set the branch composition for
Expand All @@ -229,10 +231,12 @@ def get_composition_recurs(system: ModelSystem, atom_labels: List[str]) -> None:
to the atom indices stored in system.
"""
subsystems = system.model_system
set_composition_formula(system, subsystems, atom_labels)
set_composition_formula(
system=system, subsystems=subsystems, atom_labels=atom_labels
)
if subsystems:
for subsystem in subsystems:
get_composition_recurs(subsystem, atom_labels)
get_composition_recurs(system=subsystem, atom_labels=atom_labels)

atoms_state = (
system_parent.cell[0].atoms_state if system_parent.cell is not None else []
Expand All @@ -242,7 +246,7 @@ def get_composition_recurs(system: ModelSystem, atom_labels: List[str]) -> None:
if atoms_state is not None
else []
)
get_composition_recurs(system_parent, atom_labels)
get_composition_recurs(system=system_parent, atom_labels=atom_labels)

def normalize(self, archive, logger) -> None:
super(EntryData, self).normalize(archive, logger)
Expand All @@ -263,8 +267,8 @@ def normalize(self, archive, logger) -> None:
system_parent.branch_depth = 0
if len(system_parent.model_system) == 0:
continue
self._set_system_branch_depth(system_parent)
self._set_system_branch_depth(system_parent=system_parent)

if is_not_representative(system_parent, logger):
if is_not_representative(model_system=system_parent, logger=logger):
continue
self.resolve_composition_formula(system_parent)
self.resolve_composition_formula(system_parent=system_parent)
62 changes: 34 additions & 28 deletions src/nomad_simulations/model_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import re
import numpy as np
import ase
from typing import Tuple, Optional, List
from typing import Tuple, Optional
from structlog.stdlib import BoundLogger

from matid import SymmetryAnalyzer, Classifier # pylint: disable=import-error
Expand All @@ -39,16 +39,15 @@
Formula,
get_normalized_wyckoff,
search_aflow_prototype,
get_composition,
)

from nomad.metainfo import Quantity, SubSection, SectionProxy, MEnum, Section, Context
from nomad.datamodel.data import ArchiveSection
from nomad.datamodel.metainfo.basesections import Entity, System
from nomad.datamodel.metainfo.annotations import ELNAnnotation

from .atoms_state import AtomsState
from .utils import get_sibling_section, is_not_representative
from nomad_simulations.atoms_state import AtomsState
from nomad_simulations.utils import get_sibling_section, is_not_representative


class GeometricSpace(Entity):
Expand Down Expand Up @@ -185,7 +184,7 @@ def get_geometric_space_for_atomic_cell(self, logger: BoundLogger) -> None:
Args:
logger (BoundLogger): The logger to log messages.
"""
atoms = self.to_ase_atoms(logger) # function defined in AtomicCell
atoms = self.to_ase_atoms(logger=logger) # function defined in AtomicCell
cell = atoms.get_cell()
self.length_vector_a, self.length_vector_b, self.length_vector_c = (
cell.lengths() * ureg.angstrom
Expand All @@ -198,7 +197,7 @@ def get_geometric_space_for_atomic_cell(self, logger: BoundLogger) -> None:
def normalize(self, archive, logger) -> None:
# Skip normalization for `Entity`
try:
self.get_geometric_space_for_atomic_cell(logger)
self.get_geometric_space_for_atomic_cell(logger=logger)
except Exception:
logger.warning(
'Could not extract the geometric space information from ASE Atoms object.',
Expand Down Expand Up @@ -348,11 +347,11 @@ def to_ase_atoms(self, logger: BoundLogger) -> Optional[ase.Atoms]:
'Could not find `AtomicCell.periodic_boundary_conditions`. They will be set to [False, False, False].'
)
self.periodic_boundary_conditions = [False, False, False]
ase_atoms.set_pbc(self.periodic_boundary_conditions)
ase_atoms.set_pbc(pbc=self.periodic_boundary_conditions)

# Lattice vectors
if self.lattice_vectors is not None:
ase_atoms.set_cell(self.lattice_vectors.to('angstrom').magnitude)
ase_atoms.set_cell(cell=self.lattice_vectors.to('angstrom').magnitude)
else:
logger.info('Could not find `AtomicCell.lattice_vectors`.')

Expand All @@ -363,7 +362,9 @@ def to_ase_atoms(self, logger: BoundLogger) -> Optional[ase.Atoms]:
'Length of `AtomicCell.positions` does not coincide with the length of the `AtomicCell.atoms_state`.'
)
return None
ase_atoms.set_positions(self.positions.to('angstrom').magnitude)
ase_atoms.set_positions(
newpositions=self.positions.to('angstrom').magnitude
)
else:
logger.warning('Could not find `AtomicCell.positions`.')
return None
Expand Down Expand Up @@ -528,7 +529,7 @@ def resolve_analyzed_atomic_cell(
) # ? why do we need to pass units
atomic_cell.wyckoff_letters = wyckoff
atomic_cell.equivalent_atoms = equivalent_atoms
atomic_cell.get_geometric_space_for_atomic_cell(logger)
atomic_cell.get_geometric_space_for_atomic_cell(logger=logger)
return atomic_cell

def resolve_bulk_symmetry(
Expand All @@ -550,7 +551,7 @@ def resolve_bulk_symmetry(
"""
symmetry = {}
try:
ase_atoms = original_atomic_cell.to_ase_atoms(logger)
ase_atoms = original_atomic_cell.to_ase_atoms(logger=logger)
symmetry_analyzer = SymmetryAnalyzer(
ase_atoms, symmetry_tol=config.normalize.symmetry_tolerance
)
Expand Down Expand Up @@ -584,12 +585,12 @@ def resolve_bulk_symmetry(

# Populating the primitive AtomState information
primitive_atomic_cell = self.resolve_analyzed_atomic_cell(
symmetry_analyzer, 'primitive', logger
symmetry_analyzer=symmetry_analyzer, cell_type='primitive', logger=logger
)

# Populating the conventional AtomState information
conventional_atomic_cell = self.resolve_analyzed_atomic_cell(
symmetry_analyzer, 'conventional', logger
symmetry_analyzer=symmetry_analyzer, cell_type='conventional', logger=logger
)

# Getting prototype_formula, prototype_aflow_id, and strukturbericht designation from
Expand All @@ -606,10 +607,11 @@ def resolve_bulk_symmetry(
conventional_wyckoff = conventional_atomic_cell.wyckoff_letters
# Normalize wyckoff letters
norm_wyckoff = get_normalized_wyckoff(
conventional_num, conventional_wyckoff
atomic_numbers=conventional_num, wyckoff_letters=conventional_wyckoff
)
aflow_prototype = search_aflow_prototype(
symmetry.get('space_group_number'), norm_wyckoff
space_group=symmetry.get('space_group_number'),
norm_wyckoff=norm_wyckoff,
)
if aflow_prototype:
strukturbericht = aflow_prototype.get('Strukturbericht Designation')
Expand Down Expand Up @@ -642,7 +644,9 @@ def normalize(self, archive, logger) -> None:
(
primitive_atomic_cell,
conventional_atomic_cell,
) = self.resolve_bulk_symmetry(atomic_cell, logger)
) = self.resolve_bulk_symmetry(
original_atomic_cell=atomic_cell, logger=logger
)
self.m_parent.m_add_sub_section(ModelSystem.cell, primitive_atomic_cell)
self.m_parent.m_add_sub_section(ModelSystem.cell, conventional_atomic_cell)
# Reference to the standarized cell, and if not, fallback to the originally parsed one
Expand Down Expand Up @@ -712,11 +716,11 @@ def resolve_chemical_formulas(self, formula: Formula) -> None:
Args:
formula (Formula): The Formula object from NOMAD atomutils containing the chemical formulas.
"""
self.descriptive = formula.format('descriptive')
self.reduced = formula.format('reduced')
self.iupac = formula.format('iupac')
self.hill = formula.format('hill')
self.anonymous = formula.format('anonymous')
self.descriptive = formula.format(fmt='descriptive')
self.reduced = formula.format(fmt='reduced')
self.iupac = formula.format(fmt='iupac')
self.hill = formula.format(fmt='hill')
self.anonymous = formula.format(fmt='anonymous')

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)
Expand All @@ -727,10 +731,10 @@ def normalize(self, archive, logger) -> None:
if atomic_cell is None:
logger.warning('Could not resolve the sibling `AtomicCell` section.')
return
ase_atoms = atomic_cell.to_ase_atoms(logger)
ase_atoms = atomic_cell.to_ase_atoms(logger=logger)
formula = None
try:
formula = Formula(ase_atoms.get_chemical_formula())
formula = Formula(formula=ase_atoms.get_chemical_formula())
# self.chemical_composition = ase_atoms.get_chemical_formula(mode="all")
except ValueError as e:
logger.warning(
Expand All @@ -739,7 +743,7 @@ def normalize(self, archive, logger) -> None:
error=str(e),
)
if formula:
self.resolve_chemical_formulas(formula)
self.resolve_chemical_formulas(formula=formula)
self.m_cache['elemental_composition'] = formula.elemental_composition()


Expand Down Expand Up @@ -957,7 +961,7 @@ def resolve_system_type_and_dimensionality(
radii='covalent',
cluster_threshold=config.normalize.cluster_threshold,
)
classification = classifier.classify(ase_atoms)
classification = classifier.classify(input_system=ase_atoms)
except Exception as e:
logger.warning(
'MatID system classification failed.', exc_info=e, error=str(e)
Expand Down Expand Up @@ -993,7 +997,7 @@ def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# We don't need to normalize if the system is not representative
if is_not_representative(self, logger):
if is_not_representative(model_system=self, logger=logger):
return

# Extracting ASE Atoms object from the originally parsed AtomicCell section
Expand All @@ -1004,7 +1008,7 @@ def normalize(self, archive, logger) -> None:
return
if self.cell[0].name == 'AtomicCell':
self.cell[0].type = 'original'
ase_atoms = self.cell[0].to_ase_atoms(logger)
ase_atoms = self.cell[0].to_ase_atoms(logger=logger)
if not ase_atoms:
return

Expand All @@ -1016,7 +1020,9 @@ def normalize(self, archive, logger) -> None:
(
self.type,
self.dimensionality,
) = self.resolve_system_type_and_dimensionality(ase_atoms, logger)
) = self.resolve_system_type_and_dimensionality(
ase_atoms=ase_atoms, logger=logger
)
# Creating and normalizing Symmetry section
if self.type == 'bulk' and self.symmetry is not None:
sec_symmetry = self.m_create(Symmetry)
Expand Down
17 changes: 6 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def generate_model_system(


def generate_atomic_cell(
lattice_vectors: List = [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
positions=None,
periodic_boundary_conditions=None,
chemical_symbols: List = ['H', 'H', 'O'],
atomic_numbers: List = [1, 1, 8],
lattice_vectors: List[List[float]] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
positions: Optional[list] = None,
periodic_boundary_conditions: List[bool] = [False, False, False],
chemical_symbols: List[str] = ['H', 'H', 'O'],
atomic_numbers: List[int] = [1, 1, 8],
) -> AtomicCell:
"""
Generate an `AtomicCell` section with the given parameters.
Expand All @@ -133,18 +133,13 @@ def generate_atomic_cell(
if positions is None and chemical_symbols is not None:
n_atoms = len(chemical_symbols)
positions = [[i / n_atoms, i / n_atoms, i / n_atoms] for i in range(n_atoms)]
# Define periodic boundary conditions if not provided
if periodic_boundary_conditions is None:
periodic_boundary_conditions = [False, False, False]

# Define the atomic cell
atomic_cell = AtomicCell()
atomic_cell = AtomicCell(periodic_boundary_conditions=periodic_boundary_conditions)
if lattice_vectors:
atomic_cell.lattice_vectors = lattice_vectors * ureg('angstrom')
if positions:
atomic_cell.positions = positions * ureg('angstrom')
if periodic_boundary_conditions:
atomic_cell.periodic_boundary_conditions = periodic_boundary_conditions

# Add the elements information
for index, atom in enumerate(chemical_symbols):
Expand Down
Loading

0 comments on commit 908a34b

Please sign in to comment.