diff --git a/src/puyapy/awst_build/validation/__init__.py b/src/puya/awst/validation/__init__.py similarity index 100% rename from src/puyapy/awst_build/validation/__init__.py rename to src/puya/awst/validation/__init__.py diff --git a/src/puyapy/awst_build/validation/arc4_copy.py b/src/puya/awst/validation/arc4_copy.py similarity index 99% rename from src/puyapy/awst_build/validation/arc4_copy.py rename to src/puya/awst/validation/arc4_copy.py index 4a6f597aef..731c49172c 100644 --- a/src/puyapy/awst_build/validation/arc4_copy.py +++ b/src/puya/awst/validation/arc4_copy.py @@ -1,6 +1,7 @@ from collections.abc import Iterator import attrs + from puya import log from puya.awst import ( nodes as awst_nodes, diff --git a/src/puyapy/awst_build/validation/base_invoker.py b/src/puya/awst/validation/base_invoker.py similarity index 100% rename from src/puyapy/awst_build/validation/base_invoker.py rename to src/puya/awst/validation/base_invoker.py diff --git a/src/puyapy/awst_build/validation/inner_transactions.py b/src/puya/awst/validation/inner_transactions.py similarity index 100% rename from src/puyapy/awst_build/validation/inner_transactions.py rename to src/puya/awst/validation/inner_transactions.py diff --git a/src/puyapy/awst_build/validation/labels.py b/src/puya/awst/validation/labels.py similarity index 100% rename from src/puyapy/awst_build/validation/labels.py rename to src/puya/awst/validation/labels.py diff --git a/src/puyapy/awst_build/validation/main.py b/src/puya/awst/validation/main.py similarity index 57% rename from src/puyapy/awst_build/validation/main.py rename to src/puya/awst/validation/main.py index 61f933c13c..98df0bedcb 100644 --- a/src/puyapy/awst_build/validation/main.py +++ b/src/puya/awst/validation/main.py @@ -1,15 +1,14 @@ from puya.awst import nodes as awst_nodes - -from puyapy.awst_build.validation.arc4_copy import ARC4CopyValidator -from puyapy.awst_build.validation.base_invoker import BaseInvokerValidator -from puyapy.awst_build.validation.inner_transactions import ( +from puya.awst.validation.arc4_copy import ARC4CopyValidator +from puya.awst.validation.base_invoker import BaseInvokerValidator +from puya.awst.validation.inner_transactions import ( InnerTransactionsValidator, InnerTransactionUsedInALoopValidator, StaleInnerTransactionsValidator, ) -from puyapy.awst_build.validation.labels import LabelsValidator -from puyapy.awst_build.validation.scratch_slots import ScratchSlotReservationValidator -from puyapy.awst_build.validation.storage import StorageTypesValidator +from puya.awst.validation.labels import LabelsValidator +from puya.awst.validation.scratch_slots import ScratchSlotReservationValidator +from puya.awst.validation.storage import StorageTypesValidator def validate_awst(module: awst_nodes.AWST) -> None: diff --git a/src/puyapy/awst_build/validation/scratch_slots.py b/src/puya/awst/validation/scratch_slots.py similarity index 100% rename from src/puyapy/awst_build/validation/scratch_slots.py rename to src/puya/awst/validation/scratch_slots.py diff --git a/src/puyapy/awst_build/validation/storage.py b/src/puya/awst/validation/storage.py similarity index 100% rename from src/puyapy/awst_build/validation/storage.py rename to src/puya/awst/validation/storage.py diff --git a/src/puya/compile.py b/src/puya/compile.py index 212a5bcbd1..20b7cfb050 100644 --- a/src/puya/compile.py +++ b/src/puya/compile.py @@ -10,6 +10,7 @@ from puya.arc32 import create_arc32_json from puya.artifact_sorter import ArtifactCompilationSorter from puya.awst.nodes import AWST +from puya.awst.validation.main import validate_awst from puya.context import CompileContext from puya.errors import CodeError, InternalError from puya.ir.main import awst_to_ir, optimize_and_destructure_ir @@ -52,6 +53,8 @@ def awst_to_teal( *, write: bool = True, ) -> list[CompilationArtifact]: + validate_awst(awst) + log_ctx.exit_if_errors() context = CompileContext( options=options, compilation_set=compilation_set, diff --git a/src/puyapy/awst_build/module.py b/src/puyapy/awst_build/module.py index a9c117c09d..90738ef75f 100644 --- a/src/puyapy/awst_build/module.py +++ b/src/puyapy/awst_build/module.py @@ -27,7 +27,6 @@ get_decorators_by_fullname, get_unaliased_fullname, ) -from puyapy.awst_build.validation.main import validate_awst logger = log.get_logger(__name__) @@ -67,7 +66,6 @@ def convert(self) -> AWST: for deferred in deferrals: awst_node = deferred(self.context) awst.append(awst_node) - validate_awst(awst) # TODO: move/split this to/with puya core return awst # Supported Statements diff --git a/tests/test_expected_output/data.py b/tests/test_expected_output/data.py index e97ec4653b..999a7849f2 100644 --- a/tests/test_expected_output/data.py +++ b/tests/test_expected_output/data.py @@ -1,13 +1,14 @@ -from __future__ import annotations - import contextlib import difflib import tempfile import typing as t +from collections.abc import Iterator, Sequence from pathlib import Path +import _pytest._code.code import attrs import pytest +from puya.awst.nodes import AWST from puya.awst.to_code_visitor import ToCodeVisitor from puya.compile import awst_to_teal from puya.errors import PuyaError, log_exceptions @@ -19,12 +20,6 @@ from tests.utils import narrowed_compile_context -if t.TYPE_CHECKING: - from collections.abc import Sequence - - import _pytest._code.code - from puya.awst.nodes import AWST - THIS_DIR = Path(__file__).parent REPO_DIR = THIS_DIR.parent.parent CASE_COMMENT = "##" @@ -300,49 +295,52 @@ def compile_and_update_cases(cases: list[TestCase]) -> None: awst, compilation_set = transform_ast(parse_result) # lower each case further if possible and process for case in cases: - if case_has_awst_errors(awst_log_ctx.logs, case): - case_logs = [] - else: - # lower awst for each case individually to order to get any output - # from lower layers - # this needs a new logging context so AWST errors from other cases - # are not seen - case_options = attrs.evolve( - puyapy_options, cli_template_definitions=case.template_vars - ) - case_sources_by_path, case_compilation_set = narrowed_compile_context( - parse_result, - case_path[case], - awst, - compilation_set, + case_awst = [ + n + for n in awst + if n.source_location.file == case_path[case] + # hacky way to keep "framework" sources included, good enough for now + # the real solution here is to remove mypy, so we don't need to do this special + # combine+split of sources to achieve decent mypy parsing speed + or n.source_location.line < 0 + ] + # lower awst for each case individually to order to get any output + # from lower layers + # this needs a new logging context so AWST errors from other cases + # are not seen + case_options = attrs.evolve( + puyapy_options, cli_template_definitions=case.template_vars + ) + case_sources_by_path, case_compilation_set = narrowed_compile_context( + parse_result, + case_path[case], + awst, + compilation_set, + case_options, + ) + with ( + contextlib.suppress(SystemExit), + logging_context() as case_log_ctx, + log_exceptions(), + ): + case_log_ctx.logs.extend(filter_logs(awst_log_ctx.logs, case)) + awst_to_teal( + case_log_ctx, case_options, + case_compilation_set, + case_sources_by_path, + case_awst, + write=False, ) - with ( - contextlib.suppress(SystemExit), - logging_context() as case_log_ctx, - log_exceptions(), - ): - awst_to_teal( - case_log_ctx, - case_options, - case_compilation_set, - case_sources_by_path, - awst, - write=False, - ) - case_logs = case_log_ctx.logs - process_test_case(case, awst_log_ctx.logs + case_logs, awst) + process_test_case(case, case_log_ctx.logs, case_awst) -def case_has_awst_errors(captured_logs: list[Log], case: TestCase) -> bool: +def filter_logs(captured_logs: list[Log], case: TestCase) -> Iterator[Log]: for file in case.files: path = file.src_path assert path is not None abs_path = path.resolve() - path_records = [record for record in captured_logs if record.file == abs_path] - if any(r.level == LogLevel.error and r.line is not None for r in path_records): - return True - return False + yield from (record for record in captured_logs if record.file == abs_path) def get_python_file_name(name: str) -> str: @@ -369,13 +367,11 @@ def process_test_case(case: TestCase, captured_logs: Sequence[Log], awst: AWST) for file in case.files: path = file.src_path assert path is not None - abs_path = path.resolve() expected_output = { (line, message) for line, messages in file.expected_output.items() for message in messages } - path_records = [record for record in captured_logs if record.file == abs_path] seen_output = { ( record.line, @@ -384,7 +380,7 @@ def process_test_case(case: TestCase, captured_logs: Sequence[Log], awst: AWST) output=record.message.strip(), ), ) - for record in path_records + for record in captured_logs if record.line is not None and record.level >= MIN_LEVEL_TO_REPORT } file_missing_output = expected_output - seen_output