Skip to content

Commit

Permalink
First prototype of checking for when to call copy.deepcopy
Browse files Browse the repository at this point in the history
  • Loading branch information
sternj committed Oct 10, 2024
1 parent 228437f commit e439e8c
Showing 1 changed file with 104 additions and 15 deletions.
119 changes: 104 additions & 15 deletions hypothesis-python/src/hypothesis/extra/ghostwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,8 +776,88 @@ def _get_qualname(obj: Any, *, include_module: bool = False) -> str:
return qname


class BaseNameVisitor(ast.NodeVisitor):
def __init__(self):
self.base_name = None

def visit_Attribute(self, node: ast.Attribute):
self.visit(node.value)

def visit_Subscript(self, node: ast.Subscript):
self.visit(node.value)

def visit_Call(self, node: ast.Call):
self.visit(node.func)

def visit_Name(self, node: ast.Name):
assert self.base_name is None
self.base_name = node.id

def potentially_modified_vars(func: ast.FunctionDef, params: Set[str]) -> Set[str]:
"""
We want to get any variables that might be modified in a way that might
cause side-effects so that we can copy them.
A variable are potentially modified if:
- a method is called on it, an attribute under it, or an index
of it is anywhere on the LHS of an ast.Assign or an ast.AugAssign
- a method called on it, an attribute of it, or an index of it
- Do we want to count methods on the RHS?
"""
class ModifiedVisitor(ast.NodeVisitor):
def __init__(self):
self.potentially_modified = set()
def visit_Assign(self, node: ast.Assign):
for target in node.targets:
if not isinstance(node, ast.Name):
# the only other things on the LHS
# can be a Subscript or an Attribute
bv = BaseNameVisitor()
bv.visit(target)
self.potentially_modified.add(bv.base_name)

def visit_Call(self, node: ast.Call):
# Might overlap with above method,
# but since we're using a set that's okay
bv = BaseNameVisitor()
bv.visit(node.func)
self.potentially_modified.add(bv.base_name)
self.generic_visit(node)
v = ModifiedVisitor()
v.visit(func)
return v.potentially_modified & params


# def copy_arg(arg):
# return f'copy.deepcopy({arg})'
# def copy_kwarg(name, val):
# return f'{name}={copy_arg(val)}'

def generate_param_expr(name: str, val: Optional[str], kind: inspect._ParameterKind, potentially_modified: Set[str]):
if val is None:
val = name
if kind is inspect.Parameter.POSITIONAL_ONLY:
if name in potentially_modified:
return f'copy.deepcopy({val})'
return val
if kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
if val in potentially_modified:
return f'{name}=copy.deepcopy({val})'
return f'{name}={val}'
if kind is inspect.Parameter.KEYWORD_ONLY:
if val in potentially_modified:
return f'{name}=copy.deepcopy({val})'
return f'{name}={val}'
if kind is inspect.Parameter.VAR_POSITIONAL:
# TODO: I don't know what correct behavior is here
pass
if kind is inspect.Parameter.VAR_KEYWORD:
# TODO: I don't know what correct behavior is here
pass
raise ValueError(f'Unknown kind {kind}')

