Skip to content

Commit

Permalink
Added more testing and fix normalize functions
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Aug 27, 2024
1 parent 2ac054c commit bc037d5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,17 @@ def find_space_id(space_map: dict) -> str:
return ''

space_id = find_space_id(_real_space_map) + find_space_id(_time_space_map)
return space_id
return space_id if space_id else None

def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)

if not self.space_id:
self.space_id = self.resolve_space_id()

space_id = self.resolve_space_id()
if (self.space_id is not None and self.space_id != space_id):
logger.warning(
f'The stored `space_id`, {self.space_id}, does not coincide with the resolved one, {space_id}. We will update it.'
)
self.space_id = space_id

class ElectronicGreensFunction(BaseGreensFunction):
"""
Expand Down Expand Up @@ -365,7 +368,7 @@ def resolve_system_correlation_strengths(self) -> str:
return 'OSMI'
elif np.all(self.value < 1e-2):
return 'Mott insulator'
return ''
return None

def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)
Expand All @@ -376,5 +379,9 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
)
return

if not self.system_correlation_strengths:
self.system_correlation_strengths = self.resolve_system_correlation_strengths()
system_correlation_strengths = self.resolve_system_correlation_strengths()
if (self.system_correlation_strengths is not None and self.system_correlation_strengths != system_correlation_strengths):
logger.warning(
f'The stored `system_correlation_strengths`, {self.system_correlation_strengths}, does not coincide with the resolved one, {system_correlation_strengths}. We will update it.'
)
self.system_correlation_strengths = system_correlation_strengths
67 changes: 60 additions & 7 deletions tests/test_greens_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
# limitations under the License.
#

from typing import Optional, Union
from typing import Union, Optional

import pytest

from nomad.datamodel import EntryArchive
from nomad_simulations.schema_packages.properties import (
ElectronicGreensFunction,
ElectronicSelfEnergy,
HybridizationFunction,
QuasiparticleWeight,
)
from nomad_simulations.schema_packages.properties.greens_function import (
Expand All @@ -38,6 +36,8 @@
WignerSeitz,
)

from . import logger


class TestBaseGreensFunction:
"""
Expand All @@ -47,7 +47,7 @@ class TestBaseGreensFunction:
@pytest.mark.parametrize(
'variables, result',
[
([], ''),
([], None),
([WignerSeitz()], 'r'),
([KMesh()], 'k'),
([Time()], 't'),
Expand All @@ -72,6 +72,38 @@ def test_resolve_space_id(self, variables: list[Union[WignerSeitz, KMesh, Time,
gfs.variables = variables
assert gfs.resolve_space_id() == result

@pytest.mark.parametrize(
'space_id, variables, result',
[
('', [], None), # empty `space_id`
('rt', [], None), # `space_id` set by parser
('', [WignerSeitz()], 'r'), # resolving `space_id`
('rt', [WignerSeitz()], 'r'), # normalize overwrites `space_id`
('', [KMesh()], 'k'),
('', [Time()], 't'),
('', [ImaginaryTime()], 'it'),
('', [Frequency()], 'w'),
('', [MatsubaraFrequency()], 'iw'),
('', [WignerSeitz(), Time()], 'rt'),
('', [WignerSeitz(), ImaginaryTime()], 'rit'),
('', [WignerSeitz(), Frequency()], 'rw'),
('', [WignerSeitz(), MatsubaraFrequency()], 'riw'),
('', [KMesh(), Time()], 'kt'),
('', [KMesh(), ImaginaryTime()], 'kit'),
('', [KMesh(), Frequency()], 'kw'),
('', [KMesh(), MatsubaraFrequency()], 'kiw'),
],
)
def test_normalize(self, space_id: str, variables: list[Union[WignerSeitz, KMesh, Time, ImaginaryTime, Frequency, MatsubaraFrequency]], result: Optional[str]):
"""
Test the `normalize` method of the `BaseGreensFunction` class.
"""
gfs = BaseGreensFunction(n_atoms=1, n_correlated_orbitals=1)
gfs.variables = variables
gfs.space_id = space_id if space_id else None
gfs.normalize(archive=EntryArchive(), logger=logger)
assert gfs.space_id == result



class TestQuasiparticleWeight:
Expand Down Expand Up @@ -103,13 +135,34 @@ def test_is_valid_quasiparticle_weight(self, value: list[float], result: bool):
([[0.2, 0.3, 0.1]], 'strongly-correlated metal'),
([[0, 0.3, 0.1]], 'OSMI'),
([[0, 0, 0]], 'Mott insulator'),
([[1.0, 0.8, 0.2]], ''),
([[1.0, 0.8, 0.2]], None),
],
)
def test_resolve_system_correlation_strengths(self, value: list[float], result: str):
def test_resolve_system_correlation_strengths(self, value: list[float], result: Optional[str]):
"""
Test the `resolve_system_correlation_strengths` method of the `QuasiparticleWeight` class.
"""
quasiparticle_weight = QuasiparticleWeight(n_atoms=1, n_correlated_orbitals=3)
quasiparticle_weight.value = value
assert quasiparticle_weight.resolve_system_correlation_strengths() == result

@pytest.mark.parametrize(
'value, result',
[
([[1, 0.5, -2]], None),
([[1, 0.5, 8]], None),
([[1, 0.9, 0.8]], 'non-correlated metal'),
([[0.2, 0.3, 0.1]], 'strongly-correlated metal'),
([[0, 0.3, 0.1]], 'OSMI'),
([[0, 0, 0]], 'Mott insulator'),
([[1.0, 0.8, 0.2]], None),
],
)
def test_normalize(self, value: list[float], result: Optional[str]):
"""
Test the `normalize` method of the `QuasiparticleWeight` class.
"""
quasiparticle_weight = QuasiparticleWeight(n_atoms=1, n_correlated_orbitals=3)
quasiparticle_weight.value = value
quasiparticle_weight.normalize(archive=EntryArchive(), logger=logger)
assert quasiparticle_weight.system_correlation_strengths == result

0 comments on commit bc037d5

Please sign in to comment.