-
Notifications
You must be signed in to change notification settings - Fork 9
/
main_trajectory.py
428 lines (374 loc) · 14.5 KB
/
main_trajectory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
"""Main script for trajectory optimization."""
import io
import os
from pathlib import Path
import random
from typing import Tuple, Optional
import cv2
from matplotlib import pyplot as plt
import numpy as np
import tap
import torch
import torch.distributed as dist
from torch.nn import functional as F
from datasets import RLBenchDataset
from engine import BaseTrainTester
from model import DiffusionPlanner
from utils.utils_without_rlbench import (
load_instructions, count_parameters, get_gripper_loc_bounds
)
class Arguments(tap.Tap):
    """Command-line arguments for trajectory-optimization training/evaluation.

    Parsed with ``tap`` (typed-argument-parser): every annotated class
    attribute is a command-line option. ``tasks``, ``dataset`` and ``valset``
    have no default value here (tap treats such options as required --
    verify for the installed tap version). Several on/off switches are
    declared as ``int`` and converted with ``bool()`` where they are used.
    """
    # local_rank: int
    # NOTE: local_rank is not a CLI option; the __main__ block assigns it from
    # the LOCAL_RANK environment variable after parsing.
    cameras: Tuple[str, ...] = ("wrist", "left_shoulder", "right_shoulder")
    # Comma-separated integers, parsed into a tuple of ints in get_model().
    image_size: str = "256,256"
    max_episodes_per_task: int = 100
    # Passed to load_instructions(), which decides the tasks/variations used.
    instructions: Optional[Path] = "instructions.pkl"
    seed: int = 0
    tasks: Tuple[str, ...]
    variations: Tuple[int, ...] = (0,)
    checkpoint: Optional[Path] = None
    # Gradients are accumulated over this many steps before optimizer.step().
    accumulate_grad_batches: int = 1
    # Interval (in steps) for train-loss logging in this file; presumably the
    # engine also uses it as the validation interval -- TODO confirm.
    val_freq: int = 500
    # None -> replaced in __main__ by the box [[-2,-2,-2],[2,2,2]]; otherwise
    # replaced by the result of get_gripper_loc_bounds(<this value>, ...).
    gripper_loc_bounds: Optional[str] = None
    eval_only: int = 0
    # Training and validation datasets
    dataset: Path
    valset: Path
    # Treated as a boolean (bool()) when building the datasets.
    dense_interpolation: int = 0
    interpolation_length: int = 100
    # Logging to base_log_dir/exp_log_dir/run_log_dir
    base_log_dir: Path = Path(__file__).parent / "train_logs"
    exp_log_dir: str = "exp"
    run_log_dir: str = "run"
    # Main training parameters
    num_workers: int = 1
    batch_size: int = 16
    batch_size_val: int = 4
    # Forwarded as RLBenchDataset(cache_size=...) for the train / val split.
    cache_size: int = 100
    cache_size_val: int = 100
    lr: float = 1e-4
    # Also forwarded to the training dataset as num_iters.
    train_iters: int = 200_000
    max_episode_length: int = 5 # -1 for no limit
    # Data augmentations
    image_rescale: str = "0.75,1.25" # (min, max), "1.0,1.0" for no rescaling
    point_cloud_rotate_yaw_range: float = 0.0 # in degrees, 0.0 for no rot
    # Model
    # Also the DiffusionPlanner output_dim and the datasets' action_dim.
    action_dim: int = 7
    backbone: str = "clip" # one of "resnet", "clip"
    embedding_dim: int = 120
    num_query_cross_attn_layers: int = 6
    num_vis_ins_attn_layers: int = 2
    # The use_* and weight_tying switches below are converted with bool().
    use_instruction: int = 0
    use_goal: int = 0
    use_goal_at_test: int = 1
    feat_scales_to_use: int = 1
    attn_rounds: int = 1
    weight_tying: int = 0
    rotation_parametrization: str = 'quat'
    diffusion_timesteps: int = 100
