diff --git a/simulationdataschema/model_system.py b/simulationdataschema/model_system.py index e913addf..d161d5c8 100644 --- a/simulationdataschema/model_system.py +++ b/simulationdataschema/model_system.py @@ -37,6 +37,8 @@ import re import numpy as np import ase +from typing import Union, Tuple +from structlog.stdlib import BoundLogger from matid import SymmetryAnalyzer, Classifier # pylint: disable=import-error from matid.classification.classifications import ( @@ -178,13 +180,13 @@ class AtomicCell(GeometricSpace): """, ) - def to_ase_atoms(self, logger): + def to_ase_atoms(self, logger: BoundLogger) -> Union[ase.Atoms, None]: """ Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell` section (labels, periodic_boundary_conditions, positions, and lattice_vectors). Returns: - ase.Atoms: The ASE Atoms object with the basic information from the `AtomicCell`. + Union[ase.Atoms, None]: The ASE Atoms object with the basic information from the `AtomicCell`. """ # Initialize ase.Atoms object with labels ase_atoms = ase.Atoms(symbols=self.labels) @@ -203,11 +205,11 @@ def to_ase_atoms(self, logger): logger.error( "Length of `AtomicCell.positions` does not coincide with the length of the `AtomicCell.labels`." ) - return + return None ase_atoms.set_positions(self.positions.to("angstrom").magnitude) else: logger.error("Could not find `AtomicCell.positions`.") - return + return None # Lattice vectors if self.lattice_vectors is not None: @@ -364,13 +366,25 @@ class Symmetry(ArchiveSection): ) def resolve_analyzed_atomic_cell( - self, symmetry_analyzer: SymmetryAnalyzer, cell_type: str, logger - ): + self, symmetry_analyzer: SymmetryAnalyzer, cell_type: str, logger: BoundLogger + ) -> Union[AtomicCell, None]: + """ + Resolves the `AtomicCell` section from the `SymmetryAnalyzer` object and the cell_type + (primitive or conventional). + + Args: + symmetry_analyzer (SymmetryAnalyzer): The `SymmetryAnalyzer` object used to resolve. + cell_type (str): The type of cell to resolve, either 'primitive' or 'conventional'. + + Returns: + Union[AtomicCell, None]: The resolved `AtomicCell` section or None if the cell_type + is not recognized. + """ if cell_type not in ["primitive", "conventional"]: logger.error( "Cell type not recognized, only 'primitive' and 'conventional' are allowed." ) - return + return None wyckoff = getattr(symmetry_analyzer, f"get_wyckoff_letters_{cell_type}")() equivalent_atoms = getattr( symmetry_analyzer, f"get_equivalent_atoms_{cell_type}" @@ -392,7 +406,9 @@ def resolve_analyzed_atomic_cell( atomic_cell.get_geometric_space_for_atomic_cell(logger) return atomic_cell - def resolve_bulk_symmetry(self, original_atomic_cell, logger): + def resolve_bulk_symmetry( + self, original_atomic_cell: AtomicCell, logger: BoundLogger + ) -> Tuple[Union[AtomicCell, None], Union[AtomicCell, None]]: """ Resolves the symmetry of the material being simulated using MatID and the originally parsed data under original_atomic_cell. It generates two other @@ -416,10 +432,10 @@ def resolve_bulk_symmetry(self, original_atomic_cell, logger): logger.debug( "Symmetry analysis with MatID is not available.", details=str(e) ) - return + return None, None except Exception as e: logger.warning("Symmetry analysis with MatID failed.", exc_info=e) - return + return None, None # We store symmetry_analyzer info in a dictionary symmetry["bravais_lattice"] = symmetry_analyzer.get_bravais_lattice() @@ -746,7 +762,9 @@ class ModelSystem(System): model_system = SubSection(sub_section=SectionProxy("ModelSystem"), repeats=True) - def resolve_system_type_and_dimensionality(self, ase_atoms): + def resolve_system_type_and_dimensionality( + self, ase_atoms: ase.Atoms + ) -> Tuple[str, int]: """ Determine the ModelSystem.type and ModelSystem.dimensionality using MatID classification analyzer: