Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

(PR-SERIES-P5) [REFACTOR]: Execution Context #580

Merged
merged 3 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 63 additions & 36 deletions jaclang/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from jaclang.plugin.builtin import dotgen
from jaclang.plugin.feature import JacCmd as Cmd
from jaclang.plugin.feature import JacFeature as Jac
from jaclang.runtimelib.constructs import Anchor, Architype
from jaclang.runtimelib.machine import JacProgram
from jaclang.runtimelib.constructs import Architype, WalkerArchitype
from jaclang.runtimelib.context import ExecutionContext
from jaclang.runtimelib.machine import JacMachine, JacProgram
from jaclang.utils.helpers import debugger as db
from jaclang.utils.lang_tools import AstTool

Expand Down Expand Up @@ -90,23 +91,27 @@ def run(
base, mod = os.path.split(filename)
base = base if base else "./"
mod = mod[:-4]
Jac.context().init_memory(base_path=base, session=session)

jctx = ExecutionContext.create(session=session)
if filename.endswith(".jac"):
JacMachine(base).attach_program(JacProgram(mod_bundle=None, bytecode=None))

ret_module = jac_import(
target=mod,
base_path=base,
cachable=cache,
override_name="__main__" if main else None,
)

if ret_module is None:
loaded_mod = None
else:
(loaded_mod,) = ret_module
elif filename.endswith(".jir"):
with open(filename, "rb") as f:
ir = pickle.load(f)
jac_program = JacProgram(mod_bundle=ir, bytecode=None)
Jac.context().jac_machine.attach_program(jac_program)
JacMachine(base).attach_program(
JacProgram(mod_bundle=pickle.load(f), bytecode=None)
)
ret_module = jac_import(
target=mod,
base_path=base,
Expand All @@ -122,13 +127,13 @@ def run(
return

if not node or node == "root":
entrypoint: Architype = Jac.get_root()
else:
obj = Jac.context().mem.find_by_id(UUID(node))
if not isinstance(obj, Anchor) or obj.architype is None:
print(f"Entrypoint {node} not found.")
return
entrypoint: Architype = jctx.root.architype
elif obj := jctx.mem.find_by_id(UUID(node)):
entrypoint = obj.architype
else:
print(f"Entrypoint {node} not found.")
jctx.close()
return

# TODO: handle no override name
if walker:
Expand All @@ -138,7 +143,8 @@ def run(
else:
print(f"Walker {walker} not found.")

Jac.reset_context()
jctx.close()
JacMachine.detach()


@cmd_registry.register
Expand All @@ -147,21 +153,18 @@ def get_object(id: str, session: str = "") -> dict:
if session == "":
session = cmd_registry.args.session if "session" in cmd_registry.args else ""

Jac.context().init_memory(session=session)
jctx = ExecutionContext.create(session=session)

data = {}
if id == "root":
id_uuid = UUID(int=0)
data = jctx.root.__getstate__()
elif obj := jctx.mem.find_by_id(UUID(id)):
data = obj.__getstate__()
else:
id_uuid = UUID(id)

obj = Jac.context().mem.find_by_id(id_uuid)
if obj is None:
print(f"Object with id {id} not found.")
Jac.reset_context()
return {}
else:
Jac.reset_context()
return obj.__getstate__()

jctx.close()
return data


@cmd_registry.register
Expand Down Expand Up @@ -211,26 +214,45 @@ def lsp() -> None:


@cmd_registry.register
def enter(filename: str, entrypoint: str, args: list) -> None:
"""Run the specified entrypoint function in the given .jac file.
def enter(
filename: str,
entrypoint: str,
args: list,
session: str = "",
root: str = "",
node: str = "",
) -> None:
"""
Run the specified entrypoint function in the given .jac file.

:param filename: The path to the .jac file.
:param entrypoint: The name of the entrypoint function.
:param args: Arguments to pass to the entrypoint function.
:param session: shelve.Shelf file path.
:param root: root executor.
:param node: starting node.
"""
jctx = ExecutionContext.create(session=session, root=root, entry=node)

if filename.endswith(".jac"):
base, mod_name = os.path.split(filename)
base = base if base else "./"
mod_name = mod_name[:-4]
(mod,) = jac_import(target=mod_name, base_path=base)
JacMachine(base).attach_program(JacProgram(mod_bundle=None, bytecode=None))
if not mod:
print("Errors occurred while importing the module.")
return
else:
getattr(mod, entrypoint)(*args)
architype = getattr(mod, entrypoint)(*args)
if isinstance(architype, WalkerArchitype):
Jac.spawn_call(jctx.entry.architype, architype)
JacMachine.detach()

else:
print("Not a .jac file.")

jctx.close()


@cmd_registry.register
def test(
Expand All @@ -252,6 +274,8 @@ def test(

jac test => jac test -d .
"""
jctx = ExecutionContext.create()

failcount = Jac.run_test(
filepath=filepath,
filter=filter,
Expand All @@ -260,6 +284,9 @@ def test(
directory=directory,
verbose=verbose,
)

jctx.close()

if failcount:
raise SystemExit(f"Tests failed: {failcount}")

Expand Down Expand Up @@ -361,13 +388,13 @@ def dot(
base, mod = os.path.split(filename)
base = base if base else "./"
mod = mod[:-4]
Jac.context().init_memory(base_path=base, session=session)

jctx = ExecutionContext.create(session=session)

if filename.endswith(".jac"):
jac_import(
target=mod,
base_path=base,
)
module = Jac.context().jac_machine.loaded_modules.get(mod)
jac_machine = JacMachine(base)
jac_import(target=mod, base_path=base)
module = jac_machine.loaded_modules.get(mod)
globals().update(vars(module))
try:
node = globals().get(initial, eval(initial)) if initial else None
Expand All @@ -385,7 +412,7 @@ def dot(
import traceback

traceback.print_exc()
Jac.reset_context()
jctx.close()
return
file_name = saveto if saveto else f"{mod}.dot"
with open(file_name, "w") as file:
Expand All @@ -394,7 +421,7 @@ def dot(
else:
print("Not a .jac file.")

Jac.reset_context()
jctx.close()


@cmd_registry.register
Expand Down
20 changes: 11 additions & 9 deletions jaclang/compiler/tests/test_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,37 @@

from jaclang import jac_import
from jaclang.cli import cli
from jaclang.plugin.feature import JacFeature as Jac
from jaclang.runtimelib.machine import JacMachine, JacProgram
from jaclang.utils.test import TestCase


class TestLoader(TestCase):
"""Test Jac self.prse."""

def setUp(self) -> None:
"""Set up test."""
return super().setUp()

def test_import_basic_python(self) -> None:
"""Test basic self loading."""
Jac.context().init_memory(base_path=self.fixture_abs_path(__file__))
JacMachine(self.fixture_abs_path(__file__)).attach_program(
JacProgram(mod_bundle=None, bytecode=None)
)
(h,) = jac_import("fixtures.hello_world", base_path=__file__)
self.assertEqual(h.hello(), "Hello World!") # type: ignore
JacMachine.detach()

def test_modules_correct(self) -> None:
"""Test basic self loading."""
Jac.context().init_memory(base_path=self.fixture_abs_path(__file__))
JacMachine(self.fixture_abs_path(__file__)).attach_program(
JacProgram(mod_bundle=None, bytecode=None)
)
jac_import("fixtures.hello_world", base_path=__file__)
self.assertIn(
"module 'fixtures.hello_world'",
str(Jac.context().jac_machine.loaded_modules),
str(JacMachine.get().loaded_modules),
)
self.assertIn(
"/tests/fixtures/hello_world.jac",
str(Jac.context().jac_machine.loaded_modules),
str(JacMachine.get().loaded_modules),
)
JacMachine.detach()

def test_jac_py_import(self) -> None:
"""Basic test for pass."""
Expand Down
44 changes: 15 additions & 29 deletions jaclang/plugin/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@
ExecutionContext,
GenericEdge,
JacTestCheck,
Memory,
NodeAnchor,
NodeArchitype,
Root,
WalkerAnchor,
WalkerArchitype,
exec_context,
)
from jaclang.runtimelib.importer import ImportPathSpec, JacImporter, PythonImporter
from jaclang.runtimelib.machine import JacMachine, JacProgram
from jaclang.runtimelib.utils import traverse_graph
from jaclang.plugin.feature import JacFeature as Jac # noqa: I100
from jaclang.plugin.spec import P, T
Expand Down Expand Up @@ -69,28 +68,9 @@ class JacFeatureDefaults:

@staticmethod
@hookimpl
def context(session: str = "") -> ExecutionContext:
"""Get the execution context."""
ctx = exec_context.get()
if ctx is None:
ctx = ExecutionContext()
exec_context.set(ctx)
return ctx

@staticmethod
@hookimpl
def reset_context() -> None:
"""Reset the execution context."""
ctx = exec_context.get()
if ctx:
ctx.reset()
exec_context.set(None)

@staticmethod
@hookimpl
def memory_hook() -> Memory | None:
"""Return the memory hook."""
return Jac.context().mem
def get_context() -> ExecutionContext:
"""Get current execution context."""
return ExecutionContext.get()

@staticmethod
@hookimpl
Expand Down Expand Up @@ -263,12 +243,18 @@ def jac_import(
lng,
items,
)

jac_machine = JacMachine.get(base_path)
if not jac_machine.jac_program:
jac_machine.attach_program(JacProgram(mod_bundle=None, bytecode=None))

if lng == "py":
import_result = PythonImporter(Jac.context().jac_machine).run_import(spec)
import_result = PythonImporter(JacMachine.get()).run_import(spec)
else:
import_result = JacImporter(Jac.context().jac_machine).run_import(
import_result = JacImporter(JacMachine.get()).run_import(
spec, reload_module
)

return (
(import_result.ret_mod,)
if absorb or not items
Expand Down Expand Up @@ -503,14 +489,14 @@ def disconnect(
and node == source
and target.architype in right
):
anchor.destroy()
anchor.destroy() if anchor.persistent else anchor.detach()
disconnect_occurred = True
if (
dir in [EdgeDir.IN, EdgeDir.ANY]
and node == target
and source.architype in right
):
anchor.destroy()
anchor.destroy() if anchor.persistent else anchor.detach()
disconnect_occurred = True

return disconnect_occurred
Expand All @@ -531,7 +517,7 @@ def assign_compr(
@hookimpl
def get_root() -> Root:
"""Jac's assign comprehension feature."""
return Jac.context().get_root()
return ExecutionContext.get_root()

@staticmethod
@hookimpl
Expand Down
19 changes: 4 additions & 15 deletions jaclang/plugin/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@

import jaclang.compiler.absyntree as ast
from jaclang.compiler.passes.main.pyast_gen_pass import PyastGenPass
from jaclang.plugin.default import ExecutionContext
from jaclang.plugin.spec import JacBuiltin, JacCmdSpec, JacFeatureSpec, P, T
from jaclang.runtimelib.constructs import (
Architype,
EdgeArchitype,
Memory,
NodeAnchor,
NodeArchitype,
Root,
WalkerArchitype,
)
from jaclang.runtimelib.context import ExecutionContext

import pluggy

Expand All @@ -43,19 +42,9 @@ class JacFeature:
Walker: TypeAlias = WalkerArchitype

@staticmethod
def context(session: str = "") -> ExecutionContext:
"""Create execution context."""
return pm.hook.context(session=session)

@staticmethod
def reset_context() -> None:
"""Reset execution context."""
return pm.hook.reset_context()

@staticmethod
def memory_hook() -> Memory | None:
"""Create memory abstraction."""
return pm.hook.memory_hook()
def get_context() -> ExecutionContext:
"""Get current execution context."""
return pm.hook.get_context()

@staticmethod
def make_architype(
Expand Down
Loading
Loading