Skip to content

Commit

Permalink
Merge pull request #138 from davidprueser/tests
Browse files Browse the repository at this point in the history
[orm] improved orm insert speed
  • Loading branch information
tomsch420 authored Apr 2, 2024
2 parents 88b0421 + 24ea429 commit 5457009
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 85 deletions.
21 changes: 9 additions & 12 deletions src/pycram/designator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .orm.action_designator import (Action as ORMAction)
from .orm.object_designator import (Object as ORMObjectDesignator)

from .orm.base import Quaternion, Position, Base, RobotState, ProcessMetaData
from .orm.base import RobotState, ProcessMetaData
from .task import with_tree


Expand Down Expand Up @@ -424,15 +424,15 @@ def insert(self, session: Session, *args, **kwargs) -> ORMAction:
metadata = ProcessMetaData().insert(session)

# create robot-state object
robot_state = RobotState(self.robot_torso_height, self.robot_type, pose.id)
robot_state.process_metadata_id = metadata.id
robot_state = RobotState(self.robot_torso_height, self.robot_type)
robot_state.pose = pose
robot_state.process_metadata = metadata
session.add(robot_state)
session.commit()

# create action
action = self.to_sql()
action.process_metadata_id = metadata.id
action.robot_state_id = robot_state.id
action.process_metadata = metadata
action.robot_state = robot_state

return action

