diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e2a3016..5bfdd2e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,13 +10,14 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: python-version: ${{matrix.python-version}} architecture: x64 + allow-prereleases: true - run: pip install nox==2023.4.22 nox-poetry poetry - run: nox --session mypy-${{matrix.python-version}} @@ -24,12 +25,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: python-version: ${{matrix.python-version}} architecture: x64 + allow-prereleases: true - run: pip install nox==2023.4.22 nox-poetry poetry - run: nox --sessions test-${{matrix.python-version}} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b977f30..318d0aa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,9 +44,7 @@ repos: repo: https://github.com/psf/black rev: 23.7.0 - hooks: - - additional_dependencies: - - black==23.7.0 - id: blacken-docs + - id: blacken-docs repo: https://github.com/asottile/blacken-docs rev: 1.16.0 diff --git a/noxfile.py b/noxfile.py index 45b5a0f..eb85f35 100644 --- a/noxfile.py +++ b/noxfile.py @@ -4,9 +4,8 @@ from nox_poetry import session nox.options.sessions = ["clean", "test", "report", "mypy"] -nox.options.reuse_existing_virtualenvs = True -python_versions = ["3.7", "3.8", "3.9", "3.10", "3.11"] +python_versions = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] @session(python="python3.11") @@ -18,13 +17,21 @@ def clean(session): @session(python=python_versions) def mypy(session): - session.install(".", "mypy", "pytest", "rich") + session.install(".", "mypy", "pytest", "rich", "inline-snapshot") session.run("mypy", "pysource_codegen", "tests") @session(python=python_versions) def test(session): - session.install(".", "pytest", "pytest-xdist", "rich", "coverage-enable-subprocess") + session.install( + ".", + "pytest", + "pytest-xdist", + "rich", + "coverage-enable-subprocess", + "inline-snapshot", + ) + session.env["COVERAGE_PROCESS_START"] = str( Path(__file__).parent / "pyproject.toml" ) diff --git a/poetry.lock b/poetry.lock index cf72e2e..990209d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,22 @@ # This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +[[package]] +name = "asttokens" +version = "2.4.0" +description = "Annotate AST trees with source code positions" +optional = false +python-versions = "*" +files = [ + {file = "asttokens-2.4.0-py2.py3-none-any.whl", hash = "sha256:cf8fc9e61a86461aa9fb161a14a0841a03c405fa829ac6b202670b3495d2ce69"}, + {file = "asttokens-2.4.0.tar.gz", hash = "sha256:2e0171b991b2c959acc6c49318049236844a5da1d65ba2672c4880c1c894834e"}, +] + +[package.dependencies] +six = ">=1.12.0" + +[package.extras] +test = ["astroid", "pytest"] + [[package]] name = "astunparse" version = "1.6.3" @@ -15,6 +32,71 @@ files = [ six = ">=1.6.1,<2.0" wheel = ">=0.23.0,<1.0" +[[package]] +name = "black" +version = "23.3.0" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.7" +files = [ + {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"}, + {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"}, + {file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"}, + {file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"}, + {file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"}, + {file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"}, + {file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"}, + {file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"}, + {file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"}, + {file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"}, + {file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"}, + {file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"}, + {file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"}, + {file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"}, + {file = "black-23.3.0-py3-none-any.whl", hash = "sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"}, + {file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""} +typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} +importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} + [[package]] name = "colorama" version = "0.4.6" @@ -140,6 +222,20 @@ files = [ [package.extras] testing = ["hatch", "pre-commit", "pytest", "tox"] +[[package]] +name = "executing" +version = "1.2.0" +description = "Get the currently executing AST node of a frame, and other information" +optional = false +python-versions = "*" +files = [ + {file = "executing-1.2.0-py2.py3-none-any.whl", hash = "sha256:0314a69e37426e3608aada02473b4161d4caf5a4b244d1d0c48072b8fee7bacc"}, + {file = "executing-1.2.0.tar.gz", hash = "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107"}, +] + +[package.extras] +tests = ["asttokens", "littleutils", "pytest", "rich"] + [[package]] name = "importlib-metadata" version = "6.7.0" @@ -171,6 +267,23 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "inline-snapshot" +version = "0.4.0" +description = "golden master/snapshot/approval testing library which puts the values right into your source code" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "inline_snapshot-0.4.0-py3-none-any.whl", hash = "sha256:68a3316674eabfd071f980cfb766de16e2b3292bc16cc0eeac1433ad8717f904"}, + {file = "inline_snapshot-0.4.0.tar.gz", hash = "sha256:f80d5a257060943710822f347a8f22309bd2bba0d9846ab7fc7cfa0dc2a91e05"}, +] + +[package.dependencies] +asttokens = ">=2.0.5,<3.0.0" +black = ">=23.3.0,<24.0.0" +click = ">=8.1.4,<9.0.0" +executing = ">=1.2.0,<2.0.0" + [[package]] name = "markdown-it-py" version = "2.2.0" @@ -276,6 +389,35 @@ files = [ {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, ] +[[package]] +name = "pathspec" +version = "0.11.2" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, + {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, +] + +[[package]] +name = "platformdirs" +version = "3.10.0" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = ">=3.7" +files = [ + {file = "platformdirs-3.10.0-py3-none-any.whl", hash = "sha256:d7c24979f292f916dc9cbf8648319032f551ea8c49a4c9bf2fb556a02070ec1d"}, + {file = "platformdirs-3.10.0.tar.gz", hash = "sha256:b45696dab2d7cc691a3226759c0d3b00c47c8b6e293d96f6436f733303f77f6d"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.8\""} + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] + [[package]] name = "pluggy" version = "1.2.0" @@ -512,4 +654,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "b29e08c6dd1c99dbb385197dfe59e244c5a9ff232b8742654e89481a5334cfca" +content-hash = "6b0079486554bb6707e440c1bafd15cd8fb5cb0fa8215721c3275d8c4744db14" diff --git a/pyproject.toml b/pyproject.toml index af18cc1..7b8acb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ pysource-codegen = "pysource_codegen.__main__:run" [tool.poetry.dependencies] python = "^3.7" -astunparse = "^1.6.3" +astunparse = { version = "^1.6.3", python ="<3.9"} typed-ast = "^1.5.5" typing-extensions = "^4.7.1" @@ -21,6 +21,7 @@ pytest-xdist = {extras = ["psutil"], version = "^3.2.1"} pytest = "^7.2.1" mypy = "^1.2.0" coverage-enable-subprocess = "^1.0" +inline-snapshot = "^0.4.0" [build-system] requires = ["poetry-core"] @@ -31,3 +32,6 @@ source = ["tests","pysource_codegen"] parallel = true branch = true data_file = "$TOP/.coverage" + +[tool.coverage.report] +exclude_lines = ["assert False", "raise NotImplemented"] diff --git a/pysource_codegen/_codegen.py b/pysource_codegen/_codegen.py index f60ea47..6cc8252 100644 --- a/pysource_codegen/_codegen.py +++ b/pysource_codegen/_codegen.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import ast import inspect import itertools import re import sys -from typing import Dict -from typing import Union +from typing import Any from .types import BuiltinNodeType from .types import NodeType @@ -15,13 +16,15 @@ else: from astunparse import unparse # type: ignore +from ._limits import f_string_format_limit, f_string_expr_limit py38plus = (3, 8) <= sys.version_info py39plus = (3, 9) <= sys.version_info py310plus = (3, 10) <= sys.version_info py311plus = (3, 11) <= sys.version_info +py312plus = (3, 12) <= sys.version_info -type_infos: Dict[str, Union[NodeType, BuiltinNodeType, UnionNodeType]] = {} +type_infos: dict[str, NodeType | BuiltinNodeType | UnionNodeType] = {} def all_args(args): @@ -127,25 +130,40 @@ def inside(types, not_types=()): ]: return 0 - if parents[-1] == ("FormattedValue", "value") and child_name != "Constant": + # f-string + if parents[-1] == ("JoinedStr", "values") and child_name not in ( + "Constant", + "FormattedValue", + ): + return 0 + + if ( + not py312plus + and parents[-1] == ("FormattedValue", "value") + and child_name != "Constant" + ): return 0 if parents[-1] == ("FormattedValue", "format_spec") and child_name != "JoinedStr": return 0 - if parents[-1] == ("JoinedStr", "values") and child_name not in ( - "Constant", - "FormattedValue", + if ( + child_name == "JoinedStr" + and parents.count(("FormattedValue", "format_spec")) > f_string_format_limit ): return 0 - if child_name == "JoinedStr" and parent_types.count("JoinedStr") >= 2: + if ( + child_name == "JoinedStr" + and parents.count(("FormattedValue", "value")) > f_string_expr_limit + ): return 0 if child_name == "FormattedValue" and parents[-1][0] != "JoinedStr": # TODO: doc says this should be valid, maybe a bug in the python doc return 0 + # function statements if child_name in ( "Nonlocal", "Return", @@ -276,6 +294,38 @@ def valid_deco_parents(parents): if valid_deco_parents(deco_parents) and child_name != "Name": return 0 + # type alias + if py312plus: + if parents[-1] == ("TypeAlias", "name") and child_name != "Name": + return 0 + + if child_name in ( + "NamedExpr", + "Yield", + "YieldFrom", + "Await", + "ListComp", + "DictComp", + "SetComp", + "GeneratorExp", + ) and inside( + ( + "ClassDef.bases", + "ClassDef.keywords", + "FunctionDef.returns", + "AsyncFunctionDef.returns", + "arg.annotation", + "TypeAlias.value", + "TypeVar.bound", + ) + ): + # todo this should only be invalid in type scopes (when the class/def has type parameters) + # and only for async comprehensions + return 0 + + if child_name == "Await" and inside("AnnAssign.annotation"): + return 0 + if child_name == "Expr": return 30 @@ -396,73 +446,6 @@ def visit_NamedExpr(self, node: ast.NamedExpr): node = Transformer().visit(node) - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Module)): - while True: - try: - code = unparse(ast.fix_missing_locations(node)) - compile(code, "", "exec") - break - except ValueError: - break - except SyntaxError as e: - m = re.match("name '(.*)' is used prior to global declaration", str(e)) - - if not m: - m = re.match( - "name '(.*)' is assigned to before global declaration", str(e) - ) - - if not m: - m = re.match("name '(.*)' is parameter and global", str(e)) - if not m: - m = re.match("annotated name '(.*)' can't be global", str(e)) - if not m: - m = re.match("name '(.*)' is nonlocal and global", str(e)) - - if m: - name = m.group(1) - - class Transformer(ast.NodeTransformer): - def visit_Global(self, node): - if name in node.names: - node.names.remove(name) - if not node.names: - return ast.Pass() - return node - - node = Transformer().visit(node) - continue - - m = re.match("name '(.*)' is parameter and nonlocal", str(e)) - if not m: - m = re.match( - "name '(.*)' is used prior to nonlocal declaration", str(e) - ) - if not m: - m = re.match("no binding for nonlocal '(.*)' found", str(e)) - if not m: - m = re.match( - "name '(.*)' is assigned to before nonlocal declaration", str(e) - ) - if not m: - m = re.match("annotated name '(.*)' can't be nonlocal", str(e)) - - if m: - name = m.group(1) - - class Transformer(ast.NodeTransformer): - def visit_Nonlocal(self, node): - if name in node.names: - node.names.remove(name) - if not node.names: - return ast.Pass() - return node - - node = Transformer().visit(node) - continue - - break - # pattern matching if sys.version_info >= (3, 10): @@ -677,6 +660,189 @@ def visit_MatchOr(self, node: ast.MatchOr): if node.args.kwarg: node.args.kwarg.annotation = None + if sys.version_info >= (3, 12): + if isinstance(node, ast.Global): + node.names = list(set(node.names)) + + # type scopes + if hasattr(node, "type_params"): + node.type_params = unique_by(node.type_params, lambda p: p.name) + + def cleanup_annotation(annotation): + class Transformer(ast.NodeTransformer): + def visit_NamedExpr(self, node: ast.NamedExpr): + return self.visit(node.value) + + def visit_Yield(self, node: ast.Yield) -> Any: + if node.value is None: + return ast.Constant(value=None) + return self.visit(node.value) + + def visit_YieldFrom(self, node: ast.YieldFrom) -> Any: + return self.visit(node.value) + + def visit_Lambda(self, node: ast.Lambda) -> Any: + return self.visit(node.body) + + return Transformer().visit(annotation) + + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.type_params + ): + for arg in [ + *node.args.posonlyargs, + *node.args.args, + *node.args.kwonlyargs, + node.args.vararg, + node.args.kwarg, + ]: + if arg is not None and arg.annotation: + arg.annotation = cleanup_annotation(arg.annotation) + + if node.returns is not None: + node.returns = cleanup_annotation(node.returns) + + if isinstance(node, ast.ClassDef) and node.type_params: + node.bases = [cleanup_annotation(b) for b in node.bases] + for kw in node.keywords: + kw.value = cleanup_annotation(kw.value) + + for n in ast.walk(node): + if isinstance(n, ast.TypeAlias): + n.value = cleanup_annotation(n.value) + + if isinstance(node, ast.ClassDef): + for n in ast.walk(node): + if isinstance(n, ast.TypeVar) and n.bound is not None: + n.bound = cleanup_annotation(n.bound) + + if isinstance(node, ast.AnnAssign): + node.annotation = cleanup_annotation(node.annotation) + + return node + + +def fix_result(node): + return fix_nonlocal(node) + + +def arguments(node: ast.FunctionDef | ast.AsyncFunctionDef) -> list[ast.arg]: + args = node.args + l = [ + *args.args, + args.vararg, + *args.kwonlyargs, + args.kwarg, + ] + + if sys.version_info >= (3, 8): + l += args.posonlyargs + + return [arg for arg in l if arg is not None] + + +def fix_nonlocal(node): + class NonLocalFixer(ast.NodeTransformer): + def __init__(self, locals, nonlocals, globals): + self.locals = set(locals) + self.used_names = set(locals) + + # nonlocals from the parent function + self.nonlocals = set(nonlocals) + + # globals from the global scope + self.globals = set(globals) + + def visit_Name(self, node: ast.Name) -> Any: + if isinstance(node.ctx, ast.Store): + self.locals.add(node.id) + self.used_names.add(node.id) + return node + + def visit_Nonlocal(self, node: ast.Nonlocal) -> Any: + node.names = [ + name + for name in node.names + if name not in self.locals + and name in self.nonlocals + and name not in self.used_names + ] + self.locals |= set(node.names) + + if not node.names: + return ast.Pass() + + return node + + def visit_Global(self, node: ast.Global) -> Any: + print("visit global", node.names, self.globals, self.locals) + node.names = [ + name + for name in node.names + if name not in self.locals and name not in self.used_names + ] + self.locals |= set(node.names) + + if not node.names: + return ast.Pass() + + return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + for default in [*node.args.defaults, *node.args.kw_defaults]: + if default is not None: + self.visit(default) + return node + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: + for default in [*node.args.defaults, *node.args.kw_defaults]: + if default is not None: + self.visit(default) + return node + + def visit_Lambda(self, node: ast.Lambda) -> Any: + for default in [*node.args.defaults, *node.args.kw_defaults]: + if default is not None: + self.visit(default) + return node + + class FunctionTransformer(ast.NodeTransformer): + def __init__(self, nonlocals, globals): + self.nonlocals = set(nonlocals) + self.globals = set(globals) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + return self.handle_function(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: + return self.handle_function(node) + + def visit_Lambda(self, node: ast.Lambda) -> Any: + # there are no globals/nonlocals/functiondefs in lambdas + return node + + def handle_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> Any: + args = node.args + names = {arg.arg for arg in arguments(node)} + if sys.version_info >= (3, 12): + names |= {typ.name for typ in node.type_params} # type: ignore + print("handle function") + + fixer = NonLocalFixer(names, self.nonlocals, self.globals) + node.body = [fixer.visit(stmt) for stmt in node.body] + + ft = FunctionTransformer(fixer.locals | self.nonlocals, self.globals) + node.body = [ft.visit(stmt) for stmt in node.body] + + return node + + print("handle module", node) + + fixer = NonLocalFixer([], [], []) + node = fixer.visit(node) + + node = FunctionTransformer([], []).visit(node) return node @@ -691,6 +857,11 @@ def cnd(self): return self.rand.choice([True, False]) def generate(self, name: str, parents=(), depth=0): + result = self.generate_impl(name, parents, depth) + result = fix_result(result) + return result + + def generate_impl(self, name: str, parents=(), depth=0): depth += 1 self.nodes += 1 @@ -726,14 +897,14 @@ def range_for(child, attr_name): def child_node(n, t, q, parents): if q == "": - return self.generate(t, parents, depth) + return self.generate_impl(t, parents, depth) elif q == "*": return [ - self.generate(t, parents, depth) + self.generate_impl(t, parents, depth) for _ in range_for(parents[-1][0], n) ] elif q == "?": - return self.generate(t, parents, depth) if self.cnd() else None + return self.generate_impl(t, parents, depth) if self.cnd() else None else: assert False @@ -758,7 +929,7 @@ def child_node(n, t, q, parents): # TODO: better handling of `type?` return None - return self.generate( + return self.generate_impl( self.rand.choices(*zip(*options.items()))[0], parents, depth ) if isinstance(info, BuiltinNodeType): @@ -790,3 +961,70 @@ def generate( tree = generator.generate(root_node) ast.fix_missing_locations(tree) return unparse(tree) + + +# next algo + +# design targets: +# * enumerate "all" possible ast-node combinations +# * check if propability 0 would produce incorrect code +# * the algo should be able to generate every possible syntax combination for every python version. +# * hypothesis integration +# * do not use compile() in the implementation +# * generation should be customizable (custom propabilities and random values) + +# features: +# * node-context: function-scope async-scope type-scope class-scope ... +# * names: nonlocal global + +from dataclasses import dataclass + + +@dataclass +class ParentRef: + node: PartialNode + attr_name: str + index: int + _context: dict + + def __getattr__(self, name): + if name.startswith("ctx_"): + return getattr(node, name) + raise AttributeError + + +# (d:=[n] | q_parent("Delete.targets")) and len(d.targets)==1 + + +@dataclass +class PartialValue: + value: int | str | bool + + +@dataclass +class PartialNode: + _node_type_name: str + parent_ref: ParentRef | None + _defined_attrs: dict + _context: dict + + def inside(self, spec) -> PartialNode | None: + ... + + @property + def parent(self): + return self.parent_ref.node + + def __getattr__(self, name): + if name.startswith("ctx_"): + return getattr(node, name) + + if name not in self._defined_attrs: + raise RuntimeError(f"{self._node_type_name}.{name} is not defined jet") + + return self._defined_attrs[name] + + +def gen(node: PartialNode): + # parents [(node,attr_name)] + pass diff --git a/pysource_codegen/_limits.py b/pysource_codegen/_limits.py new file mode 100644 index 0000000..900da34 --- /dev/null +++ b/pysource_codegen/_limits.py @@ -0,0 +1,35 @@ +def calc_f_string_expr_limit(): + n = 0 + s = "1" + while True: + for q in ("'", '"', '"""', "'''"): + ns = "f" + q + "{" + s + "}" + q + + try: + eval(ns) + s = ns + break + except: + continue + else: + return n + n += 1 + + +def calc_f_string_format_limit(): + n = 0 + s = "{1}" + while True: + s = "{2:" + s + "}" + + try: + eval(f"f'{s}'") + except: + break + n += 1 + + return n + + +f_string_expr_limit = calc_f_string_expr_limit() +f_string_format_limit = calc_f_string_format_limit() diff --git a/tests/test_codegen.py b/tests/test_codegen.py index aa32e73..f4f5452 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -8,13 +8,13 @@ from pysource_codegen import generate -@pytest.mark.parametrize("seed", list(range(500))) +@pytest.mark.parametrize("seed", list(range(50))) def test_codegen(seed): with tempfile.NamedTemporaryFile("w", delete=False) as file: + source = generate(seed) + file.write(source) + file.flush() try: - source = generate(seed) - file.write(source) - file.flush() compile(source, file.name, "exec") except Exception as e: diff --git a/tests/test_fix_nonlocal.py b/tests/test_fix_nonlocal.py new file mode 100644 index 0000000..650a28e --- /dev/null +++ b/tests/test_fix_nonlocal.py @@ -0,0 +1,249 @@ +import ast +import sys + +from inline_snapshot import snapshot + +from pysource_codegen._codegen import fix_nonlocal +from pysource_codegen._codegen import unparse + +known_errors = snapshot( + [ + "no binding for nonlocal 'x' found", + "name 'x' is parameter and nonlocal", + "name 'x' is used prior to nonlocal declaration", + "name 'x' is assigned to before nonlocal declaration", + "name 'x' is parameter and global", + "name 'x' is assigned to before global declaration", + "name 'x' is used prior to global declaration", + "annotated name 'x' can't be global", + "name 'x' is nonlocal and global", + "annotated name 'x' can't be nonlocal", + "nonlocal binding not allowed for type parameter 'x'", + ] +) + + +def check_code(src, snapshot_value): + try: + compile(src, "", "exec") + except SyntaxError as error: + error_str = str(error) + assert error_str.split(" (")[0] in known_errors + else: + assert False, "error expected" + + tree = ast.parse(src) + + print("original tree:") + print(ast.dump(tree, **(dict(indent=2) if sys.version_info >= (3, 9) else {}))) + print("original src:") + print(src) + print("error:", str(error_str)) + + tree = fix_nonlocal(tree) + new_src = unparse(tree).strip() + "\n" + + print() + print("transformed tree:") + print(ast.dump(tree, **(dict(indent=2) if sys.version_info >= (3, 9) else {}))) + print("transformed src:") + print(new_src) + + compile(new_src, "", "exec") + + assert new_src == snapshot_value + + +def test_global(): + check_code( + """ +def a(x): + global x + """, + snapshot( + """\ +def a(x): + pass +""" + ), + ) + + check_code( + """ +def a(): + x = 0 + global x + """, + snapshot( + """\ +def a(): + x = 0 + pass +""" + ), + ) + + check_code( + """ +def a(): + print(x) + global x + """, + snapshot( + """\ +def a(): + print(x) + pass +""" + ), + ) + check_code( + """ +def a(): + x:int + global x + """, + snapshot( + """\ +def a(): + x: int + pass +""" + ), + ) + + check_code( + """ + +def a(): + x=5 + def b(): + nonlocal x + global x + """, + snapshot( + """\ +def a(): + x = 5 + + def b(): + nonlocal x + pass +""" + ), + ) + + +def test_nonlocal(): + check_code( + """ +def b(): + def a(): + nonlocal x + """, + snapshot( + """\ +def b(): + + def a(): + pass +""" + ), + ) + + check_code( + """ +def b(): + x=0 + def a(x): + nonlocal x + """, + snapshot( + """\ +def b(): + x = 0 + + def a(x): + pass +""" + ), + ) + + check_code( + """ +def b(): + x=0 + def a(): + print(x) + nonlocal x + """, + snapshot( + """\ +def b(): + x = 0 + + def a(): + print(x) + pass +""" + ), + ) + + check_code( + """ +def b(): + x=0 + def a(): + x=5 + nonlocal x + """, + snapshot( + """\ +def b(): + x = 0 + + def a(): + x = 5 + pass +""" + ), + ) + + check_code( + """ +def b(): + x=0 + def a(): + x:int + nonlocal x + """, + snapshot( + """\ +def b(): + x = 0 + + def a(): + x: int + pass +""" + ), + ) + + if sys.version_info >= (3, 12): + check_code( + """ +def b(): + x=0 + def a[x:int](): + nonlocal x + """, + snapshot( + """\ +def b(): + x = 0 + + def a[x: int](): + pass +""" + ), + )