From ac8335ab26b8e99d20dba29cc59f96273badc715 Mon Sep 17 00:00:00 2001 From: Theo Date: Mon, 30 Dec 2024 21:56:33 +0100 Subject: [PATCH] Refactor a bit and implement a bit of hot code realoding without throwing --- bootstrap/hot_reloading/engine.py | 132 +++++++++++++++++++++--------- bootstrap/hot_reloading/module.py | 4 +- bootstrap/tui/builder_ui.py | 10 +-- 3 files changed, 99 insertions(+), 47 deletions(-) diff --git a/bootstrap/hot_reloading/engine.py b/bootstrap/hot_reloading/engine.py index ddaaa8a..6c25ffd 100644 --- a/bootstrap/hot_reloading/engine.py +++ b/bootstrap/hot_reloading/engine.py @@ -3,7 +3,7 @@ import inspect import sys import traceback -from types import FrameType +from types import CodeType, FrameType, TracebackType from typing import ( Callable, Optional, @@ -25,7 +25,7 @@ def __init__(self, ui: App): @classmethod def get_class_frame( - cls, func: Callable, exc_traceback + cls, func: Callable, exc_traceback: TracebackType ) -> Tuple[Optional[FrameType], Optional[FrameType]]: """ Find the frame of the last callable within the scope of the MatchboxModule in @@ -34,7 +34,7 @@ def get_class_frame( function that threw (or originated) the exception. """ print("============= get_class_frame() =========") - last_frame_in_scope, last_frame = None, None + root_frame, last_frame = None, None for frame, _ in traceback.walk_tb(exc_traceback): last_frame = frame print(frame.f_code.co_qualname) @@ -48,13 +48,15 @@ def get_class_frame( and "self" in inspect.getargs(frame.f_code).args ): print(f"Found method {val} in traceback, continuing...") - last_frame_in_scope = frame + root_frame = frame print("============================================") - return last_frame_in_scope, last_frame + return root_frame, last_frame @classmethod def get_lambda_child_frame( - cls, func: Callable, exc_traceback + cls, + func: Callable, + exc_traceback: TracebackType, ) -> Tuple[Optional[FrameType], Optional[str]]: """ Find the frame of the last callable within the scope of the MatchboxModule in @@ -95,21 +97,33 @@ def get_lambda_child_frame( return None, None @classmethod - def get_function_frame(cls, func: Callable, exc_traceback) -> Optional[FrameType]: + def get_function_frame( + cls, func: Callable, exc_traceback: TracebackType + ) -> Tuple[Optional[FrameType], Optional[FrameType]]: + """ + Find the frame of the last callable within the scope of the MatchboxModule in + the traceback. In this instance, the MatchboxModule is a module-level function + so we want to find its frame. + """ print("============= get_function_frame() =========") - last_frame = None + root_frame, last_frame = None, None for frame, _ in traceback.walk_tb(exc_traceback): print(frame.f_code.co_qualname) + last_frame = frame if frame.f_code.co_qualname == func.__name__: print( f"Found module.underlying_fn ({func.__name__}) in traceback, continuing..." ) - for name, val in inspect.getmembers(func.__module__): - if name == frame.f_code.co_name: - print(f"Found function {val} in traceback, continuing...") - last_frame = frame + root_frame = frame + # print( + # f"Looking into its {frame.f_code.co_qualname}'s members via {func.__module__}..." + # ) + # for name, val in inspect.getmembers(func.__module__): + # if name == frame.f_code.co_name: + # print(f"Found function {val} in traceback, continuing...") + # last_frame = frame print("============================================") - return last_frame + return root_frame, last_frame async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs): try: @@ -128,7 +142,9 @@ async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs): return output except Exception as exception: # If the exception came from the wrapper itself, we should not catch it! - exc_type, exc_value, exc_traceback = sys.exc_info() + exc_traceback: Optional[TracebackType] = None + _, _, exc_traceback = sys.exc_info() + assert exc_traceback is not None if exc_traceback.tb_next is None: self.ui.exit(1) raise RuntimeError("Could not find the next frame in the call stack!") @@ -157,18 +173,24 @@ async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs): root_frame, lambda_argname = self.get_lambda_child_frame( func, exc_traceback ) - module.throw_lambda_argname = lambda_argname + module.root_lambda_argname = lambda_argname elif inspect.isfunction(func): - root_frame = self.get_function_frame(func, exc_traceback) + root_frame, throwing_frame = self.get_function_frame( + func, exc_traceback + ) else: self.ui.exit(1) raise NotImplementedError() locals_f = root_frame if throwing_frame is None else throwing_frame if not root_frame: - self.ui.exit(1) - raise RuntimeError( + await self.ui.hang(threw=True) + self.ui.print_err( f"Could not find the frame of the original function {func} in the traceback." ) + # self.ui.exit(1) + # raise RuntimeError( + # f"Could not find the frame of the original function {func} in the traceback." + # ) else: # NOTE: Here we reloaded the root frame of the throwing call, i.e. # the frame that's in the scope of our MatchboxModule so a class @@ -176,7 +198,7 @@ async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs): await self.ui.set_locals( locals_f.f_locals, locals_f.f_code.co_qualname ) - module.throw_frame = root_frame + module.root_frame = root_frame info = ( ( f"Exception thrown in <{locals_f.f_code.co_qualname}>" @@ -195,29 +217,13 @@ async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs): self.ui.print_info("Hanged.") await self.ui.hang(threw=True) - async def reload_module(self, module: MatchboxModule): - if module.to_reload and module.throw_frame is None: - self.ui.exit(1) - raise RuntimeError( - f"Module {module} is set to reload but we don't have the frame that threw!" - ) - elif not module.to_reload: - # TODO: This works as long as we init the builder UI with skip_frozen=True - # so that we can get the root frame at least once, but it will fail if the - # module never throws in the first place. We should fix it by decoupling - # root frame finding from the catch_and_hang() method above. Or ideally by - # finding the code object more efficiently! - self.ui.print_info(f"Reloading MatchboxModule({module.underlying_fn})...") - self.ui.print_err("Hot reloading without throwing is not implemented yet.") - code_obj = None - await self.ui.hang(threw=False) + async def _reload_code_obj(self, code_obj: CodeType, module: MatchboxModule): self.ui.log_tracer( Text( - f"Reloading code from {module.throw_frame.f_code.co_filename}", + f"Reloading code from {code_obj.co_filename}", style="purple", ) ) - code_obj = module.throw_frame.f_code print(code_obj.co_qualname, inspect.getmodule(code_obj)) code_module = inspect.getmodule(code_obj) if code_module is None: @@ -273,9 +279,9 @@ async def reload_module(self, module: MatchboxModule): ) ) if module.underlying_fn.__name__ == "": - assert module.throw_lambda_argname is not None + assert module.root_lambda_argname is not None module.reload_surgically_in_lambda( - module.throw_lambda_argname, code_obj.co_name, rld_callable + module.root_lambda_argname, code_obj.co_name, rld_callable ) else: module.reload_surgically(code_obj.co_name, rld_callable) @@ -296,5 +302,51 @@ async def reload_module(self, module: MatchboxModule): print(inspect.getsource(func)) module.reload(func) return + # self.ui.hang(threw=False) while True: - await asyncio.sleep(1) + await asyncio.sleep(1) # TODO: Why not hang? + + async def reload_module(self, module: MatchboxModule): + if module.to_reload: + if module.root_frame is None: + self.ui.exit(1) + raise RuntimeError( + f"Module {module} is set to reload but we don't have the frame that threw!" + ) + code_obj = module.root_frame.f_code + else: + if module.underlying_fn.__name__ == "": + self.ui.exit(1) + raise NotImplementedError( + "Non-throwing Lambda reloading not implemented yet." + ) + # TODO: Get the lambda arguments, and for each argument, find the code + # object and the arg name. The run the following. + lambda_args = inspect.getargs(module.underlying_fn.__code__).args + # print(lambda_args) + print(module.partial.args, module.partial.keywords) + all_args = list(module.partial.args) + list( + module.partial.keywords.values() + ) + + def get_code_obj(a): + if inspect.iscode(a): + return a + if inspect.isclass(a): + return a.__init__.__code__ + if inspect.ismethod(a): + return a.__func__.__code__ + if inspect.isfunction(a): + return a.__code__ + return get_code_obj(a.__class__) + + assert len(lambda_args) == len(all_args) + for argname, argval in zip(lambda_args, all_args): + code_obj = get_code_obj(argval) + module.root_lambda_argname = argname + await self._reload_code_obj(code_obj, module) + elif inspect.isclass(module.underlying_fn): + code_obj = module.underlying_fn.__init__.__code__ + else: + code_obj = module.underlying_fn.__code__ + await self._reload_code_obj(code_obj, module) diff --git a/bootstrap/hot_reloading/module.py b/bootstrap/hot_reloading/module.py index e18bf14..7cb83a6 100644 --- a/bootstrap/hot_reloading/module.py +++ b/bootstrap/hot_reloading/module.py @@ -15,8 +15,8 @@ def __init__(self, name: str, fn: Callable | Partial, *args, **kwargs): self.to_reload = False self.result = None self.is_frozen = False - self.throw_frame: Optional[FrameType] = None - self.throw_lambda_argname: Optional[str] = None + self.root_frame: Optional[FrameType] = None + self.root_lambda_argname: Optional[str] = None def reload(self, new_func: Callable) -> None: print(f"Replacing {self.underlying_fn} with {new_func}") diff --git a/bootstrap/tui/builder_ui.py b/bootstrap/tui/builder_ui.py index a155587..ec1cea1 100644 --- a/bootstrap/tui/builder_ui.py +++ b/bootstrap/tui/builder_ui.py @@ -49,7 +49,7 @@ def __init__(self, chain: List[MatchboxModule]): self._module_chain: List[MatchboxModule] = chain self._runner_task = None self._engine = HotReloadingEngine(self) - self._skip_frozen = True # NOTE: Leave this to true for the first run or it will attempt to reload on the first run, which is not really desireable + self._reload_on_throw_only = True # NOTE: Leave this to true for the first run or it will attempt to reload on the first run, which is not really desireable async def on_mount(self): await self._chain_up() @@ -74,10 +74,10 @@ async def _run_chain(self) -> None: for module in self._module_chain: await self.query_one(LocalsPanel).clear() self.query_one(Tracer).clear() - if module.is_frozen and self._skip_frozen: + if module.is_frozen: self.log_tracer(Text(f"Skipping frozen module {module}", style="green")) continue - if module.to_reload or not self._skip_frozen: + if module.to_reload or not self._reload_on_throw_only: self.log_tracer(Text(f"Reloading module: {module}", style="yellow")) await self._engine.reload_module(module) self.log_tracer(Text(f"Running module: {module}", style="yellow")) @@ -131,11 +131,11 @@ def _reload(self) -> None: self.run_chain() def action_reload(self) -> None: - self._skip_frozen = False + self._reload_on_throw_only = False self._reload() def action_forward_reload(self) -> None: - self._skip_frozen = True + self._reload_on_throw_only = True self._reload() def on_checkbox_changed(self, message: Checkbox.Changed):