class TrainTester(BaseTrainTester):
    """Train/test a trajectory optimization algorithm.

    Supplies the task-specific pieces used by ``BaseTrainTester`` (imported
    from ``engine``): the RLBench datasets, the ``DiffusionPlanner`` model,
    the ``TrajectoryCriterion`` and the per-step train / evaluation logic.
    """

    def __init__(self, args):
        """Initialize from parsed ``Arguments``; delegates to the base class."""
        super().__init__(args)

    def _shared_dataset_kwargs(self, instruction, taskvar):
        """Return the ``RLBenchDataset`` keyword arguments common to both splits.

        Split-specific arguments (``root``, ``cache_size``, ``num_iters``,
        ``training``) are added by ``get_datasets``.
        """
        return dict(
            instructions=instruction,
            taskvar=taskvar,
            max_episode_length=self.args.max_episode_length,
            max_episodes_per_task=self.args.max_episodes_per_task,
            cameras=self.args.cameras,
            gripper_loc_bounds=self.args.gripper_loc_bounds,
            # "min,max" string -> (min, max) tuple of floats, parsed once
            image_rescale=tuple(
                float(x) for x in self.args.image_rescale.split(",")
            ),
            point_cloud_rotate_yaw_range=self.args.point_cloud_rotate_yaw_range,
            return_low_lvl_trajectory=True,
            dense_interpolation=bool(self.args.dense_interpolation),
            interpolation_length=self.args.interpolation_length,
            action_dim=self.args.action_dim,
            predict_short=False
        )

    def get_datasets(self):
        """Build the training and validation datasets.

        The (task, variation) pairs to load are taken from the instructions
        returned by ``load_instructions``.

        Returns:
            ``(train_dataset, test_dataset)``: ``RLBenchDataset`` instances
            rooted at ``args.dataset`` and ``args.valset`` respectively.

        Raises:
            NotImplementedError: if ``load_instructions`` returns None.
        """
        # Load instruction, based on which we load tasks/variations
        instruction = load_instructions(
            self.args.instructions,
            tasks=self.args.tasks,
            variations=self.args.variations
        )
        if instruction is None:
            raise NotImplementedError(
                "load_instructions() returned None; datasets cannot be "
                "built without instructions"
            )
        taskvar = [
            (task, var)
            for task, var_instr in instruction.items()
            for var in var_instr.keys()
        ]

        shared = self._shared_dataset_kwargs(instruction, taskvar)
        train_dataset = RLBenchDataset(
            root=self.args.dataset,
            cache_size=self.args.cache_size,
            num_iters=self.args.train_iters,
            training=True,
            **shared
        )
        test_dataset = RLBenchDataset(
            root=self.args.valset,
            cache_size=self.args.cache_size_val,
            training=False,
            **shared
        )
        return train_dataset, test_dataset

    def get_model(self):
        """Build the ``DiffusionPlanner`` described by ``self.args``.

        ``image_size`` (comma-separated string) is parsed into a tuple of
        ints and the integer on/off switches are converted with ``bool()``.
        Prints the value returned by ``count_parameters`` for the model.
        """
        _model = DiffusionPlanner(
            backbone=self.args.backbone,
            image_size=tuple(int(x) for x in self.args.image_size.split(",")),
            embedding_dim=self.args.embedding_dim,
            output_dim=self.args.action_dim,
            num_vis_ins_attn_layers=self.args.num_vis_ins_attn_layers,
            num_query_cross_attn_layers=self.args.num_query_cross_attn_layers,
            use_instruction=bool(self.args.use_instruction),
            use_goal=bool(self.args.use_goal),
            use_goal_at_test=bool(self.args.use_goal_at_test),
            feat_scales_to_use=self.args.feat_scales_to_use,
            attn_rounds=self.args.attn_rounds,
            weight_tying=bool(self.args.weight_tying),
            gripper_loc_bounds=self.args.gripper_loc_bounds,
            rotation_parametrization=self.args.rotation_parametrization,
            diffusion_timesteps=self.args.diffusion_timesteps
        )
        print("Model parameters:", count_parameters(_model))
        return _model

    @staticmethod
    def get_criterion():
        """Return the criterion used for the training loss and val metrics."""
        return TrajectoryCriterion()

    def train_one_step(self, model, criterion, optimizer, step_id, sample):
        """Run a single training step (forward, backward, maybe update).

        Gradients are accumulated over ``args.accumulate_grad_batches``
        consecutive steps: they are zeroed on the first step of each group
        and applied with ``optimizer.step()`` on the last one.

        Args:
            model: model whose training-mode forward output is passed
                unchanged to ``criterion.compute_loss`` as the loss.
            criterion: a ``TrajectoryCriterion``.
            optimizer: torch optimizer over the model parameters.
            step_id: global training step index.
            sample: collated batch (see ``traj_collate_fn``).
        """
        accum = self.args.accumulate_grad_batches
        if step_id % accum == 0:
            optimizer.zero_grad()

        # Forward pass. Inputs are not moved to a device here -- presumably
        # the DDP wrapper created by the engine does it (TODO confirm).
        out = model(
            sample["trajectory"],
            sample["trajectory_mask"],
            sample["rgbs"],
            sample["pcds"],
            sample["instr"],
            sample["curr_gripper"],
            sample["action"]
        )

        # Backward pass. In loss mode compute_loss returns `out` unchanged.
        # NOTE(review): the loss is not divided by `accum`, so accumulated
        # gradients are summed rather than averaged -- confirm intended.
        loss = criterion.compute_loss(out)
        loss.backward()

        # Update on the last step of each accumulation group
        if step_id % accum == accum - 1:
            optimizer.step()

        # Log (rank 0 only, every val_freq steps)
        if dist.get_rank() == 0 and (step_id + 1) % self.args.val_freq == 0:
            self.writer.add_scalar("lr", self.args.lr, step_id)
            self.writer.add_scalar(
                "train-loss/noise_mse", loss.item(), step_id
            )

    @torch.no_grad()
    def evaluate_nsteps(self, model, criterion, loader, step_id, val_iters,
                        split='val'):
        """Evaluate on up to ``val_iters`` batches and log averaged metrics.

        Args:
            model: model whose forward with ``run_inference=True`` returns
                predicted trajectories.
            criterion: object providing ``compute_metrics(pred, gt, mask)``.
            loader: iterable of collated samples (see ``traj_collate_fn``).
            step_id: step used for logging; TensorBoard scalars and the
                visualization are only written when ``step_id > -1``.
            val_iters: maximum number of batches to evaluate.
            split: prefix of every logged key (e.g. ``'val'``).

        Returns:
            The averaged ``'val-losses/traj_action_mse'`` value, or None when
            that key was not produced (e.g. ``split != 'val'``).
        """
        values = {}
        device = next(model.parameters()).device
        # Remember the caller's mode so it can be restored afterwards.
        was_training = model.training
        model.eval()

        def append(key, value):
            """Append the scalar tensor ``value`` to the series under ``key``."""
            if key not in values:
                values[key] = torch.Tensor([]).to(device)
            values[key] = torch.cat([values[key], value.unsqueeze(0)])

        for i, sample in enumerate(loader):
            if i == val_iters:
                break
            # Used by the model, the metrics and the visualization:
            # transfer to the device once instead of three times.
            trajectory = sample["trajectory"].to(device)
            trajectory_mask = sample["trajectory_mask"].to(device)
            action = model(
                trajectory,
                trajectory_mask,
                sample["rgbs"].to(device),
                sample["pcds"].to(device),
                sample["instr"].to(device),
                sample["curr_gripper"].to(device),
                sample["action"].to(device),
                run_inference=True
            )
            losses, losses_B = criterion.compute_metrics(
                action, trajectory, trajectory_mask
            )

            # Gather global statistics
            for n, l in losses.items():
                append(f"{split}-losses/{n}", l)

            # Gather per-task statistics
            # NOTE(review): these keys depend on the tasks present in this
            # rank's batches; confirm synchronize_between_processes copes
            # with ranks holding different key sets.
            tasks = np.array(sample["task"])
            unique_tasks = np.unique(tasks)
            for n, l in losses_B.items():
                for task in unique_tasks:
                    append(f"{split}-loss/{task}/{n}", l[tasks == task].mean())

            # Generate visualizations (first batch, rank 0, real steps only)
            if i == 0 and dist.get_rank() == 0 and step_id > -1:
                viz = generate_visualizations(
                    action, trajectory, trajectory_mask
                )
                self.writer.add_image(f'{split}-viz/viz', viz, step_id)

        # Log all statistics
        values = self.synchronize_between_processes(values)
        values = {k: v.mean().item() for k, v in values.items()}
        if dist.get_rank() == 0:
            if step_id > -1:
                for key, val in values.items():
                    self.writer.add_scalar(key, val, step_id)

            # Also log to terminal
            print(f"Step {step_id}:")
            for key, value in values.items():
                print(f"{key}: {value:.03f}")

        model.train(was_training)
        return values.get('val-losses/traj_action_mse', None)
