Skip to content

Commit

Permalink
Added testing for Variables
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Apr 8, 2024
1 parent 5d9bfac commit f694cba
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/nomad_simulations/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#

import numpy as np
from typing import Optional
from structlog.stdlib import BoundLogger

from nomad.datamodel.data import ArchiveSection
from nomad.metainfo import Quantity, Section, Context
Expand Down Expand Up @@ -53,17 +55,35 @@ class Variables(ArchiveSection):

# grid_points_error = Quantity()

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# Setting `n_bins` if these are not defined
if self.grid_points is not None:
if self.n_grid_points != len(self.grid_points):
def get_n_grid_points(
self, grid_points: Optional[list], logger: BoundLogger
) -> Optional[int]:
"""
Get the number of grid points from the `grid_points` list. If `n_grid_points` is previously defined
and does not coincide with the length of `grid_points`, a warning is issued and this function re-assigns `n_grid_points`
as the length of `grid_points`.
Args:
grid_points (Optional[list]): The grid points in which the variable is discretized.
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[int]): The number of grid points.
"""
if grid_points is not None and len(grid_points) > 0:
if self.n_grid_points != len(grid_points):
logger.warning(
f'The stored `n_grid_points`, {self.n_grid_points}, does not coincide with the length of `grid_points`, '
f'{len(self.grid_points)}. We will re-assign `n_grid_points` as the length of `grid_points`.'
f'{len(grid_points)}. We will re-assign `n_grid_points` as the length of `grid_points`.'
)
self.n_grid_points = len(self.grid_points)
return len(grid_points)
return self.n_grid_points

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# Setting `n_grid_points` if these are not defined
self.n_grid_points = self.get_n_grid_points(self.grid_points, logger)


class Temperature(Variables):
Expand Down
52 changes: 52 additions & 0 deletions tests/test_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#
# Copyright The NOMAD Authors.
#
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest

from . import logger

from nomad_simulations.variables import Variables


class TestVariables:
"""
Test the `Variables` class defined in `variables.py`.
"""

@pytest.mark.parametrize(
'n_grid_points, grid_points, result',
[
(3, [-1, 0, 1], 3),
(5, [-1, 0, 1], 3),
(None, [-1, 0, 1], 3),
(4, None, 4),
(4, [], 4),
],
)
def test_normalize(self, n_grid_points: int, grid_points: list, result: int):
"""
Test the `normalize` and `get_n_grid_points` methods.
"""
variable = Variables(
name='variable_1',
n_grid_points=n_grid_points,
grid_points=grid_points,
)
assert variable.get_n_grid_points(grid_points, logger) == result
variable.normalize(None, logger)
assert variable.n_grid_points == result

0 comments on commit f694cba

Please sign in to comment.