Expand Down Expand Up @@ -545,16 +545,13 @@ def insert(self, session: Session) -> ORMObjectDesignator:
:return: The completely instanced ORM object
"""
metadata = ProcessMetaData().insert(session)
pose = self.pose.insert(session)

# create object orm designator
obj = self.to_sql()
obj.process_metadata_id = metadata.id

pose = self.pose.insert(session)
obj.pose_id = pose.id

obj.process_metadata = metadata
obj.pose = pose
session.add(obj)
session.commit()
return obj

def frozen_copy(self) -> 'ObjectDesignatorDescription.Object':
Expand Down
9 changes: 4 additions & 5 deletions src/pycram/designators/actions/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...bullet_world import BulletWorld
from ...helper import multiply_quaternions
from ...local_transformer import LocalTransformer
from ...orm.base import Base, Pose as ORMPose
from ...orm.base import Pose as ORMPose
from ...orm.object_designator import Object as ORMObject, ObjectPart as ORMObjectPart
from ...orm.action_designator import Action as ORMAction
from ...plan_failures import ObjectUnfetchable, ReachabilityFailure
Expand Down Expand Up @@ -94,12 +94,11 @@ def insert(self, session: Session, **kwargs) -> Action:
if key not in orm_class_variables:
variable = value.insert(session)
if isinstance(variable, ORMObject):
action.object_id = variable.id
action.object = variable
elif isinstance(variable, ORMPose):
action.pose_id = variable.id

action.pose = variable
session.add(action)
session.commit()

return action


Expand Down
32 changes: 8 additions & 24 deletions src/pycram/designators/motion_designator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def insert(self, session: Session, *args, **kwargs) -> ORMMotionDesignator:
metadata = ProcessMetaData().insert(session)

motion = self.to_sql()
motion.process_metadata_id = metadata.id
motion.process_metadata = metadata

return motion

Expand Down Expand Up @@ -108,12 +108,9 @@ def to_sql(self) -> ORMMoveMotion:

def insert(self, session, *args, **kwargs) -> ORMMoveMotion:
motion = super().insert(session)

pose = self.target.insert(session)
motion.pose_id = pose.id

motion.pose = pose
session.add(motion)
session.commit()

return motion

Expand Down Expand Up @@ -147,12 +144,9 @@ def to_sql(self) -> ORMMoveTCPMotion:

def insert(self, session: Session, *args, **kwargs) -> ORMMoveTCPMotion:
motion = super().insert(session)

pose = self.target.insert(session)
motion.pose_id = pose.id

motion.pose = pose
session.add(motion)
session.commit()

return motion

Expand All @@ -174,12 +168,9 @@ def to_sql(self) -> ORMLookingMotion:

def insert(self, session: Session, *args, **kwargs) -> ORMLookingMotion:
motion = super().insert(session)

pose = self.target.insert(session)
motion.pose_id = pose.id

motion.pose = pose
session.add(motion)
session.commit()

return motion

Expand Down Expand Up @@ -213,9 +204,8 @@ def to_sql(self) -> ORMMoveGripperMotion:

def insert(self, session: Session, *args, **kwargs) -> ORMMoveGripperMotion:
motion = super().insert(session)

session.add(motion)
session.commit()

return motion


Expand Down Expand Up @@ -250,7 +240,7 @@ def to_sql(self) -> ORMDetectingMotion:
def insert(self, session: Session, *args, **kwargs) -> ORMDetectingMotion:
motion = super().insert(session)
session.add(motion)
session.commit()

return motion


Expand Down Expand Up @@ -353,12 +343,9 @@ def to_sql(self) -> ORMOpeningMotion:

def insert(self, session: Session, *args, **kwargs) -> ORMOpeningMotion:
motion = super().insert(session)

op = self.object_part.insert(session)
motion.object_id = op.id

motion.object = op
session.add(motion)
session.commit()

return motion

Expand Down Expand Up @@ -388,11 +375,8 @@ def to_sql(self) -> ORMClosingMotion:

def insert(self, session: Session, *args, **kwargs) -> ORMClosingMotion:
motion = super().insert(session)

op = self.object_part.insert(session)
motion.object_id = op.id

motion.object = op
session.add(motion)
session.commit()

return motion
14 changes: 6 additions & 8 deletions src/pycram/designators/object_designator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def to_sql(self) -> ORMBelieveObject:
return ORMBelieveObject(self.type, self.name)

def insert(self, session: sqlalchemy.orm.session.Session) -> ORMBelieveObject:
metadata = ProcessMetaData().insert(session)
self_ = self.to_sql()
self_.process_metadata = metadata
session.add(self_)
session.commit()
metadata = ProcessMetaData().insert(session)
self_.process_metadata_id = metadata.id

return self_


Expand All @@ -47,14 +47,12 @@ def to_sql(self) -> ORMObjectPart:
return ORMObjectPart(self.type, self.name)

def insert(self, session: sqlalchemy.orm.session.Session) -> ORMObjectPart:
obj = self.to_sql()
metadata = ProcessMetaData().insert(session)
obj.process_metadata_id = metadata.id
pose = self.part_pose.insert(session)
obj.pose_id = pose.id

obj = self.to_sql()
obj.process_metadata = metadata
obj.pose = pose
session.add(obj)
session.commit()

return obj

Expand Down
5 changes: 1 addition & 4 deletions src/pycram/orm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Base(_Base):
__abstract__ = True

@declared_attr
def process_metadata_id(self) -> Mapped[Optional[int]]:
def process_metadata_id(self) -> Mapped[int]:
return mapped_column(ForeignKey(f'{ProcessMetaData.__tablename__}.id'), default=None, init=False)
"""Related MetaData Object to store information about the context of this experiment."""

Expand Down Expand Up @@ -164,7 +164,6 @@ def insert(self, session: Session):
"""Insert this into the database using the session. Skipped if it already is inserted."""
if not self.committed():
session.add(self)
session.commit()
return self

@classmethod
Expand Down Expand Up @@ -223,8 +222,6 @@ class Color(Base):
class RobotState(PoseMixin, Base):
"""ORM Representation of a robots state."""

pose_to_init = True

torso_height: Mapped[float]
"""The torso height of the robot."""

Expand Down
14 changes: 7 additions & 7 deletions src/pycram/orm/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ class TaskTreeNode(Base):

id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True, init=False)

action_id: Mapped[Optional[int]] = mapped_column(ForeignKey(f'{Designator.__tablename__}.id'), default=None)
action_id: Mapped[Optional[int]] = mapped_column(ForeignKey(f'{Designator.__tablename__}.id'), init=False)
action: Mapped[Optional[Designator]] = relationship(init=False)

start_time: Mapped[datetime.datetime] = mapped_column(default=None)
end_time: Mapped[Optional[datetime.datetime]] = mapped_column(default=None)
start_time: Mapped[datetime.datetime]
end_time: Mapped[Optional[datetime.datetime]]

status: Mapped[TaskStatus] = mapped_column(default=None)
reason: Mapped[Optional[str]] = mapped_column(default=None)
status: Mapped[TaskStatus]
reason: Mapped[Optional[str]]

parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("TaskTreeNode.id"), default=None)
parent: Mapped["TaskTreeNode"] = relationship(foreign_keys=[parent_id], init=False, remote_side=[id])
parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("TaskTreeNode.id"), init=False)
parent: Mapped[Optional["TaskTreeNode"]] = relationship(init=False, remote_side=[id])


15 changes: 6 additions & 9 deletions src/pycram/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,17 @@ def insert(self, session: sqlalchemy.orm.Session) -> ORMPose:
metadata = ProcessMetaData().insert(session)

position = Position(*self.position_as_list())
position.process_metadata_id = metadata.id
position.process_metadata = metadata
orientation = Quaternion(*self.orientation_as_list())
orientation.process_metadata_id = metadata.id

orientation.process_metadata = metadata
session.add(position)
session.add(orientation)
session.commit()
pose = self.to_sql()
pose.process_metadata_id = metadata.id
pose.position_id = position.id
pose.orientation_id = orientation.id

pose = self.to_sql()
pose.process_metadata = metadata
pose.orientation = orientation
pose.position = position
session.add(pose)
session.commit()

return pose

Expand Down
27 changes: 15 additions & 12 deletions src/pycram/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,17 @@ def to_sql(self) -> ORMTaskTreeNode:
else:
reason = None

return ORMTaskTreeNode(None, self.start_time, self.end_time, self.status.name,
reason, id(self.parent) if self.parent else None)
return ORMTaskTreeNode(self.start_time, self.end_time, self.status.name, reason)

def insert(self, session: sqlalchemy.orm.session.Session, recursive: bool = True,
parent_id: Optional[int] = None, use_progress_bar: bool = True,
parent: Optional[TaskTreeNode] = None, use_progress_bar: bool = True,
progress_bar: Optional[tqdm.tqdm] = None) -> ORMTaskTreeNode:
"""
Insert this node into the database.
:param session: The current session with the database.
:param recursive: Rather if the entire tree should be inserted or just this node, defaults to True
:param parent_id: The primary key of the parent node, defaults to None
:param parent: The parent node, defaults to None
:param use_progress_bar: Rather to use a progressbar or not
:param progress_bar: The progressbar to update. If a progress bar is desired and this is None, a new one will be
created.
Expand All @@ -147,30 +146,34 @@ def insert(self, session: sqlalchemy.orm.session.Session, recursive: bool = True
# insert action if possible
if getattr(self.action, "insert", None):
action = self.action.insert(session)
node.action_id = action.id
node.action = action
else:
action = None
node.action_id = None
node.action = None

# get and set metadata
metadata = ProcessMetaData().insert(session)
node.process_metadata_id = metadata.id
node.process_metadata = metadata

# set parent to id from constructor
node.parent_id = parent_id
# set node parent
node.parent = parent

# add the node to database to retrieve the new id
# add the node to the session; note that the instance is not yet committed to the db, but rather in a
# pending state
session.add(node)
session.commit()

if progress_bar:
progress_bar.update()

# if recursive, insert all children
if recursive:
[child.insert(session, parent_id=node.id, use_progress_bar=use_progress_bar, progress_bar=progress_bar)
[child.insert(session, parent=node, use_progress_bar=use_progress_bar, progress_bar=progress_bar)
for child in self.children]

# once recursion is done and the root node is reached again, commit the session to the database
if self.parent is None:
session.commit()

return node


Expand Down
4 changes: 0 additions & 4 deletions test/test_orm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
import unittest

import anytree
from sqlalchemy import select
import sqlalchemy.orm
import pycram.orm.action_designator
Expand All @@ -12,7 +10,6 @@
import pycram.task
import pycram.task
from bullet_world_testcase import BulletWorldTestCase
import test_task_tree
from pycram.bullet_world import Object
from pycram.designators import action_designator, object_designator, motion_designator
from pycram.designators.actions.actions import ParkArmsActionPerformable, MoveTorsoActionPerformable, \
Expand All @@ -38,7 +35,6 @@ def setUp(self):
super().setUp()
pycram.orm.base.Base.metadata.create_all(self.engine)
self.session = sqlalchemy.orm.Session(bind=self.engine)
self.session.commit()

def tearDown(self):
super().tearDown()
Expand Down

0 comments on commit 5457009

Please sign in to comment.