-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathdiffusion_policy.py
786 lines (648 loc) · 30.2 KB
/
diffusion_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
"""
Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi
"""
from typing import Callable, Union
import math
from collections import OrderedDict, deque
from packaging.version import parse as parse_version
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
# requires diffusers==0.11.1
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.training_utils import EMAModel
import robomimic.models.obs_nets as ObsNets
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.algo import register_algo_factory_func, PolicyAlgo
import random
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
import os
from transformers import AutoTokenizer, AutoModel
# Module-level DistilBERT tokenizer and encoder used for language conditioning by
# DiffusionPolicyUNet (training-time and rollout-time instruction encoding).
# NOTE(review): these are created at import time and the model is moved to 'cuda',
# so merely importing this module loads the weights (from_pretrained may also need
# network access / a local HF cache) and requires a CUDA device — confirm intended.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
# Weights are loaded in float16; the encoded features are cast back to float32 by
# the code that consumes them.
lang_model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16)
lang_model.to('cuda')
# import torch.distributed as dist
# from torch.nn.parallel import DistributedDataParallel as DDP
import cv2
import copy
@register_algo_factory_func("diffusion_policy")
def algo_config_to_class(algo_config):
    """
    Maps algo config to the Diffusion Policy algo class to instantiate, along with additional algo kwargs.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm

    Raises:
        NotImplementedError: if the transformer backbone is selected (not implemented)
        RuntimeError: if neither the UNet nor the transformer backbone is enabled
    """
    if algo_config.unet.enabled:
        return DiffusionPolicyUNet, {}
    if algo_config.transformer.enabled:
        raise NotImplementedError(
            "Transformer-based Diffusion Policy is not implemented; enable algo_config.unet instead."
        )
    raise RuntimeError(
        "Diffusion Policy requires either algo_config.unet.enabled or algo_config.transformer.enabled to be True."
    )
