Skip to content

Commit

Permalink
add proper errors for missing kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Nov 30, 2023
1 parent afd1042 commit 788f59b
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.next import common, constructors, field_utils, utils
from gt4py.next import common, constructors, errors, field_utils, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context


Expand Down Expand Up @@ -56,7 +56,11 @@ def field_operator_call(

else:
# field_operator called directly
if "offset_provider" not in kwargs:
raise errors.MissingArgumentError(None, "offset_provider", True)
offset_provider = kwargs.pop("offset_provider", None)
if "out" not in kwargs:
raise errors.MissingArgumentError(None, "out", True)
out = kwargs.pop("out")
domain = kwargs.pop("domain", None)

Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .exceptions import (
DSLError,
InvalidParameterAnnotationError,
MissingArgumentError,
MissingAttributeError,
MissingParameterAnnotationError,
UndefinedSymbolError,
Expand All @@ -33,6 +34,7 @@
"InvalidParameterAnnotationError",
"MissingAttributeError",
"MissingParameterAnnotationError",
"MissingArgumentError",
"UndefinedSymbolError",
"UnsupportedPythonFeatureError",
"set_verbose_exceptions",
Expand Down
12 changes: 12 additions & 0 deletions src/gt4py/next/errors/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None:
self.attr_name = attr_name


class MissingArgumentError(DSLError):
arg_name: str
is_kwarg: bool

def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg: bool) -> None:
super().__init__(
location, f"expected {'keyword-' if is_kwarg else ''}argument '{arg_name}'"
)
self.attr_name = arg_name
self.is_kwarg = is_kwarg


class TypeError_(DSLError):
def __init__(self, location: Optional[SourceLocation], message: str) -> None:
super().__init__(location, message)
Expand Down
10 changes: 8 additions & 2 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
from gt4py.next import allocators as next_allocators, embedded as next_embedded
from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.embedded import operators as embedded_operators
from gt4py.next.ffront import (
Expand Down Expand Up @@ -682,6 +682,7 @@ def as_program(
self._program_cache[hash_] = Program(
past_node=past_node,
closure_vars=closure_vars,
definition=None,
backend=self.backend,
grid_type=self.grid_type,
)
Expand All @@ -694,7 +695,12 @@ def __call__(
) -> None:
if not next_embedded.context.within_context() and self.backend is not None:
# non embedded execution
offset_provider = kwargs.pop("offset_provider", None)
if "offset_provider" not in kwargs:
raise errors.MissingArgumentError(None, "offset_provider", True)
offset_provider = kwargs.pop("offset_provider")

if "out" not in kwargs:
raise errors.MissingArgumentError(None, "out", True)
out = kwargs.pop("out")
args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
# TODO(tehrengruber): check all offset providers are given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest

from gt4py import next as gtx
from gt4py.next import errors

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import IField, cartesian_case # noqa: F401 # fixtures
Expand All @@ -38,3 +39,19 @@ def copy(a: IField) -> IField:
# due to `fieldview_backend` fixture (dependency of `cartesian_case`)
# setting the default backend to something invalid.
_ = copy(a, out=a, offset_provider={})


def test_missing_arg(cartesian_case): # noqa: F811 # fixtures
"""Test that calling a field_operator without required args raises an error."""

@gtx.field_operator(backend=cartesian_case.backend)
def copy(a: IField) -> IField:
return a

a = cases.allocate(cartesian_case, copy, "a")()

with pytest.raises(errors.MissingArgumentError, match="'out'"):
_ = copy(a, offset_provider={})

with pytest.raises(errors.MissingArgumentError, match="'offset_provider'"):
_ = copy(a, out=a)

0 comments on commit 788f59b

Please sign in to comment.