forked from nickgkan/3d_diffuser_actor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_dataset.py
176 lines (148 loc) · 6.08 KB
/
generate_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from rlbench.backend.task import Task
from rlbench.backend.scene import DemoError
from rlbench.observation_config import ObservationConfig, CameraConfig
from pyrep import PyRep
from pyrep.robots.arms.panda import Panda
from pyrep.robots.end_effectors.panda_gripper import PandaGripper
from rlbench.backend.const import TTT_FILE
from rlbench.backend.scene import Scene
from rlbench.backend.utils import task_file_to_task_class
from rlbench.backend.task import TASKS_PATH
from rlbench.backend.robot import Robot
from pyrep.const import RenderMode
import numpy as np
import os
import argparse
import pickle
# Number of times a variation's batch of demos is re-run before the
# validator gives up on it (see the retry loop in task_smoke).
DEMO_ATTEMPTS: int = 5
# Tasks whose variation_count() exceeds this limit are rejected.
MAX_VARIATIONS: int = 100
class TaskValidationError(Exception):
    """Raised when a task fails one of the validator's sanity checks."""
def task_smoke(task: Task, scene: Scene, variation=-1, demos=1, success=0.50,
               max_variations=3, test_demos=True):
    """Validate ``task`` in ``scene`` and record ground-truth/augmented targets.

    Loads the task, sanity-checks its variation count, base rotation bounds
    and boundary root, then runs ``demos`` demos per tested variation,
    retrying the whole batch up to ``DEMO_ATTEMPTS`` times.

    When ``test_demos`` is true, every demo whose ground-truth pass and
    augmented pass both succeed appends two dicts to the module-level
    ``json_dump`` list:

    1. the ground-truth record, holding ``task.left``/``left_orient``/
       ``right``/``right_orient`` right after ``init_episode``;
    2. a copy of that record extended with the same attributes read after
       ``task.augment()`` on a non-randomly-placed re-initialisation.

    Args:
        task: Task instance to validate.
        scene: Scene the task is loaded into.
        variation: Variation index to test. A negative value tests the first
            ``min(variation_count, max_variations)`` variations.
        demos: Number of demos per attempt; must be positive.
        success: Minimum fraction of demos in an attempt that must succeed.
        max_variations: Cap on variations tested when ``variation < 0``.
        test_demos: Whether to record demos and append pose records.

    Raises:
        ValueError: If ``demos`` is not positive.
        TaskValidationError: If a sanity check fails, or too many demos of a
            variation fail on every attempt.
    """
    if demos <= 0:
        # The success ratio below divides by ``demos``.
        raise ValueError('demos must be a positive integer, got %r' % (demos,))

    print('Running task validator on task: %s' % task.get_name())

    # Loading
    scene.load(task)

    # Number of variations: a usable task must expose at least one.
    variation_count = task.variation_count()
    if variation_count <= 0:
        raise TaskValidationError(
            "The method 'variation_count' should return a number > 0.")
    if variation_count > MAX_VARIATIONS:
        raise TaskValidationError(
            "This task had %d variations. Currently the limit is set to %d" %
            (variation_count, MAX_VARIATIONS))

    # Base rotation bounds: two sequences of length 3 are expected.
    base_pos, base_ori = task.base_rotation_bounds()
    if len(base_pos) != 3 or len(base_ori) != 3:
        raise TaskValidationError(
            "The method 'base_rotation_bounds' should return a tuple "
            "containing a list of floats.")

    # Boundary root
    root = task.boundary_root()
    if not root.still_exists():
        raise TaskValidationError(
            "The method 'boundary_root' should return a Dummy that is the root "
            "of the task.")

    def run_demos(variation_num):
        """Run ``demos`` episodes of ``variation_num``; return how many failed."""
        fails = 0
        for _ in range(demos):
            try:
                scene.reset()
                desc = scene.init_episode(variation_num, max_attempts=10)
                # Remember the boundary root pose so the augmented pass can be
                # started from the same placement.
                origin = (task.boundary_root().get_position(),
                          task.boundary_root().get_orientation())
                if not isinstance(desc, list) or len(desc) <= 0:
                    raise TaskValidationError(
                        "The method 'init_variation' should return a list of "
                        "string descriptions.")
                if not test_demos:
                    continue

                # Ground-truth pass: inference with gt position.
                scene.get_demo(record=True)
                # NOTE(review): if task.left/right are arrays that augment()
                # mutates in place, these entries would change too -- confirm.
                gt_state = {'variation_num': variation_num,
                            'left_gt_coord': task.left,
                            'left_gt_orientation': task.left_orient,
                            'right_gt_coord': task.right,
                            'right_gt_orientation': task.right_orient}

                # Augmented pass: re-initialise without random placement,
                # restore the recorded boundary root pose, then augment.
                scene.reset()
                scene.init_episode(variation_num, max_attempts=10,
                                   randomly_place=False)
                task.boundary_root().set_position(origin[0])
                task.boundary_root().set_orientation(origin[1])
                task.augment()
                scene.get_demo(record=True, randomly_place=False)

                # Copy rather than alias: the original code mutated and
                # re-appended the ground-truth dict itself.
                aug_state = dict(gt_state)
                aug_state['left_coord'] = task.left
                aug_state['left_orientation'] = task.left_orient
                aug_state['right_coord'] = task.right
                aug_state['right_orientation'] = task.right_orient

                # Append both only once both passes succeeded, so a failure
                # never leaves an unpaired record behind.
                json_dump.append(gt_state)
                json_dump.append(aug_state)
            except DemoError as e:
                fails += 1
                print(e)
            except Exception as e:
                # TODO: check that we don't fall through all of these cases
                fails += 1
                print(e)
        return fails

    def variation_smoke(i):
        """Run demos for variation ``i``, retrying; raise if it never passes."""
        print('Running task validator on variation: %d' % i)
        attempt_result = False
        failed_demos = 0
        for j in range(DEMO_ATTEMPTS):
            failed_demos = run_demos(i)
            attempt_result = (failed_demos / float(demos) <= 1. - success)
            if attempt_result:
                break
            print('Failed on attempt %d. Trying again...' % j)
        # Make sure we don't fail too often
        if not attempt_result:
            raise TaskValidationError(
                "Too many failed demo runs. %d of %d demos failed." % (
                    failed_demos, demos))
        print('Variation %d of task %s is good!' % (i, task.get_name()))
        if test_demos:
            print('%d of %d demos were successful.' % (
                demos - failed_demos, demos))

    if variation < 0:
        variations_to_test = list(range(min(variation_count, max_variations)))
    else:
        variations_to_test = [variation]

    # Task set-up
    scene.init_task()
    for i in variations_to_test:
        variation_smoke(i)
