From 54f4ba21d8abce51711b7fce0df3772a34e61a97 Mon Sep 17 00:00:00 2001 From: zhuzhongshu123 Date: Mon, 8 Jan 2024 14:21:50 +0800 Subject: [PATCH] add relation type check --- python/knext/knext/operator/op.py | 6 ++++-- python/knext/knext/operator/spg_record.py | 9 +++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/knext/knext/operator/op.py b/python/knext/knext/operator/op.py index 8ea7d867c..2fa315736 100644 --- a/python/knext/knext/operator/op.py +++ b/python/knext/knext/operator/op.py @@ -10,7 +10,7 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. from abc import ABC -from typing import List, Dict, Any +from typing import List, Dict, Tuple, Any import knext.common.cache from knext.common.schema_helper import SPGTypeName, TripletName @@ -27,7 +27,9 @@ class ExtractOp(BaseOp, ABC): def __init__(self, params: Dict[str, str] = None): super().__init__(params) - def invoke(self, record: Dict[str, str]) -> List[Dict[str, str]]: + def invoke( + self, record: Dict[str, str] + ) -> Tuple[List[Dict[str, str]], List[SPGRecord]]: raise NotImplementedError( f"{self.__class__.__name__} need to implement `invoke` method." ) diff --git a/python/knext/knext/operator/spg_record.py b/python/knext/knext/operator/spg_record.py index 391add896..62819d466 100644 --- a/python/knext/knext/operator/spg_record.py +++ b/python/knext/knext/operator/spg_record.py @@ -16,6 +16,7 @@ SPGTypeName, PropertyName, RelationName, + SPGTypeHelper, ) @@ -170,6 +171,10 @@ def upsert_relation( :param value: The updated relation value. # noqa: E501 :type: str """ + if isinstance(self._spg_type_name, SPGTypeHelper): + if not hasattr(self._spg_type_name, str(relation_name)): + raise KeyError(f"Relation {relation_name} not defined in Schema") + self.relations[relation_name + "#" + object_type_name] = value return self @@ -181,7 +186,7 @@ def upsert_relations(self, relations: Dict[Tuple[RelationName, SPGTypeName], str :type: dict """ for (relation_name, object_type_name), value in relations.items(): - self.relations[relation_name + "#" + object_type_name] = value + self.upsert_relation(relation_name, object_type_name, value) return self def remove_relation( @@ -204,7 +209,7 @@ def remove_relations(self, relation_names: List[Tuple[RelationName, SPGTypeName] :param relation_names: A list of relation names. # noqa: E501 :type: list """ - for (relation_name, object_type_name) in relation_names: + for relation_name, object_type_name in relation_names: self.relations.pop(relation_name + "#" + object_type_name) return self