From 20d2fb2dc36bc46e9e5d39bc0db847b899ee5eac Mon Sep 17 00:00:00 2001 From: ndaelman-hu <107392603+ndaelman-hu@users.noreply.github.com> Date: Wed, 16 Oct 2024 14:36:00 +0200 Subject: [PATCH] Extend equalities with comparison (#122) * Add support for comparison operators to (`Atomic`)`Cell`, similar to `<`, `>`, `<=`, `>=` * Rewrite (in)equality operators --------- Co-authored-by: ndaelman --- .../schema_packages/__init__.py | 4 +- .../schema_packages/model_system.py | 209 +++++++++++------ .../schema_packages/utils/__init__.py | 1 + .../schema_packages/utils/utils.py | 18 +- tests/test_model_system.py | 210 +++++++++++------- 5 files changed, 282 insertions(+), 160 deletions(-) diff --git a/src/nomad_simulations/schema_packages/__init__.py b/src/nomad_simulations/schema_packages/__init__.py index 8b730793..78d66557 100644 --- a/src/nomad_simulations/schema_packages/__init__.py +++ b/src/nomad_simulations/schema_packages/__init__.py @@ -31,8 +31,8 @@ class NOMADSimulationsEntryPoint(SchemaPackageEntryPoint): description='Limite of the number of atoms in the unit cell to be treated for the system type classification from MatID to work. This is done to avoid overhead of the package.', ) equal_cell_positions_tolerance: float = Field( - 1e-12, - description='Tolerance (in meters) for the cell positions to be considered equal.', + 12, + description='Decimal order or tolerance (in meters) for comparing cell positions.', ) def load(self): diff --git a/src/nomad_simulations/schema_packages/model_system.py b/src/nomad_simulations/schema_packages/model_system.py index 0555c432..d47bf673 100644 --- a/src/nomad_simulations/schema_packages/model_system.py +++ b/src/nomad_simulations/schema_packages/model_system.py @@ -1,5 +1,25 @@ +# +# Copyright The NOMAD Authors. +# +# This file is part of NOMAD. See https://nomad-lab.eu for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import re -from typing import TYPE_CHECKING, Optional +from functools import lru_cache +from hashlib import sha1 +from typing import TYPE_CHECKING import ase import numpy as np @@ -22,12 +42,17 @@ from nomad.units import ureg if TYPE_CHECKING: + from collections.abc import Generator + from typing import Any, Callable, Optional + + import pint from nomad.datamodel.datamodel import EntryArchive from nomad.metainfo import Context, Section from structlog.stdlib import BoundLogger from nomad_simulations.schema_packages.atoms_state import AtomsState from nomad_simulations.schema_packages.utils import ( + catch_not_implemented, get_sibling_section, is_not_representative, ) @@ -200,6 +225,72 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: return +def _check_implemented(func: 'Callable'): + """ + Decorator to restrict the comparison functions to the same class. + """ + + def wrapper(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return func(self, other) + + return wrapper + + +class PartialOrderElement: + def __init__(self, representative_variable): + self.representative_variable = representative_variable + + def __hash__(self): + return self.representative_variable.__hash__() + + @_check_implemented + def __eq__(self, other): + return self.representative_variable == other.representative_variable + + @_check_implemented + def __lt__(self, other): + return False + + @_check_implemented + def __gt__(self, other): + return False + + def __le__(self, other): + return self.__eq__(other) + + def __ge__(self, other): + return self.__eq__(other) + + # __ne__ assumes that usage in a finite set with its comparison definitions + + +class HashedPositions(PartialOrderElement): + # `representative_variable` is a `pint.Quantity` object + + def __hash__(self): + hash_str = sha1( + np.ascontiguousarray( + np.round( + self.representative_variable.to_base_units().magnitude, + decimals=configuration.equal_cell_positions_tolerance, + out=None, + ) + ).tobytes() + ).hexdigest() + return int(hash_str, 16) + + def __eq__(self, other): + """Equality as defined between HashedPositions.""" + if ( + self.representative_variable is None + or other.representative_variable is None + ): + return NotImplemented + return np.allclose(self.representative_variable, other.representative_variable) + + class Cell(GeometricSpace): """ A base section used to specify the cell quantities of a system at a given moment in time. @@ -217,7 +308,7 @@ class Cell(GeometricSpace): type=MEnum('original', 'primitive', 'conventional'), description=""" Representation type of the cell structure. It might be: - - 'original' as in origanally parsed, + - 'original' as in originally parsed, - 'primitive' as the primitive unit cell, - 'conventional' as the conventional cell used for referencing. """, @@ -278,45 +369,36 @@ class Cell(GeometricSpace): """, ) - def _check_positions(self, positions_1, positions_2) -> list: - # Check that all the `positions`` of `cell_1` match with the ones in `cell_2` - check_positions = [] - for i1, pos1 in enumerate(positions_1): - for i2, pos2 in enumerate(positions_2): - if np.allclose( - pos1, pos2, atol=configuration.equal_cell_positions_tolerance - ): - check_positions.append([i1, i2]) - break - return check_positions - - def is_equal_cell(self, other) -> bool: - """ - Check if the cell is equal to an`other` cell by comparing the `positions`. - Args: - other: The other cell to compare with. - Returns: - bool: True if the cells are equal, False otherwise. - """ - # TODO implement checks on `lattice_vectors` and other quantities to ensure the equality of primitive cells - if not isinstance(other, Cell): - return False + @staticmethod + def _generate_comparer(obj: 'Cell') -> 'Generator[Any, None, None]': + try: + return ((HashedPositions(pos)) for pos in obj.positions) + except AttributeError: + raise NotImplementedError - # If the `positions` are empty, return False - if self.positions is None or other.positions is None: - return False + @catch_not_implemented + def is_lt_cell(self, other) -> bool: + return set(self._generate_comparer(self)) < set(self._generate_comparer(other)) - # The `positions` should have the same length (same number of positions) - if len(self.positions) != len(other.positions): - return False - n_positions = len(self.positions) + @catch_not_implemented + def is_gt_cell(self, other) -> bool: + return set(self._generate_comparer(self)) > set(self._generate_comparer(other)) - check_positions = self._check_positions( - positions_1=self.positions, positions_2=other.positions - ) - if len(check_positions) != n_positions: - return False - return True + @catch_not_implemented + def is_le_cell(self, other) -> bool: + return set(self._generate_comparer(self)) <= set(self._generate_comparer(other)) + + @catch_not_implemented + def is_ge_cell(self, other) -> bool: + return set(self._generate_comparer(self)) >= set(self._generate_comparer(other)) + + @catch_not_implemented + def is_equal_cell(self, other) -> bool: # TODO: improve naming + return set(self._generate_comparer(self)) == set(self._generate_comparer(other)) + + def is_ne_cell(self, other) -> bool: + # this does not hold in general, but here we use finite sets + return not self.is_equal_cell(other) def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) @@ -361,40 +443,20 @@ def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwarg # Set the name of the section self.name = self.m_def.name - def is_equal_cell(self, other) -> bool: - """ - Check if the atomic cell is equal to an`other` atomic cell by comparing the `positions` and - the `AtomsState[*].chemical_symbol`. - Args: - other: The other atomic cell to compare with. - Returns: - bool: True if the atomic cells are equal, False otherwise. - """ - if not isinstance(other, AtomicCell): - return False - - # Compare positions using the parent sections's `__eq__` method - if not super().is_equal_cell(other=other): - return False - - # Check that the `chemical_symbol` of the atoms in `cell_1` match with the ones in `cell_2` - check_positions = self._check_positions( - positions_1=self.positions, positions_2=other.positions - ) + @staticmethod + def _generate_comparer(obj: 'AtomicCell') -> 'Generator[Any, None, None]': + # presumes `atoms_state` mapping 1-to-1 with `positions` and conserves the order try: - for atom in check_positions: - element_1 = self.atoms_state[atom[0]].chemical_symbol - element_2 = other.atoms_state[atom[1]].chemical_symbol - if element_1 != element_2: - return False - except Exception: - return False - return True + return ( + (HashedPositions(pos), PartialOrderElement(st.chemical_symbol)) + for pos, st in zip(obj.positions, obj.atoms_state) + ) + except AttributeError: + raise NotImplementedError def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]: """ Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`. - Args: logger (BoundLogger): The logger to log messages. @@ -412,7 +474,7 @@ def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]: chemical_symbols.append(atom_state.chemical_symbol) return chemical_symbols - def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]: + def to_ase_atoms(self, logger: 'BoundLogger') -> 'Optional[ase.Atoms]': """ Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell` section (labels, periodic_boundary_conditions, positions, and lattice_vectors). @@ -602,8 +664,11 @@ class Symmetry(ArchiveSection): ) def resolve_analyzed_atomic_cell( - self, symmetry_analyzer: SymmetryAnalyzer, cell_type: str, logger: 'BoundLogger' - ) -> Optional[AtomicCell]: + self, + symmetry_analyzer: 'SymmetryAnalyzer', + cell_type: str, + logger: 'BoundLogger', + ) -> 'Optional[AtomicCell]': """ Resolves the `AtomicCell` section from the `SymmetryAnalyzer` object and the cell_type (primitive or conventional). @@ -647,8 +712,8 @@ def resolve_analyzed_atomic_cell( return atomic_cell def resolve_bulk_symmetry( - self, original_atomic_cell: AtomicCell, logger: 'BoundLogger' - ) -> tuple[Optional[AtomicCell], Optional[AtomicCell]]: + self, original_atomic_cell: 'AtomicCell', logger: 'BoundLogger' + ) -> 'tuple[Optional[AtomicCell], Optional[AtomicCell]]': """ Resolves the symmetry of the material being simulated using MatID and the originally parsed data under original_atomic_cell. It generates two other diff --git a/src/nomad_simulations/schema_packages/utils/__init__.py b/src/nomad_simulations/schema_packages/utils/__init__.py index 52d9ca22..f9945a34 100644 --- a/src/nomad_simulations/schema_packages/utils/__init__.py +++ b/src/nomad_simulations/schema_packages/utils/__init__.py @@ -1,5 +1,6 @@ from .utils import ( RussellSaundersState, + catch_not_implemented, get_composition, get_sibling_section, get_variables, diff --git a/src/nomad_simulations/schema_packages/utils/utils.py b/src/nomad_simulations/schema_packages/utils/utils.py index 1d40aa4a..eff18376 100644 --- a/src/nomad_simulations/schema_packages/utils/utils.py +++ b/src/nomad_simulations/schema_packages/utils/utils.py @@ -5,7 +5,7 @@ from nomad.config import config if TYPE_CHECKING: - from typing import Optional + from typing import Callable, Optional from nomad.datamodel.data import ArchiveSection from structlog.stdlib import BoundLogger @@ -154,3 +154,19 @@ def get_composition(children_names: 'list[str]') -> str: children_count_tup = np.unique(children_names, return_counts=True) formula = ''.join([f'{name}({count})' for name, count in zip(*children_count_tup)]) return formula if formula else None + + +def catch_not_implemented(func: 'Callable') -> 'Callable': + """ + Decorator to default comparison functions outside the same class to `False`. + """ + + def wrapper(self, other) -> bool: + if not isinstance(other, self.__class__): + return False # ? should this throw an error instead? + try: + return func(self, other) + except (TypeError, NotImplementedError): + return False + + return wrapper diff --git a/tests/test_model_system.py b/tests/test_model_system.py index f334da23..088ecc6b 100644 --- a/tests/test_model_system.py +++ b/tests/test_model_system.py @@ -18,96 +18,104 @@ from .conftest import generate_atomic_cell -class TestCell: +class TestAtomicCell: """ - Test the `Cell` section defined in model_system.py + Test the `AtomicCell`, `Cell` and `GeometricSpace` classes defined in model_system.py """ @pytest.mark.parametrize( 'cell_1, cell_2, result', [ - (Cell(), None, False), # one cell is None - (Cell(), Cell(), False), # both cells are empty + (Cell(), None, {'lt': False, 'gt': False, 'eq': False}), # one cell is None + # (Cell(), Cell(), False), # both cells are empty + # ( + # Cell(positions=[[1, 0, 0]]), + # Cell(), + # False, + # ), # one cell has positions, the other is empty ( Cell(positions=[[1, 0, 0]]), - Cell(), - False, - ), # one cell has positions, the other is empty + Cell(positions=[[2, 0, 0]]), + {'lt': False, 'gt': False, 'eq': False}, + ), # position vectors are treated as the fundamental set elements ( Cell(positions=[[1, 0, 0], [0, 1, 0]]), Cell(positions=[[1, 0, 0]]), - False, - ), # length mismatch - ( - Cell(positions=[[1, 0, 0], [0, 1, 0]]), - Cell(positions=[[1, 0, 0], [0, -1, 0]]), - False, - ), # different positions - ( - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - True, - ), # same ordered positions - ( - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - Cell(positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]]), - True, - ), # different ordered positions but same cell - ], - ) - def test_is_equal_cell(self, cell_1: Cell, cell_2: Cell, result: bool): - """ - Test the `is_equal_cell` methods of `Cell`. - """ - assert cell_1.is_equal_cell(other=cell_2) == result - - -class TestAtomicCell: - """ - Test the `AtomicCell`, `Cell` and `GeometricSpace` classes defined in model_system.py - """ - - @pytest.mark.parametrize( - 'cell_1, cell_2, result', - [ - (Cell(), None, False), # one cell is None - (Cell(), Cell(), False), # both cells are empty + {'lt': False, 'gt': True, 'eq': False}, + ), # one is a subset of the other ( Cell(positions=[[1, 0, 0]]), - Cell(), - False, - ), # one cell has positions, the other is empty - ( Cell(positions=[[1, 0, 0], [0, 1, 0]]), - Cell(positions=[[1, 0, 0]]), - False, - ), # length mismatch + {'lt': True, 'gt': False, 'eq': False}, + ), # one is a subset of the other ( Cell(positions=[[1, 0, 0], [0, 1, 0]]), Cell(positions=[[1, 0, 0], [0, -1, 0]]), - False, + {'lt': False, 'gt': False, 'eq': False}, ), # different positions ( Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - True, + {'lt': False, 'gt': False, 'eq': True}, ), # same ordered positions ( Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), Cell(positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]]), - True, + {'lt': False, 'gt': False, 'eq': True}, ), # different ordered positions but same cell + # ( + # AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + # Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + # False, + # ), # one atomic cell and another cell (missing chemical symbols) + # ( + # AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + # AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + # False, + # ), # missing chemical symbols + # ND: the comparison will now return an error here + # handling a case that should be resolved by the normalizer falls outside its scope ( - AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - False, - ), # one atomic cell and another cell (missing chemical symbols) + AtomicCell( + positions=[[1, 0, 0]], + atoms_state=[ + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + ], + ), + {'lt': False, 'gt': False, 'eq': False}, + ), # chemical symbols are treated as the fundamental set elements ( - AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - False, - ), # missing chemical symbols + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + ], + ), + {'lt': False, 'gt': True, 'eq': False}, + ), # one is a subset of the other ( + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + ], + ), AtomicCell( positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], atoms_state=[ @@ -116,6 +124,16 @@ class TestAtomicCell: AtomsState(chemical_symbol='O'), ], ), + {'lt': True, 'gt': False, 'eq': False}, + ), # one is a subset of the other + ( + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), AtomicCell( positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], atoms_state=[ @@ -124,7 +142,26 @@ class TestAtomicCell: AtomsState(chemical_symbol='O'), ], ), - True, + {'lt': False, 'gt': False, 'eq': False}, + ), + ( + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + {'lt': False, 'gt': False, 'eq': True}, ), # same ordered positions and chemical symbols ( AtomicCell( @@ -143,7 +180,7 @@ class TestAtomicCell: AtomsState(chemical_symbol='O'), ], ), - False, + {'lt': False, 'gt': False, 'eq': False}, ), # same ordered positions but different chemical symbols ( AtomicCell( @@ -162,38 +199,41 @@ class TestAtomicCell: AtomsState(chemical_symbol='H'), ], ), - True, - ), # different ordered positions but same chemical symbols - ], - ) - def test_is_equal_cell(self, cell_1: Cell, cell_2: Cell, result: bool): - """ - Test the `is_equal_cell` methods of `AtomicCell`. - """ - assert cell_1.is_equal_cell(other=cell_2) == result - - @pytest.mark.parametrize( - 'atomic_cell, result', - [ - (AtomicCell(), []), - (AtomicCell(atoms_state=[AtomsState(chemical_symbol='H')]), ['H']), + {'lt': False, 'gt': False, 'eq': True}, + ), # same position-symbol map, different overall order ( AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]], atoms_state=[ AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='Fe'), + AtomsState(chemical_symbol='H'), AtomsState(chemical_symbol='O'), - ] + ], ), - ['H', 'Fe', 'O'], - ), + {'lt': False, 'gt': False, 'eq': False}, + ), # different position-symbol map ], ) - def test_get_chemical_symbols(self, atomic_cell: AtomicCell, result: list[str]): + def test_partial_order( + self, cell_1: 'Cell', cell_2: 'Cell', result: dict[str, bool] + ): """ - Test the `get_chemical_symbols` method of `AtomicCell`. + Test the comparison operators of `Cell` and `AtomicCell`. """ - assert atomic_cell.get_chemical_symbols(logger=logger) == result + assert cell_1.is_lt_cell(cell_2) == result['lt'] + assert cell_1.is_gt_cell(cell_2) == result['gt'] + assert cell_1.is_le_cell(cell_2) == (result['lt'] or result['eq']) + assert cell_1.is_ge_cell(cell_2) == (result['gt'] or result['eq']) + assert cell_1.is_equal_cell(cell_2) == result['eq'] + assert cell_1.is_ne_cell(cell_2) == (not result['eq']) @pytest.mark.parametrize( 'chemical_symbols, atomic_numbers, formula, lattice_vectors, positions, periodic_boundary_conditions',