if __name__ == '__main__':
    # Script entry point: validate one task file and collect pose records.
    parser = argparse.ArgumentParser()
    parser.add_argument("task", help="The task file to test.")
    parser.add_argument(
        "--ttt-dir",
        default='/home/commonsense/data/cvpr/3d_diffuser_actor/RLBench/rlbench',
        help="Directory that contains the RLBench scene file (TTT_FILE).")
    parser.add_argument(
        "--variation", type=int, default=1,
        help="Variation index to test (negative tests several variations).")
    parser.add_argument(
        "--demos", type=int, default=1,
        help="Number of demos to run per attempt.")
    parser.add_argument(
        "--headless", action="store_true",
        help="Launch the simulator without a GUI.")
    parser.add_argument(
        "--output", default=None,
        help="Optional path to pickle the collected records to.")
    args = parser.parse_args()

    python_file = os.path.join(TASKS_PATH, args.task)
    if not os.path.isfile(python_file):
        raise RuntimeError('Could not find the task file: %s' % python_file)
    task_class = task_file_to_task_class(args.task)

    ttt_file = os.path.join(args.ttt_dir, TTT_FILE)
    # Module-level on purpose: task_smoke appends its records to this global.
    json_dump = []

    sim = PyRep()
    # The responsive-UI thread only makes sense when a GUI is shown.
    sim.launch(ttt_file, headless=args.headless,
               responsive_ui=not args.headless)
    try:
        sim.step_ui()
        sim.set_simulation_timestep(0.005)
        sim.step_ui()
        sim.start()
        robot = Robot(Panda(), PandaGripper())
        active_task = task_class(sim, robot)

        # camera: only the wrist RGB camera is enabled.
        obs = ObservationConfig()
        obs.set_all(False)
        cam_config = CameraConfig(rgb=True, depth=False, mask=False,
                                  render_mode=RenderMode.OPENGL)
        obs.wrist_camera = cam_config
        scene = Scene(sim, robot, obs)

        task_smoke(active_task, scene, variation=args.variation,
                   demos=args.demos)
    finally:
        # Shut the simulator down on every exit path, not only when a
        # TaskValidationError escapes.
        sim.shutdown()

    print(json_dump)
    if args.output:
        # pickle (not JSON): the recorded values may not be JSON-serialisable.
        with open(args.output, 'wb') as f:
            pickle.dump(json_dump, f)
    print('Validation successful!')