Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update message #30

Merged
merged 2 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions src/scwidgets/_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,56 @@
import re

from termcolor import colored


class Printer:
# move to output
# TODO rename to Formatter
# remove print funcs
LINE_LENGTH = 120
INFO_COLOR = "blue"
ERROR_COLOR = "red"
SUCCESS_COLOR = "green"

@staticmethod
def format_title_message(message: str) -> str:
return message.center(Printer.LINE_LENGTH - len(message) // 2, "-")

@staticmethod
def break_lines(message: str) -> str:
return "\n ".join(re.findall(r".{1," + str(Printer.LINE_LENGTH) + "}", message))

@staticmethod
def color_error_message(message: str) -> str:
return colored(message, Printer.ERROR_COLOR, attrs=["bold"])

@staticmethod
def print_error_message(message: str):
print(colored(message, "red", attrs=["bold"]))
print(Printer.color_error_message(message))

@staticmethod
def color_success_message(message: str) -> str:
return colored(message, Printer.SUCCESS_COLOR, attrs=["bold"])

@staticmethod
def print_success_message(message: str):
print(colored(message, "green", attrs=["bold"]))
print(Printer.color_success_message(message))

@staticmethod
def color_info_message(message: str):
return colored(message, Printer.INFO_COLOR, attrs=["bold"])

@staticmethod
def print_info_message(message: str):
print(colored(message, "blue", attrs=["bold"]))
print(Printer.color_info_message(message))

@staticmethod
def color_assert_failed(message: str) -> str:
return colored(message, "light_" + Printer.ERROR_COLOR)

@staticmethod
def color_assert_info(message: str) -> str:
return colored(message, "light_" + Printer.INFO_COLOR)

@staticmethod
def color_assert_success(message: str) -> str:
return colored(message, "light_" + Printer.SUCCESS_COLOR)
3 changes: 2 additions & 1 deletion src/scwidgets/check/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
assert_shape,
assert_type,
)
from ._check import Check, ChecksLog
from ._check import AssertResult, Check, ChecksLog
from ._widget_check_registry import CheckableWidget, CheckRegistry