class DiffusionPolicyUNet(PolicyAlgo):
    """
    Diffusion Policy with a conditional 1D UNet noise-prediction network.

    Observations over the observation horizon are encoded, flattened into a global
    conditioning vector, and used to iteratively denoise an action sequence of length
    prediction_horizon. At rollout time, action_horizon actions starting at index
    observation_horizon - 1 are queued and executed before inference runs again.
    """

    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.

        Also builds the noise scheduler (DDPM or DDIM, whichever is enabled in the algo
        config), the optional EMA wrapper, and initializes the rollout state.

        Raises:
            RuntimeError: if neither DDPM nor DDIM is enabled
        """
        # set up different observation groups for @MIMO_MLP
        observation_group_shapes = OrderedDict()
        observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)
        encoder_kwargs = ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder)

        obs_encoder = ObsNets.ObservationGroupEncoder(
            observation_group_shapes=observation_group_shapes,
            encoder_kwargs=encoder_kwargs,
        )
        # IMPORTANT!
        # replace all BatchNorm with GroupNorm to work with EMA
        # performance will tank if you forget to do this!
        obs_encoder = replace_bn_with_gn(obs_encoder)

        obs_dim = obs_encoder.output_shape()[0]

        # create network object
        noise_pred_net = ConditionalUnet1D(
            input_dim=self.ac_dim,
            global_cond_dim=obs_dim * self.algo_config.horizon.observation_horizon
        )

        # the final arch has 2 parts, each wrapped in DataParallel over all visible GPUs
        device_ids = list(range(torch.cuda.device_count()))
        nets = nn.ModuleDict({
            'policy': nn.ModuleDict({
                'obs_encoder': torch.nn.parallel.DataParallel(obs_encoder, device_ids=device_ids),
                'noise_pred_net': torch.nn.parallel.DataParallel(noise_pred_net, device_ids=device_ids)
            })
        })

        nets = nets.float().to(self.device)

        # setup noise scheduler
        noise_scheduler = None
        if self.algo_config.ddpm.enabled:
            noise_scheduler = DDPMScheduler(
                num_train_timesteps=self.algo_config.ddpm.num_train_timesteps,
                beta_schedule=self.algo_config.ddpm.beta_schedule,
                clip_sample=self.algo_config.ddpm.clip_sample,
                prediction_type=self.algo_config.ddpm.prediction_type
            )
        elif self.algo_config.ddim.enabled:
            noise_scheduler = DDIMScheduler(
                num_train_timesteps=self.algo_config.ddim.num_train_timesteps,
                beta_schedule=self.algo_config.ddim.beta_schedule,
                clip_sample=self.algo_config.ddim.clip_sample,
                set_alpha_to_one=self.algo_config.ddim.set_alpha_to_one,
                steps_offset=self.algo_config.ddim.steps_offset,
                prediction_type=self.algo_config.ddim.prediction_type
            )
        else:
            raise RuntimeError()

        # setup EMA
        ema = None
        if self.algo_config.ema.enabled:
            ema = EMAModel(model=nets, power=self.algo_config.ema.power)

        # set attrs
        self.nets = nets
        self.noise_scheduler = noise_scheduler
        self.ema = ema
        self.action_check_done = False
        self.obs_queue = None
        self.action_queue = None

    @staticmethod
    def _encode_language(raw_strings, To):
        """
        Encode language instructions into per-timestep DistilBERT features.

        Token hidden states are summed over the sequence with padding positions masked
        out (via the tokenizer's attention mask). A batch of instructions padded to a
        common length therefore yields the same feature per instruction as encoding that
        instruction on its own, which is how the instruction is encoded at rollout time.

        Args:
            raw_strings (list of str): B instructions
            To (int): observation horizon; the feature is repeated along this time dim

        Returns:
            torch.Tensor: float32 tensor of shape [B, To, D]
        """
        with torch.no_grad():
            encoded_input = tokenizer(raw_strings, padding=True, truncation=True, return_tensors='pt').to('cuda')
            outputs = lang_model(**encoded_input)
            hidden = outputs.last_hidden_state  # [B, L, D]
            mask = encoded_input['attention_mask'].unsqueeze(-1).to(hidden.dtype)  # [B, L, 1]
            # no .squeeze() here: it would also drop the batch dim when B == 1
            pooled = (hidden * mask).sum(1)  # [B, D]
        return pooled.unsqueeze(1).repeat(1, To, 1).type(torch.float32)

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Observations are cut to the first observation_horizon steps and actions to the
        first prediction_horizon steps. Raw-language entries are either passed through
        as strings or encoded with DistilBERT. NaNs in tensors are replaced via
        torch.nan_to_num.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training

        Raises:
            ValueError: (first call only) if any action lies outside [-1, 1]
        """
        To = self.algo_config.horizon.observation_horizon
        Tp = self.algo_config.horizon.prediction_horizon

        input_batch = dict()
        ## Semi-hacky fix which does the filtering for raw language which is just a list of lists of strings
        input_batch["obs"] = {k: batch["obs"][k][:, :To, :] for k in batch["obs"] if "raw" not in k}
        if "lang_fixed/language_raw" in batch["obs"]:
            str_ls = list(batch["obs"]["lang_fixed/language_raw"][0])
            # NOTE(review): this list of strings is passed through TensorUtils.to_float below;
            # confirm that helper tolerates non-tensor leaves.
            input_batch["obs"]["lang_fixed/language_raw"] = [str_ls] * To
        if "raw_language" in batch["obs"]:
            raw_lang_strings = [byte_string.decode('utf-8') for byte_string in batch["obs"]["raw_language"]]
            input_batch["obs"]["lang_fixed/language_distilbert"] = self._encode_language(raw_lang_strings, To)
        input_batch["actions"] = batch["actions"][:, :Tp, :]

        # check if actions are normalized to [-1,1]
        if not self.action_check_done:
            actions = input_batch["actions"]
            in_range = (-1 <= actions) & (actions <= 1)
            all_in_range = torch.all(in_range).item()
            if not all_in_range:
                raise ValueError('"actions" must be in range [-1,1] for Diffusion Policy! Check if hdf5_normalize_action is enabled.')
            self.action_check_done = True

        # Replace NaNs in tensor observations; non-tensor entries (e.g. raw-language string
        # lists) cannot go through torch.nan_to_num and are left as-is.
        for key, value in input_batch["obs"].items():
            if torch.is_tensor(value):
                input_batch["obs"][key] = torch.nan_to_num(value)
        input_batch["actions"] = torch.nan_to_num(input_batch["actions"])
        return TensorUtils.to_device(TensorUtils.to_float(input_batch), self.device)

    def train_on_batch(self, batch, epoch, validate=False):
        """
        Training on a single batch of data.

        Each example is noised with algo_config.noise_samples independent noise draws at a
        single shared diffusion timestep. The network is trained to predict that noise
        with an MSE loss.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): epoch number - required by some Algos that need
                to perform staged training and early stopping

            validate (bool): if True, don't perform any learning updates.

        Returns:
            info (dict): dictionary of relevant inputs, outputs, and losses
                that might be relevant for logging
        """
        B = batch['actions'].shape[0]

        with TorchUtils.maybe_no_grad(no_grad=validate):
            info = super(DiffusionPolicyUNet, self).train_on_batch(batch, epoch, validate=validate)
            actions = batch['actions']

            # encode obs
            inputs = {
                'obs': batch["obs"],
            }
            for k in self.obs_shapes:
                ## Shape assertion does not apply to list of strings for raw language
                if "raw" in k:
                    continue
                # first two dimensions should be [B, T] for inputs
                assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k])
            obs_features = TensorUtils.time_distributed(
                {"obs": inputs["obs"]}, self.nets['policy']['obs_encoder'], inputs_as_kwargs=True)
            assert obs_features.ndim == 3  # [B, T, D]
            obs_cond = obs_features.flatten(start_dim=1)

            # sample noise to add to actions: one independent draw per noise sample -> [N, *actions.shape]
            num_noise_samples = self.algo_config.noise_samples
            noise = torch.randn([num_noise_samples] + list(actions.shape), device=self.device)

            # sample a diffusion iteration for each data point (shared by all of its noise samples)
            timesteps = torch.randint(
                0, self.noise_scheduler.config.num_train_timesteps,
                (B,), device=self.device
            ).long()

            # forward diffusion; noised copies are stacked sample-major along dim 0 -> [N * B, ...]
            noisy_actions = torch.cat([
                self.noise_scheduler.add_noise(actions, noise[i], timesteps)
                for i in range(len(noise))
            ], dim=0)
            # tile conditioning and timesteps in the same sample-major order
            obs_cond = obs_cond.repeat(num_noise_samples, 1)
            timesteps = timesteps.repeat(num_noise_samples)

            # predict the noise residual
            noise_pred = self.nets['policy']['noise_pred_net'](
                noisy_actions, timesteps, global_cond=obs_cond)

            # L2 loss against the target noise flattened to the same [N * B, ...] layout
            noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
            loss = F.mse_loss(noise_pred, noise)

            # logging
            losses = {
                'l2_loss': loss
            }
            info["losses"] = TensorUtils.detach(losses)

            if not validate:
                # gradient step
                policy_grad_norms = TorchUtils.backprop_for_loss(
                    net=self.nets,
                    optim=self.optimizers["policy"],
                    loss=loss,
                )

                # update Exponential Moving Average of the model weights
                if self.ema is not None:
                    self.ema.step(self.nets)

                step_info = {
                    'policy_grad_norms': policy_grad_norms
                }
                info.update(step_info)

        return info

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        log = super(DiffusionPolicyUNet, self).log_info(info)
        log["Loss"] = info["losses"]["l2_loss"].item()
        if "policy_grad_norms" in info:
            log["Policy_Grad_Norms"] = info["policy_grad_norms"]
        return log

    def reset(self):
        """
        Reset algo state to prepare for environment rollouts.
        """
        # setup inference queues
        To = self.algo_config.horizon.observation_horizon
        Ta = self.algo_config.horizon.action_horizon
        self.obs_queue = deque(maxlen=To)
        self.action_queue = deque(maxlen=Ta)

    @staticmethod
    def _load_goal_image(path):
        """
        Read an image file as a float RGB tensor scaled to [0, 1], shaped [1, 1, C, H, W], on CUDA.

        Raises:
            FileNotFoundError: if OpenCV cannot read @path
        """
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising
            raise FileNotFoundError(f"Could not read goal image: {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255.0
        return torch.FloatTensor(rgb).cuda().permute(2, 0, 1).unsqueeze(0).unsqueeze(0)

    def get_action(self, obs_dict, goal_mode=None, eval_mode=False):
        """
        Get policy action outputs.

        Diffusion inference runs only when the action queue is empty. Otherwise the next
        action from the most recently predicted chunk is returned.

        Args:
            obs_dict (dict): current observation; tensor values are expected to already carry
                batch and time dims [1, To, ...] (frame stacking is assumed to be done by the
                caller — see the TODO below)
            goal_mode: only read when @eval_mode is True. If not None, goal images are loaded
                from <cwd>/eval_params and concatenated to each camera image. Otherwise the
                instruction in <cwd>/eval_params/lang_command.txt is encoded into
                obs_dict["lang_fixed/language_distilbert"]
            eval_mode (bool): if True, augment @obs_dict in place as described for @goal_mode

        Returns:
            action (torch.Tensor): action tensor [1, Da]
        """
        To = self.algo_config.horizon.observation_horizon

        if eval_mode:
            from droid.misc.parameters import hand_camera_id, varied_camera_1_id, varied_camera_2_id
            root_path = os.path.join(os.getcwd(), "eval_params")

            if goal_mode is not None:
                # Read in goal images, one per (camera, side), named "<camera_id>_<side>.png".
                # All images are loaded before obs_dict is touched, so a missing file leaves it unmodified.
                camera_ids = (
                    ("hand_camera", hand_camera_id),
                    ("varied_camera_1", varied_camera_1_id),
                    ("varied_camera_2", varied_camera_2_id),
                )
                goal_images = {}
                for camera_name, camera_id in camera_ids:
                    for side in ("left", "right"):
                        obs_key = f"camera/image/{camera_name}_{side}_image"
                        goal_images[obs_key] = self._load_goal_image(
                            os.path.join(root_path, f"{camera_id}_{side}.png"))
                # append each goal image to its camera stream along dim 2 (channels of [1, To, C, H, W])
                for obs_key, goal_image in goal_images.items():
                    obs_dict[obs_key] = torch.cat([obs_dict[obs_key], goal_image.repeat(1, To, 1, 1, 1)], dim=2)
            # Note: currently assumes that you are never doing both goal and language conditioning
            else:
                # Reads in current language instruction from file and fills the appropriate obs key, only will
                # actually use it if the policy uses language instructions
                with open(os.path.join(root_path, "lang_command.txt"), 'r') as file:
                    raw_lang = file.read()
                obs_dict["lang_fixed/language_distilbert"] = self._encode_language([raw_lang], To)

        ###############################
        # TODO: obs_queue already handled by frame_stack
        # make sure we have at least To observations in obs_queue
        # if not enough, repeat
        # if already full, append one to the obs_queue
        # n_repeats = max(To - len(self.obs_queue), 1)
        # self.obs_queue.extend([obs_dict] * n_repeats)

        if len(self.action_queue) == 0:
            # no actions left, run inference
            # [1,T,Da]
            action_sequence = self._get_action_trajectory(obs_dict=obs_dict)

            # put actions into the queue
            self.action_queue.extend(action_sequence[0])

        # has action, execute from left to right
        # [Da]
        action = self.action_queue.popleft()

        # [1,Da]
        action = action.unsqueeze(0)
        return action

    def _get_action_trajectory(self, obs_dict):
        """
        Sample an action chunk by running the reverse diffusion process.

        Uses the EMA weights when EMA is enabled. Requires the networks to be in eval mode.

        Args:
            obs_dict (dict): observation tensors with leading [B, T] dims

        Returns:
            torch.Tensor: [B, Ta, Da] actions taken from index To - 1 of the denoised
                [B, Tp, Da] sequence

        Raises:
            ValueError: if neither DDPM nor DDIM is enabled
        """
        assert not self.nets.training
        To = self.algo_config.horizon.observation_horizon
        Ta = self.algo_config.horizon.action_horizon
        Tp = self.algo_config.horizon.prediction_horizon
        action_dim = self.ac_dim
        if self.algo_config.ddpm.enabled is True:
            num_inference_timesteps = self.algo_config.ddpm.num_inference_timesteps
        elif self.algo_config.ddim.enabled is True:
            num_inference_timesteps = self.algo_config.ddim.num_inference_timesteps
        else:
            raise ValueError

        # select network
        nets = self.nets
        if self.ema is not None:
            nets = self.ema.averaged_model

        # encode obs
        inputs = {
            'obs': obs_dict,
        }
        for k in self.obs_shapes:
            ## Shape assertion does not apply to list of strings for raw language
            if "raw" in k:
                continue
            # first two dimensions should be [B, T] for inputs
            assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k])
        # call the wrapped modules directly (.module) to bypass DataParallel at inference
        obs_features = TensorUtils.time_distributed(
            {"obs": inputs["obs"]}, nets['policy']['obs_encoder'].module, inputs_as_kwargs=True)
        assert obs_features.ndim == 3  # [B, T, D]
        B = obs_features.shape[0]

        # reshape observation to (B,obs_horizon*obs_dim)
        obs_cond = obs_features.flatten(start_dim=1)

        # initialize action from Gaussian noise
        naction = torch.randn((B, Tp, action_dim), device=self.device)

        # init scheduler
        self.noise_scheduler.set_timesteps(num_inference_timesteps)

        for k in self.noise_scheduler.timesteps:
            # predict noise
            noise_pred = nets['policy']['noise_pred_net'].module(
                sample=naction,
                timestep=k,
                global_cond=obs_cond
            )

            # inverse diffusion step (remove noise)
            naction = self.noise_scheduler.step(
                model_output=noise_pred,
                timestep=k,
                sample=naction
            ).prev_sample

        # process action using Ta
        start = To - 1
        end = start + Ta
        action = naction[:, start:end]
        return action

    def serialize(self):
        """
        Get dictionary of current model parameters.
        """
        return {
            "nets": self.nets.state_dict(),
            "ema": self.ema.averaged_model.state_dict() if self.ema is not None else None,
        }

    def deserialize(self, model_dict):
        """
        Load model from a checkpoint.

        EMA weights stored in the checkpoint are restored only when EMA is enabled for
        this instance. Otherwise they are ignored instead of raising AttributeError.

        Args:
            model_dict (dict): a dictionary saved by self.serialize() that contains
                the same keys as @self.network_classes
        """
        self.nets.load_state_dict(model_dict["nets"])
        if model_dict.get("ema", None) is not None and self.ema is not None:
            self.ema.averaged_model.load_state_dict(model_dict["ema"])
