From e9e830a966a58304123f0a1d135db360c02caf58 Mon Sep 17 00:00:00 2001 From: jrudz Date: Wed, 4 Dec 2024 21:54:28 +0100 Subject: [PATCH] initial investigation into harmonizing particles and atoms --- .../schema_packages/atoms_state.py | 16 ++++- .../schema_packages/general.py | 58 +++++++++++-------- .../schema_packages/model_system.py | 16 ++--- .../schema_packages/particles_state.py | 10 +++- 4 files changed, 67 insertions(+), 33 deletions(-) diff --git a/src/nomad_simulations/schema_packages/atoms_state.py b/src/nomad_simulations/schema_packages/atoms_state.py index 32fbdd11..ed2c6526 100644 --- a/src/nomad_simulations/schema_packages/atoms_state.py +++ b/src/nomad_simulations/schema_packages/atoms_state.py @@ -552,7 +552,17 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: ) -class AtomsState(Entity): +class State(Entity): + """ + A base section to define the state information of the system. + """ + + def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwargs): + super().__init__(m_def, m_context, **kwargs) + self.labels = None + + +class AtomsState(State): """ A base section to define each atom state information. """ @@ -641,3 +651,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: self.chemical_symbol = self.resolve_chemical_symbol(logger=logger) if self.atomic_number is None: self.atomic_number = self.resolve_atomic_number(logger=logger) + + # Set the labels + if self.chemical_symbol is not None: + self.labels = self.chemical_symbol diff --git a/src/nomad_simulations/schema_packages/general.py b/src/nomad_simulations/schema_packages/general.py index 59777858..496cebad 100644 --- a/src/nomad_simulations/schema_packages/general.py +++ b/src/nomad_simulations/schema_packages/general.py @@ -1,10 +1,17 @@ -from typing import TYPE_CHECKING +#! TODO: Why is TYPE_CHECKING False? +from typing import TYPE_CHECKING, List, Iterable, Union -if TYPE_CHECKING: - from collections.abc import Callable - - from nomad.datamodel.datamodel import EntryArchive - from structlog.stdlib import BoundLogger +if not TYPE_CHECKING: + from nomad.datamodel.datamodel import ( + EntryArchive, + ) + from nomad.metainfo import ( + Context, + Section, + ) + from structlog.stdlib import ( + BoundLogger, + ) import numpy as np from nomad.config import config @@ -227,7 +234,7 @@ def resolve_composition_formula(self, system_parent: ModelSystem) -> None: """ def set_composition_formula( - system: ModelSystem, subsystems: list[ModelSystem], atom_labels: list[str] + system: ModelSystem, subsystems: list[ModelSystem], labels: list[str] ) -> None: """Determine the composition formula for `system` based on its `subsystems`. If `system` has no children, the atom_labels are used to determine the formula. @@ -243,8 +250,8 @@ def set_composition_formula( system.atom_indices if system.atom_indices is not None else [] ) subsystem_labels = ( - [np.array(atom_labels)[atom_indices]] - if atom_labels + [np.array(labels)[atom_indices]] + if labels else ['Unknown' for atom in range(len(atom_indices))] ) else: @@ -259,7 +266,7 @@ def set_composition_formula( children_names=subsystem_labels ) - def get_composition_recurs(system: ModelSystem, atom_labels: list[str]) -> None: + def get_composition_recurs(system: ModelSystem, labels: list[str]) -> None: """Traverse the system hierarchy downward and set the branch composition for all (sub)systems at each level. @@ -269,23 +276,28 @@ def get_composition_recurs(system: ModelSystem, atom_labels: list[str]) -> None: to the atom indices stored in system. """ subsystems = system.model_system - set_composition_formula( - system=system, subsystems=subsystems, atom_labels=atom_labels - ) + set_composition_formula(system=system, subsystems=subsystems, labels=labels) if subsystems: for subsystem in subsystems: - get_composition_recurs(system=subsystem, atom_labels=atom_labels) + get_composition_recurs(system=subsystem, labels=labels) # ! CG: system_parent.cell[0].particles_state instead of atoms_state! - atoms_state = ( - system_parent.cell[0].atoms_state if system_parent.cell is not None else [] - ) - atom_labels = ( - [atom.chemical_symbol for atom in atoms_state] - if atoms_state is not None - else [] - ) - get_composition_recurs(system=system_parent, atom_labels=atom_labels) + labels = [] + if system_parent.cell is not None: + if system_parent.cell[0].name == 'AtomicCell': + labels = ( + [atom.labels for atom in system_parent.cell[0].atoms_state] + if system_parent.cell[0].atoms_state is not None + else [] + ) + elif system_parent.cell[0].name == 'ParticleCell': + labels = ( + [atom.labels for atom in system_parent.cell[0].particles_state] + if system_parent.cell[0].particles_state is not None + else [] + ) + + get_composition_recurs(system=system_parent, labels=labels) def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super(Schema, self).normalize(archive, logger) diff --git a/src/nomad_simulations/schema_packages/model_system.py b/src/nomad_simulations/schema_packages/model_system.py index 21cd6ad4..d0cb0042 100644 --- a/src/nomad_simulations/schema_packages/model_system.py +++ b/src/nomad_simulations/schema_packages/model_system.py @@ -1372,10 +1372,12 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: #! ChemicalFormula calls `ase_atoms = atomic_cell.to_ase_atoms(logger=logger)` and `ase_atoms.get_chemical_formula()` # Creating and normalizing ChemicalFormula section - # TODO add support for fractional formulas (possibly add `AtomicCell.concentrations` for each species) - sec_chemical_formula = self.m_create(ChemicalFormula) - sec_chemical_formula.normalize(archive, logger) - if sec_chemical_formula.m_cache: - self.elemental_composition = sec_chemical_formula.m_cache.get( - 'elemental_composition', [] - ) + if any(cell.name == 'AtomicCell' for cell in self.cell): + # TODO: get_sibling_section() may need to be updated to more specifically search for AtomicCell in ChemicalFormula and Symmetry, in cases where multiple different cells are present + # TODO add support for fractional formulas (possibly add `AtomicCell.concentrations` for each species) + sec_chemical_formula = self.m_create(ChemicalFormula) + sec_chemical_formula.normalize(archive, logger) + if sec_chemical_formula.m_cache: + self.elemental_composition = sec_chemical_formula.m_cache.get( + 'elemental_composition', [] + ) diff --git a/src/nomad_simulations/schema_packages/particles_state.py b/src/nomad_simulations/schema_packages/particles_state.py index 28afb1e6..4fa5bbe1 100644 --- a/src/nomad_simulations/schema_packages/particles_state.py +++ b/src/nomad_simulations/schema_packages/particles_state.py @@ -5,7 +5,8 @@ import ase.geometry import numpy as np import pint -from deprecated import deprecated + +# from deprecated import deprecated from nomad.datamodel.data import ArchiveSection from nomad.datamodel.metainfo.annotations import ELNAnnotation from nomad.datamodel.metainfo.basesections import Entity @@ -17,6 +18,8 @@ from nomad.metainfo import Context, Section from structlog.stdlib import BoundLogger +from nomad_simulations.schema_packages.atoms_state import State + class Particles: """Particle object. @@ -424,7 +427,7 @@ def _set_positions(self, pos): # ? How generic (usable for any CG model) vs. Martini-specific do we want to be? -class ParticlesState(Entity): +class ParticlesState(State): """ A base section to define individual coarse-grained (CG) particle information. """ @@ -491,3 +494,6 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: # Get particle_type as string, if possible. if not isinstance(self.particle_type, str): self.particle_type = self.resolve_particle_type(logger=logger) + + if self.particle_type is not None: + self.labels = self.particle_type