diff --git a/src/nomad_simulations/schema_packages/properties/greens_function.py b/src/nomad_simulations/schema_packages/properties/greens_function.py index 967fa1ab..c4788c7e 100644 --- a/src/nomad_simulations/schema_packages/properties/greens_function.py +++ b/src/nomad_simulations/schema_packages/properties/greens_function.py @@ -173,14 +173,17 @@ def find_space_id(space_map: dict) -> str: return '' space_id = find_space_id(_real_space_map) + find_space_id(_time_space_map) - return space_id + return space_id if space_id else None def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) - if not self.space_id: - self.space_id = self.resolve_space_id() - + space_id = self.resolve_space_id() + if (self.space_id is not None and self.space_id != space_id): + logger.warning( + f'The stored `space_id`, {self.space_id}, does not coincide with the resolved one, {space_id}. We will update it.' + ) + self.space_id = space_id class ElectronicGreensFunction(BaseGreensFunction): """ @@ -365,7 +368,7 @@ def resolve_system_correlation_strengths(self) -> str: return 'OSMI' elif np.all(self.value < 1e-2): return 'Mott insulator' - return '' + return None def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) @@ -376,5 +379,9 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: ) return - if not self.system_correlation_strengths: - self.system_correlation_strengths = self.resolve_system_correlation_strengths() + system_correlation_strengths = self.resolve_system_correlation_strengths() + if (self.system_correlation_strengths is not None and self.system_correlation_strengths != system_correlation_strengths): + logger.warning( + f'The stored `system_correlation_strengths`, {self.system_correlation_strengths}, does not coincide with the resolved one, {system_correlation_strengths}. We will update it.' + ) + self.system_correlation_strengths = system_correlation_strengths diff --git a/tests/test_greens_function.py b/tests/test_greens_function.py index 25c8b208..1cac4038 100644 --- a/tests/test_greens_function.py +++ b/tests/test_greens_function.py @@ -16,14 +16,12 @@ # limitations under the License. # -from typing import Optional, Union +from typing import Union, Optional import pytest +from nomad.datamodel import EntryArchive from nomad_simulations.schema_packages.properties import ( - ElectronicGreensFunction, - ElectronicSelfEnergy, - HybridizationFunction, QuasiparticleWeight, ) from nomad_simulations.schema_packages.properties.greens_function import ( @@ -38,6 +36,8 @@ WignerSeitz, ) +from . import logger + class TestBaseGreensFunction: """ @@ -47,7 +47,7 @@ class TestBaseGreensFunction: @pytest.mark.parametrize( 'variables, result', [ - ([], ''), + ([], None), ([WignerSeitz()], 'r'), ([KMesh()], 'k'), ([Time()], 't'), @@ -72,6 +72,38 @@ def test_resolve_space_id(self, variables: list[Union[WignerSeitz, KMesh, Time, gfs.variables = variables assert gfs.resolve_space_id() == result + @pytest.mark.parametrize( + 'space_id, variables, result', + [ + ('', [], None), # empty `space_id` + ('rt', [], None), # `space_id` set by parser + ('', [WignerSeitz()], 'r'), # resolving `space_id` + ('rt', [WignerSeitz()], 'r'), # normalize overwrites `space_id` + ('', [KMesh()], 'k'), + ('', [Time()], 't'), + ('', [ImaginaryTime()], 'it'), + ('', [Frequency()], 'w'), + ('', [MatsubaraFrequency()], 'iw'), + ('', [WignerSeitz(), Time()], 'rt'), + ('', [WignerSeitz(), ImaginaryTime()], 'rit'), + ('', [WignerSeitz(), Frequency()], 'rw'), + ('', [WignerSeitz(), MatsubaraFrequency()], 'riw'), + ('', [KMesh(), Time()], 'kt'), + ('', [KMesh(), ImaginaryTime()], 'kit'), + ('', [KMesh(), Frequency()], 'kw'), + ('', [KMesh(), MatsubaraFrequency()], 'kiw'), + ], + ) + def test_normalize(self, space_id: str, variables: list[Union[WignerSeitz, KMesh, Time, ImaginaryTime, Frequency, MatsubaraFrequency]], result: Optional[str]): + """ + Test the `normalize` method of the `BaseGreensFunction` class. + """ + gfs = BaseGreensFunction(n_atoms=1, n_correlated_orbitals=1) + gfs.variables = variables + gfs.space_id = space_id if space_id else None + gfs.normalize(archive=EntryArchive(), logger=logger) + assert gfs.space_id == result + class TestQuasiparticleWeight: @@ -103,13 +135,34 @@ def test_is_valid_quasiparticle_weight(self, value: list[float], result: bool): ([[0.2, 0.3, 0.1]], 'strongly-correlated metal'), ([[0, 0.3, 0.1]], 'OSMI'), ([[0, 0, 0]], 'Mott insulator'), - ([[1.0, 0.8, 0.2]], ''), + ([[1.0, 0.8, 0.2]], None), ], ) - def test_resolve_system_correlation_strengths(self, value: list[float], result: str): + def test_resolve_system_correlation_strengths(self, value: list[float], result: Optional[str]): """ Test the `resolve_system_correlation_strengths` method of the `QuasiparticleWeight` class. """ quasiparticle_weight = QuasiparticleWeight(n_atoms=1, n_correlated_orbitals=3) quasiparticle_weight.value = value assert quasiparticle_weight.resolve_system_correlation_strengths() == result + + @pytest.mark.parametrize( + 'value, result', + [ + ([[1, 0.5, -2]], None), + ([[1, 0.5, 8]], None), + ([[1, 0.9, 0.8]], 'non-correlated metal'), + ([[0.2, 0.3, 0.1]], 'strongly-correlated metal'), + ([[0, 0.3, 0.1]], 'OSMI'), + ([[0, 0, 0]], 'Mott insulator'), + ([[1.0, 0.8, 0.2]], None), + ], + ) + def test_normalize(self, value: list[float], result: Optional[str]): + """ + Test the `normalize` method of the `QuasiparticleWeight` class. + """ + quasiparticle_weight = QuasiparticleWeight(n_atoms=1, n_correlated_orbitals=3) + quasiparticle_weight.value = value + quasiparticle_weight.normalize(archive=EntryArchive(), logger=logger) + assert quasiparticle_weight.system_correlation_strengths == result