Skip to content

Commit

Permalink
Refactor a bit and implement a bit of hot code realoding without thro…
Browse files Browse the repository at this point in the history
…wing
  • Loading branch information
DubiousCactus committed Dec 30, 2024
1 parent cd226c7 commit ac8335a
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 47 deletions.
132 changes: 92 additions & 40 deletions bootstrap/hot_reloading/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import sys
import traceback
from types import FrameType
from types import CodeType, FrameType, TracebackType
from typing import (
Callable,
Optional,
Expand All @@ -25,7 +25,7 @@ def __init__(self, ui: App):

@classmethod
def get_class_frame(
cls, func: Callable, exc_traceback
cls, func: Callable, exc_traceback: TracebackType
) -> Tuple[Optional[FrameType], Optional[FrameType]]:
"""
Find the frame of the last callable within the scope of the MatchboxModule in
Expand All @@ -34,7 +34,7 @@ def get_class_frame(
function that threw (or originated) the exception.
"""
print("============= get_class_frame() =========")
last_frame_in_scope, last_frame = None, None
root_frame, last_frame = None, None
for frame, _ in traceback.walk_tb(exc_traceback):
last_frame = frame
print(frame.f_code.co_qualname)
Expand All @@ -48,13 +48,15 @@ def get_class_frame(
and "self" in inspect.getargs(frame.f_code).args
):
print(f"Found method {val} in traceback, continuing...")
last_frame_in_scope = frame
root_frame = frame
print("============================================")
return last_frame_in_scope, last_frame
return root_frame, last_frame

@classmethod
def get_lambda_child_frame(
cls, func: Callable, exc_traceback
cls,
func: Callable,
exc_traceback: TracebackType,
) -> Tuple[Optional[FrameType], Optional[str]]:
"""
Find the frame of the last callable within the scope of the MatchboxModule in
Expand Down Expand Up @@ -95,21 +97,33 @@ def get_lambda_child_frame(
return None, None

@classmethod
def get_function_frame(cls, func: Callable, exc_traceback) -> Optional[FrameType]:
def get_function_frame(
cls, func: Callable, exc_traceback: TracebackType
) -> Tuple[Optional[FrameType], Optional[FrameType]]:
"""
Find the frame of the last callable within the scope of the MatchboxModule in
the traceback. In this instance, the MatchboxModule is a module-level function
so we want to find its frame.
"""
print("============= get_function_frame() =========")
last_frame = None
root_frame, last_frame = None, None
for frame, _ in traceback.walk_tb(exc_traceback):
print(frame.f_code.co_qualname)
last_frame = frame
if frame.f_code.co_qualname == func.__name__:
print(
f"Found module.underlying_fn ({func.__name__}) in traceback, continuing..."
)
for name, val in inspect.getmembers(func.__module__):
if name == frame.f_code.co_name:
print(f"Found function {val} in traceback, continuing...")
last_frame = frame
root_frame = frame
# print(
# f"Looking into its {frame.f_code.co_qualname}'s members via {func.__module__}..."
# )
# for name, val in inspect.getmembers(func.__module__):
# if name == frame.f_code.co_name:
# print(f"Found function {val} in traceback, continuing...")
# last_frame = frame
print("============================================")
return last_frame
return root_frame, last_frame

async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs):
try:
Expand All @@ -128,7 +142,9 @@ async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs):
return output
except Exception as exception:
# If the exception came from the wrapper itself, we should not catch it!
exc_type, exc_value, exc_traceback = sys.exc_info()
exc_traceback: Optional[TracebackType] = None
_, _, exc_traceback = sys.exc_info()
assert exc_traceback is not None
if exc_traceback.tb_next is None:
self.ui.exit(1)
raise RuntimeError("Could not find the next frame in the call stack!")
Expand Down Expand Up @@ -157,26 +173,32 @@ async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs):
root_frame, lambda_argname = self.get_lambda_child_frame(
func, exc_traceback
)
module.throw_lambda_argname = lambda_argname
module.root_lambda_argname = lambda_argname
elif inspect.isfunction(func):
root_frame = self.get_function_frame(func, exc_traceback)
root_frame, throwing_frame = self.get_function_frame(
func, exc_traceback
)
else:
self.ui.exit(1)
raise NotImplementedError()
locals_f = root_frame if throwing_frame is None else throwing_frame
if not root_frame:
self.ui.exit(1)
raise RuntimeError(
await self.ui.hang(threw=True)
self.ui.print_err(
f"Could not find the frame of the original function {func} in the traceback."
)
# self.ui.exit(1)
# raise RuntimeError(
# f"Could not find the frame of the original function {func} in the traceback."
# )
else:
# NOTE: Here we reloaded the root frame of the throwing call, i.e.
# the frame that's in the scope of our MatchboxModule so a class
# method if the module is a class.
await self.ui.set_locals(
locals_f.f_locals, locals_f.f_code.co_qualname
)
module.throw_frame = root_frame
module.root_frame = root_frame
info = (
(
f"Exception thrown in <{locals_f.f_code.co_qualname}>"
Expand All @@ -195,29 +217,13 @@ async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs):
self.ui.print_info("Hanged.")
await self.ui.hang(threw=True)

async def reload_module(self, module: MatchboxModule):
if module.to_reload and module.throw_frame is None:
self.ui.exit(1)
raise RuntimeError(
f"Module {module} is set to reload but we don't have the frame that threw!"
)
elif not module.to_reload:
# TODO: This works as long as we init the builder UI with skip_frozen=True
# so that we can get the root frame at least once, but it will fail if the
# module never throws in the first place. We should fix it by decoupling
# root frame finding from the catch_and_hang() method above. Or ideally by
# finding the code object more efficiently!
self.ui.print_info(f"Reloading MatchboxModule({module.underlying_fn})...")
self.ui.print_err("Hot reloading without throwing is not implemented yet.")
code_obj = None
await self.ui.hang(threw=False)
async def _reload_code_obj(self, code_obj: CodeType, module: MatchboxModule):
self.ui.log_tracer(
Text(
f"Reloading code from {module.throw_frame.f_code.co_filename}",
f"Reloading code from {code_obj.co_filename}",
style="purple",
)
)
code_obj = module.throw_frame.f_code
print(code_obj.co_qualname, inspect.getmodule(code_obj))
code_module = inspect.getmodule(code_obj)
if code_module is None:
Expand Down Expand Up @@ -273,9 +279,9 @@ async def reload_module(self, module: MatchboxModule):
)
)
if module.underlying_fn.__name__ == "<lambda>":
assert module.throw_lambda_argname is not None
assert module.root_lambda_argname is not None
module.reload_surgically_in_lambda(
module.throw_lambda_argname, code_obj.co_name, rld_callable
module.root_lambda_argname, code_obj.co_name, rld_callable
)
else:
module.reload_surgically(code_obj.co_name, rld_callable)
Expand All @@ -296,5 +302,51 @@ async def reload_module(self, module: MatchboxModule):
print(inspect.getsource(func))
module.reload(func)
return
# self.ui.hang(threw=False)
while True:
await asyncio.sleep(1)
await asyncio.sleep(1) # TODO: Why not hang?

async def reload_module(self, module: MatchboxModule):
if module.to_reload:
if module.root_frame is None:
self.ui.exit(1)
raise RuntimeError(
f"Module {module} is set to reload but we don't have the frame that threw!"
)
code_obj = module.root_frame.f_code
else:
if module.underlying_fn.__name__ == "<lambda>":
self.ui.exit(1)
raise NotImplementedError(
"Non-throwing Lambda reloading not implemented yet."
)
# TODO: Get the lambda arguments, and for each argument, find the code
# object and the arg name. The run the following.
lambda_args = inspect.getargs(module.underlying_fn.__code__).args
# print(lambda_args)
print(module.partial.args, module.partial.keywords)
all_args = list(module.partial.args) + list(
module.partial.keywords.values()
)

def get_code_obj(a):
if inspect.iscode(a):
return a
if inspect.isclass(a):
return a.__init__.__code__
if inspect.ismethod(a):
return a.__func__.__code__
if inspect.isfunction(a):
return a.__code__
return get_code_obj(a.__class__)

assert len(lambda_args) == len(all_args)
for argname, argval in zip(lambda_args, all_args):
code_obj = get_code_obj(argval)
module.root_lambda_argname = argname
await self._reload_code_obj(code_obj, module)
elif inspect.isclass(module.underlying_fn):
code_obj = module.underlying_fn.__init__.__code__
else:
code_obj = module.underlying_fn.__code__
await self._reload_code_obj(code_obj, module)
4 changes: 2 additions & 2 deletions bootstrap/hot_reloading/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def __init__(self, name: str, fn: Callable | Partial, *args, **kwargs):
self.to_reload = False
self.result = None
self.is_frozen = False
self.throw_frame: Optional[FrameType] = None
self.throw_lambda_argname: Optional[str] = None
self.root_frame: Optional[FrameType] = None
self.root_lambda_argname: Optional[str] = None

def reload(self, new_func: Callable) -> None:
print(f"Replacing {self.underlying_fn} with {new_func}")
Expand Down
10 changes: 5 additions & 5 deletions bootstrap/tui/builder_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, chain: List[MatchboxModule]):
self._module_chain: List[MatchboxModule] = chain
self._runner_task = None
self._engine = HotReloadingEngine(self)
self._skip_frozen = True # NOTE: Leave this to true for the first run or it will attempt to reload on the first run, which is not really desireable
self._reload_on_throw_only = True # NOTE: Leave this to true for the first run or it will attempt to reload on the first run, which is not really desireable

async def on_mount(self):
await self._chain_up()
Expand All @@ -74,10 +74,10 @@ async def _run_chain(self) -> None:
for module in self._module_chain:
await self.query_one(LocalsPanel).clear()
self.query_one(Tracer).clear()
if module.is_frozen and self._skip_frozen:
if module.is_frozen:
self.log_tracer(Text(f"Skipping frozen module {module}", style="green"))
continue
if module.to_reload or not self._skip_frozen:
if module.to_reload or not self._reload_on_throw_only:
self.log_tracer(Text(f"Reloading module: {module}", style="yellow"))
await self._engine.reload_module(module)
self.log_tracer(Text(f"Running module: {module}", style="yellow"))
Expand Down Expand Up @@ -131,11 +131,11 @@ def _reload(self) -> None:
self.run_chain()

def action_reload(self) -> None:
self._skip_frozen = False
self._reload_on_throw_only = False
self._reload()

def action_forward_reload(self) -> None:
self._skip_frozen = True
self._reload_on_throw_only = True
self._reload()

def on_checkbox_changed(self, message: Checkbox.Changed):
Expand Down

0 comments on commit ac8335a

Please sign in to comment.