diff --git a/jaclang_jaseci/core/architype.py b/jaclang_jaseci/core/architype.py
index 88e2998..f2e6e1e 100644
--- a/jaclang_jaseci/core/architype.py
+++ b/jaclang_jaseci/core/architype.py
@@ -1,16 +1,13 @@
 """Core constructs for Jac Language."""
 from dataclasses import asdict, dataclass, field, fields, is_dataclass
-from inspect import iscoroutine
 from os import getenv
 from re import IGNORECASE, compile
 from typing import (
     Any,
     Callable,
     ClassVar,
-    Iterable,
     Mapping,
-    Type,
     TypeVar,
     cast,
     get_args,
@@ -20,59 +17,77 @@
 from bson import ObjectId
-from jaclang.compiler.constant import EdgeDir, T
+from jaclang.compiler.constant import EdgeDir
 from jaclang.runtimelib.architype import (
     Access as _Access,
     AccessLevel,
-    Anchor as _Anchor,
-    AnchorState as _AnchorState,
-    AnchorType,
-    Architype as _Architype,
+    Anchor,
+    Architype,
     DSFunc,
+    EdgeAnchor as _EdgeAnchor,
+    EdgeArchitype as _EdgeArchitype,
+    NodeAnchor as _NodeAnchor,
+    NodeArchitype as _NodeArchitype,
     Permission as _Permission,
+    TANCH,
+    WalkerAnchor as _WalkerAnchor,
+    WalkerArchitype as _WalkerArchitype,
 )
-from jaclang.runtimelib.utils import collect_node_connections
-
-from motor.motor_asyncio import AsyncIOMotorClientSession
 from orjson import dumps
 from pymongo import ASCENDING, DeleteMany, DeleteOne, InsertOne, UpdateMany, UpdateOne
+from pymongo.client_session import ClientSession
 from pymongo.errors import ConnectionFailure, OperationFailure
 from ..jaseci.datasources import Collection as BaseCollection
 from ..jaseci.utils import logger
-
-GENERIC_ID_REGEX = compile(r"^(g|n|e|w):([^:]*):([a-f\d]{24})$", IGNORECASE)
+MANUAL_SAVE = getenv("MANUAL_SAVE")
+GENERIC_ID_REGEX = compile(r"^(n|e|w):([^:]*):([a-f\d]{24})$", IGNORECASE)
 NODE_ID_REGEX = compile(r"^n:([^:]*):([a-f\d]{24})$", IGNORECASE)
 EDGE_ID_REGEX = compile(r"^e:([^:]*):([a-f\d]{24})$", IGNORECASE)
 WALKER_ID_REGEX = compile(r"^w:([^:]*):([a-f\d]{24})$", IGNORECASE)
-TA = TypeVar("TA", bound="Architype")
+T = TypeVar("T")
+TBA = TypeVar("TBA", bound="BaseArchitype")
+
+
+def architype_to_dataclass(cls: type[T], data: dict[str, Any], **kwargs: object) -> T:
+    """Parse dict to architype."""
+    _to_dataclass(cls, data)
+    architype = object.__new__(cls)
+    architype.__init__(**data, **kwargs)  # type: ignore[misc]
+    return architype
 def to_dataclass(cls: type[T], data: dict[str, Any], **kwargs: object) -> T:
     """Parse dict to dataclass."""
+    _to_dataclass(cls, data)
+    return cls(**data, **kwargs)
+
+
+def _to_dataclass(cls: type[T], data: dict[str, Any]) -> None:
+    """Parse dict to dataclass implementation."""
     hintings = get_type_hints(cls)
-    for attr in fields(cls):  # type: ignore[arg-type]
-        if target := data.get(attr.name):
-            hint = hintings[attr.name]
-            if is_dataclass(hint):
-                data[attr.name] = to_dataclass(hint, target)
-            else:
-                origin = get_origin(hint)
-                if origin == dict and isinstance(target, dict):
-                    if is_dataclass(inner_cls := get_args(hint)[-1]):
-                        for key, value in target.items():
+    if is_dataclass(cls):
+        for attr in fields(cls):
+            if target := data.get(attr.name):
+                hint = hintings[attr.name]
+                if is_dataclass(hint):
+                    data[attr.name] = to_dataclass(hint, target)
+                else:
+                    origin = get_origin(hint)
+                    if origin == dict and isinstance(target, dict):
+                        if is_dataclass(inner_cls := get_args(hint)[-1]):
+                            for key, value in target.items():
+                                target[key] = to_dataclass(inner_cls, value)
+                    elif (
+                        origin == list
+                        and isinstance(target, list)
+                        and is_dataclass(inner_cls := get_args(hint)[-1])
+                    ):
+                        for key, value in enumerate(target):
+                            target[key] = to_dataclass(inner_cls, value)
-                elif (
-                    origin == list
-                    and isinstance(target, list)
-                    and is_dataclass(inner_cls := get_args(hint)[-1])
-                ):
-                    for key, value in enumerate(target):
-                        target[key] = to_dataclass(inner_cls, value)
-    return cls(**data, **kwargs)
@@ -87,14 +102,13 @@ class BulkWrite:
     )
     operations: dict[
-        AnchorType,
+        type["BaseAnchor"],
         list[InsertOne[Any] | DeleteMany | DeleteOne | UpdateMany | UpdateOne],
     ] = field(
         default_factory=lambda: {
-            AnchorType.node: [],
-            AnchorType.edge: [],
-            AnchorType.walker: [],
-            AnchorType.generic: [],  # ignored
+            NodeAnchor: [],
+            EdgeAnchor: [],
+            WalkerAnchor: [],
         }
     )
@@ -105,7 +119,7 @@ def del_node(self, id: ObjectId) -> None:
         """Add node to delete many operations."""
         if not self.del_ops_nodes:
