Skip to content

Commit

Permalink
sovle align to world error
Browse files Browse the repository at this point in the history
  • Loading branch information
BJHYZJ committed Dec 21, 2024
1 parent 439690b commit 863a7a9
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 712 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ dist

*.TimeRecord

*.so
*.pkl
*.mp4
data
checkpoints

evaluation/output
data_example
*.zip
*.zip

test.py
2 changes: 1 addition & 1 deletion ace/train_ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ def train_ace(scene, output_map_file):
torch.cuda.empty_cache()

if __name__ == '__main__':
scene = Path("/home/yanzj/workspace/code1/DovSG/data/company")
scene = Path("/home/yanzj/workspace/code/DovSG/data/company")
output_map_file = scene / "ace/ace.pt"
train_ace(scene, output_map_file)
17 changes: 11 additions & 6 deletions dovsg/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
nb_neighbors: int=35,
std_ratio: float=1.5,

socket_ip: str="192.168.1.54",
socket_ip: str="192.168.1.50",
socket_port: str="9999"
):

Expand Down Expand Up @@ -722,8 +722,7 @@ def get_align_observations(self, use_inlier_mask=False, show_align=False, just_w
print("===> get observations from robot.")

_, observations = self.get_observations(just_wrist=just_wrist, save_name=save_name)
rough_poses = self.test_ace(observations)


for name, obs in observations.items():
point = obs["point"]
rgb = obs["rgb"]
Expand All @@ -732,14 +731,17 @@ def get_align_observations(self, use_inlier_mask=False, show_align=False, just_w
inlier_mask = get_inlier_mask(point=point, color=rgb, mask=mask)
mask = np.logical_and(mask, inlier_mask)
obs["mask"] = mask
obs["pose"] = rough_poses[name]
# obs["pose"] = rough_poses[name]

if self_align:
observations = self.self_align_observations(observations)

is_success = True
# align observations from base coord to world coord
if align_to_world:
rough_poses = self.test_ace(observations)
for name, obs in observations.items():
obs["pose"] = rough_poses[name]
observations, is_success = self.correct_pose_observations(observations)
if show_align and is_success:
self.show_pointcloud_for_align(observations)
Expand Down Expand Up @@ -1111,6 +1113,9 @@ def run_tasks(self, tasks: Union[List[dict]]):
align_to_world=True,
save_name="0_start",
)
if not correct_success:
assert 1 == 0, "Init Pose Error!"

init_position, init_rotation = self.get_current_position(observations)
current_position = init_position
current_rotation = init_rotation
Expand Down Expand Up @@ -1149,8 +1154,8 @@ def run_tasks(self, tasks: Union[List[dict]]):
# just after pick up and place, update scene
if task["action"] in ["Pick up", "Place"]:
if not correct_success:
print("relocalize error.")
exit(0)
assert 1 == 0, "relocalize error!"

self.update_scene(observations=observations)
end_time = time.time()
print(f"Step spend time is: {end_time - start_time}")
Expand Down
Loading

0 comments on commit 863a7a9

Please sign in to comment.