diff --git a/src/pycram/designator.py b/src/pycram/designator.py index d91b18beb..4c53fab0a 100644 --- a/src/pycram/designator.py +++ b/src/pycram/designator.py @@ -25,7 +25,7 @@ from .orm.action_designator import (Action as ORMAction) from .orm.object_designator import (Object as ORMObjectDesignator) -from .orm.base import Quaternion, Position, Base, RobotState, ProcessMetaData +from .orm.base import RobotState, ProcessMetaData from .task import with_tree @@ -424,15 +424,15 @@ def insert(self, session: Session, *args, **kwargs) -> ORMAction: metadata = ProcessMetaData().insert(session) # create robot-state object - robot_state = RobotState(self.robot_torso_height, self.robot_type, pose.id) - robot_state.process_metadata_id = metadata.id + robot_state = RobotState(self.robot_torso_height, self.robot_type) + robot_state.pose = pose + robot_state.process_metadata = metadata session.add(robot_state) - session.commit() # create action action = self.to_sql() - action.process_metadata_id = metadata.id - action.robot_state_id = robot_state.id + action.process_metadata = metadata + action.robot_state = robot_state return action @@ -545,16 +545,13 @@ def insert(self, session: Session) -> ORMObjectDesignator: :return: The completely instanced ORM object """ metadata = ProcessMetaData().insert(session) + pose = self.pose.insert(session) # create object orm designator obj = self.to_sql() - obj.process_metadata_id = metadata.id - - pose = self.pose.insert(session) - obj.pose_id = pose.id - + obj.process_metadata = metadata + obj.pose = pose session.add(obj) - session.commit() return obj def frozen_copy(self) -> 'ObjectDesignatorDescription.Object': diff --git a/src/pycram/designators/actions/actions.py b/src/pycram/designators/actions/actions.py index f9dec14fc..16175592a 100644 --- a/src/pycram/designators/actions/actions.py +++ b/src/pycram/designators/actions/actions.py @@ -14,7 +14,7 @@ from ...bullet_world import BulletWorld from ...helper import multiply_quaternions from ...local_transformer import LocalTransformer -from ...orm.base import Base, Pose as ORMPose +from ...orm.base import Pose as ORMPose from ...orm.object_designator import Object as ORMObject, ObjectPart as ORMObjectPart from ...orm.action_designator import Action as ORMAction from ...plan_failures import ObjectUnfetchable, ReachabilityFailure @@ -94,12 +94,11 @@ def insert(self, session: Session, **kwargs) -> Action: if key not in orm_class_variables: variable = value.insert(session) if isinstance(variable, ORMObject): - action.object_id = variable.id + action.object = variable elif isinstance(variable, ORMPose): - action.pose_id = variable.id - + action.pose = variable session.add(action) - session.commit() + return action diff --git a/src/pycram/designators/motion_designator.py b/src/pycram/designators/motion_designator.py index b6779c300..b289f345d 100644 --- a/src/pycram/designators/motion_designator.py +++ b/src/pycram/designators/motion_designator.py @@ -54,7 +54,7 @@ def insert(self, session: Session, *args, **kwargs) -> ORMMotionDesignator: metadata = ProcessMetaData().insert(session) motion = self.to_sql() - motion.process_metadata_id = metadata.id + motion.process_metadata = metadata return motion @@ -108,12 +108,9 @@ def to_sql(self) -> ORMMoveMotion: def insert(self, session, *args, **kwargs) -> ORMMoveMotion: motion = super().insert(session) - pose = self.target.insert(session) - motion.pose_id = pose.id - + motion.pose = pose session.add(motion) - session.commit() return motion @@ -147,12 +144,9 @@ def to_sql(self) -> ORMMoveTCPMotion: def insert(self, session: Session, *args, **kwargs) -> ORMMoveTCPMotion: motion = super().insert(session) - pose = self.target.insert(session) - motion.pose_id = pose.id - + motion.pose = pose session.add(motion) - session.commit() return motion @@ -174,12 +168,9 @@ def to_sql(self) -> ORMLookingMotion: def insert(self, session: Session, *args, **kwargs) -> ORMLookingMotion: motion = super().insert(session) - pose = self.target.insert(session) - motion.pose_id = pose.id - + motion.pose = pose session.add(motion) - session.commit() return motion @@ -213,9 +204,8 @@ def to_sql(self) -> ORMMoveGripperMotion: def insert(self, session: Session, *args, **kwargs) -> ORMMoveGripperMotion: motion = super().insert(session) - session.add(motion) - session.commit() + return motion @@ -250,7 +240,7 @@ def to_sql(self) -> ORMDetectingMotion: def insert(self, session: Session, *args, **kwargs) -> ORMDetectingMotion: motion = super().insert(session) session.add(motion) - session.commit() + return motion @@ -353,12 +343,9 @@ def to_sql(self) -> ORMOpeningMotion: def insert(self, session: Session, *args, **kwargs) -> ORMOpeningMotion: motion = super().insert(session) - op = self.object_part.insert(session) - motion.object_id = op.id - + motion.object = op session.add(motion) - session.commit() return motion @@ -388,11 +375,8 @@ def to_sql(self) -> ORMClosingMotion: def insert(self, session: Session, *args, **kwargs) -> ORMClosingMotion: motion = super().insert(session) - op = self.object_part.insert(session) - motion.object_id = op.id - + motion.object = op session.add(motion) - session.commit() return motion diff --git a/src/pycram/designators/object_designator.py b/src/pycram/designators/object_designator.py index 02a58ef41..dbb3d8fbe 100644 --- a/src/pycram/designators/object_designator.py +++ b/src/pycram/designators/object_designator.py @@ -24,11 +24,11 @@ def to_sql(self) -> ORMBelieveObject: return ORMBelieveObject(self.type, self.name) def insert(self, session: sqlalchemy.orm.session.Session) -> ORMBelieveObject: + metadata = ProcessMetaData().insert(session) self_ = self.to_sql() + self_.process_metadata = metadata session.add(self_) - session.commit() - metadata = ProcessMetaData().insert(session) - self_.process_metadata_id = metadata.id + return self_ @@ -47,14 +47,12 @@ def to_sql(self) -> ORMObjectPart: return ORMObjectPart(self.type, self.name) def insert(self, session: sqlalchemy.orm.session.Session) -> ORMObjectPart: - obj = self.to_sql() metadata = ProcessMetaData().insert(session) - obj.process_metadata_id = metadata.id pose = self.part_pose.insert(session) - obj.pose_id = pose.id - + obj = self.to_sql() + obj.process_metadata = metadata + obj.pose = pose session.add(obj) - session.commit() return obj diff --git a/src/pycram/orm/base.py b/src/pycram/orm/base.py index 6bf7f55a9..d238dfff3 100644 --- a/src/pycram/orm/base.py +++ b/src/pycram/orm/base.py @@ -49,7 +49,7 @@ class Base(_Base): __abstract__ = True @declared_attr - def process_metadata_id(self) -> Mapped[Optional[int]]: + def process_metadata_id(self) -> Mapped[int]: return mapped_column(ForeignKey(f'{ProcessMetaData.__tablename__}.id'), default=None, init=False) """Related MetaData Object to store information about the context of this experiment.""" @@ -164,7 +164,6 @@ def insert(self, session: Session): """Insert this into the database using the session. Skipped if it already is inserted.""" if not self.committed(): session.add(self) - session.commit() return self @classmethod @@ -223,8 +222,6 @@ class Color(Base): class RobotState(PoseMixin, Base): """ORM Representation of a robots state.""" - pose_to_init = True - torso_height: Mapped[float] """The torso height of the robot.""" diff --git a/src/pycram/orm/task.py b/src/pycram/orm/task.py index 9a5669890..e352f065e 100644 --- a/src/pycram/orm/task.py +++ b/src/pycram/orm/task.py @@ -14,16 +14,16 @@ class TaskTreeNode(Base): id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True, init=False) - action_id: Mapped[Optional[int]] = mapped_column(ForeignKey(f'{Designator.__tablename__}.id'), default=None) + action_id: Mapped[Optional[int]] = mapped_column(ForeignKey(f'{Designator.__tablename__}.id'), init=False) action: Mapped[Optional[Designator]] = relationship(init=False) - start_time: Mapped[datetime.datetime] = mapped_column(default=None) - end_time: Mapped[Optional[datetime.datetime]] = mapped_column(default=None) + start_time: Mapped[datetime.datetime] + end_time: Mapped[Optional[datetime.datetime]] - status: Mapped[TaskStatus] = mapped_column(default=None) - reason: Mapped[Optional[str]] = mapped_column(default=None) + status: Mapped[TaskStatus] + reason: Mapped[Optional[str]] - parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("TaskTreeNode.id"), default=None) - parent: Mapped["TaskTreeNode"] = relationship(foreign_keys=[parent_id], init=False, remote_side=[id]) + parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("TaskTreeNode.id"), init=False) + parent: Mapped[Optional["TaskTreeNode"]] = relationship(init=False, remote_side=[id]) diff --git a/src/pycram/pose.py b/src/pycram/pose.py index 852f20967..dc45b4679 100644 --- a/src/pycram/pose.py +++ b/src/pycram/pose.py @@ -242,20 +242,17 @@ def insert(self, session: sqlalchemy.orm.Session) -> ORMPose: metadata = ProcessMetaData().insert(session) position = Position(*self.position_as_list()) - position.process_metadata_id = metadata.id + position.process_metadata = metadata orientation = Quaternion(*self.orientation_as_list()) - orientation.process_metadata_id = metadata.id - + orientation.process_metadata = metadata session.add(position) session.add(orientation) - session.commit() - pose = self.to_sql() - pose.process_metadata_id = metadata.id - pose.position_id = position.id - pose.orientation_id = orientation.id + pose = self.to_sql() + pose.process_metadata = metadata + pose.orientation = orientation + pose.position = position session.add(pose) - session.commit() return pose diff --git a/src/pycram/task.py b/src/pycram/task.py index 2477c1880..6dbdcfb5c 100644 --- a/src/pycram/task.py +++ b/src/pycram/task.py @@ -118,18 +118,17 @@ def to_sql(self) -> ORMTaskTreeNode: else: reason = None - return ORMTaskTreeNode(None, self.start_time, self.end_time, self.status.name, - reason, id(self.parent) if self.parent else None) + return ORMTaskTreeNode(self.start_time, self.end_time, self.status.name, reason) def insert(self, session: sqlalchemy.orm.session.Session, recursive: bool = True, - parent_id: Optional[int] = None, use_progress_bar: bool = True, + parent: Optional[TaskTreeNode] = None, use_progress_bar: bool = True, progress_bar: Optional[tqdm.tqdm] = None) -> ORMTaskTreeNode: """ Insert this node into the database. :param session: The current session with the database. :param recursive: Rather if the entire tree should be inserted or just this node, defaults to True - :param parent_id: The primary key of the parent node, defaults to None + :param parent: The parent node, defaults to None :param use_progress_bar: Rather to use a progressbar or not :param progress_bar: The progressbar to update. If a progress bar is desired and this is None, a new one will be created. @@ -147,30 +146,34 @@ def insert(self, session: sqlalchemy.orm.session.Session, recursive: bool = True # insert action if possible if getattr(self.action, "insert", None): action = self.action.insert(session) - node.action_id = action.id + node.action = action else: action = None - node.action_id = None + node.action = None # get and set metadata metadata = ProcessMetaData().insert(session) - node.process_metadata_id = metadata.id + node.process_metadata = metadata - # set parent to id from constructor - node.parent_id = parent_id + # set node parent + node.parent = parent - # add the node to database to retrieve the new id + # add the node to the session; note that the instance is not yet committed to the db, but rather in a + # pending state session.add(node) - session.commit() if progress_bar: progress_bar.update() # if recursive, insert all children if recursive: - [child.insert(session, parent_id=node.id, use_progress_bar=use_progress_bar, progress_bar=progress_bar) + [child.insert(session, parent=node, use_progress_bar=use_progress_bar, progress_bar=progress_bar) for child in self.children] + # once recursion is done and the root node is reached again, commit the session to the database + if self.parent is None: + session.commit() + return node diff --git a/test/test_orm.py b/test/test_orm.py index 23150b9be..ce5775a84 100644 --- a/test/test_orm.py +++ b/test/test_orm.py @@ -1,7 +1,5 @@ import os import unittest - -import anytree from sqlalchemy import select import sqlalchemy.orm import pycram.orm.action_designator @@ -12,7 +10,6 @@ import pycram.task import pycram.task from bullet_world_testcase import BulletWorldTestCase -import test_task_tree from pycram.bullet_world import Object from pycram.designators import action_designator, object_designator, motion_designator from pycram.designators.actions.actions import ParkArmsActionPerformable, MoveTorsoActionPerformable, \ @@ -38,7 +35,6 @@ def setUp(self): super().setUp() pycram.orm.base.Base.metadata.create_all(self.engine) self.session = sqlalchemy.orm.Session(bind=self.engine) - self.session.commit() def tearDown(self): super().tearDown()