-            self.operations[AnchorType.node].append(
+            self.operations[NodeAnchor].append(
                 DeleteMany({"_id": {"$in": self.del_ops_nodes}})
             )
@@ -114,7 +128,7 @@ def del_node(self, id: ObjectId) -> None:
     def del_edge(self, id: ObjectId) -> None:
         """Add edge to delete many operations."""
         if not self.del_ops_edges:
-            self.operations[AnchorType.edge].append(
+            self.operations[EdgeAnchor].append(
                 DeleteMany({"_id": {"$in": self.del_ops_edges}})
            )
@@ -123,7 +137,7 @@ def del_edge(self, id: ObjectId) -> None:
     def del_walker(self, id: ObjectId) -> None:
         """Add walker to delete many operations."""
         if not self.del_ops_walker:
-            self.operations[AnchorType.walker].append(
+            self.operations[WalkerAnchor].append(
                 DeleteMany({"_id": {"$in": self.del_ops_walker}})
             )
@@ -135,13 +149,13 @@ def has_operations(self) -> bool:
         return any(val for val in self.operations.values())
     @staticmethod
-    async def commit(session: AsyncIOMotorClientSession) -> None:
+    def commit(session: ClientSession) -> None:
         """Commit current session."""
         commit_retry = 0
         commit_max_retry = BulkWrite.SESSION_MAX_COMMIT_RETRY
         while commit_retry <= commit_max_retry:
             try:
-                await session.commit_transaction()
+                session.commit_transaction()
                 break
             except (ConnectionFailure, OperationFailure) as ex:
                 if ex.has_error_label("UnknownTransactionCommitResult"):
@@ -156,29 +170,23 @@ async def commit(session: AsyncIOMotorClientSession) -> None:
                     )
                     raise
             except Exception:
-                await session.abort_transaction()
+                session.abort_transaction()
                 logger.error("Error commiting bulk write!")
                 raise
-    async def execute(self, session: AsyncIOMotorClientSession) -> None:
+    def execute(self, session: ClientSession) -> None:
         """Execute all operations."""
         transaction_retry = 0
        transaction_max_retry = self.SESSION_MAX_TRANSACTION_RETRY
         while transaction_retry <= transaction_max_retry:
             try:
-                if node_operation := self.operations[AnchorType.node]:
-                    await NodeAnchor.Collection.bulk_write(
-                        node_operation, False, session
-                    )
-                if edge_operation := self.operations[AnchorType.edge]:
-                    await EdgeAnchor.Collection.bulk_write(
-                        edge_operation, False, session
-                    )
-                if walker_operation := self.operations[AnchorType.walker]:
-                    await WalkerAnchor.Collection.bulk_write(
-                        walker_operation, False, session
-                    )
-                await self.commit(session)
+                if node_operation := self.operations[NodeAnchor]:
+                    NodeAnchor.Collection.bulk_write(node_operation, False, session)
+                if edge_operation := self.operations[EdgeAnchor]:
+                    EdgeAnchor.Collection.bulk_write(edge_operation, False, session)
+                if walker_operation := self.operations[WalkerAnchor]:
+                    WalkerAnchor.Collection.bulk_write(walker_operation, False, session)
+                self.commit(session)
                 break
             except (ConnectionFailure, OperationFailure) as ex:
                 if ex.has_error_label("TransientTransactionError"):
ex.has_error_label("TransientTransactionError"): @@ -204,7 +212,6 @@ class Access(_Access): def serialize(self) -> dict[str, object]: """Serialize Access.""" return { - "whitelist": self.whitelist, "anchors": {key: val.name for key, val in self.anchors.items()}, } @@ -213,7 +220,6 @@ def deserialize(cls, data: dict[str, Any]) -> "Access": """Deserialize Access.""" anchors = cast(dict[str, str], data.get("anchors")) return Access( - whitelist=bool(data.get("whitelist")), anchors={key: AccessLevel[val] for key, val in anchors.items()}, ) @@ -238,34 +244,28 @@ def deserialize(cls, data: dict[str, Any]) -> "Permission": @dataclass -class AnchorState(_AnchorState): +class AnchorState: """Anchor state handler.""" - # checker if needs to update on db changes: dict[str, dict[str, Any]] = field(default_factory=dict) - # context checker if update happens for each field hashes: dict[str, int] = field(default_factory=dict) + deleted: bool | None = None + connected: bool = False + persistent: bool = False -@dataclass -class WalkerAnchorState(AnchorState): - """Anchor state handler.""" - - disengaged: bool = False - persistent: bool = False # disabled by default +@dataclass(eq=False, repr=False, kw_only=True) +class BaseAnchor: + """Base Anchor.""" + architype: "BaseArchitype" + name: str = "" + id: ObjectId = field(default_factory=ObjectId) + root: ObjectId | None = None + access: Permission + state: AnchorState -@dataclass(eq=False) -class Anchor(_Anchor): - """Object Anchor.""" - - id: ObjectId = field(default_factory=ObjectId) # type: ignore[assignment] - root: ObjectId | None = None # type: ignore[assignment] - access: Permission = field(default_factory=Permission) - architype: "Architype | None" = None - state: AnchorState = field(default_factory=AnchorState) - - class Collection(BaseCollection["Anchor"]): + class Collection(BaseCollection["BaseAnchor"]): """Anchor collection interface.""" pass @@ -273,24 +273,28 @@ class Collection(BaseCollection["Anchor"]): @property def ref_id(self) -> str: """Return id in reference type.""" - return f"{self.type.value}:{self.name}:{self.id}" + return f"{self.__class__.__name__[:1].lower()}:{self.name}:{self.id}" @staticmethod - def ref(ref_id: str) -> "Anchor | None": + def ref(ref_id: str) -> "BaseAnchor | Anchor": """Return ObjectAnchor instance if .""" - if matched := GENERIC_ID_REGEX.search(ref_id): - cls: type = Anchor - match AnchorType(matched.group(1)): - case AnchorType.node: + if match := GENERIC_ID_REGEX.search(ref_id): + cls: type[BaseAnchor] + + match match.group(1): + case "n": cls = NodeAnchor - case AnchorType.edge: + case "e": cls = EdgeAnchor - case AnchorType.walker: + case "w": cls = WalkerAnchor case _: - pass - return cls(name=matched.group(2), id=ObjectId(matched.group(3))) - return None + raise ValueError(f"{ref_id}] is not a valid reference!") + anchor = object.__new__(cls) + anchor.name = str(match.group(2)) + anchor.id = ObjectId(match.group(3)) + return anchor + raise ValueError(f"{ref_id}] is not a valid reference!") #################################################### # QUERY OPERATIONS # @@ -322,7 +326,7 @@ def _pull(self) -> dict: return self.state.changes["$pull"] - def add_to_set(self, field: str, anchor: "Anchor", remove: bool = False) -> None: + def add_to_set(self, field: str, anchor: Anchor, remove: bool = False) -> None: """Add to set.""" if field not in (add_to_set := self._add_to_set): add_to_set[field] = {"$each": set()} @@ -336,7 +340,7 @@ def add_to_set(self, field: str, anchor: "Anchor", remove: bool = False) -> 
            ops.add(anchor)
            self.pull(field, anchor, True)
-    def pull(self, field: str, anchor: "Anchor", remove: bool = False) -> None:
+    def pull(self, field: str, anchor: Anchor, remove: bool = False) -> None:
         """Pull from set."""
         if field not in (pull := self._pull):
             pull[field] = {"$in": set()}
@@ -350,128 +354,37 @@ def pull(self, field: str, anchor: "Anchor", remove: bool = False) -> None:
             ops.add(anchor)
             self.add_to_set(field, anchor, True)
-    def connect_edge(self, anchor: "Anchor") -> None:
+    def connect_edge(self, anchor: Anchor) -> None:
         """Push update that there's newly added edge."""
         self.add_to_set("edges", anchor)
-    def disconnect_edge(self, anchor: "Anchor") -> None:
+    def disconnect_edge(self, anchor: Anchor) -> None:
         """Push update that there's edge that has been removed."""
         self.pull("edges", anchor)
-    # def whitelist_nodes(self, whitelist: bool = True) -> None:
-    #     """Toggle node whitelist/blacklist."""
-    #     if whitelist != self.access.nodes.whitelist:
-    #         self._set.update({"access.nodes.whitelist": whitelist})
-    # def allow_node(self, node: Anchor, level: AccessLevel | int | str = AccessLevel.READ) -> None:
-    #     """Allow all access from target node to current Architype."""
-    #     level = AccessLevel.cast(level)
-    #     access = self.access.nodes
-    #     if access.whitelist:
-    #         if (ref_id := node.ref_id) and level != access.anchors.get(ref_id, AccessLevel.NO_ACCESS):
-    #             access.anchors[ref_id] = level
-    #             self._set.update({f"access.nodes.anchors.{ref_id}": level.name})
-    #             self._unset.pop(f"access.nodes.anchors.{ref_id}", None)
-    #     else:
-    #         self.disallow_node(node, level)
-    # def disallow_node(self, node: Anchor, level: AccessLevel | int | str = AccessLevel.READ) -> None:
-    #     """Disallow all access from target node to current Architype."""
-    #     level = AccessLevel.cast(level)
-    #     access = self.access.nodes
-    #     if access.whitelist:
-    #         if (ref_id := node.ref_id) and access.anchors.pop(ref_id, None) is not None:
-    #             self._unset.update({f"access.nodes.anchors.{ref_id}": True})
-    #             self._set.pop(f"access.nodes.anchors.{ref_id}", None)
-    #     else:
-    #         self.allow_node(node, level)
-    # def add_types(self, type: type[NodeArchitype]) -> None:
-    #     """Add type checking."""
-    #     if not self.access.types.get(type):
-    #         name = type.__ref_cls__()
-    #         self._set.update({f"access.types.{name}": {}})
-    #         self._unset.pop(f"access.types.{name}", None)
-    # def remove_types(self, type: type[NodeArchitype]) -> None:
-    #     """Remove type checking."""
-    #     if self.access.types.pop(type, None):
-    #         name = type.__ref_cls__()
-    #         self._unset.update({f"access.types.{name}": True})
-    #         self._set.pop(f"access.types.{name}", None)
-    # def whitelist_types(
-    #     self, type: type[NodeArchitype], whitelist: bool = True
-    # ) -> None:
-    #     """Toggle type whitelist/blacklist."""
-    #     if (access := self.access.types.get(type)) and whitelist != access.whitelist:
-    #         self._set.update(
-    #             {f"access.types.{type.__ref_cls__()}.whitelist": whitelist}
-    #         )
-    # def allow_type(
-    #     self, type: type[NodeArchitype], node: Anchor, level: AccessLevel | int | str = AccessLevel.READ
-    # ) -> None:
-    #     """Allow all access from target type graph to current Architype."""
-    #     level = AccessLevel.cast(level)
-    #     if access := self.access.types.get(type):
-    #         if access.whitelist:
-    #             if (ref_id := node.ref_id) and level != access.anchors.get(ref_id, AccessLevel.NO_ACCESS):
-    #                 access.anchors[ref_id] = level
-    #                 name = type.__ref_cls__()
-    #                 self._set.update({f"access.types.{name}.anchors.{ref_id}": level.name})
-    #                 self._unset.pop(f"access.types.{name}.anchors.{ref_id}", None)
self._unset.pop(f"access.types.{name}.anchors.{ref_id}", None) - # else: - # self.disallow_type(type, node, level) - - # def disallow_type( - # self, type: type[NodeArchitype], node: Anchor, level: AccessLevel | int | str = AccessLevel.READ - # ) -> None: - # """Disallow all access from target type graph to current Architype.""" - # level = AccessLevel.cast(level) - # if access := self.access.types.get(type): - # if access.whitelist: - # if (ref_id := node.ref_id) and access.anchors.pop(ref_id, None) is not None: - # name = type.__ref_cls__() - # self._unset.update({f"access.types.{name}.anchors.{ref_id}": True}) - # self._set.pop(f"access.types.{name}.anchors.{ref_id}", None) - # else: - # self.allow_type(type, node, level) - - def whitelist_roots(self, whitelist: bool = True) -> None: - """Toggle root whitelist/blacklist.""" - if whitelist != self.access.roots.whitelist: - self.access.roots.whitelist = whitelist - self._set.update({"access.roots.whitelist": whitelist}) - - def allow_root( # type: ignore[override] - self, root: "Anchor", level: AccessLevel | int | str = AccessLevel.READ + def allow_root( + self, root: Anchor, level: AccessLevel | int | str = AccessLevel.READ ) -> None: """Allow all access from target root graph to current Architype.""" level = AccessLevel.cast(level) access = self.access.roots - if access.whitelist: - if (ref_id := root.ref_id) and level != access.anchors.get( - ref_id, AccessLevel.NO_ACCESS - ): - access.anchors[ref_id] = level - self._set.update({f"access.roots.anchors.{ref_id}": level.name}) - self._unset.pop(f"access.roots.anchors.{ref_id}", None) - else: - self.disallow_root(root, level) + if (ref_id := root.ref_id) and level != access.anchors.get( + ref_id, AccessLevel.NO_ACCESS + ): + access.anchors[ref_id] = level + self._set.update({f"access.roots.anchors.{ref_id}": level.name}) + self._unset.pop(f"access.roots.anchors.{ref_id}", None) - def disallow_root( # type: ignore[override] - self, root: "Anchor", level: AccessLevel | int | str = AccessLevel.READ + def disallow_root( + self, root: Anchor, level: AccessLevel | int | str = AccessLevel.READ ) -> None: """Disallow all access from target root graph to current Architype.""" level = AccessLevel.cast(level) access = self.access.roots - if access.whitelist: - if (ref_id := root.ref_id) and access.anchors.pop(ref_id, None) is not None: - self._unset.update({f"access.roots.anchors.{ref_id}": True}) - self._set.pop(f"access.roots.anchors.{ref_id}", None) - else: - self.allow_root(root, level) + + if (ref_id := root.ref_id) and access.anchors.pop(ref_id, None) is not None: + self._unset.update({f"access.roots.anchors.{ref_id}": True}) + self._set.pop(f"access.roots.anchors.{ref_id}", None) def unrestrict(self, level: AccessLevel | int | str = AccessLevel.READ) -> None: """Allow everyone to access current Architype.""" @@ -486,38 +399,37 @@ def restrict(self) -> None: self.access.all = AccessLevel.NO_ACCESS self._set.update({"access.all": AccessLevel.NO_ACCESS.name}) - # ------------------------------------------------ # - - async def sync(self, node: "NodeAnchor | None" = None) -> "Architype | None": # type: ignore[override] - """Retrieve the Architype from db and return.""" - if self.state.deleted is not None: - return None + #################################################### + # POPULATE OPERATIONS # + #################################################### - if architype := self.architype: - if await (node or self).has_read_access(self): - return architype - return None + def is_populated(self) -> 
+        """Check if populated."""
+        return "architype" in self.__dict__
+
+    def make_stub(self: "BaseAnchor | TANCH") -> "BaseAnchor | TANCH":
+        """Return unsynced copy of anchor."""
+        if self.is_populated():
+            unloaded = object.__new__(self.__class__)
+            unloaded.name = self.name
+            unloaded.id = self.id
+            return unloaded
+        return self
+    def populate(self) -> None:
+        """Retrieve the Architype from db and return."""
         from .context import JaseciContext
-        jsrc = JaseciContext.get_datasource()
-        anchor = await jsrc.find_one(self.__class__, self)
+        jsrc = JaseciContext.get().mem
-        if anchor and await (node or self).has_read_access(anchor):
+        if anchor := jsrc.find_by_id(self):
             self.__dict__.update(anchor.__dict__)
+        else:
+            raise ValueError(
+                f"{self.__class__.__name__} [{self.ref_id}] is not a valid reference!"
+            )
-        return self.architype
-
-    def allocate(self) -> None:
-        """Allocate hashes and memory."""
-        from .context import JASECI_CONTEXT
-
-        if jctx := JASECI_CONTEXT.get(None):
-            if self.root is None and not isinstance(self.architype, Root):
-                self.root = jctx.root.id
-            jctx.datasource.set(self)
-
-    def _save(  # type: ignore[override]
+    def build_query(
         self,
         bulk_write: BulkWrite,
     ) -> None:
@@ -530,22 +442,24 @@ def _save(  # type: ignore[override]
                 self.state.connected = True
                 self.sync_hash()
                 self.insert(bulk_write)
-            elif self.state.current_access_level > AccessLevel.READ:
+            elif self.has_connect_access(self):  # type: ignore[attr-defined]
                 self.update(bulk_write, True)
-    async def save(self, session: AsyncIOMotorClientSession | None = None) -> BulkWrite:  # type: ignore[override]
+    def apply(self, session: ClientSession | None = None) -> BulkWrite:
         """Save Anchor."""
         bulk_write = BulkWrite()
-        self._save(bulk_write)
+        self.build_query(bulk_write)
         if bulk_write.has_operations:
             if session:
-                await bulk_write.execute(session)
+                bulk_write.execute(session)
             else:
-                async with await BaseCollection.get_session() as session:
-                    async with session.start_transaction():
-                        await bulk_write.execute(session)
+                with (
+                    BaseCollection.get_session() as session,
+                    session.start_transaction(),
+                ):
+                    bulk_write.execute(session)
         return bulk_write
@@ -554,26 +468,32 @@ def insert(
         bulk_write: BulkWrite,
     ) -> None:
         """Append Insert Query."""
-        bulk_write.operations[self.type].append(InsertOne(self.serialize()))
+        raise NotImplementedError("insert must be implemented in subclasses")
     def update(self, bulk_write: BulkWrite, propagate: bool = False) -> None:
         """Append Update Query."""
         changes = self.state.changes
         self.state.changes = {}  # renew reference
-        operations = bulk_write.operations[self.type]
+        operations = bulk_write.operations[self.__class__]
         operation_filter = {"_id": self.id}
         ############################################################
         #                   POPULATE CONTEXT                       #
         ############################################################
+        from .context import JaseciContext
-        if self.state.current_access_level > AccessLevel.CONNECT:
+        if JaseciContext.get().root.has_write_access(self):
             set_architype = changes.pop("$set", {})
             if is_dataclass(architype := self.architype) and not isinstance(
                 architype, type
             ):
-                for key, val in architype.__getstate__().items():
+                for (
+                    key,
+                    val,
+                ) in (
+                    architype.__serialize__().items()  # type:ignore[attr-defined]  # mypy issue
+                ):
                     if (h := hash(dumps(val))) != self.state.hashes.get(key):
                         self.state.hashes[key] = h
                         set_architype[f"architype.{key}"] = val
@@ -585,18 +505,18 @@ def update(self, bulk_write: BulkWrite, propagate: bool = False) -> None:
        # -------------------------------------------------------- #
-        if self.type is AnchorType.node:
+        if isinstance(self, NodeAnchor):
             ############################################################
             #                POPULATE ADDED EDGES                      #
             ############################################################
-            added_edges: set[Anchor] = (
+            added_edges: set[BaseAnchor | Anchor] = (
                 changes.get("$addToSet", {}).get("edges", {}).get("$each", [])
             )
             if added_edges:
                 _added_edges = []
                 for anchor in added_edges:
                     if propagate:
-                        anchor._save(bulk_write)
+                        anchor.build_query(bulk_write)
                     _added_edges.append(anchor.ref_id)
                 changes["$addToSet"]["edges"]["$each"] = _added_edges
             else:
@@ -607,7 +527,7 @@ def update(self, bulk_write: BulkWrite, propagate: bool = False) -> None:
             ############################################################
             #               POPULATE REMOVED EDGES                     #
             ############################################################
-            pulled_edges: set[Anchor] = (
+            pulled_edges: set[BaseAnchor | Anchor] = (
                 changes.get("$pull", {}).get("edges", {}).get("$in", [])
             )
             if pulled_edges:
@@ -643,17 +563,7 @@ def delete(self, bulk_write: BulkWrite) -> None:
     def destroy(self) -> None:
         """Delete Anchor."""
-        if (
-            self.architype
-            and self.state.current_access_level > AccessLevel.CONNECT
-            and self.state.deleted is None
-        ):
-            from .context import JaseciContext
-
-            jsrc = JaseciContext.get_datasource()
-
-            self.state.deleted = False
-            jsrc.remove(self)
+        raise NotImplementedError("destroy must be implemented in subclasses")
     def sync_hash(self) -> None:
         """Sync current serialization hash."""
@@ -661,88 +571,22 @@
             architype, type
         ):
             self.state.hashes = {
-                key: hash(dumps(val)) for key, val in architype.__getstate__().items()
+                key: hash(dumps(val))
+                for key, val in architype.__serialize__().items()  # type:ignore[attr-defined]  # mypy issue
             }
-    async def has_read_access(self, to: "Anchor") -> bool:  # type: ignore[override]
-        """Read Access Validation."""
-        return await self.access_level(to) > AccessLevel.NO_ACCESS
-
-    async def has_connect_access(self, to: "Anchor") -> bool:  # type: ignore[override]
-        """Write Access Validation."""
-        return await self.access_level(to) > AccessLevel.READ
-
-    async def has_write_access(self, to: "Anchor") -> bool:  # type: ignore[override]
-        """Write Access Validation."""
-        return await self.access_level(to) > AccessLevel.CONNECT
-
-    async def access_level(self, to: "Anchor") -> AccessLevel:  # type: ignore[override]
-        """Access validation."""
-        from .context import JaseciContext
-
-        jctx = JaseciContext.get()
+    # ---------------------------------------------------------------------- #
-        to.state.current_access_level = AccessLevel.NO_ACCESS
-        if jctx.root == jctx.super_root or jctx.root.id == to.root or jctx.root == to:
-            to.state.current_access_level = AccessLevel.WRITE
-            return to.state.current_access_level
-
-        if (to_access := to.access).all > AccessLevel.NO_ACCESS:
-            to.state.current_access_level = to_access.all
-
-        if to.root and (
-            to_root := await jctx.datasource.find_one(
-                NodeAnchor,
-                NodeAnchor(id=to.root, state=AnchorState(connected=True)),
-            )
-        ):
-            if to_root.access.all > to.state.current_access_level:
-                to.state.current_access_level = to_root.access.all
-
-            whitelist, level = to_root.access.roots.check(jctx.root.ref_id)
-            if not whitelist:
-                if level < AccessLevel.READ:
-                    to.state.current_access_level = AccessLevel.NO_ACCESS
-                    return to.state.current_access_level
-                elif level < to.state.current_access_level:
-                    level = to.state.current_access_level
-            elif whitelist and level > to.state.current_access_level:
-                to.state.current_access_level = level
-
-        whitelist, level = to_access.roots.check(jctx.root.ref_id)
-        if not whitelist:
-            if level < AccessLevel.READ:
-                to.state.current_access_level = AccessLevel.NO_ACCESS
-                return to.state.current_access_level
-            elif level < to.state.current_access_level:
-                level = to.state.current_access_level
-        elif whitelist and level > to.state.current_access_level:
-            to.state.current_access_level = level
-
-        # if (architype := self.architype) and (
-        #     access_type := to_access.types.get(architype.__class__)
-        # ):
-        #     whitelist, level = access_type.check(self)
-        #     if not whitelist:
-        #         if level < AccessLevel.READ:
-        #             to.state.current_access_level = AccessLevel.NO_ACCESS
-        #             return to.state.current_access_level
-        #         elif level < to.state.current_access_level:
-        #             level = to.state.current_access_level
-        #     elif whitelist and level > to.state.current_access_level:
-        #         to.state.current_access_level = level
-
-        # whitelist, level = to_access.nodes.check(self)
-        # if not whitelist:
-        #     if level < AccessLevel.READ:
-        #         to.state.current_access_level = AccessLevel.NO_ACCESS
-        #         return to.state.current_access_level
-        #     elif level < to.state.current_access_level:
-        #         level = to.state.current_access_level
-        # elif whitelist and level > to.state.current_access_level:
-        #     to.state.current_access_level = level
-
-        return to.state.current_access_level
+    def report(self) -> dict[str, object]:
+        """Report Anchor."""
+        return {
+            "id": self.ref_id,
+            "context": (
+                self.architype.__serialize__()  # type:ignore[attr-defined]  # mypy issue
+                if is_dataclass(self.architype) and not isinstance(self.architype, type)
+                else {}
+            ),
+        }
     def serialize(self) -> dict[str, object]:
         """Serialize Anchor."""
@@ -752,20 +596,32 @@ def serialize(self) -> dict[str, object]:
             "root": self.root,
             "access": self.access.serialize(),
             "architype": (
-                self.architype.__getstate__()
+                self.architype.__serialize__()  # type:ignore[attr-defined]  # mypy issue
                 if is_dataclass(self.architype) and not isinstance(self.architype, type)
                 else {}
             ),
         }
+    def __repr__(self) -> str:
+        """Override representation."""
+        if self.is_populated():
+            attrs = ""
+            for f in fields(self):
+                if f.name in self.__dict__:
+                    attrs += f"{f.name}={self.__dict__[f.name]}, "
+            attrs = attrs[:-2]
+        else:
+            attrs = f"name={self.name}, id={self.id}"
-@dataclass(eq=False)
-class NodeAnchor(Anchor):
+        return f"{self.__class__.__name__}({attrs})"
+
+
+@dataclass(eq=False, repr=False, kw_only=True)
+class NodeAnchor(BaseAnchor, _NodeAnchor):  # type: ignore[misc]
     """Node Anchor."""
-    type: ClassVar[AnchorType] = AnchorType.node
-    architype: "NodeArchitype | None" = None
-    edges: list["EdgeAnchor"] = field(default_factory=list)
+    architype: "NodeArchitype"
+    edges: list["EdgeAnchor"]
     class Collection(BaseCollection["NodeAnchor"]):
         """NodeAnchor collection interface."""
@@ -779,28 +635,32 @@ class Collection(BaseCollection["NodeAnchor"]):
        def __document__(cls, doc: Mapping[str, Any]) -> "NodeAnchor":
            """Parse document to NodeAnchor."""
            doc = cast(dict, doc)
-            architype: dict = doc.pop("architype")
+
+            architype = architype_to_dataclass(
+                NodeArchitype.__get_class__(doc.get("name") or "Root"),
+                doc.pop("architype"),
+            )
            anchor = NodeAnchor(
+                architype=architype,
                id=doc.pop("_id"),
                edges=[e for edge in doc.pop("edges") if (e := EdgeAnchor.ref(edge))],
                access=Permission.deserialize(doc.pop("access")),
                state=AnchorState(connected=True),
                **doc,
            )
-            architype_cls = NodeArchitype.__get_class__(doc.get("name") or "Root")
-            anchor.architype = to_dataclass(architype_cls, architype, __jac__=anchor)
+            architype.__jac__ = anchor
            anchor.sync_hash()
            return anchor
    @classmethod
-    def ref(cls, ref_id: str) -> "NodeAnchor | None":
+    def ref(cls, ref_id: str) -> "NodeAnchor":
        """Return NodeAnchor instance if existing."""
        if match := NODE_ID_REGEX.search(ref_id):
-            return cls(
-                name=match.group(1),
-                id=ObjectId(match.group(2)),
-            )
-        return None
+            anchor = object.__new__(cls)
+            anchor.name = str(match.group(1))
+            anchor.id = ObjectId(match.group(2))
+            return anchor
+        raise ValueError(f"[{ref_id}] is not a valid reference!")
    def insert(
        self,
@@ -808,9 +668,9 @@ def insert(
    ) -> None:
        """Append Insert Query."""
        for edge in self.edges:
-            edge._save(bulk_write)
+            edge.build_query(bulk_write)
-        super().insert(bulk_write)
+        bulk_write.operations[NodeAnchor].append(InsertOne(self.serialize()))
    def delete(self, bulk_write: BulkWrite) -> None:
        """Append Delete Query."""
@@ -825,135 +685,44 @@ def delete(self, bulk_write: BulkWrite) -> None:
    def destroy(self) -> None:
        """Delete Anchor."""
-        if (
-            self.architype
-            and self.state.current_access_level > AccessLevel.CONNECT
-            and self.state.deleted is None
-        ):
+        if self.state.deleted is None:
            from .context import JaseciContext
-            jsrc = JaseciContext.get_datasource()
-            self.state.deleted = False
+            jctx = JaseciContext.get()
-            for edge in self.edges:
-                edge.destroy()
-            jsrc.remove(self)
+            if jctx.root.has_write_access(self):
-    async def sync(self, node: "NodeAnchor | None" = None) -> "NodeArchitype | None":  # type: ignore[override]
-        """Retrieve the Architype from db and return."""
-        return cast(NodeArchitype | None, await super().sync(node))
+                self.state.deleted = False
-    def connect_node(self, nd: "NodeAnchor", edg: "EdgeAnchor") -> None:
-        """Connect a node with given edge."""
-        edg.attach(self, nd)
+                for edge in self.edges:
+                    edge.destroy()
+                jctx.mem.remove(self.id)
-    async def get_edges(
+    def get_edges(
        self,
        dir: EdgeDir,
        filter_func: Callable[[list["EdgeArchitype"]], list["EdgeArchitype"]] | None,
-        target_cls: list[Type["NodeArchitype"]] | None,
+        target_obj: list["NodeArchitype"] | None,
    ) -> list["EdgeArchitype"]:
        """Get edges connected to this node."""
        from .context import JaseciContext
-        await JaseciContext.get_datasource().populate_data(self.edges)
-
-        ret_edges: list[EdgeArchitype] = []
-        for anchor in self.edges:
-            if (
-                (architype := await anchor.sync(self))
-                and (source := anchor.source)
-                and (target := anchor.target)
-                and (not filter_func or filter_func([architype]))
-                and (src_arch := await source.sync())
-                and (trg_arch := await target.sync())
-            ):
-                if (
-                    dir in [EdgeDir.OUT, EdgeDir.ANY]
-                    and self == source
-                    and (not target_cls or trg_arch.__class__ in target_cls)
-                    and await source.has_read_access(target)
-                ):
-                    ret_edges.append(architype)
-                if (
-                    dir in [EdgeDir.IN, EdgeDir.ANY]
-                    and self == target
-                    and (not target_cls or src_arch.__class__ in target_cls)
-                    and await target.has_read_access(source)
-                ):
-                    ret_edges.append(architype)
-        return ret_edges
+        JaseciContext.mem.populate_data(self.edges)
-    async def edges_to_nodes(
+        return super().get_edges(dir, filter_func, target_obj)
+
+    def edges_to_nodes(
        self,
        dir: EdgeDir,
        filter_func: Callable[[list["EdgeArchitype"]], list["EdgeArchitype"]] | None,
-        target_cls: list[Type["NodeArchitype"]] | None,
+        target_obj: list["NodeArchitype"] | None,
    ) -> list["NodeArchitype"]:
        """Get set of nodes connected to this node."""
        from .context import JaseciContext
-        await JaseciContext.get_datasource().populate_data(self.edges)
-
-        ret_edges: list[NodeArchitype] = []
-        for anchor in self.edges:
-            if (
-                (architype := await anchor.sync(self))
-                and (source := anchor.source)
-                and (target := anchor.target)
-                and (not filter_func or filter_func([architype]))
-                and (src_arch := await source.sync())
-                and (trg_arch := await target.sync())
-            ):
-                if (
-                    dir in [EdgeDir.OUT, EdgeDir.ANY]
-                    and self == source
-                    and (not target_cls or trg_arch.__class__ in target_cls)
-                    and await source.has_read_access(target)
-                ):
-                    ret_edges.append(trg_arch)
-                if (
-                    dir in [EdgeDir.IN, EdgeDir.ANY]
-                    and self == target
-                    and (not target_cls or src_arch.__class__ in target_cls)
-                    and await target.has_read_access(source)
-                ):
-                    ret_edges.append(src_arch)
-        return ret_edges
-
-    def remove_edge(self, edge: "EdgeAnchor") -> None:
-        """Remove reference without checking sync status."""
-        for idx, ed in enumerate(self.edges):
-            if ed.id == edge.id:
-                self.edges.pop(idx)
-                break
-
-    def gen_dot(self, dot_file: str | None = None) -> str:
-        """Generate Dot file for visualizing nodes and edges."""
-        visited_nodes: set[NodeAnchor] = set()
-        connections: set[tuple[NodeArchitype, NodeArchitype, str]] = set()
-        unique_node_id_dict = {}
-
-        collect_node_connections(self, visited_nodes, connections)  # type: ignore[arg-type]
-        dot_content = 'digraph {\nnode [style="filled", shape="ellipse", fillcolor="invis", fontcolor="black"];\n'
-        for idx, i in enumerate([nodes_.architype for nodes_ in visited_nodes]):
-            unique_node_id_dict[i] = (i.__class__.__name__, str(idx))
-            dot_content += f'{idx} [label="{i}"];\n'
-        dot_content += 'edge [color="gray", style="solid"];\n'
-
-        for pair in list(set(connections)):
-            dot_content += (
-                f"{unique_node_id_dict[pair[0]][1]} -> {unique_node_id_dict[pair[1]][1]}"
-                f' [label="{pair[2]}"];\n'
-            )
-        if dot_file:
-            with open(dot_file, "w") as f:
-                f.write(dot_content + "}")
-        return dot_content + "}"
+        JaseciContext.get().mem.populate_data(self.edges)
-    async def spawn_call(self, walk: "WalkerAnchor") -> "WalkerArchitype":
-        """Invoke data spatial call."""
-        return await walk.spawn_call(self)
+        return super().edges_to_nodes(dir, filter_func, target_obj)
    def serialize(self) -> dict[str, object]:
        """Serialize Node Anchor."""
@@ -963,15 +732,14 @@ def serialize(self) -> dict[str, object]:
        }
-@dataclass(eq=False)
-class EdgeAnchor(Anchor):
+@dataclass(eq=False, repr=False, kw_only=True)
+class EdgeAnchor(BaseAnchor, _EdgeAnchor):  # type: ignore[misc]
    """Edge Anchor."""
-    type: ClassVar[AnchorType] = AnchorType.edge
-    architype: "EdgeArchitype | None" = None
-    source: NodeAnchor | None = None
-    target: NodeAnchor | None = None
-    is_undirected: bool = False
+    architype: "EdgeArchitype"
+    source: NodeAnchor
+    target: NodeAnchor
+    is_undirected: bool
    class Collection(BaseCollection["EdgeAnchor"]):
        """EdgeAnchor collection interface."""
@@ -985,8 +753,12 @@ class Collection(BaseCollection["EdgeAnchor"]):
        def __document__(cls, doc: Mapping[str, Any]) -> "EdgeAnchor":
            """Parse document to EdgeAnchor."""
            doc = cast(dict, doc)
-            architype: dict = doc.pop("architype")
+            architype = architype_to_dataclass(
+                EdgeArchitype.__get_class__(doc.get("name") or "GenericEdge"),
+                doc.pop("architype"),
+            )
            anchor = EdgeAnchor(
+                architype=architype,
                id=doc.pop("_id"),
                source=NodeAnchor.ref(doc.pop("source")),
                target=NodeAnchor.ref(doc.pop("target")),
@@ -994,61 +766,56 @@ def __document__(cls, doc: Mapping[str, Any]) -> "EdgeAnchor":
                state=AnchorState(connected=True),
                **doc,
            )
-            architype_cls = EdgeArchitype.__get_class__(
-                doc.get("name") or "GenericEdge"
-            )
-            anchor.architype = to_dataclass(architype_cls, architype, __jac__=anchor)
+            architype.__jac__ = anchor
            anchor.sync_hash()
            return anchor
+    def __post_init__(self) -> None:
+        """Populate edge to source and target."""
+        self.source.edges.append(self)
+        self.target.edges.append(self)
+
    @classmethod
-    def ref(cls, ref_id: str) -> "EdgeAnchor | None":
+    def ref(cls, ref_id: str) -> "EdgeAnchor":
        """Return EdgeAnchor instance if existing."""
        if match := EDGE_ID_REGEX.search(ref_id):
-            return cls(
-                name=match.group(1),
-                id=ObjectId(match.group(2)),
-            )
-        return None
+            anchor = object.__new__(cls)
+            anchor.name = str(match.group(1))
+            anchor.id = ObjectId(match.group(2))
+            return anchor
+        raise ValueError(f"[{ref_id}] is not a valid reference!")
    def insert(self, bulk_write: BulkWrite) -> None:
        """Append Insert Query."""
        if source := self.source:
-            source._save(bulk_write)
+            source.build_query(bulk_write)
        if target := self.target:
-            target._save(bulk_write)
+            target.build_query(bulk_write)
-        super().insert(bulk_write)
+        bulk_write.operations[EdgeAnchor].append(InsertOne(self.serialize()))
    def delete(self, bulk_write: BulkWrite) -> None:
        """Append Delete Query."""
        if source := self.source:
-            source._save(bulk_write)
+            source.build_query(bulk_write)
        if target := self.target:
-            target._save(bulk_write)
+            target.build_query(bulk_write)
        bulk_write.del_edge(self.id)
    def destroy(self) -> None:
        """Delete Anchor."""
-        if (
-            self.architype
-            and self.state.current_access_level > AccessLevel.CONNECT
-            and self.state.deleted is None
-        ):
+        if self.state.deleted is None:
            from .context import JaseciContext
-            jsrc = JaseciContext.get_datasource()
+            jctx = JaseciContext.get()
-            self.state.deleted = False
-            self.detach()
-            jsrc.remove(self)
-
-    async def sync(self, node: "NodeAnchor | None" = None) -> "EdgeArchitype | None":  # type: ignore[override]
-        """Retrieve the Architype from db and return."""
-        return cast(EdgeArchitype | None, await super().sync(node))
+            if jctx.root.has_write_access(self):
+                self.state.deleted = False
+                self.detach()
+                jctx.mem.remove(self.id)
    def attach(
        self, src: NodeAnchor, trg: NodeAnchor, is_undirected: bool = False
@@ -1072,13 +839,6 @@ def detach(self) -> None:
            target.remove_edge(self)
            target.disconnect_edge(self)
-    async def spawn_call(self, walk: "WalkerAnchor") -> "WalkerArchitype":
-        """Invoke data spatial call."""
-        if target := self.target:
-            return await walk.spawn_call(target)
-        else:
-            raise ValueError("Edge has no target.")
-
    def serialize(self) -> dict[str, object]:
        """Serialize Node Anchor."""
        return {
@@ -1089,17 +849,16 @@ def serialize(self) -> dict[str, object]:
        }
-@dataclass(eq=False)
-class WalkerAnchor(Anchor):
+@dataclass(eq=False, repr=False, kw_only=True)
+class WalkerAnchor(BaseAnchor, _WalkerAnchor):  # type: ignore[misc]
    """Walker Anchor."""
-    type: ClassVar[AnchorType] = AnchorType.walker
-    architype: "WalkerArchitype | None" = None
+    architype: "WalkerArchitype"
    path: list[Anchor] = field(default_factory=list)
    next: list[Anchor] = field(default_factory=list)
    returns: list[Any] = field(default_factory=list)
    ignores: list[Anchor] = field(default_factory=list)
-    state: WalkerAnchorState = field(default_factory=WalkerAnchorState)
+    disengaged: bool = False
    class Collection(BaseCollection["WalkerAnchor"]):
        """WalkerAnchor collection interface."""
@@ -1113,140 +872,118 @@ class Collection(BaseCollection["WalkerAnchor"]):
        def __document__(cls, doc: Mapping[str, Any]) -> "WalkerAnchor":
"""Parse document to WalkerAnchor.""" doc = cast(dict, doc) - architype: dict = doc.pop("architype") + architype = architype_to_dataclass( + WalkerArchitype.__get_class__(doc.get("name") or ""), + doc.pop("architype"), + ) anchor = WalkerAnchor( + architype=architype, id=doc.pop("_id"), access=Permission.deserialize(doc.pop("access")), - state=WalkerAnchorState(connected=True), + state=AnchorState(connected=True), **doc, ) - architype_cls = WalkerArchitype.__get_class__(doc.get("name") or "") - anchor.architype = to_dataclass(architype_cls, architype, __jac__=anchor) + architype.__jac__ = anchor anchor.sync_hash() return anchor @classmethod - def ref(cls, ref_id: str) -> "WalkerAnchor | None": + def ref(cls, ref_id: str) -> "WalkerAnchor": """Return EdgeAnchor instance if existing.""" - if ref_id and (match := WALKER_ID_REGEX.search(ref_id)): - return cls( - name=match.group(1), - id=ObjectId(match.group(2)), - ) - return None + if match := WALKER_ID_REGEX.search(ref_id): + anchor = object.__new__(cls) + anchor.name = str(match.group(1)) + anchor.id = ObjectId(match.group(2)) + return anchor + raise ValueError(f"{ref_id}] is not a valid reference!") + + def insert( + self, + bulk_write: BulkWrite, + ) -> None: + """Append Insert Query.""" + bulk_write.operations[WalkerAnchor].append(InsertOne(self.serialize())) def delete(self, bulk_write: BulkWrite) -> None: """Append Delete Query.""" bulk_write.del_walker(self.id) - async def sync(self, node: "NodeAnchor | None" = None) -> "WalkerArchitype | None": # type: ignore[override] - """Retrieve the Architype from db and return.""" - return cast(WalkerArchitype | None, await super().sync(node)) - - async def visit_node(self, anchors: Iterable[NodeAnchor | EdgeAnchor]) -> bool: - """Walker visits node.""" - before_len = len(self.next) - for anchor in anchors: - if anchor not in self.ignores: - if isinstance(anchor, NodeAnchor): - self.next.append(anchor) - elif isinstance(anchor, EdgeAnchor): - if await anchor.sync() and (target := anchor.target): - self.next.append(target) - else: - raise ValueError("Edge has no target.") - return len(self.next) > before_len - - async def ignore_node(self, anchors: Iterable[NodeAnchor | EdgeAnchor]) -> bool: - """Walker ignores node.""" - before_len = len(self.ignores) - for anchor in anchors: - if anchor not in self.ignores: - if isinstance(anchor, NodeAnchor): - self.ignores.append(anchor) - elif isinstance(anchor, EdgeAnchor): - if await anchor.sync() and (target := anchor.target): - self.ignores.append(target) - else: - raise ValueError("Edge has no target.") - return len(self.ignores) > before_len - - def disengage_now(self) -> None: - """Disengage walker from traversal.""" - self.state.disengaged = True - - async def await_if_coroutine(self, ret: Any) -> Any: # noqa: ANN401 - """Await return if it's a coroutine.""" - if iscoroutine(ret): - ret = await ret - - self.returns.append(ret) - - return ret - - async def spawn_call(self, nd: Anchor) -> "WalkerArchitype": + def destroy(self) -> None: + """Delete Anchor.""" + if self.state.deleted is None: + from .context import JaseciContext + + jctx = JaseciContext.get() + + if jctx.root.has_write_access(self): + self.state.deleted = False + jctx.mem.remove(self.id) + + def spawn_call(self, node: Anchor) -> "WalkerArchitype": """Invoke data spatial call.""" - if walker := await self.sync(): + if walker := self.architype: self.path = [] - self.next = [nd] + self.next = [node] self.returns = [] while len(self.next): - if node := await self.next.pop(0).sync(): - for i 
+                if current_node := self.next.pop(0).architype:
+                    for i in current_node._jac_entry_funcs_:
                        if not i.trigger or isinstance(walker, i.trigger):
                            if i.func:
-                                await self.await_if_coroutine(i.func(node, walker))
+                                self.returns.append(i.func(current_node, walker))
                            else:
                                raise ValueError(f"No function {i.name} to call.")
-                        if self.state.disengaged:
+                        if self.disengaged:
                            return walker
                    for i in walker._jac_entry_funcs_:
-                        if not i.trigger or isinstance(node, i.trigger):
+                        if not i.trigger or isinstance(current_node, i.trigger):
                            if i.func:
-                                await self.await_if_coroutine(i.func(walker, node))
+                                self.returns.append(i.func(walker, current_node))
                            else:
                                raise ValueError(f"No function {i.name} to call.")
-                        if self.state.disengaged:
+                        if self.disengaged:
                            return walker
                    for i in walker._jac_exit_funcs_:
-                        if not i.trigger or isinstance(node, i.trigger):
+                        if not i.trigger or isinstance(current_node, i.trigger):
                            if i.func:
-                                await self.await_if_coroutine(i.func(walker, node))
+                                self.returns.append(i.func(walker, current_node))
                            else:
                                raise ValueError(f"No function {i.name} to call.")
-                        if self.state.disengaged:
+                        if self.disengaged:
                            return walker
-                    for i in node._jac_exit_funcs_:
+                    for i in current_node._jac_exit_funcs_:
                        if not i.trigger or isinstance(walker, i.trigger):
                            if i.func:
-                                await self.await_if_coroutine(i.func(node, walker))
+                                self.returns.append(i.func(current_node, walker))
                            else:
                                raise ValueError(f"No function {i.name} to call.")
-                        if self.state.disengaged:
+                        if self.disengaged:
                            return walker
            self.ignores = []
            return walker
-        raise Exception(f"Invalid Reference {self.ref_id}")
+        raise Exception(f"Invalid Reference {self.id}")
+
+
+@dataclass(eq=False, repr=False, kw_only=True)
+class ObjectAnchor(BaseAnchor, Anchor):  # type: ignore[misc]
+    """Object Anchor."""
-class Architype(_Architype):
+class BaseArchitype:
    """Architype Protocol."""
-    __jac_classes__: dict[str, type["Architype"]]
+    __jac_classes__: dict[str, type["BaseArchitype"]]
    __jac_hintings__: dict[str, type]
    __jac__: Anchor
-    def __init__(self, __jac__: Anchor | None = None) -> None:
-        """Create default architype."""
-        if not __jac__:
-            __jac__ = Anchor(architype=self)
-            __jac__.allocate()
-        self.__jac__ = __jac__
-
-    def __getstate__(self) -> dict[str, Any]:
+    def __serialize__(self) -> dict[str, Any]:
        """Process default serialization."""
-        return asdict(self)
+        if is_dataclass(self) and not isinstance(self, type):
+            return asdict(self)
+        raise ValueError(
+            f"{self.__jac__.__class__.__name__} {self.__class__.__name__} is not serializable!"
+        )
    @classmethod
    def __ref_cls__(cls) -> str:
@@ -1265,7 +1002,7 @@ def __set_classes__(cls) -> dict[str, Any]:
        return jac_classes
    @classmethod
-    def __get_class__(cls: type[TA], name: str) -> type[TA]:
+    def __get_class__(cls: type[TBA], name: str) -> type[TBA]:
        """Build class map from subclasses."""
        jac_classes: dict[str, Any] | None = getattr(cls, "__jac_classes__", None)
        if not jac_classes or not (jac_class := jac_classes.get(name)):
@@ -1275,17 +1012,20 @@ def __get_class__(cls: type[TA], name: str) -> type[TA]:
        return jac_class
-class NodeArchitype(Architype):
+class NodeArchitype(BaseArchitype, _NodeArchitype):
    """Node Architype Protocol."""
    __jac__: NodeAnchor
-    def __init__(self, __jac__: NodeAnchor | None = None) -> None:
+    def __post_init__(self) -> None:
        """Create node architype."""
-        if not __jac__:
-            __jac__ = NodeAnchor(name=self.__class__.__name__, architype=self)
-            __jac__.allocate()
-        self.__jac__ = __jac__
+        self.__jac__ = NodeAnchor(
+            architype=self,
+            name=self.__class__.__name__,
+            edges=[],
+            access=Permission(),
+            state=AnchorState(),
+        )
    @classmethod
    def __ref_cls__(cls) -> str:
@@ -1293,17 +1033,27 @@ def __ref_cls__(cls) -> str:
        return f"n:{cls.__name__}"
-class EdgeArchitype(Architype):
+class EdgeArchitype(BaseArchitype, _EdgeArchitype):
    """Edge Architype Protocol."""
    __jac__: EdgeAnchor
-    def __init__(self, __jac__: EdgeAnchor | None = None) -> None:
-        """Create edge architype."""
-        if not __jac__:
-            __jac__ = EdgeAnchor(name=self.__class__.__name__, architype=self)
-            __jac__.allocate()
-        self.__jac__ = __jac__
+    def __attach__(
+        self,
+        source: NodeAnchor,
+        target: NodeAnchor,
+        is_undirected: bool,
+    ) -> None:
+        """Attach EdgeAnchor properly."""
+        self.__jac__ = EdgeAnchor(
+            architype=self,
+            name=self.__class__.__name__,
+            source=source,
+            target=target,
+            is_undirected=is_undirected,
+            access=Permission(),
+            state=AnchorState(),
+        )
    @classmethod
    def __ref_cls__(cls) -> str:
@@ -1311,17 +1061,19 @@ def __ref_cls__(cls) -> str:
        return f"e:{cls.__name__}"
-class WalkerArchitype(Architype):
+class WalkerArchitype(BaseArchitype, _WalkerArchitype):
    """Walker Architype Protocol."""
    __jac__: WalkerAnchor
-    def __init__(self, __jac__: WalkerAnchor | None = None) -> None:
+    def __post_init__(self) -> None:
        """Create walker architype."""
-        if not __jac__:
-            __jac__ = WalkerAnchor(name=self.__class__.__name__, architype=self)
-            __jac__.allocate()
-        self.__jac__ = __jac__
+        self.__jac__ = WalkerAnchor(
+            architype=self,
+            name=self.__class__.__name__,
+            access=Permission(),
+            state=AnchorState(),
+        )
    @classmethod
    def __ref_cls__(cls) -> str:
@@ -1329,31 +1081,57 @@ def __ref_cls__(cls) -> str:
        return f"w:{cls.__name__}"
+class ObjectArchitype(BaseArchitype, Architype):
+    """Object Architype Protocol."""
+
+    __jac__: ObjectAnchor
+
+    def __post_init__(self) -> None:
+        """Create default architype."""
+        self.__jac__ = ObjectAnchor(
+            architype=self,
+            name=self.__class__.__name__,
+            access=Permission(),
+            state=AnchorState(),
+        )
+
+
@dataclass(eq=False)
class GenericEdge(EdgeArchitype):
    """Generic Root Node."""
-    _jac_entry_funcs_: ClassVar[list[DSFunc]] = []  # type: ignore[misc]
-    _jac_exit_funcs_: ClassVar[list[DSFunc]] = []  # type: ignore[misc]
+    _jac_entry_funcs_: ClassVar[list[DSFunc]] = []
+    _jac_exit_funcs_: ClassVar[list[DSFunc]] = []
-    def __init__(self, __jac__: EdgeAnchor | None = None) -> None:
-        """Create walker architype."""
-        if not __jac__:
-            __jac__ = EdgeAnchor(architype=self)
-            __jac__.allocate()
-        self.__jac__ = __jac__
+    def __attach__(
+        self,
+        source: NodeAnchor,
+        target: NodeAnchor,
+        is_undirected: bool,
+    ) -> None:
+        """Attach EdgeAnchor properly."""
+        self.__jac__ = EdgeAnchor(
+            architype=self,
+            source=source,
+            target=target,
+            is_undirected=is_undirected,
+            access=Permission(),
+            state=AnchorState(),
+        )
@dataclass(eq=False)
class Root(NodeArchitype):
    """Generic Root Node."""
-    _jac_entry_funcs_: ClassVar[list[DSFunc]] = []  # type: ignore[misc]
-    _jac_exit_funcs_: ClassVar[list[DSFunc]] = []  # type: ignore[misc]
+    _jac_entry_funcs_: ClassVar[list[DSFunc]] = []
+    _jac_exit_funcs_: ClassVar[list[DSFunc]] = []
-    def __init__(self, __jac__: NodeAnchor | None = None) -> None:
-        """Create walker architype."""
-        if not __jac__:
-            __jac__ = NodeAnchor(architype=self)
-            __jac__.allocate()
-        self.__jac__ = __jac__
+    def __post_init__(self) -> None:
+        """Create node architype."""
+        self.__jac__ = NodeAnchor(
+            architype=self,
+            edges=[],
+            access=Permission(),
+            state=AnchorState(),
+        )
diff --git a/jaclang_jaseci/core/context.py b/jaclang_jaseci/core/context.py
index 10dbe41..5206cb8 100644
--- a/jaclang_jaseci/core/context.py
+++ b/jaclang_jaseci/core/context.py
@@ -3,7 +3,7 @@
 from contextvars import ContextVar
 from dataclasses import asdict, is_dataclass
 from os import getenv
-from typing import Any, Callable, cast
+from typing import Any, cast
 from bson import ObjectId
@@ -15,7 +15,7 @@
     AccessLevel,
     Anchor,
     AnchorState,
-    Architype,
+    BaseArchitype,
     NodeAnchor,
     Permission,
     Root,
@@ -25,104 +25,96 @@
 SHOW_ENDPOINT_RETURNS = getenv("SHOW_ENDPOINT_RETURNS") == "true"
 JASECI_CONTEXT = ContextVar["JaseciContext | None"]("JaseciContext")
-SUPER_ROOT = ObjectId("000000000000000000000000")
-PUBLIC_ROOT = ObjectId("000000000000000000000001")
+SUPER_ROOT_ID = ObjectId("000000000000000000000000")
+PUBLIC_ROOT_ID = ObjectId("000000000000000000000001")
+SUPER_ROOT = NodeAnchor.ref(f"n::{SUPER_ROOT_ID}")
+PUBLIC_ROOT = NodeAnchor.ref(f"n::{PUBLIC_ROOT_ID}")
-class JaseciContext:
+
+class JaseciContext(ExecutionContext):
     """Execution Context."""
-    base: ExecutionContext
-    request: Request
-    datasource: MongoDB
+    mem: MongoDB
     reports: list
-    super_root: NodeAnchor
+    system_root: NodeAnchor
     root: NodeAnchor
-    entry: NodeAnchor
-
-    def generate_super_root(self) -> NodeAnchor:
-        """Generate default super root."""
-        super_root = NodeAnchor(
-            id=SUPER_ROOT, state=AnchorState(current_access_level=AccessLevel.WRITE)
-        )
-        architype = super_root.architype = object.__new__(Root)
-        architype.__jac__ = super_root
-        self.datasource.set(super_root)
-        return super_root
-
-    def generate_public_root(self) -> NodeAnchor:
-        """Generate default super root."""
-        public_root = NodeAnchor(
-            id=PUBLIC_ROOT,
-            access=Permission(all=AccessLevel.WRITE),
-            state=AnchorState(current_access_level=AccessLevel.WRITE),
-        )
-        architype = public_root.architype = object.__new__(Root)
-        architype.__jac__ = public_root
-        self.datasource.set(public_root)
-        return public_root
-
-    async def load(
-        self,
-        anchor: NodeAnchor | None,
-        default: NodeAnchor | Callable[[], NodeAnchor],
-    ) -> NodeAnchor:
-        """Load initial anchors."""
-        if anchor:
-            if not anchor.state.connected:
-                if _anchor := await self.datasource.find_one(NodeAnchor, anchor):
-                    _anchor.state.current_access_level = AccessLevel.WRITE
-                    return _anchor
-            else:
-                self.datasource.set(anchor)
-                return anchor
-
-        return default() if callable(default) else default
-
-    async def validate_access(self) -> bool:
+    entry_node: NodeAnchor
+    base: ExecutionContext
+    request: Request
+
+    def validate_access(self) -> bool:
"""Validate access.""" - return await self.root.has_read_access(self.entry) + return self.root.has_read_access(self.entry_node) - async def close(self) -> None: + def close(self) -> None: """Clean up context.""" - await self.datasource.close() + self.mem.close() @staticmethod - async def create( - request: Request, entry: NodeAnchor | None = None - ) -> "JaseciContext": + def create(request: Request, entry: NodeAnchor | None = None) -> "JaseciContext": # type: ignore[override] """Create JacContext.""" ctx = JaseciContext() ctx.base = ExecutionContext.get() ctx.request = request - ctx.datasource = MongoDB() + ctx.mem = MongoDB() ctx.reports = [] - ctx.super_root = await ctx.load( - NodeAnchor(id=SUPER_ROOT), ctx.generate_super_root - ) - ctx.root = await ctx.load( - getattr(request, "_root", None) or NodeAnchor(id=PUBLIC_ROOT), - ctx.generate_public_root, - ) - ctx.entry = await ctx.load(entry, ctx.root) + + if not isinstance(system_root := ctx.mem.find_by_id(SUPER_ROOT), NodeAnchor): + system_root = NodeAnchor( + architype=object.__new__(Root), + id=SUPER_ROOT_ID, + access=Permission(), + state=AnchorState(persistent=True), + edges=[], + ) + system_root.architype.__jac__ = system_root + ctx.mem.set(system_root.id, system_root) + + ctx.system_root = system_root + + if _root := getattr(request, "_root", None): + ctx.root = _root + else: + if not isinstance( + public_root := ctx.mem.find_by_id(PUBLIC_ROOT), NodeAnchor + ): + public_root = NodeAnchor( + architype=object.__new__(Root), + id=PUBLIC_ROOT_ID, + access=Permission(all=AccessLevel.WRITE), + state=AnchorState(persistent=True), + edges=[], + ) + public_root.architype.__jac__ = public_root + ctx.mem.set(public_root.id, public_root) + + ctx.root = public_root + + if entry: + if not isinstance(entry_node := ctx.mem.find_by_id(entry), NodeAnchor): + raise ValueError(f"Invalid anchor id {entry.ref_id} !") + ctx.entry_node = entry_node + else: + ctx.entry_node = ctx.root if _ctx := JASECI_CONTEXT.get(None): - await _ctx.close() + _ctx.close() JASECI_CONTEXT.set(ctx) return ctx @staticmethod def get() -> "JaseciContext": - """Get current ExecutionContext.""" - if not isinstance(ctx := JASECI_CONTEXT.get(None), JaseciContext): - raise Exception("JaseciContext is not yet available!") - return ctx + """Get current JaseciContext.""" + if ctx := JASECI_CONTEXT.get(None): + return ctx + raise Exception("JaseciContext is not yet available!") @staticmethod - def get_datasource() -> MongoDB: - """Get current datasource.""" - return JaseciContext.get().datasource + def get_root() -> Root: # type: ignore[override] + """Get current root.""" + return cast(Root, JaseciContext.get().root.architype) def response(self, returns: list[Any], status: int = 200) -> dict[str, Any]: """Return serialized version of reports.""" @@ -154,7 +146,7 @@ def clean_response( self.clean_response(key, dval, val) case Anchor(): cast(dict, obj)[key] = val.report() - case Architype(): + case BaseArchitype(): cast(dict, obj)[key] = val.__jac__.report() case val if is_dataclass(val) and not isinstance(val, type): cast(dict, obj)[key] = asdict(val) diff --git a/jaclang_jaseci/core/memory.py b/jaclang_jaseci/core/memory.py index 53448fd..5dae6f2 100644 --- a/jaclang_jaseci/core/memory.py +++ b/jaclang_jaseci/core/memory.py @@ -1,196 +1,153 @@ """Memory abstraction for jaseci plugin.""" -from dataclasses import dataclass, field +from dataclasses import dataclass from os import getenv -from typing import AsyncGenerator, Callable, Generator, Iterable, Type, TypeVar, cast +from typing 
 from bson import ObjectId
-from jaclang.runtimelib.architype import MANUAL_SAVE
+from jaclang.runtimelib.memory import Memory
-from motor.motor_asyncio import AsyncIOMotorClientSession
-
 from pymongo import InsertOne
+from pymongo.client_session import ClientSession
 from .architype import (
-    AccessLevel,
     Anchor,
-    AnchorType,
+    BaseAnchor,
     BulkWrite,
     EdgeAnchor,
     NodeAnchor,
     Root,
+    WalkerAnchor,
 )
 from ..jaseci.datasources import Collection
 DISABLE_AUTO_CLEANUP = getenv("DISABLE_AUTO_CLEANUP") == "true"
 SINGLE_QUERY = getenv("SINGLE_QUERY") == "true"
 IDS = ObjectId | Iterable[ObjectId]
-A = TypeVar("A", bound="Anchor")
-
-
-@dataclass
-class Memory:
-    """Generic Memory Handler."""
-
-    __mem__: dict[ObjectId, Anchor] = field(default_factory=dict)
-    __gc__: set[Anchor] = field(default_factory=set)
-
-    def close(self) -> None:
-        """Close memory handler."""
-        self.__mem__.clear()
-        self.__gc__.clear()
-
-    def find(
-        self, ids: IDS, filter: Callable[[Anchor], Anchor] | None = None
-    ) -> Generator[Anchor, None, None]:
-        """Find anchors from memory by ids with filter."""
-        if not isinstance(ids, Iterable):
-            ids = [ids]
-
-        return (
-            anchor
-            for id in ids
-            if (anchor := self.__mem__.get(id)) and (not filter or filter(anchor))
-        )
-
-    def find_one(
-        self,
-        ids: IDS,
-        filter: Callable[[Anchor], Anchor] | None = None,
-    ) -> Anchor | None:
-        """Find one anchor from memory by ids with filter."""
-        return next(self.find(ids, filter), None)
-
-    def set(self, data: Anchor | Iterable[Anchor]) -> None:
-        """Save anchor/s to memory."""
-        if isinstance(data, Iterable):
-            for d in data:
-                if d not in self.__gc__:
-                    self.__mem__[d.id] = d
-        elif data not in self.__gc__:
-            self.__mem__[data.id] = data
-
-    def remove(self, data: Anchor | Iterable[Anchor]) -> None:
-        """Remove anchor/s from memory."""
-        if isinstance(data, Iterable):
-            for d in data:
-                self.__mem__.pop(d.id, None)
-                self.__gc__.add(d)
-        else:
-            self.__mem__.pop(data.id, None)
-            self.__gc__.add(data)
+BA = TypeVar("BA", bound="BaseAnchor")
 @dataclass
-class MongoDB(Memory):
+class MongoDB(Memory[ObjectId, BaseAnchor | Anchor]):
     """Shelf Handler."""
-    __session__: AsyncIOMotorClientSession | None = None
+    __session__: ClientSession | None = None
-    async def find(  # type: ignore[override]
+    def populate_data(self, edges: Iterable[EdgeAnchor]) -> None:
+        """Populate data to avoid multiple query."""
+        if not SINGLE_QUERY:
+            nodes: set[NodeAnchor] = set()
+            for edge in self.find(edges):
+                if edge.source:
+                    nodes.add(edge.source)
+                if edge.target:
+                    nodes.add(edge.target)
+            self.find(nodes)
+
+    def find(  # type: ignore[override]
         self,
-        type: Type[A],
-        anchors: A | Iterable[A],
+        anchors: BA | Iterable[BA],
         filter: Callable[[Anchor], Anchor] | None = None,
-        session: AsyncIOMotorClientSession | None = None,
-    ) -> AsyncGenerator[A, None]:
+        session: ClientSession | None = None,
+    ) -> Generator[BA, None, None]:
         """Find anchors from datasource by ids with filter."""
         if not isinstance(anchors, Iterable):
             anchors = [anchors]
-        async for anchor in await type.Collection.find(
-            {
-                "_id": {
-                    "$in": [
-                        anchor.id
-                        for anchor in anchors
-                        if anchor.id not in self.__mem__ and anchor not in self.__gc__
-                    ]
+        collections: dict[type[Collection[BaseAnchor]], list[ObjectId]] = {}
+        for anchor in anchors:
+            if anchor.id not in self.__mem__ and anchor not in self.__gc__:
+                coll = collections.get(anchor.Collection)
+                if coll is None:
+                    coll = collections[anchor.Collection] = []
+
+                coll.append(anchor.id)
+
+        for cl, ids in collections.items():
+            for anch_db in cl.find(
+                {
+                    "_id": {"$in": ids},
                 },
-            },
-            session=session or self.__session__,
-        ):
-            self.__mem__[anchor.id] = anchor
+                session=session or self.__session__,
+            ):
+                self.__mem__[anch_db.id] = anch_db

         for anchor in anchors:
             if (
                 anchor not in self.__gc__
-                and (_anchor := self.__mem__.get(anchor.id))
-                and (not filter or filter(_anchor))
+                and (anch_mem := self.__mem__.get(anchor.id))
+                and (not filter or filter(anch_mem))  # type: ignore[arg-type]
             ):
-                yield cast(A, _anchor)
+                yield cast(BA, anch_mem)

-    async def find_one(  # type: ignore[override]
+    def find_one(  # type: ignore[override]
         self,
-        type: Type[A],
-        anchors: A | Iterable[A],
+        anchors: BA | Iterable[BA],
         filter: Callable[[Anchor], Anchor] | None = None,
-        session: AsyncIOMotorClientSession | None = None,
-    ) -> A | None:
+        session: ClientSession | None = None,
+    ) -> BA | None:
         """Find one anchor from memory by ids with filter."""
-        return await anext(self.find(type, anchors, filter, session), None)
+        return next(self.find(anchors, filter, session), None)

-    async def populate_data(self, edges: Iterable[EdgeAnchor]) -> None:
-        """Populate data to avoid multiple query."""
-        if not SINGLE_QUERY:
-            nodes: set[NodeAnchor] = set()
-            async for edge in self.find(EdgeAnchor, edges):
-                if edge.source:
-                    nodes.add(edge.source)
-                if edge.target:
-                    nodes.add(edge.target)
-            self.find(NodeAnchor, nodes)
+    def find_by_id(self, anchor: BA) -> BA | None:
+        """Find one by id."""
+        data = super().find_by_id(anchor.id)

-    def get_bulk_write(self) -> BulkWrite:
-        """Sync memory to database."""
-        bulk_write = BulkWrite()
+        if not data and (data := anchor.Collection.find_by_id(anchor.id)):
+            self.__mem__[data.id] = data

-        for anchor in self.__gc__:
-            if anchor.state.deleted is False:
-                match anchor.type:
-                    case AnchorType.node:
-                        bulk_write.del_node(anchor.id)
-                    case AnchorType.edge:
-                        bulk_write.del_edge(anchor.id)
-                    case AnchorType.walker:
-                        bulk_write.del_walker(anchor.id)
-                    case _:
-                        pass
-
-        if not MANUAL_SAVE:
-            for anchor in self.__mem__.values():
-                if anchor.architype and anchor.state.persistent:
-                    if not anchor.state.connected:
-                        anchor.state.connected = True
-                        anchor.sync_hash()
-                        bulk_write.operations[anchor.type].append(
-                            InsertOne(anchor.serialize())
-                        )
-                    elif anchor.state.current_access_level > AccessLevel.READ:
-                        if (
-                            not DISABLE_AUTO_CLEANUP
-                            and isinstance(anchor, NodeAnchor)
-                            and not isinstance(anchor.architype, Root)
-                            and not anchor.edges
-                        ):
-                            bulk_write.del_node(anchor.id)
-                        else:
-                            anchor.update(bulk_write)
+        return data

-        return bulk_write
-
-    async def close(self) -> None:  # type: ignore[override]
+    def close(self) -> None:
         """Close memory handler."""
         bulk_write = self.get_bulk_write()
         if bulk_write.has_operations:
             if session := self.__session__:
-                await bulk_write.execute(session)
+                bulk_write.execute(session)
             else:
-                async with await Collection.get_session() as session:
-                    async with session.start_transaction():
-                        await bulk_write.execute(session)
+                with Collection.get_session() as session, session.start_transaction():
+                    bulk_write.execute(session)

         super().close()
+
+    def get_bulk_write(self) -> BulkWrite:
+        """Sync memory to database."""
+        from .context import JaseciContext
+
+        JaseciContext
+        bulk_write = BulkWrite()
+
+        for anchor in self.__gc__:
+            match anchor:
+                case NodeAnchor():
+                    bulk_write.del_node(anchor.id)
+                case EdgeAnchor():
+                    bulk_write.del_edge(anchor.id)
+                case WalkerAnchor():
+                    bulk_write.del_walker(anchor.id)
+                case _:
+                    pass
+
+        for anchor in self.__mem__.values():
+            if anchor.architype and anchor.state.persistent:
+                if not anchor.state.connected:
+                    anchor.state.connected = True
+                    anchor.sync_hash()
+                    bulk_write.operations[anchor.__class__].append(
+                        InsertOne(anchor.serialize())
+                    )
+                elif anchor.has_connect_access(anchor):
+                    if (
+                        not DISABLE_AUTO_CLEANUP
+                        and isinstance(anchor, NodeAnchor)
+                        and not isinstance(anchor.architype, Root)
+                        and not anchor.edges
+                    ):
+                        bulk_write.del_node(anchor.id)
+                    else:
+                        anchor.update(bulk_write)
+
+        return bulk_write
diff --git a/jaclang_jaseci/jaseci/__init__.py b/jaclang_jaseci/jaseci/__init__.py
index 3ddb269..9ba429c 100644
--- a/jaclang_jaseci/jaseci/__init__.py
+++ b/jaclang_jaseci/jaseci/__init__.py
@@ -25,7 +25,7 @@ def get(cls) -> _FaststAPI:
         async def lifespan(app: _FaststAPI) -> AsyncGenerator[None, _FaststAPI]:
             from .datasources import Collection

-            await Collection.apply_indexes()
+            Collection.apply_indexes()
             yield

         cls.__app__ = _FaststAPI(lifespan=lifespan)
diff --git a/jaclang_jaseci/jaseci/datasources/collection.py b/jaclang_jaseci/jaseci/datasources/collection.py
index 47fd28d..5425268 100644
--- a/jaclang_jaseci/jaseci/datasources/collection.py
+++ b/jaclang_jaseci/jaseci/datasources/collection.py
@@ -125,7 +125,7 @@ def __document__(cls, doc: Mapping[str, Any]) -> T:
         return cast(T, doc)

     @classmethod
-    def __documents__(cls, docs: Cursor) -> Generator[T, None]:
+    def __documents__(cls, docs: Cursor) -> Generator[T, None, None]:
         """
         Return parsed version of multiple documents.

@@ -151,7 +151,7 @@ def get_client() -> MongoClient:
     @staticmethod
     def get_session() -> ClientSession:
         """Return pymongo.client_session.ClientSession used for mongodb transactional operations."""
-        return await Collection.get_client().start_session()
+        return Collection.get_client().start_session()

     @staticmethod
     def get_database() -> Database:
@@ -238,7 +238,7 @@ def find(
         projection: Mapping[str, Any] | Iterable[str] | None = None,
         session: ClientSession | None = None,
         **kwargs: Any,  # noqa: ANN401
-    ) -> Generator[T, None]:
+    ) -> Generator[T, None, None]:
         """Retrieve multiple documents."""
         if projection is None:
             projection = cls.__excluded_obj__
@@ -378,7 +378,7 @@ class AsyncCollection(Generic[T]):
     @staticmethod
     async def apply_indexes() -> None:
         """Apply Indexes."""
-        queue: list[type[Collection]] = Collection.__subclasses__()
+        queue: list[type[AsyncCollection]] = AsyncCollection.__subclasses__()
         while queue:
             cls = queue.pop(-1)

@@ -429,8 +429,8 @@ async def __documents__(cls, docs: AsyncIOMotorCursor) -> AsyncGenerator[T, None
     @staticmethod
     def get_client() -> AsyncIOMotorClient:
         """Return pymongo.database.Database for mongodb connection."""
-        if not isinstance(Collection.__client__, AsyncIOMotorClient):
-            Collection.__client__ = AsyncIOMotorClient(
+        if not isinstance(AsyncCollection.__client__, AsyncIOMotorClient):
+            AsyncCollection.__client__ = AsyncIOMotorClient(
                 getenv(
                     "DATABASE_HOST",
                     "mongodb://localhost/?retryWrites=true&w=majority",
@@ -438,27 +438,27 @@ def get_client() -> AsyncIOMotorClient:
                 server_api=ServerApi("1"),
             )

-        return Collection.__client__
+        return AsyncCollection.__client__

     @staticmethod
     async def get_session() -> AsyncIOMotorClientSession:
         """Return pymongo.client_session.ClientSession used for mongodb transactional operations."""
-        return await Collection.get_client().start_session()
+        return await AsyncCollection.get_client().start_session()

     @staticmethod
     def get_database() -> AsyncIOMotorDatabase:
         """Return pymongo.database.Database for database connection based from current client connection."""
-        if not isinstance(Collection.__database__, AsyncIOMotorDatabase):
-            Collection.__database__ = Collection.get_client().get_database(
+        if not isinstance(AsyncCollection.__database__, AsyncIOMotorDatabase):
+            AsyncCollection.__database__ = AsyncCollection.get_client().get_database(
                 getenv("DATABASE_NAME", "jaseci")
             )

-        return Collection.__database__
+        return AsyncCollection.__database__

     @staticmethod
     def get_collection(collection: str) -> AsyncIOMotorCollection:
         """Return pymongo.collection.Collection for collection connection based from current database connection."""
-        return Collection.get_database().get_collection(collection)
+        return AsyncCollection.get_database().get_collection(collection)

     @classmethod
     def collection(cls) -> AsyncIOMotorCollection:
diff --git a/jaclang_jaseci/jaseci/models/user.py b/jaclang_jaseci/jaseci/models/user.py
index 43b11f4..1fe40a2 100644
--- a/jaclang_jaseci/jaseci/models/user.py
+++ b/jaclang_jaseci/jaseci/models/user.py
@@ -70,9 +70,9 @@ def __document__(cls, doc: Mapping[str, Any]) -> "User":
         )

     @classmethod
-    async def find_by_email(cls, email: str) -> "User | None":
+    def find_by_email(cls, email: str) -> "User | None":
         """Retrieve user via email."""
-        return await cls.find_one(filter={"email": email}, projection={})
+        return cls.find_one(filter={"email": email}, projection={})

     def serialize(self) -> dict:
         """Return BaseModel.model_dump excluding the password field."""
@@ -118,16 +118,16 @@ def register_type() -> Type[UserRegistration]:
         return create_model("UserRegister", __base__=UserRegistration, **user_model)

     @staticmethod
-    async def send_verification_code(code: str, email: str) -> None:
+    def send_verification_code(code: str, email: str) -> None:
         """Send verification code."""
         pass

     @staticmethod
-    async def send_reset_code(code: str, email: str) -> None:
+    def send_reset_code(code: str, email: str) -> None:
         """Send verification code."""
         pass

     @staticmethod
-    async def sso_mapper(open_id: OpenID) -> dict[str, object]:
+    def sso_mapper(open_id: OpenID) -> dict[str, object]:
         """Send verification code."""
         return {}
diff --git a/jaclang_jaseci/jaseci/routers/sso.py b/jaclang_jaseci/jaseci/routers/sso.py
index 5123825..cb77565 100644
--- a/jaclang_jaseci/jaseci/routers/sso.py
+++ b/jaclang_jaseci/jaseci/routers/sso.py
@@ -3,6 +3,8 @@
 from os import getenv
 from typing import Any, cast

+from asyncer import syncify
+
 from bson import ObjectId

 from fastapi import APIRouter, Request, Response
@@ -80,7 +82,7 @@


 @router.get("/{platform}/{operation}")
-async def sso_operation(
+def sso_operation(
     request: Request,
     platform: str,
     operation: str,
@@ -91,7 +93,7 @@
     if sso := SSO.get(platform):
         with sso:
             params = request.query_params._dict
-            return await sso.get_login_redirect(
+            return syncify(sso.get_login_redirect)(
                 redirect_uri=params.pop("redirect_uri", None)
                 or f"{SSO_HOST}/{platform}/{operation}/callback",
                 state=params.pop("state", None),
@@ -101,24 +103,24 @@


 @router.get("/{platform}/{operation}/callback")
-async def sso_callback(
+def sso_callback(
     request: Request, platform: str, operation: str, redirect_uri: str | None = None
 ) -> Response:
     """SSO Login API."""
     if sso := SSO.get(platform):
         with sso:
-            if open_id := await sso.verify_and_process(
+            if open_id := syncify(sso.verify_and_process)(
                 request,
                 redirect_uri=redirect_uri
                 or f"{SSO_HOST}/{platform}/{operation}/callback",
             ):
                 match operation:
                     case "login":
-                        return await login(platform, open_id)
+                        return login(platform, open_id)
                     case "register":
-                        return await register(platform, open_id)
+                        return register(platform, open_id)
                     case "attach":
-                        return await attach(platform, open_id)
+                        return attach(platform, open_id)
                     case _:
                         pass
@@ -126,10 +128,10 @@ async def sso_callback(


 @router.post("/attach", dependencies=authenticator)
-async def sso_attach(request: Request, attach_sso: AttachSSO) -> ORJSONResponse:
+def sso_attach(request: Request, attach_sso: AttachSSO) -> ORJSONResponse:
     """Generate token from user."""
     if SSO.get(attach_sso.platform):
-        if await User.Collection.find_one(
+        if User.Collection.find_one(
             {
                 "$or": [
                     {f"sso.{platform}.id": attach_sso.id},
@@ -139,7 +141,7 @@ async def sso_attach(request: Request, attach_sso: AttachSSO) -> ORJSONResponse:
         ):
             return ORJSONResponse({"message": "Already Attached!"}, 403)

-        await User.Collection.update_one(
+        User.Collection.update_one(
             {"_id": ObjectId(request._user.id)},
             {
                 "$set": {
@@ -156,10 +158,10 @@ async def sso_attach(request: Request, attach_sso: AttachSSO) -> ORJSONResponse:


 @router.delete("/detach", dependencies=authenticator)
-async def sso_detach(request: Request, detach_sso: DetachSSO) -> ORJSONResponse:
+def sso_detach(request: Request, detach_sso: DetachSSO) -> ORJSONResponse:
     """Generate token from user."""
     if SSO.get(detach_sso.platform):
-        await User.Collection.update_one(
+        User.Collection.update_one(
             {"_id": ObjectId(request._user.id)},
             {"$unset": {f"sso.{detach_sso.platform}": 1}},
         )
@@ -167,17 +169,17 @@ async def sso_detach(request: Request, detach_sso: DetachSSO) -> ORJSONResponse:
     return ORJSONResponse({"message": "Feature not yet implemented!"}, 501)


-async def get_token(user: User) -> ORJSONResponse:  # type: ignore
+def get_token(user: User) -> ORJSONResponse:  # type: ignore
     """Generate token from user."""
     user_json = user.serialize()  # type: ignore[attr-defined]
-    token = await create_token(user_json)
+    token = create_token(user_json)
     return ORJSONResponse(content={"token": token, "user": user_json})


-async def login(platform: str, open_id: OpenID) -> Response:
+def login(platform: str, open_id: OpenID) -> Response:
     """Login user method."""
-    if user := await BaseUser.Collection.find_one(
+    if user := BaseUser.Collection.find_one(
         {
             "$or": [
                 {f"sso.{platform}.id": open_id.id},
@@ -186,21 +188,19 @@ async def login(platform: str, open_id: OpenID) -> Response:
         }
     ):
         if not user.is_activated:
-            await User.send_verification_code(
-                await create_code(ObjectId(user.id)), user.email
-            )
+            User.send_verification_code(create_code(ObjectId(user.id)), user.email)
             raise HTTPException(
                 status_code=400,
                 detail="Account not yet verified! Resending verification code...",
             )

-        return await get_token(user)
+        return get_token(user)

     return ORJSONResponse({"message": "Not Existing!"}, 403)


-async def register(platform: str, open_id: OpenID) -> Response:
+def register(platform: str, open_id: OpenID) -> Response:
     """Register user method."""
-    if user := await User.Collection.find_one(
+    if user := User.Collection.find_one(
         {
             "$or": [
                 {f"sso.{platform}.id": open_id.id},
@@ -208,65 +208,60 @@ async def register(platform: str, open_id: OpenID) -> Response:
             ]
         }
     ):
-        return await get_token(cast(User, user))  # type: ignore
-
-    async with await User.Collection.get_session() as session:
-        async with session.start_transaction():
-            retry = 0
-            max_retry = BulkWrite.SESSION_MAX_TRANSACTION_RETRY
-            while retry <= max_retry:
-                try:
-                    if not await User.Collection.update_one(
-                        {"email": open_id.email},
-                        {
-                            "$set": {
-                                f"sso.{platform}": {
-                                    "id": open_id.id,
-                                    "email": open_id.email,
-                                },
-                                "is_activated": True,
-                            }
-                        },
-                        session=session,
-                    ):
-                        root = Root().__jac__
-                        ureq: dict[str, object] = User.register_type()(
-                            email=open_id.email,
-                            password=NO_PASSWORD,
-                            **await User.sso_mapper(open_id),
-                        ).obfuscate()
-                        ureq["root_id"] = root.id
-                        ureq["is_activated"] = True
-                        ureq["sso"] = {
-                            platform: {"id": open_id.id, "email": open_id.email}
+        return get_token(cast(User, user))  # type: ignore
+
+    with User.Collection.get_session() as session, session.start_transaction():
+        retry = 0
+        max_retry = BulkWrite.SESSION_MAX_TRANSACTION_RETRY
+        while retry <= max_retry:
+            try:
+                if not User.Collection.update_one(
+                    {"email": open_id.email},
+                    {
+                        "$set": {
+                            f"sso.{platform}": {
+                                "id": open_id.id,
+                                "email": open_id.email,
+                            },
+                            "is_activated": True,
                         }
-                        await NodeAnchor.Collection.insert_one(
-                            root.serialize(), session
-                        )
-                        await User.Collection.insert_one(ureq, session=session)
-                        await BulkWrite.commit(session)
-                        return await login(platform, open_id)
-                except (ConnectionFailure, OperationFailure) as ex:
-                    if ex.has_error_label("TransientTransactionError"):
-                        retry += 1
-                        logger.error(
-                            "Error executing bulk write! "
-                            f"Retrying [{retry}/{max_retry}] ..."
-                        )
-                        continue
-                    logger.exception("Error executing bulk write!")
-                    await session.abort_transaction()
-                    break
-                except Exception:
-                    logger.exception("Error executing bulk write!")
-                    await session.abort_transaction()
-                    break
+                    },
+                    session=session,
+                ):
+                    root = Root().__jac__
+                    ureq: dict[str, object] = User.register_type()(
+                        email=open_id.email,
+                        password=NO_PASSWORD,
+                        **User.sso_mapper(open_id),
+                    ).obfuscate()
+                    ureq["root_id"] = root.id
+                    ureq["is_activated"] = True
+                    ureq["sso"] = {platform: {"id": open_id.id, "email": open_id.email}}
+                    NodeAnchor.Collection.insert_one(root.serialize(), session)
+                    User.Collection.insert_one(ureq, session=session)
+                    BulkWrite.commit(session)
+                    return login(platform, open_id)
+            except (ConnectionFailure, OperationFailure) as ex:
+                if ex.has_error_label("TransientTransactionError"):
+                    retry += 1
+                    logger.error(
+                        "Error executing bulk write! "
+                        f"Retrying [{retry}/{max_retry}] ..."
+                    )
+                    continue
+                logger.exception("Error executing bulk write!")
+                session.abort_transaction()
+                break
+            except Exception:
+                logger.exception("Error executing bulk write!")
+                session.abort_transaction()
+                break

     return ORJSONResponse({"message": "Registration Failed!"}, 409)


-async def attach(platform: str, open_id: OpenID) -> Response:
+def attach(platform: str, open_id: OpenID) -> Response:
     """Return openid ."""
-    if await User.Collection.find_one(
+    if User.Collection.find_one(
         {
             "$or": [
                 {f"sso.{platform}.id": open_id.id},
diff --git a/jaclang_jaseci/jaseci/routers/user.py b/jaclang_jaseci/jaseci/routers/user.py
index 63a959e..e4a1bc2 100644
--- a/jaclang_jaseci/jaseci/routers/user.py
+++ b/jaclang_jaseci/jaseci/routers/user.py
@@ -36,47 +36,42 @@

 @router.post("/register", status_code=status.HTTP_200_OK)
-async def register(req: User.register_type()) -> ORJSONResponse:  # type: ignore
+def register(req: User.register_type()) -> ORJSONResponse:  # type: ignore
     """Register user API."""
-    async with await User.Collection.get_session() as session:
-        async with session.start_transaction():
-            root = Root().__jac__
-
-            req_obf: dict = req.obfuscate()
-            req_obf["root_id"] = root.id
-            is_activated = req_obf["is_activated"] = not Emailer.has_client()
-
-            retry = 0
-            max_retry = BulkWrite.SESSION_MAX_TRANSACTION_RETRY
-            while retry <= max_retry:
-                try:
-                    await NodeAnchor.Collection.insert_one(root.serialize(), session)
-                    if id := (
-                        await User.Collection.insert_one(req_obf, session=session)
-                    ).inserted_id:
-                        await BulkWrite.commit(session)
-                        if not is_activated:
-                            await User.send_verification_code(
-                                await create_code(id), req.email
-                            )
-                        return ORJSONResponse(
-                            {"message": "Successfully Registered!"}, 201
-                        )
-                except (ConnectionFailure, OperationFailure) as ex:
-                    if ex.has_error_label("TransientTransactionError"):
-                        retry += 1
-                        logger.error(
-                            "Error executing bulk write! "
-                            f"Retrying [{retry}/{max_retry}] ..."
-                        )
-                        continue
-                    logger.exception("Error executing bulk write!")
-                    await session.abort_transaction()
-                    break
-                except Exception:
-                    logger.exception("Error executing bulk write!")
-                    await session.abort_transaction()
-                    break
+    with User.Collection.get_session() as session, session.start_transaction():
+        root = Root().__jac__
+
+        req_obf: dict = req.obfuscate()
+        req_obf["root_id"] = root.id
+        is_activated = req_obf["is_activated"] = not Emailer.has_client()
+
+        retry = 0
+        max_retry = BulkWrite.SESSION_MAX_TRANSACTION_RETRY
+        while retry <= max_retry:
+            try:
+                NodeAnchor.Collection.insert_one(root.serialize(), session)
+                if id := (
+                    User.Collection.insert_one(req_obf, session=session)
+                ).inserted_id:
+                    BulkWrite.commit(session)
+                    if not is_activated:
+                        User.send_verification_code(create_code(id), req.email)
+                    return ORJSONResponse({"message": "Successfully Registered!"}, 201)
+            except (ConnectionFailure, OperationFailure) as ex:
+                if ex.has_error_label("TransientTransactionError"):
+                    retry += 1
+                    logger.error(
+                        "Error executing bulk write! "
+                        f"Retrying [{retry}/{max_retry}] ..."
+                    )
+                    continue
+                logger.exception("Error executing bulk write!")
+                session.abort_transaction()
+                break
+            except Exception:
+                logger.exception("Error executing bulk write!")
+                session.abort_transaction()
+                break

     return ORJSONResponse({"message": "Registration Failed!"}, 409)

@@ -85,20 +80,20 @@ async def register(req: User.register_type()) -> ORJSONResponse:  # type: ignore
     status_code=status.HTTP_200_OK,
     dependencies=authenticator,
 )
-async def send_verification_code(request: Request) -> ORJSONResponse:
+def send_verification_code(request: Request) -> ORJSONResponse:
     """Verify user API."""
     user: BaseUser = request._user  # type: ignore[attr-defined]
     if user.is_activated:
         return ORJSONResponse({"message": "Account is already verified!"}, 400)
     else:
-        await User.send_verification_code(await create_code(user.id), user.email)
+        User.send_verification_code(create_code(user.id), user.email)
         return ORJSONResponse({"message": "Successfully sent verification code!"}, 200)


 @router.post("/verify")
-async def verify(req: UserVerification) -> ORJSONResponse:
+def verify(req: UserVerification) -> ORJSONResponse:
     """Verify user API."""
-    if (user_id := await verify_code(req.code)) and await User.Collection.update_by_id(
+    if (user_id := verify_code(req.code)) and User.Collection.update_by_id(
         user_id, {"$set": {"is_activated": True}}
     ):
         return ORJSONResponse({"message": "Successfully Verified!"}, 200)
@@ -107,21 +102,21 @@


 @router.post("/login")
-async def root(req: UserRequest) -> ORJSONResponse:
+def root(req: UserRequest) -> ORJSONResponse:
     """Login user API."""
-    user: BaseUser = await User.Collection.find_by_email(req.email)  # type: ignore
+    user: BaseUser = User.Collection.find_by_email(req.email)  # type: ignore
     if not user or not pbkdf2_sha512.verify(req.password, user.password):
         raise HTTPException(status_code=400, detail="Invalid Email/Password!")

     if RESTRICT_UNVERIFIED_USER and not user.is_activated:
-        await User.send_verification_code(await create_code(user.id), req.email)
+        User.send_verification_code(create_code(user.id), req.email)
         raise HTTPException(
             status_code=400,
             detail="Account not yet verified! Resending verification code...",
         )

     user_json = user.serialize()
-    token = await create_token(user_json)
+    token = create_token(user_json)
     return ORJSONResponse(content={"token": token, "user": user_json})

@@ -129,41 +124,39 @@ async def root(req: UserRequest) -> ORJSONResponse:
 @router.post(
     "/change_password", status_code=status.HTTP_200_OK, dependencies=authenticator
 )
-async def change_password(request: Request, ucp: UserChangePassword) -> ORJSONResponse:  # type: ignore
+def change_password(request: Request, ucp: UserChangePassword) -> ORJSONResponse:  # type: ignore
     """Register user API."""
     user: BaseUser | None = getattr(request, "_user", None)
     if user:
-        with_pass = await User.Collection.find_by_email(user.email)
+        with_pass = User.Collection.find_by_email(user.email)
         if (
             with_pass
             and pbkdf2_sha512.verify(ucp.old_password, with_pass.password)
-            and await User.Collection.update_one(
+            and User.Collection.update_one(
                 {"_id": user.id},
                 {"$set": {"password": pbkdf2_sha512.hash(ucp.new_password).encode()}},
             )
         ):
-            await invalidate_token(user.id)
+            invalidate_token(user.id)
             return ORJSONResponse({"message": "Successfully Updated!"}, 200)

     return ORJSONResponse({"message": "Update Failed!"}, 403)


 @router.post("/forgot_password", status_code=status.HTTP_200_OK)
-async def forgot_password(ufp: UserForgotPassword) -> ORJSONResponse:
+def forgot_password(ufp: UserForgotPassword) -> ORJSONResponse:
     """Forgot password API."""
-    user = await User.Collection.find_by_email(ufp.email)
+    user = User.Collection.find_by_email(ufp.email)
     if isinstance(user, User):
-        await User.send_reset_code(await create_code(user.id, True), user.email)
+        User.send_reset_code(create_code(user.id, True), user.email)
         return ORJSONResponse({"message": "Reset password email sent!"}, 200)
     else:
         return ORJSONResponse({"message": "Failed to process forgot password!"}, 403)


 @router.post("/reset_password", status_code=status.HTTP_200_OK)
-async def reset_password(urp: UserResetPassword) -> ORJSONResponse:
+def reset_password(urp: UserResetPassword) -> ORJSONResponse:
     """Reset password API."""
-    if (
-        user_id := await verify_code(urp.code, True)
-    ) and await User.Collection.update_by_id(
+    if (user_id := verify_code(urp.code, True)) and User.Collection.update_by_id(
         user_id,
         {
             "$set": {
@@ -172,7 +165,7 @@ async def reset_password(urp: UserResetPassword) -> ORJSONResponse:
             }
         },
     ):
-        await invalidate_token(user_id)
+        invalidate_token(user_id)
         return ORJSONResponse({"message": "Password reset successfully!"}, 200)

     return ORJSONResponse({"message": "Failed to reset password!"}, 403)
diff --git a/jaclang_jaseci/jaseci/security/__init__.py b/jaclang_jaseci/jaseci/security/__init__.py
index 63288e3..64bfa64 100644
--- a/jaclang_jaseci/jaseci/security/__init__.py
+++ b/jaclang_jaseci/jaseci/security/__init__.py
@@ -14,7 +14,7 @@
 from ..datasources.redis import CodeRedis, TokenRedis
 from ..models.user import User as BaseUser
 from ..utils import logger, random_string, utc_timestamp
-from ...core.architype import AccessLevel, NodeAnchor
+from ...core.architype import NodeAnchor

 TOKEN_SECRET = getenv("TOKEN_SECRET", random_string(50))
@@ -39,7 +39,7 @@ def decrypt(token: str) -> dict | None:
         return None


-async def create_code(user_id: ObjectId, reset: bool = False) -> str:
+def create_code(user_id: ObjectId, reset: bool = False) -> str:
     """Generate Verification Code."""
     verification = encrypt(
         {
@@ -50,41 +50,41 @@ async def create_code(user_id: ObjectId, reset: bool = False) -> str:
             ),
         }
     )
-    if await CodeRedis.hset(key=verification, data=True):
+    if CodeRedis.hset(key=verification, data=True):
         return verification
     raise HTTPException(500, "Verification Creation Failed!")


-async def verify_code(code: str, reset: bool = False) -> ObjectId | None:
+def verify_code(code: str, reset: bool = False) -> ObjectId | None:
     """Verify Code."""
     decrypted = decrypt(code)
     if (
         decrypted
         and decrypted["reset"] == reset
         and decrypted["expiration"] > utc_timestamp()
-        and await CodeRedis.hget(key=code)
+        and CodeRedis.hget(key=code)
     ):
-        await CodeRedis.hdelete(code)
+        CodeRedis.hdelete(code)
         return ObjectId(decrypted["user_id"])
     return None


-async def create_token(user: dict[str, Any]) -> str:
+def create_token(user: dict[str, Any]) -> str:
     """Generate token for current user."""
     user["expiration"] = utc_timestamp(hours=TOKEN_TIMEOUT)
     user["state"] = random_string(8)
     token = encrypt(user)
-    if await TokenRedis.hset(f"{user['id']}:{token}", True):
+    if TokenRedis.hset(f"{user['id']}:{token}", True):
         return token
     raise HTTPException(500, "Token Creation Failed!")


-async def invalidate_token(user_id: ObjectId) -> None:
+def invalidate_token(user_id: ObjectId) -> None:
     """Invalidate token of current user."""
-    await TokenRedis.hdelete_rgx(f"{user_id}:*")
+    TokenRedis.hdelete_rgx(f"{user_id}:*")


-async def authenticate(request: Request) -> None:
+def authenticate(request: Request) -> None:
     """Authenticate current request and attach authenticated user and their root."""
     authorization = request.headers.get("Authorization")
     if authorization and authorization.lower().startswith("bearer"):
@@ -93,11 +93,10 @@ async def authenticate(request: Request) -> None:
         if (
             decrypted
             and decrypted["expiration"] > utc_timestamp()
-            and await TokenRedis.hget(f"{decrypted['id']}:{token}")
-            and (user := await User.Collection.find_by_id(decrypted["id"]))
-            and (root := await NodeAnchor.Collection.find_by_id(user.root_id))
+            and TokenRedis.hget(f"{decrypted['id']}:{token}")
+            and (user := User.Collection.find_by_id(decrypted["id"]))
+            and (root := NodeAnchor.Collection.find_by_id(user.root_id))
         ):
-            root.state.current_access_level = AccessLevel.WRITE
             request._user = user  # type: ignore[attr-defined]
             request._root = root  # type: ignore[attr-defined]
             return
diff --git a/jaclang_jaseci/plugin/jaseci.py b/jaclang_jaseci/plugin/jaseci.py
index 92ce55b..fa507c4 100644
--- a/jaclang_jaseci/plugin/jaseci.py
+++ b/jaclang_jaseci/plugin/jaseci.py
@@ -3,11 +3,12 @@
 from collections import OrderedDict
 from dataclasses import Field, MISSING, fields
 from functools import wraps
-from inspect import iscoroutinefunction
 from os import getenv
 from re import compile
 from typing import Any, Callable, Type, TypeVar, cast, get_type_hints

+from asyncer import syncify
+
 from fastapi import (
     APIRouter,
     Depends,
@@ -19,12 +20,9 @@
 )
 from fastapi.responses import ORJSONResponse

-from jaclang.compiler.absyntree import Ability, AstAsyncNode
-from jaclang.compiler.constant import EdgeDir
-from jaclang.compiler.passes.main.pyast_gen_pass import PyastGenPass
 from jaclang.plugin.default import hookimpl
 from jaclang.plugin.feature import JacFeature as Jac
-from jaclang.runtimelib.context import ExecutionContext
+from jaclang.runtimelib.architype import DSFunc

 from orjson import loads
@@ -35,16 +33,16 @@
 from ..core.architype import (
     Anchor,
     Architype,
-    DSFunc,
     EdgeArchitype,
     GenericEdge,
     NodeAnchor,
     NodeArchitype,
+    ObjectArchitype,
     Root,
     WalkerAnchor,
     WalkerArchitype,
 )
-from ..core.context import JASECI_CONTEXT, JaseciContext
+from ..core.context import JaseciContext
 from ..jaseci.security import authenticator
@@ -141,7 +139,7 @@ def populate_apis(cls: Type[WalkerArchitype]) -> None:

     payload_model = create_model(f"{cls.__name__.lower()}_request_model", **payload)

-    async def api_entry(
+    def api_entry(
         request: Request,
         node: str | None,
         payload: payload_model = Depends(),  # type: ignore # noqa: B008
@@ -150,31 +148,31 @@ async def api_entry(
         body = pl.get("body", {})

         if isinstance(body, BaseUploadFile) and body_model:
-            body = loads(await body.read())
+            body = loads(syncify(body.read)())
             try:
                 body = body_model(**body).model_dump()
             except ValidationError as e:
                 return ORJSONResponse({"detail": e.errors()})

-        jctx = await JaseciContext.create(request, NodeAnchor.ref(node or ""))
+        jctx = JaseciContext.create(request, NodeAnchor.ref(node) if node else None)

         wlk: WalkerAnchor = cls(**body, **pl["query"], **pl["files"]).__jac__
-        if await jctx.validate_access():
-            await wlk.spawn_call(jctx.entry)
-            await jctx.close()
+        if jctx.validate_access():
+            wlk.spawn_call(jctx.entry_node)
+            jctx.close()
             return ORJSONResponse(jctx.response(wlk.returns))
         else:
-            await jctx.close()
+            jctx.close()
             raise HTTPException(
                 403,
-                f"You don't have access on target entry{cast(Anchor, jctx.entry).ref_id}!",
+                f"You don't have access on target entry{cast(Anchor, jctx.entry_node).ref_id}!",
             )

-    async def api_root(
+    def api_root(
         request: Request,
         payload: payload_model = Depends(),  # type: ignore # noqa: B008
     ) -> Response:
-        return await api_entry(request, None, payload)
+        return api_entry(request, None, payload)

     for method in methods:
         method = method.lower()
@@ -245,12 +243,9 @@ class JacPlugin:

     @staticmethod
     @hookimpl
-    def current_context() -> JaseciContext | ExecutionContext:
-        """Get the execution context."""
-        if not isinstance(ctx := JASECI_CONTEXT.get(None), JaseciContext):
-            return ctx
-
-        return ExecutionContext.get()
+    def get_context() -> JaseciContext:
+        """Get current execution context."""
+        return JaseciContext.get()

     @staticmethod
     @hookimpl
@@ -291,13 +286,8 @@ def make_architype(
         inner_init = cls.__init__  # type: ignore

         @wraps(inner_init)
-        def new_init(
-            self: Architype,
-            *args: object,
-            __jac__: Anchor | None = None,
-            **kwargs: object,
-        ) -> None:
-            arch_base.__init__(self, __jac__)
+        def new_init(self: Architype, *args: object, **kwargs: object) -> None:
+            arch_base.__init__(self)
             inner_init(self, *args, **kwargs)

         cls.__init__ = new_init  # type: ignore
@@ -312,8 +302,8 @@ def make_obj(

         def decorator(cls: Type[Architype]) -> Type[Architype]:
             """Decorate class."""
-            cls = JacPlugin.make_architype(
-                cls=cls, arch_base=Architype, on_entry=on_entry, on_exit=on_exit
+            cls = Jac.make_architype(
+                cls=cls, arch_base=ObjectArchitype, on_entry=on_entry, on_exit=on_exit
             )
             return cls
@@ -328,7 +318,7 @@ def make_node(

         def decorator(cls: Type[Architype]) -> Type[Architype]:
             """Decorate class."""
-            cls = JacPlugin.make_architype(
+            cls = Jac.make_architype(
                 cls=cls, arch_base=NodeArchitype, on_entry=on_entry, on_exit=on_exit
             )
             return cls
@@ -344,7 +334,7 @@ def make_edge(

         def decorator(cls: Type[Architype]) -> Type[Architype]:
             """Decorate class."""
-            cls = JacPlugin.make_architype(
+            cls = Jac.make_architype(
                 cls=cls, arch_base=EdgeArchitype, on_entry=on_entry, on_exit=on_exit
             )
             return cls
@@ -360,7 +350,7 @@ def make_walker(

         def decorator(cls: Type[Architype]) -> Type[Architype]:
             """Decorate class."""
-            cls = JacPlugin.make_architype(
+            cls = Jac.make_architype(
                 cls=cls, arch_base=WalkerArchitype, on_entry=on_entry, on_exit=on_exit
             )
             populate_apis(cls)
@@ -368,17 +358,6 @@ def decorator(cls: Type[Architype]) -> Type[Architype]:

         return decorator

-    @staticmethod
-    @hookimpl
-    async def spawn_call(op1: Architype, op2: Architype) -> WalkerArchitype:
-        """Jac's spawn operator feature."""
-        if isinstance(op1, WalkerArchitype):
-            return await op1.__jac__.spawn_call(op2.__jac__)
-        elif isinstance(op2, WalkerArchitype):
-            return await op2.__jac__.spawn_call(op1.__jac__)
-        else:
-            raise TypeError("Invalid walker object")
-
     @staticmethod
     @hookimpl
     def report(expr: Any) -> Any:  # noqa: ANN401
@@ -387,150 +366,9 @@ def report(expr: Any) -> Any:  # noqa: ANN401

     @staticmethod
     @hookimpl
-    async def ignore(
-        walker: WalkerArchitype,
-        expr: (
-            list[NodeArchitype | EdgeArchitype]
-            | list[NodeArchitype]
-            | list[EdgeArchitype]
-            | NodeArchitype
-            | EdgeArchitype
-        ),
-    ) -> bool:
-        """Jac's ignore stmt feature."""
-        if isinstance(walker, WalkerArchitype):
-            return await walker.__jac__.ignore_node(
-                (i.__jac__ for i in expr) if isinstance(expr, list) else [expr.__jac__]
-            )
-        else:
-            raise TypeError("Invalid walker object")
-
-    @staticmethod
-    @hookimpl
-    async def visit_node(
-        walker: WalkerArchitype,
-        expr: (
-            list[NodeArchitype | EdgeArchitype]
-            | list[NodeArchitype]
-            | list[EdgeArchitype]
-            | NodeArchitype
-            | EdgeArchitype
-        ),
-    ) -> bool:
-        """Jac's visit stmt feature."""
-        if isinstance(walker, WalkerArchitype):
-            return await walker.__jac__.visit_node(
-                (i.__jac__ for i in expr) if isinstance(expr, list) else [expr.__jac__]
-            )
-        else:
-            raise TypeError("Invalid walker object")
-
-    @staticmethod
-    @hookimpl
-    async def edge_ref(
-        node_obj: NodeArchitype | list[NodeArchitype],
-        target_cls: Type[NodeArchitype] | list[Type[NodeArchitype]] | None,
-        dir: EdgeDir,
-        filter_func: Callable[[list[EdgeArchitype]], list[EdgeArchitype]] | None,
-        edges_only: bool,
-    ) -> list[NodeArchitype] | list[EdgeArchitype]:
-        """Jac's apply_dir stmt feature."""
-        if isinstance(node_obj, NodeArchitype):
-            node_obj = [node_obj]
-        targ_cls_set: list[Type[NodeArchitype]] | None = (
-            [target_cls] if isinstance(target_cls, type) else target_cls
-        )
-        if edges_only:
-            connected_edges: list[EdgeArchitype] = []
-            for node in node_obj:
-                connected_edges += await node.__jac__.get_edges(
-                    dir, filter_func, target_cls=targ_cls_set
-                )
-            return list(set(connected_edges))
-        else:
-            connected_nodes: list[NodeArchitype] = []
-            for node in node_obj:
-                connected_nodes.extend(
-                    await node.__jac__.edges_to_nodes(
-                        dir, filter_func, target_cls=targ_cls_set
-                    )
-                )
-            return list(set(connected_nodes))
-
-    @staticmethod
-    @hookimpl
-    async def connect(
-        left: NodeArchitype | list[NodeArchitype],
-        right: NodeArchitype | list[NodeArchitype],
-        edge_spec: Callable[[], EdgeArchitype],
-        edges_only: bool,
-    ) -> list[NodeArchitype] | list[EdgeArchitype]:
-        """Jac's connect operator feature.
-
-        Note: connect needs to call assign compr with tuple in op
-        """
-        left = [left] if isinstance(left, NodeArchitype) else left
-        right = [right] if isinstance(right, NodeArchitype) else right
-        edges = []
-        nodes = []
-        for i in left:
-            for j in right:
-                if await (source := i.__jac__).has_connect_access(target := j.__jac__):
-                    conn_edge = edge_spec()
-                    edges.append(conn_edge)
-                    nodes.append(j)
-                    source.connect_node(target, conn_edge.__jac__)
-        return nodes if not edges_only else edges
-
-    @staticmethod
-    @hookimpl
-    async def disconnect(
-        left: NodeArchitype | list[NodeArchitype],
-        right: NodeArchitype | list[NodeArchitype],
-        dir: EdgeDir,
-        filter_func: Callable[[list[EdgeArchitype]], list[EdgeArchitype]] | None,
-    ) -> bool:  # noqa: ANN401
-        """Jac's disconnect operator feature."""
-        disconnect_occurred = False
-        left = [left] if isinstance(left, NodeArchitype) else left
-        right = [right] if isinstance(right, NodeArchitype) else right
-        for i in left:
-            node = i.__jac__
-            for anchor in set(node.edges):
-                if (
-                    (architype := await anchor.sync(node))
-                    and (source := anchor.source)
-                    and (target := anchor.target)
-                    and (not filter_func or filter_func([architype]))
-                    and (src_arch := await source.sync())
-                    and (trg_arch := await target.sync())
-                ):
-                    if (
-                        dir in [EdgeDir.OUT, EdgeDir.ANY]
-                        and node == source
-                        and trg_arch in right
-                        and await source.has_write_access(target)
-                    ):
-                        anchor.destroy()
-                        disconnect_occurred = True
-                    if (
-                        dir in [EdgeDir.IN, EdgeDir.ANY]
-                        and node == target
-                        and src_arch in right
-                        and await target.has_write_access(source)
-                    ):
-                        anchor.destroy()
-                        disconnect_occurred = True
-
-        return disconnect_occurred
-
-    @staticmethod
-    @hookimpl
-    async def get_root() -> Root:
+    def get_root() -> Root:
         """Jac's assign comprehension feature."""
-        if architype := await JaseciContext.get().root.sync():
-            return cast(Root, architype)
-        raise Exception("No Available Root!")
+        return JaseciContext.get_root()

     @staticmethod
     @hookimpl
@@ -544,19 +382,23 @@ def build_edge(
         is_undirected: bool,
         conn_type: Type[EdgeArchitype] | EdgeArchitype | None,
         conn_assign: tuple[tuple, tuple] | None,
-    ) -> Callable[[], EdgeArchitype]:
+    ) -> Callable[[NodeAnchor, NodeAnchor], EdgeArchitype]:
        """Jac's root getter."""
         conn_type = conn_type if conn_type else GenericEdge

-        def builder() -> EdgeArchitype:
+        def builder(source: NodeAnchor, target: NodeAnchor) -> EdgeArchitype:
             edge = conn_type() if isinstance(conn_type, type) else conn_type
-            edge.__jac__.is_undirected = is_undirected
+            edge.__attach__(source, target, is_undirected)
             if conn_assign:
                 for fld, val in zip(conn_assign[0], conn_assign[1]):
                     if hasattr(edge, fld):
                         setattr(edge, fld, val)
                     else:
                         raise ValueError(f"Invalid attribute: {fld}")
+            if source.persistent or target.persistent:
+                edge.__jac__.save()
+                target.save()
+                source.save()
             return edge

         return builder
@@ -567,22 +409,7 @@ def builder() -> EdgeArchitype:
 ##########################################################

 Jac.RootType = Root  # type: ignore[assignment]
-Jac.Obj = Architype  # type: ignore[assignment]
+Jac.Obj = ObjectArchitype  # type: ignore[assignment]
 Jac.Node = NodeArchitype  # type: ignore[assignment]
 Jac.Edge = EdgeArchitype  # type: ignore[assignment]
 Jac.Walker = WalkerArchitype  # type: ignore[assignment]
-
-
-def overrided_init(self: AstAsyncNode, is_async: bool) -> None:
-    """Initialize ast."""
-    self.is_async = True if isinstance(self, Ability) else is_async
-
-
-AstAsyncNode.__init__ = overrided_init  # type: ignore[method-assign]
-PyastGenPass.set(
-    [
-        name
-        for name, func in JacPlugin.__dict__.items()
-        if isinstance(func, staticmethod) and iscoroutinefunction(func.__func__)
-    ]
-)
diff --git a/jaclang_jaseci/tests/simple_graph.jac b/jaclang_jaseci/tests/simple_graph.jac
index 6547a30..f69d7cd 100644
--- a/jaclang_jaseci/tests/simple_graph.jac
+++ b/jaclang_jaseci/tests/simple_graph.jac
@@ -221,7 +221,7 @@ walker manual_create_nested_node {
             )
         ));
         here ++> n;
-        await here.__jac__.save();
+        here.__jac__.save();

         # simulate no auto save
         jsrc = JaseciContext.get_datasource();
@@ -244,7 +244,7 @@ walker manual_update_nested_node {
         nested.json["a"] = 1;
         nested.arr.append(1);
         nested.val = 1;
-        await nested.__jac__.save();
+        nested.__jac__.save();

         # simulate no auto save
         jsrc = JaseciContext.get_datasource();
@@ -260,7 +260,7 @@ walker manual_detach_nested_node {
         nested = [-->Nested][0];
         detached = here del--> [-->Nested];
         nested.__jac__.destroy();
-        await nested.__jac__.save();
+        nested.__jac__.save();

         # simulate no auto save
         jsrc = JaseciContext.get_datasource();
@@ -275,7 +275,7 @@ walker delete_nested_node {
     can enter_root with `root entry {
         nested = [-->Nested][0];
         nested.__jac__.destroy();
-        await nested.__jac__.save();
+        nested.__jac__.save();

         report [-->Nested];
     }
@@ -285,7 +285,7 @@ walker manual_delete_nested_node {
     can enter_root with `root entry {
         nested = [-->Nested][0];
         nested.__jac__.destroy();
-        await nested.__jac__.save();
+        nested.__jac__.save();

         # simulate no auto save
         jsrc = JaseciContext.get_datasource();
@@ -309,7 +309,7 @@ walker manual_delete_nested_edge {
     can enter_root with `root entry {
         nested_edge = :e:[-->][0];
         nested_edge.__jac__.destroy();
-        await nested_edge.__jac__.save();
+        nested_edge.__jac__.save();

         # simulate no auto save
         jsrc = JaseciContext.get_datasource();
diff --git a/jaclang_jaseci/tests/test_simple_graph.py b/jaclang_jaseci/tests/test_simple_graph.py
index 87fe51c..4f92b06 100644
--- a/jaclang_jaseci/tests/test_simple_graph.py
+++ b/jaclang_jaseci/tests/test_simple_graph.py
@@ -37,7 +37,7 @@ async def asyncSetUp(self) -> None:

     async def asyncTearDown(self) -> None:
         """Clean up DB."""
-        await self.client.drop_database(self.database)
+        self.client.drop_database(self.database)

     @overload
     def post_api(self, api: str, json: dict | None = None, user: int = 0) -> dict:
@@ -459,10 +459,10 @@ def trigger_access_validation_test(

     async def nested_count_should_be(self, node: int, edge: int) -> None:
         """Test nested node count."""
-        self.assertEqual(node, await self.q_node.count_documents({"name": "Nested"}))
+        self.assertEqual(node, self.q_node.count_documents({"name": "Nested"}))
         self.assertEqual(
             edge,
-            await self.q_edge.count_documents(
+            self.q_edge.count_documents(
                 {
                     "$or": [
                         {"source": {"$regex": "^n:Nested:"}},
diff --git a/setup.py b/setup.py
index dfaec20..20fd1c4 100644
--- a/setup.py
+++ b/setup.py
@@ -29,6 +29,7 @@
         "sendgrid==6.11.0",
         "fastapi-sso==0.15.0",
         "google-auth==2.32.0",
+        "asyncer==0.0.8",
     ],
     package_data={},
     entry_points={