Skip to content

Commit

Permalink
Merge pull request #54 from shirte/main
Browse files Browse the repository at this point in the history
Restrict output formats of ProblemListConverter
  • Loading branch information
shirte authored Jan 7, 2025
2 parents 6dd1def + 738fff7 commit fdac965
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
36 changes: 27 additions & 9 deletions nerdd_module/converters/problem_list_converter.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,39 @@
from typing import Any, List, cast
import logging
from typing import Any, List, Union, cast

from ..problem import Problem
from .converter import Converter
from .converter_config import ALL, ConverterConfig
from .converter_config import ConverterConfig

__all__ = ["ProblemListConverter"]
__all__ = ["ProblemListIdentityConverter", "ProblemListConverter"]

logger = logging.getLogger(__name__)


class ProblemListIdentityConverter(Converter):
def _convert(self, input: Any, context: dict) -> Any:
return input

config = ConverterConfig(
data_types="problem_list",
output_formats=["pandas", "iterator", "record_list"],
)


class ProblemListConverter(Converter):
def _convert(self, input: Any, context: dict) -> Any:
if self.output_format in ["pandas", "iterator", "record_list"]:
return input
else:
problem_list: List[Problem] = cast(List[Problem], input)
return "; ".join([f"{problem.type}: {problem.message}" for problem in problem_list])
problem_list: List[Union[Problem, str]] = cast(List[Union[Problem, str]], input)

def _represent(problem: Union[Problem, str]) -> str:
if isinstance(problem, Problem):
return f"{problem.type}: {problem.message}"
else:
logger.warning("Item is not an instance of Problem: %s", problem)
return problem

return "; ".join([_represent(problem) for problem in problem_list])

config = ConverterConfig(
data_types="problem_list",
output_formats=ALL,
output_formats=["csv", "sdf"],
)
14 changes: 10 additions & 4 deletions tests/basic/converters/test_converters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from nerdd_module import Problem
from nerdd_module.config import ResultProperty
from nerdd_module.converters import (BasicTypeConverter, Converter,
ProblemListConverter, VoidConverter)
ProblemListConverter,
ProblemListIdentityConverter,
VoidConverter)

primitive_data_types = [
"int",
Expand All @@ -10,7 +12,7 @@
"bool",
]

output_formats = ["sdf", "csv", "pandas", "record_list", "iterator"]
output_formats = ["sdf", "csv", "pandas", "record_list", "iterator", "non-existing"]


def test_basic_data_types():
Expand All @@ -35,11 +37,15 @@ def test_problem_list_converter():
for output_format in output_formats:
converter = Converter.get_converter(result_property, output_format)
converted_value = converter.convert(problem_list, {})
assert isinstance(converter, ProblemListConverter)
if output_format in ["pandas", "record_list", "iterator"]:
assert isinstance(converter, ProblemListIdentityConverter)
assert isinstance(converted_value, list)
assert len(converted_value) == len(problem_list)
assert isinstance(converted_value[0], Problem)
else:
elif output_format in ["sdf", "csv"]:
assert isinstance(converter, ProblemListConverter)
assert isinstance(converted_value, str)
assert converted_value.startswith("problem_type")
else:
assert isinstance(converter, VoidConverter)
assert converted_value is Converter.HIDE

0 comments on commit fdac965

Please sign in to comment.