diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 64273e18e3..f131c12534 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -16,7 +16,7 @@ from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.next import common, constructors, field_utils, utils +from gt4py.next import common, constructors, errors, field_utils, utils from gt4py.next.embedded import common as embedded_common, context as embedded_context @@ -56,7 +56,11 @@ def field_operator_call( else: # field_operator called directly + if "offset_provider" not in kwargs: + raise errors.MissingArgumentError(None, "offset_provider", True) offset_provider = kwargs.pop("offset_provider", None) + if "out" not in kwargs: + raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") domain = kwargs.pop("domain", None) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 61441e83b9..dd48d6f0f9 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -21,6 +21,7 @@ from .exceptions import ( DSLError, InvalidParameterAnnotationError, + MissingArgumentError, MissingAttributeError, MissingParameterAnnotationError, UndefinedSymbolError, @@ -33,6 +34,7 @@ "InvalidParameterAnnotationError", "MissingAttributeError", "MissingParameterAnnotationError", + "MissingArgumentError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", "set_verbose_exceptions", diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index e956858549..2baed9a60e 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -81,6 +81,18 @@ def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: self.attr_name = attr_name +class MissingArgumentError(DSLError): + arg_name: str + is_kwarg: bool + + def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg: bool) -> None: + super().__init__( + location, f"expected {'keyword-' if is_kwarg else ''}argument '{arg_name}'" + ) + self.attr_name = arg_name + self.is_kwarg = is_kwarg + + class TypeError_(DSLError): def __init__(self, location: Optional[SourceLocation], message: str) -> None: super().__init__(location, message) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 3f8ae36b94..091f260645 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -33,7 +33,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, embedded as next_embedded +from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( @@ -682,6 +682,7 @@ def as_program( self._program_cache[hash_] = Program( past_node=past_node, closure_vars=closure_vars, + definition=None, backend=self.backend, grid_type=self.grid_type, ) @@ -694,7 +695,12 @@ def __call__( ) -> None: if not next_embedded.context.within_context() and self.backend is not None: # non embedded execution - offset_provider = kwargs.pop("offset_provider", None) + if "offset_provider" not in kwargs: + raise errors.MissingArgumentError(None, "offset_provider", True) + offset_provider = kwargs.pop("offset_provider") + + if "out" not in kwargs: + raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) # TODO(tehrengruber): check all offset providers are given diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py index 1f61070e23..51df5be375 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py @@ -15,6 +15,7 @@ import pytest from gt4py import next as gtx +from gt4py.next import errors from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IField, cartesian_case # noqa: F401 # fixtures @@ -38,3 +39,19 @@ def copy(a: IField) -> IField: # due to `fieldview_backend` fixture (dependency of `cartesian_case`) # setting the default backend to something invalid. _ = copy(a, out=a, offset_provider={}) + + +def test_missing_arg(cartesian_case): # noqa: F811 # fixtures + """Test that calling a field_operator without required args raises an error.""" + + @gtx.field_operator(backend=cartesian_case.backend) + def copy(a: IField) -> IField: + return a + + a = cases.allocate(cartesian_case, copy, "a")() + + with pytest.raises(errors.MissingArgumentError, match="'out'"): + _ = copy(a, offset_provider={}) + + with pytest.raises(errors.MissingArgumentError, match="'offset_provider'"): + _ = copy(a, out=a)