def traj_collate_fn(batch):
    """Collate a list of dataset items into one batch dictionary.

    Every tensor field is concatenated along its first dimension. All fields
    except ``trajectory_mask`` are cast to float; the mask keeps its dtype.
    The per-item ``task`` sequences are flattened into a single list, in
    batch order.
    """
    tensor_fields = (
        "trajectory", "trajectory_mask",
        "rgbs", "pcds", "curr_gripper", "action", "instr",
    )
    collated = {}
    for field in tensor_fields:
        if field == "trajectory_mask":
            pieces = [item[field] for item in batch]
        else:
            pieces = [item[field].float() for item in batch]
        collated[field] = torch.cat(pieces)
    collated["task"] = [name for item in batch for name in item["task"]]
    return collated
class TrajectoryCriterion:
    """Loss pass-through and evaluation metrics for predicted trajectories.

    Trajectories are indexed with the position in channels ``0:3`` and a
    quaternion in channels ``3:7`` of the last dimension.
    """

    def __init__(self):
        pass

    def compute_loss(self, pred, gt=None, mask=None, is_loss=True):
        """Return a scalar loss.

        Args:
            pred: when ``is_loss`` is True, the loss already computed by the
                model (returned unchanged); otherwise predicted trajectories.
            gt: ground-truth trajectories (required when ``is_loss`` is False).
            mask: trajectory mask (required when ``is_loss`` is False).
            is_loss: whether ``pred`` is already the loss.

        Returns:
            ``pred`` itself, or the ``'traj_action_mse'`` metric between
            ``pred`` and ``gt``.
        """
        if not is_loss:
            assert gt is not None and mask is not None
            # The key must carry the 'traj_' prefix used by compute_metrics;
            # the previous 'action_mse' lookup always raised KeyError.
            return self.compute_metrics(pred, gt, mask)[0]['traj_action_mse']
        return pred

    @staticmethod
    def _symmetric_quat_l1(pred_quat, gt_quat):
        """L1 quaternion distance, invariant to the sign ambiguity q ~ -q."""
        return torch.minimum(
            (pred_quat - gt_quat).abs().sum(-1),
            (pred_quat + gt_quat).abs().sum(-1),
        )

    @staticmethod
    def compute_metrics(pred, gt, mask):
        """Compute trajectory and final-keypose metrics.

        Args:
            pred, gt: trajectories, (B, L, 7) per the original author's note.
            mask: (B, L) trajectory mask.
                NOTE(review): ``mask`` is currently unused, so any padded steps
                are included in every metric (and index -1, used for the keypose
                metrics, may be a padded step) -- confirm this is intended.

        Returns:
            ``(ret_1, ret_2)``. ``ret_1`` maps metric names to batch-averaged
            scalars: ``traj_*`` over all steps, plus un-prefixed keypose
            metrics on the last step. ``ret_2`` maps the ``traj_*`` names
            (except ``traj_action_mse``) to per-sample (B,) tensors.
        """
        quat_l1_fn = TrajectoryCriterion._symmetric_quat_l1

        # pred/gt are (B, L, 7), mask (B, L)
        pos_l2 = ((pred[..., :3] - gt[..., :3]) ** 2).sum(-1).sqrt()
        quat_l1 = quat_l1_fn(pred[..., 3:7], gt[..., 3:7])

        tr = 'traj_'
        # Trajectory metrics
        ret_1 = {
            tr + 'action_mse': F.mse_loss(pred, gt),
            tr + 'pos_l2': pos_l2.mean(),
            tr + 'pos_acc_001': (pos_l2 < 0.01).float().mean(),
            tr + 'rot_l1': quat_l1.mean(),
            tr + 'rot_acc_0025': (quat_l1 < 0.025).float().mean()
        }
        ret_2 = {
            tr + 'pos_l2': pos_l2.mean(-1),
            tr + 'pos_acc_001': (pos_l2 < 0.01).float().mean(-1),
            tr + 'rot_l1': quat_l1.mean(-1),
            tr + 'rot_acc_0025': (quat_l1 < 0.025).float().mean(-1)
        }

        # Keypose metrics on the last step (useful when not goal-conditioned)
        pos_l2 = ((pred[:, -1, :3] - gt[:, -1, :3]) ** 2).sum(-1).sqrt()
        quat_l1 = quat_l1_fn(pred[:, -1, 3:7], gt[:, -1, 3:7])
        ret_1.update({
            'pos_l2': pos_l2.mean(),
            'pos_acc_001': (pos_l2 < 0.01).float().mean(),
            'rot_l1': quat_l1.mean(),
            'rot_acc_0025': (quat_l1 < 0.025).float().mean()
        })
        return ret_1, ret_2
