From 19e6ad740f86750c74f5a4ec15125cd481e81d8f Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Mon, 15 Jul 2024 09:14:08 +0200 Subject: [PATCH] Fix for nested errors are not correctly resolved When an exception is thrown within a WCI, the last frame is used as context to add the input lines. This does not work if the exception is thrown within a function within WCI. Therefore we now iterate through all traceback frames to find the one corresponding to WCI. The corresponding PR in WCI https://github.com/osscar-org/widget-code-input/pull/26 to solve it there. --- src/scwidgets/code/_widget_code_input.py | 110 +++++++++++++++++++++++ tests/test_code.py | 2 +- 2 files changed, 111 insertions(+), 1 deletion(-) diff --git a/src/scwidgets/code/_widget_code_input.py b/src/scwidgets/code/_widget_code_input.py index cb5fb53..2fed647 100644 --- a/src/scwidgets/code/_widget_code_input.py +++ b/src/scwidgets/code/_widget_code_input.py @@ -1,9 +1,18 @@ import inspect import re +import sys +import traceback import types +import warnings +from functools import wraps from typing import List, Optional from widget_code_input import WidgetCodeInput +from widget_code_input.utils import ( + CodeValidationError, + format_syntax_error_msg, + is_valid_variable_name, +) from ..check import Check @@ -127,3 +136,104 @@ def get_code(func: types.FunctionType) -> str: ) return source + + def get_function_object(self): + """ + Return the compiled function object. + + This can be assigned to a variable and then called, for instance:: + + func = widget.get_function_object() # This can raise a SyntaxError + retval = func(parameters) + + :raise SyntaxError: if the function code has syntax errors (or if + the function name is not a valid identifier) + """ + globals_dict = { + "__builtins__": globals()["__builtins__"], + "__name__": "__main__", + "__doc__": None, + "__package__": None, + } + + if not is_valid_variable_name(self.function_name): + raise SyntaxError("Invalid function name '{}'".format(self.function_name)) + + # Optionally one could do a ast.parse here already, to check syntax + # before execution + try: + exec( + compile(self.full_function_code, __name__, "exec", dont_inherit=True), + globals_dict, + ) + except SyntaxError as exc: + raise CodeValidationError( + format_syntax_error_msg(exc), orig_exc=exc + ) from exc + + function_object = globals_dict[self.function_name] + + def catch_exceptions(func): + @wraps(func) + def wrapper(*args, **kwargs): + """Wrap and check exceptions to return a longer and clearer + exception.""" + + try: + return func(*args, **kwargs) + except Exception as exc: + err_msg = format_generic_error_msg(exc, code_widget=self) + raise CodeValidationError(err_msg, orig_exc=exc) from exc + + return wrapper + + return catch_exceptions(function_object) + + +# Temporary fix until https://github.com/osscar-org/widget-code-input/pull/26 +# is merged +def format_generic_error_msg(exc, code_widget): + """ + Return a string reproducing the traceback of a typical error. + This includes line numbers, as well as neighboring lines. + + It will require also the code_widget instance, to get the actual source code. + + :note: this must be called from withou the exception, as it will get the + current traceback state. + + :param exc: The exception that is being processed. + :param code_widget: the instance of the code widget with the code that + raised the exception. + """ + error_class, _, tb = sys.exc_info() + frame_summaries = traceback.extract_tb(tb) + # The correct frame summary corresponding to widget_code_intput is not + # always at the end therefore we loop through all of them + wci_frame_summary = None + for frame_summary in frame_summaries: + if frame_summary.filename == "widget_code_input": + wci_frame_summary = frame_summary + if wci_frame_summary is None: + warnings.warn( + "Could not find traceback frame corresponding to " + "widget_code_input, we output whole error message.", + stacklevel=2, + ) + + return exc + line_number = wci_frame_summary[1] + code_lines = code_widget.full_function_code.splitlines() + + err_msg = f"{error_class.__name__} in code input: {str(exc)}\n" + if line_number > 2: + err_msg += f" {line_number - 2:4d} {code_lines[line_number - 3]}\n" + if line_number > 1: + err_msg += f" {line_number - 1:4d} {code_lines[line_number - 2]}\n" + err_msg += f"---> {line_number:4d} {code_lines[line_number - 1]}\n" + if line_number < len(code_lines): + err_msg += f" {line_number + 1:4d} {code_lines[line_number]}\n" + if line_number < len(code_lines) - 1: + err_msg += f" {line_number + 2:4d} {code_lines[line_number + 1]}\n" + + return err_msg diff --git a/tests/test_code.py b/tests/test_code.py index 68a6139..e1eccdd 100644 --- a/tests/test_code.py +++ b/tests/test_code.py @@ -219,7 +219,7 @@ def test_run_code(self, code_ex): def test_erroneous_run_code(self, code_ex): with pytest.raises( CodeValidationError, - match="NameError in code input: name 'bug' is not defined.*", + match="name 'bug' is not defined.*", ): code_ex.run_code(**code_ex.parameters)