Skip to content

Commit

Permalink
refactor minmax
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Sep 30, 2023
1 parent af43e56 commit 94fc67d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 15 deletions.
2 changes: 1 addition & 1 deletion vyper/ast/folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int:
continue
try:
new_node = func.evaluate(node) # type: ignore
new_node._metadata["type"] = func.fetch_call_return(node)
new_node._metadata["type"] = node._metadata.get("type")
except UnfoldableNode:
continue

Expand Down
18 changes: 6 additions & 12 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2034,26 +2034,20 @@ def evaluate(self, node):
return type(node.args[0]).from_node(node, value=value)

def fetch_call_return(self, node):
return_type = self.infer_arg_types(node).pop()
return return_type

def get_possible_types_from_node(self, node):
self._validate_arg_types(node)
types_list = get_common_types(
*node.args, filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT))
)
if not types_list:
raise TypeMismatch("Cannot perform action between dislike numeric types", node)
return types_list
return types_list

def infer_arg_types(self, node, propagated_typ=None):
types_list = self.get_possible_types_from_node(node)
def infer_arg_types(self, node, typ=None):
assert typ is not None
types_list = self.fetch_call_return(node)

if propagated_typ and propagated_typ in types_list:
type_ = propagated_typ
else:
type_ = types_list.pop()
return [type_, type_]
assert typ in types_list
return [typ, typ]

@process_inputs
def build_IR(self, expr, args, kwargs, context):
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,11 @@ def types_from_Compare(self, node):
def types_from_Call(self, node):
# function calls, e.g. `foo()` or `MyStruct()`
var = self.get_exact_type_from_node(node.func, include_type_exprs=True)
if hasattr(var, "get_possible_types_from_node"):
return var.get_possible_types_from_node(node)

return_value = var.fetch_call_return(node)
if return_value:
if isinstance(return_value, list):
return return_value
return [return_value]
raise InvalidType(f"{var} did not return a value", node)

Expand Down

0 comments on commit 94fc67d

Please sign in to comment.