def fig_to_numpy(fig, dpi=60):
    """Render a matplotlib figure to an image array via an in-memory PNG.

    The PNG is decoded with OpenCV using ``cv2.IMREAD_COLOR``. The result is
    therefore a 3-channel image in OpenCV's BGR channel order (any alpha
    channel is dropped). Callers that need RGB must reverse the channels.

    Args:
        fig: matplotlib figure to render.
        dpi: resolution passed to ``fig.savefig``.

    Returns:
        The decoded BGR image, or None if OpenCV fails to decode the PNG.
    """
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=dpi)
        # getvalue() returns a copy, so the array outlives the buffer.
        img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    return cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
def generate_visualizations(pred, gt, mask, box_size=0.3):
    """Plot the first trajectory of a batch as a 3D scatter of pred vs. gt.

    Only steps where ``mask`` is False are drawn (presumably True marks
    padding -- TODO confirm). The first three channels are used as x, y, z.

    Args:
        pred: predicted trajectories; element 0 of the batch is plotted.
        gt: ground-truth trajectories, indexed like ``pred``.
        mask: boolean mask, indexed like ``pred`` without the channel axis.
        box_size: half-width of the cubic view box, centered on the mean
            unmasked ground-truth position.

    Returns:
        A (3, H, W) uint8 RGB image, the layout expected by
        ``SummaryWriter.add_image`` with its default ``dataformats='CHW'``.
    """
    batch_idx = 0
    pred = pred[batch_idx].detach().cpu().numpy()
    gt = gt[batch_idx].detach().cpu().numpy()
    mask = mask[batch_idx].detach().cpu().numpy()
    valid_pred = pred[~mask]
    valid_gt = gt[~mask]

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(projection='3d')
    ax.scatter3D(
        valid_pred[:, 0], valid_pred[:, 1], valid_pred[:, 2],
        color='red', label='pred'
    )
    ax.scatter3D(
        valid_gt[:, 0], valid_gt[:, 1], valid_gt[:, 2],
        color='blue', label='gt'
    )
    # With every step masked, the mean would be NaN and matplotlib rejects
    # NaN axis limits; fall back to the origin instead.
    center = valid_gt.mean(0) if len(valid_gt) else np.zeros(3)
    ax.set_xlim(center[0] - box_size, center[0] + box_size)
    ax.set_ylim(center[1] - box_size, center[1] + box_size)
    ax.set_zlim(center[2] - box_size, center[2] + box_size)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    ax.legend()
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    img = fig_to_numpy(fig, dpi=120)
    # Close this figure explicitly (plt.close() only closes the current one).
    plt.close(fig)
    # fig_to_numpy decodes with OpenCV, which yields BGR; reverse the channel
    # axis so TensorBoard (RGB) shows pred in red and gt in blue.
    return np.ascontiguousarray(img[..., ::-1].transpose(2, 0, 1))