# =================== Vision Encoder Utils =====================
def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Swap every submodule matched by @predicate for the module returned by @func.

    If @root_module itself matches, func(root_module) is returned directly. Otherwise the
    matching submodules are replaced in place and @root_module is returned.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.
    """
    if predicate(root_module):
        return func(root_module)

    if parse_version(torch.__version__) < parse_version('1.9.0'):
        raise ImportError('This function requires pytorch >= 1.9.0')

    def _matching_paths():
        # dotted names of matching submodules, split into path components
        return [
            name.split('.')
            for name, module in root_module.named_modules(remove_duplicate=True)
            if predicate(module)
        ]

    for *parent_path, leaf in _matching_paths():
        owner = root_module.get_submodule('.'.join(parent_path)) if parent_path else root_module
        # nn.Sequential children are addressed by integer index, everything else by attribute
        owner_is_sequential = isinstance(owner, nn.Sequential)
        old_child = owner[int(leaf)] if owner_is_sequential else getattr(owner, leaf)
        new_child = func(old_child)
        if owner_is_sequential:
            owner[int(leaf)] = new_child
        else:
            setattr(owner, leaf, new_child)

    # verify that all modules are replaced
    assert not _matching_paths()
    return root_module
def replace_bn_with_gn(
        root_module: nn.Module,
        features_per_group: int = 16) -> nn.Module:
    """
    Replace every nn.BatchNorm2d layer in @root_module with an nn.GroupNorm.

    A BatchNorm2d over C channels becomes a GroupNorm with C // features_per_group
    groups. If that count is invalid for GroupNorm, the largest valid group count
    below it is used instead. Invalid means 0 (C < features_per_group) or not a
    divisor of C.

    Only nn.BatchNorm2d is matched; other BatchNorm variants are left untouched.

    Args:
        root_module: module to convert (its submodules are replaced in place)
        features_per_group: target number of channels per group

    Returns:
        nn.Module: the converted module. This is a new GroupNorm if @root_module was
        itself a BatchNorm2d, otherwise @root_module.
    """
    def _num_groups(num_channels):
        # GroupNorm requires 1 <= num_groups and num_channels % num_groups == 0
        groups = max(1, num_channels // features_per_group)
        while num_channels % groups != 0:
            groups -= 1
        return groups

    # return replace_submodules' result so a BatchNorm2d root is replaced as well
    return replace_submodules(
        root_module=root_module,
        predicate=lambda x: isinstance(x, nn.BatchNorm2d),
        func=lambda x: nn.GroupNorm(
            num_groups=_num_groups(x.num_features),
            num_channels=x.num_features)
    )
# =================== UNet for Diffusion ==============
class SinusoidalPosEmb(nn.Module):
    """
    Sinusoidal embedding of scalar positions (here: diffusion timesteps).

    For a 1D input of N positions, returns an [N, 2 * (dim // 2)] tensor. The first half
    of the columns holds sin(x * f_i) and the second half cos(x * f_i). The frequencies
    f_i = 10000 ** (-i / (dim // 2 - 1)) for i in [0, dim // 2) span 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        n_freqs = self.dim // 2
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(-log_step * torch.arange(n_freqs, device=x.device))
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
class Downsample1d(nn.Module):
    """
    Reduce the temporal length of a [B, C, T] sequence to ceil(T / 2).

    Uses a stride-2 Conv1d with kernel 3 and padding 1. The channel count is unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled
class Upsample1d(nn.Module):
    """
    Double the temporal length of a [B, C, T] sequence to 2 * T.

    Uses a ConvTranspose1d with kernel 4, stride 2 and padding 1. The channel count is unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled
class Conv1dBlock(nn.Module):
    '''
    Conv1d --> GroupNorm --> Mish

    Padding is kernel_size // 2, so odd kernel sizes preserve the temporal length.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        conv = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        norm = nn.GroupNorm(n_groups, out_channels)
        activation = nn.Mish()
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        out = self.block(x)
        return out
class ConditionalResidualBlock1D(nn.Module):
    """
    Two Conv1dBlocks with FiLM conditioning in between, plus a residual connection.

    The conditioning vector is mapped to a per-channel scale and bias that modulate the
    output of the first conv block. A 1x1 Conv1d projects the residual when the channel
    count changes; otherwise the residual is the identity.
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        first = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        second = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([first, second])

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # a single Linear emits both scale and bias for every output channel
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, 2 * out_channels),
            nn.Unflatten(-1, (-1, 1))
        )

        # make sure dimensions compatible
        if in_channels != out_channels:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        hidden = self.blocks[0](x)

        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        scale, bias = film.unbind(dim=1)

        hidden = self.blocks[1](hidden * scale + bias)
        return hidden + self.residual_conv(x)
class ConditionalUnet1D(nn.Module):
    """
    1D temporal UNet that predicts noise for a [B, T, input_dim] action sequence.

    Every residual block is FiLM-conditioned on the concatenation of a diffusion-step
    embedding and an optional global conditioning vector.
    """

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level (any sequence of ints; the default is
          a tuple to avoid a shared mutable default).
          The length of this sequence determines the number of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """
        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # the deepest level keeps its resolution
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # ind never reaches len(in_out) - 1 here, so every up level upsamples,
            # matching the len(in_out) - 1 downsamples above
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out * 2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        sample: (B,T,input_dim)
        timestep: (B,) tensor, 0-d tensor, or Python number (cast to long), diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C) -> (B,C,T)
        sample = sample.moveaxis(-1, -2)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)
        if global_cond is not None:
            global_feature = torch.cat([global_feature, global_cond], dim=-1)

        x = sample
        skips = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            skips.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, skips.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T) -> (B,T,C)
        x = x.moveaxis(-1, -2)
        return x