__all__ = [
"Check",
"ChecksLog",
"AssertResult",
"CheckRegistry",
"CheckableWidget",
"assert_shape",
Expand Down
96 changes: 75 additions & 21 deletions src/scwidgets/check/_asserts.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import functools
from collections import abc
from typing import Iterable, TypeVar, Union
from typing import Iterable, Union

import numpy as np

from ._check import Check
from ._check import AssertResult, Check

AssertResultT = TypeVar("AssertResultT", bound="str")
AssertFunctionOutputT = Union[str, AssertResult]


def assert_shape(
output_parameters: Check.FunOutParamsT,
output_references: Check.FunOutParamsT,
parameters_to_check: Union[Iterable[int], str] = "auto",
) -> str:
) -> AssertResult:
assert len(output_parameters) == len(
output_references
), "output_parameters and output_references have to have the same length"
Expand All @@ -22,7 +22,7 @@ def assert_shape(
if isinstance(parameters_to_check, str):
if parameters_to_check == "auto":
parameter_indices = []
for i in range(len(output_parameters)):
for i in range(len(output_references)):
if hasattr(output_references[i], "shape"):
parameter_indices.append(i)
elif parameters_to_check == "all":
Expand All @@ -40,13 +40,25 @@ def assert_shape(
f"but got type {type(parameters_to_check)}."
)

failed_parameter_indices = []
failed_parameter_values = []
messages = []
for i in parameter_indices:
if output_parameters[i].shape != output_references[i].shape:
return (
f"For parameter {i} expected shape {output_references[i].shape} "
message = (
f"Expected shape {output_references[i].shape})"
f"but got {output_parameters[i].shape}."
)
return ""
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
messages.append(message)

return AssertResult(
assert_name="assert_shape",
parameter_indices=failed_parameter_indices,
parameter_values=failed_parameter_values,
messages=messages,
)


def assert_numpy_allclose(
Expand All @@ -56,7 +68,7 @@ def assert_numpy_allclose(
rtol=1e-05,
atol=1e-08,
equal_nan=False,
) -> str:
) -> AssertResult:
assert len(output_parameters) == len(
output_references
), "output_parameters and output_references have to have the same length"
Expand Down Expand Up @@ -86,6 +98,9 @@ def assert_numpy_allclose(
f"but got type {type(parameters_to_check)}."
)

failed_parameter_indices = []
failed_parameter_values = []
messages = []
for i in parameter_indices:
is_allclose = np.allclose(
output_parameters[i],
Expand All @@ -101,18 +116,28 @@ def assert_numpy_allclose(
)
abs_diff = np.sum(diff)
rel_diff = np.sum(diff / np.abs(output_references[i]))
return (
f"Output parameter {i} is not close to reference absolute difference "

message = (
f"Output is not close to reference absolute difference "
f"is {abs_diff}, relative difference is {rel_diff}."
)
return ""
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
messages.append(message)

return AssertResult(
assert_name="assert_numpy_allclose",
parameter_indices=failed_parameter_indices,
parameter_values=failed_parameter_values,
messages=messages,
)


def assert_type(
output_parameters: Check.FunOutParamsT,
output_references: Check.FunOutParamsT,
parameters_to_check: Union[Iterable[int], str] = "all",
) -> str:
) -> AssertResult:
assert len(output_parameters) == len(
output_references
), "output_parameters and output_references have to have the same length"
Expand All @@ -134,20 +159,31 @@ def assert_type(
f"but got type {type(parameters_to_check)}."
)

failed_parameter_indices = []
failed_parameter_values = []
messages = []
for i in parameter_indices:
if not (isinstance(output_parameters[i], type(output_references[i]))):
return (
message = (
f"Expected type {type(output_references[i])} "
f"but got {type(output_parameters[i])}."
)
return ""
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
messages.append(message)
return AssertResult(
assert_name="assert_type",
parameter_indices=failed_parameter_indices,
parameter_values=failed_parameter_values,
messages=messages,
)


def assert_numpy_sub_dtype(
output_parameters: Union[Check.FunOutParamsT, tuple[Check.FingerprintT]],
numpy_type: Union[np.dtype, type],
parameters_to_check: Union[Iterable[int], str] = "all",
) -> str:
) -> AssertResult:
if parameters_to_check == "all":
parameter_indices = range(len(output_parameters))
elif isinstance(parameters_to_check, abc.Iterable):
Expand All @@ -158,23 +194,41 @@ def assert_numpy_sub_dtype(
f"but got type {type(parameters_to_check)}."
)

failed_parameter_indices = []
failed_parameter_values = []
messages = []
for i in parameter_indices:
if not (isinstance(output_parameters[i], np.ndarray)):
return (
f"Output parameter {i} expected to be numpy array "
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
message = (
f"Output expected to be numpy array "
f"but got {type(output_parameters[i])}."
)
messages.append(message)
if not (np.issubdtype(output_parameters[i].dtype, numpy_type)):
if isinstance(numpy_type, np.dtype):
type_name = numpy_type.type.__name__
else:
type_name = numpy_type.__name__
return (
f"Output parameter {i} expected to be sub dtype "
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
message = (
f"Output expected to be sub dtype "
f"numpy.{type_name} but got "
f"numpy.{output_parameters[i].dtype.type.__name__}."
)
return ""
messages.append(message)
if isinstance(numpy_type, np.dtype):
type_name = numpy_type.type.__name__
else:
type_name = numpy_type.__name__
return AssertResult(
assert_name=f"assert_numpy_{type_name}_sub_dtype",
parameter_indices=failed_parameter_indices,
parameter_values=failed_parameter_values,
messages=messages,
)


assert_numpy_floating_sub_dtype = functools.partial(
Expand Down
Loading