def _write_call(
func: Callable, *pass_variables: str, except_: Except = Exception, assign: str = ""
func: Callable, *pass_variables: str, except_: Except = Exception, assign: str = "", copy_args: Optional[bool] = None
) -> str:
"""Write a call to `func` with explicit and implicit arguments.
Expand All @@ -791,13 +871,20 @@ def _write_call(
which `func` might raise, and catch-and-reject on them... *unless* they're
subtypes of `except_`, which will be handled in an outer try-except block.
"""
params_dd = _get_params(func)
if copy_args:
potentially_modified = potentially_modified_vars(ast.parse(inspect.getsource(func)).body[0], set(params_dd.keys()))
else:
potentially_modified = set()
args = ", ".join(
(
(v or p.name)
if p.kind is inspect.Parameter.POSITIONAL_ONLY
else f"{p.name}={v or p.name}"
)
for v, p in zip_longest(pass_variables, _get_params(func).values())
generate_param_expr(v or p.name, p.name, p.kind, potentially_modified)
for v, p in zip_longest(pass_variables, params_dd.values())
# (
# (v or p.name)
# if p.kind is inspect.Parameter.POSITIONAL_ONLY
# else f"{p.name}={v or p.name}"
# )
# for v, p in zip_longest(pass_variables, params_dd.values())
)
call = f"{_get_qualname(func, include_module=True)}({args})"
if assign:
Expand Down Expand Up @@ -1187,6 +1274,7 @@ def magic(
except_: Except = (),
style: str = "pytest",
annotate: Optional[bool] = None,
copy_args: bool = False,
) -> str:
"""Guess which ghostwriters to use, for a module or collection of functions.
Expand Down Expand Up @@ -1283,7 +1371,7 @@ def make_(how, *args, **kwargs):
sentinel = object()
returns = {get_type_hints(f).get("return", sentinel) for f in group}
if len(returns - {sentinel}) <= 1:
make_(_make_equiv_body, group, annotate=annotate)
make_(_make_equiv_body, group, annotate=annotate, copy_args=copy_args)
for f in group:
by_name.pop(_get_qualname(f, include_module=True))

Expand Down Expand Up @@ -1529,10 +1617,10 @@ def _get_varnames(funcs):
return var_names


def _make_equiv_body(funcs, except_, style, annotate):
def _make_equiv_body(funcs, except_, style, annotate, copy_args):
var_names = _get_varnames(funcs)
test_lines = [
_write_call(f, assign=vname, except_=except_)
_write_call(f, assign=vname, except_=except_, copy_args=copy_args)
for vname, f in zip(var_names, funcs)
]
assertions = "\n".join(
Expand Down Expand Up @@ -1569,7 +1657,7 @@ def _make_equiv_body(funcs, except_, style, annotate):
""".rstrip()


def _make_equiv_errors_body(funcs, except_, style, annotate):
def _make_equiv_errors_body(funcs, except_, style, annotate, copy_args):
var_names = _get_varnames(funcs)
first, *rest = funcs
first_call = _write_call(first, assign=var_names[0], except_=except_)
Expand All @@ -1587,8 +1675,8 @@ def _make_equiv_errors_body(funcs, except_, style, annotate):
ctx = "self.assertRaises"
block = EQUIV_CHECK_BLOCK.format(
ctx=ctx,
check_raises=indent(_write_call(f, except_=()), " "),
call=indent(_write_call(f, assign=vname, except_=()), " "),
check_raises=indent(_write_call(f, except_=(), copy_args=copy_args), " "),
call=indent(_write_call(f, assign=vname, except_=(), copy_args=copy_args), " "),
compare=indent(_assert_eq(style, var_names[0], vname), " "),
)
test_lines.append(block)
Expand All @@ -1610,6 +1698,7 @@ def equivalent(
except_: Except = (),
style: str = "pytest",
annotate: Optional[bool] = None,
copy_args: Optional[bool] = None,
) -> str:
"""Write source code for a property-based test of ``funcs``.
Expand Down Expand Up @@ -1642,9 +1731,9 @@ def equivalent(
annotate = _are_annotations_used(*funcs)

if allow_same_errors and not any(issubclass(Exception, ex) for ex in except_):
imports, source_code = _make_equiv_errors_body(funcs, except_, style, annotate)
imports, source_code = _make_equiv_errors_body(funcs, except_, style, annotate, copy_args)
else:
imports, source_code = _make_equiv_body(funcs, except_, style, annotate)
imports, source_code = _make_equiv_body(funcs, except_, style, annotate, copy_args)
return _make_test(imports, source_code)


Expand Down

0 comments on commit e439e8c

Please sign in to comment.