if __name__ == '__main__':
    # Disable tokenizer parallelism (presumably to avoid the HuggingFace
    # tokenizers fork warning with DataLoader workers -- TODO confirm).
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Arguments
    args = Arguments().parse_args()
    print("Arguments:")
    print(args)
    print("-" * 100)

    # Replace the option with an actual bounds array.
    if args.gripper_loc_bounds is None:
        # Default: the box [-2, 2] on each axis, as a float array.
        args.gripper_loc_bounds = np.array([[-2, -2, -2], [2, 2, 2]]) * 1.0
    else:
        args.gripper_loc_bounds = get_gripper_loc_bounds(
            args.gripper_loc_bounds,
            task=args.tasks[0] if len(args.tasks) == 1 else None,
            buffer=0.04
        )

    log_dir = args.base_log_dir / args.exp_log_dir / args.run_log_dir
    args.log_dir = log_dir
    log_dir.mkdir(exist_ok=True, parents=True)
    print("Logging:", log_dir)
    print(
        "Available devices (CUDA_VISIBLE_DEVICES):",
        os.environ.get("CUDA_VISIBLE_DEVICES")
    )
    print("Device count", torch.cuda.device_count())

    # LOCAL_RANK is set by distributed launchers such as torchrun; fail with
    # a clear message instead of a bare KeyError when it is missing.
    if "LOCAL_RANK" not in os.environ:
        raise RuntimeError(
            "LOCAL_RANK is not set; launch this script with a distributed "
            "launcher such as torchrun"
        )
    args.local_rank = int(os.environ["LOCAL_RANK"])

    # Seeds (the same seed is used on every rank)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # DDP initialization
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    torch.backends.cudnn.enabled = True
    # NOTE(review): benchmark=True and deterministic=True pull in opposite
    # directions; PyTorch's reproducibility notes recommend benchmark=False
    # when determinism is the goal -- confirm which is intended.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    # Run
    train_tester = TrainTester(args)
    train_tester.main(collate_fn=traj_collate_fn)