From 29a3406bb2207b1433ec0e5ed9d13ccbc9b9c434 Mon Sep 17 00:00:00 2001 From: Gordon Watts Date: Thu, 12 Sep 2024 14:31:33 -0700 Subject: [PATCH] Type-Checking the Zip operator on Dictionary should not crash (#149) * feat: Add test for dictionary with Zip key Add a new test case to check that the type follows from a dictionary through a Zip operation. The test creates a lambda expression that selects the 'pt' attribute from a dictionary of jets, which are obtained from an event object. The test ensures that the expected exception is raised when an invalid key is used. Refactor the test_type_based_replacement.py file to include the new test_dictionary_Zip_key() function. * Fix up type and spelling errors * Fix up more spelling and type checking errors * Ignore `zip` when looking at dictionary attributes * Fix up bad addition --- func_adl/ast/function_simplifier.py | 10 +++++----- func_adl/type_based_replacement.py | 24 +++++++++++++----------- tests/test_type_based_replacement.py | 17 +++++++++++++++++ 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/func_adl/ast/function_simplifier.py b/func_adl/ast/function_simplifier.py index a476799..c6f63c6 100644 --- a/func_adl/ast/function_simplifier.py +++ b/func_adl/ast/function_simplifier.py @@ -244,8 +244,8 @@ def visit_SelectMany_of_SelectMany(self, parent: ast.Call, selection: ast.Lambda "SelectMany", [cast(ast.AST, captured_body), cast(ast.AST, func_g)] ) new_select_lambda = lambda_build(captured_arg, new_select) - new_selectmany = function_call("SelectMany", [seq, cast(ast.AST, new_select_lambda)]) - return new_selectmany + new_select_many = function_call("SelectMany", [seq, cast(ast.AST, new_select_lambda)]) + return new_select_many def call_SelectMany(self, node: ast.Call, args: List[ast.AST]): r""" @@ -442,7 +442,7 @@ def visit_Subscript_Tuple(self, v: ast.Tuple, s: Union[ast.Num, ast.Constant, as # Get the value out - this is due to supporting python 3.7-3.9 n = _get_value_from_index(s) if n is None: - return ast.Subscript(v, s, ast.Load()) + return ast.Subscript(v, s, ast.Load()) # type: ignore assert isinstance(n, int), "Programming error: index is not an integer in tuple subscript" if n >= len(v.elts): raise FuncADLIndexError( @@ -460,7 +460,7 @@ def visit_Subscript_List(self, v: ast.List, s: Union[ast.Num, ast.Constant, ast. """ n = _get_value_from_index(s) if n is None: - return ast.Subscript(v, s, ast.Load()) + return ast.Subscript(v, s, ast.Load()) # type: ignore if n >= len(v.elts): raise FuncADLIndexError( f"Attempt to access the {n}th element of a tuple" @@ -484,7 +484,7 @@ def visit_Subscript_Dict_with_value(self, v: ast.Dict, s: Union[str, int]): if _get_value_from_index(value) == s: return v.values[index] - return ast.Subscript(v, s, ast.Load()) + return ast.Subscript(v, s, ast.Load()) # type: ignore def visit_Subscript_Of_First(self, first: ast.AST, s): """ diff --git a/func_adl/type_based_replacement.py b/func_adl/type_based_replacement.py index 1ccc51d..cfc7563 100644 --- a/func_adl/type_based_replacement.py +++ b/func_adl/type_based_replacement.py @@ -102,7 +102,7 @@ def func_adl_callable( Callable[[ObjectStream[W], ast.Call], Tuple[ObjectStream[W], ast.AST]] ] = None ): - """Dectorator that will declare a function that can be used inline in + """Decorator that will declare a function that can be used inline in a `func_adl` expression. The body of the function, what the backend translates it to, must be given by another route (e.g. via `MetaData` and the `processor` argument). @@ -388,7 +388,7 @@ def _fill_in_default_arguments(func: Callable, call: ast.Call) -> Tuple[ast.Call ) return_type = Any - return call, return_type + return call, return_type # type: ignore def fixup_ast_from_modifications(transformed_ast: ast.AST, original_ast: ast.Call) -> ast.Call: @@ -582,7 +582,7 @@ def type_follow_in_callbacks( # If this is a known collection class, we can use call-backs to follow it. if get_origin(call_site_info.obj_type) in _g_collection_classes: rtn_value = self.process_method_call_on_stream_obj( - _g_collection_classes[get_origin(call_site_info.obj_type)], + _g_collection_classes[get_origin(call_site_info.obj_type)], # type: ignore m_name, r_node, get_args(call_site_info.obj_type)[0], @@ -631,7 +631,7 @@ def process_method_call(self, node: ast.Call, obj_type: type) -> ast.AST: Args: node (ast.Call): The ast node - obj_type (type): The object type this method call is occuring against + obj_type (type): The object type this method call is occurring against Returns: ast.AST: An updated ast that is the new method call (with default args, etc.) @@ -643,7 +643,7 @@ def process_method_call(self, node: ast.Call, obj_type: type) -> ast.AST: base_obj_list_all = [obj_type] if is_iterable(obj_type): item_type = unwrap_iterable(obj_type) - base_obj_list_all += [c[item_type] for c in _g_collection_classes] + base_obj_list_all += [c[item_type] for c in _g_collection_classes] # type: ignore assert isinstance(r_node.func, ast.Attribute) m_name = r_node.func.attr @@ -758,7 +758,7 @@ def process_function_call(self, node: ast.Call, func_info: _FuncAdlFunction) -> f"function {func_info.function.__name__} ({str(e)})" ) from e - def process_paramaterized_method_call( + def process_parameterized_method_call( self, node: ast.Call, obj_type: Type, @@ -815,7 +815,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST: if isinstance(t_node.func.value, ast.Attribute): found_type = self.lookup_type(t_node.func.value.value) if found_type is not None: - t_node = self.process_paramaterized_method_call( + t_node = self.process_parameterized_method_call( t_node, found_type, t_node.func.value.attr, @@ -941,17 +941,17 @@ def visit_Constant(self, node: ast.Constant) -> Any: return node def visit_Num(self, node: ast.Num) -> Any: # pragma: no cover - "3.7 compatability" + "3.7 compatibility" self._found_types[node] = type(node.n) return node def visit_Str(self, node: ast.Str) -> Any: # pragma: no cover - "3.7 compatability" + "3.7 compatibility" self._found_types[node] = str return node def visit_NameConstant(self, node: ast.NameConstant) -> Any: # pragma: no cover - "3.7 compatability" + "3.7 compatibility" if node.value is None: raise ValueError("Do not know how to work with pythons None") self._found_types[node] = bool @@ -968,6 +968,8 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: e for e, k in enumerate(t_node.value.keys) if k.value == key # type: ignore ] if len(key_index) == 0: + if t_node.attr.lower() == "zip": + return t_node raise ValueError(f"Key {key} not found in dict expression!!") value = t_node.value.values[key_index[0]] self._found_types[node] = self.lookup_type(value) @@ -999,7 +1001,7 @@ def remap_from_lambda( orig_type = o_stream.item_type var_name = l_func.args.args[0].arg stream, new_body, return_type = remap_by_types(o_stream, var_name, orig_type, l_func.body) - return stream, ast.Lambda(l_func.args, new_body), return_type + return stream, ast.Lambda(l_func.args, new_body), return_type # type: ignore def reset_global_functions(): diff --git a/tests/test_type_based_replacement.py b/tests/test_type_based_replacement.py index ca39992..a2641f5 100644 --- a/tests/test_type_based_replacement.py +++ b/tests/test_type_based_replacement.py @@ -546,6 +546,23 @@ def test_dictionary_bad_key(): assert "jetsss" in str(e) +def test_dictionary_Zip_key(): + "Check that type follow from a dictionary through a Zip works" + + s = ast_lambda( + """({ + 'jet_pt': e.Jets().Select(lambda j: j.pt()), + 'jet_eta': e.Jets().Select(lambda j: j.eta())} + .Zip() + .Select(lambda j: j.pt()))""" + ) + objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load())) + + new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s) + + assert expr_type == Any + + def test_dictionary_through_Select(): """Make sure the Select statement carries the typing all the way through"""