This repository has been archived by the owner on Nov 25, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
1804 lines (1611 loc) · 105 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import numpy as np
import tensorflow as tf
import functools
import sys
import os
import re
from tensorflow.python.framework import ops
from tensorflow.python.ops import functional_ops
from BasicRNNCell import BasicLSTMCell as LSTM
from BasicRNNCell import BasicMTRNNCell as MTRNN
from BasicRNNCell import _linear
import ops
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import entropy
from pkg_resources import parse_version
def lazy_property(function):
attribute = '_' + function.__name__
@property
@functools.wraps(function)
def wrapper(self):
if not hasattr(self, attribute):
setattr(self, attribute, function(self))
return getattr(self, attribute)
return wrapper
def print_dict(dict):
    """Print every key/value pair of *dict* as ``key = value``, one per line.

    Used by PVRNN.__init__ to dump configuration dictionaries to stdout.
    """
    # NOTE(review): the parameter name shadows the builtin `dict`; it is kept
    # unchanged for backward compatibility with any keyword callers.
    for key, value in dict.items():  # .items(): single lookup per entry
        print(f"{key} = {value}")
"""
Predictive Coding Variational Bayes Recurrent Neural Network (PV-RNN)
"""
class PVRNN(object):
def __init__(self, sess, config_data, config_network, training=True, planning=None, input_cl_ratio=1.0, learning_rate=0.001, optimizer_epsilon=1e-8, prior_generation=False, reset_posterior_src=False, hybrid_posterior_src=None, hybrid_posterior_src_idx=None, data_masking=False, overrides=None):
    """Configure a PV-RNN model and trigger graph construction.

    Parses the data/network configuration dicts, optional planning settings
    and override flags into instance attributes, builds the per-modality
    recurrent cells (LSTM or MTRNN), then touches the lazy ``build_model``
    property to construct the TensorFlow graph.

    Args:
        sess: TensorFlow session used by the model.
        config_data: dict of dataset settings (sequences, max_timesteps,
            softmax_quant, per-modality ``<m>_dims`` / ``<m>_path``).
        config_network: dict of network settings (modalities, layer sizes,
            meta-priors, cell types, activation functions, ...).
        training: True when the network is learning (training or planning).
        planning: dict of error-regression (planning) settings, or None.
        input_cl_ratio: closed-loop ratio applied to network input.
        learning_rate: optimizer learning rate.
        optimizer_epsilon: optimizer epsilon.
        prior_generation: use the prior instead of the posterior for output.
        reset_posterior_src: reset the posterior's trained input (planning).
        hybrid_posterior_src: size of the trained-A window, or None.
        hybrid_posterior_src_idx: index override for the hybrid A window.
        data_masking: mask differing sequence lengths (off if data padded).
        overrides: optional dict of extra overrides; None means no overrides.
    """
    # Fix: the signature previously used the mutable default `overrides=dict()`,
    # and several lookups below (`"kld_range" in overrides`, ...) raised
    # TypeError when callers passed overrides=None explicitly. Normalize once.
    overrides = dict() if overrides is None else overrides
    self.sess = sess
    self.modalities = [str(m.strip()) for m in config_network["modalities"].split(',')] # determine how many modalities are used
    ## Dump config
    print("training = " + str(training))
    if training:
        print("**learning args**")
        print("learning_rate = " + str(learning_rate))
        print("optimizer_epsilon = " + str(optimizer_epsilon))
    print("**generation args**")
    print("prior_generation = " + str(prior_generation))
    print("reset_posterior_src = " + str(reset_posterior_src))
    print("hybrid_posterior_src = " + str(hybrid_posterior_src))
    print("hybrid_posterior_src_idx = " + str(hybrid_posterior_src_idx))
    print("**config_data**")
    print_dict(config_data)
    print("**config_network**")
    print_dict(config_network)
    if planning is not None:
        print("**planning**")
        print_dict(planning)
    if overrides:
        print("**overrides**")
        print_dict(overrides)
    ## Data
    self.n_seq = int(config_data["sequences"]) # number of sequences
    self.mask_seq = data_masking # use masking to compensate for different sequence lengths (disable if data is padded)
    self.max_timesteps = int(config_data["max_timesteps"]) # if real data is loaded, this is overwritten
    self.dims = dict()
    self.path = dict()
    for m in self.modalities:
        self.dims[m] = int(config_data[m + "_dims"])
        self.path[m] = config_data[m + "_path"]
    self.batch_size = self.n_seq # = n for performance, = 1 for online learning/error regression
    self.softmax_quant = int(config_data["softmax_quant"])
    self.override_load_data = True if "load_training_data" in overrides and overrides["load_training_data"] else False
    ## Training
    self.training = training # set whether the network is learning (training or planning)
    self.planning = True if planning is not None else False # set to do error regression to generate a plan from start to goal
    if self.planning:
        # NOTE(review): the fallback branch yields a single modality *string*
        # while the parsed branch yields a *list* — confirm downstream users
        # of planning_init/goal_modalities handle both shapes.
        self.planning_init_modalities = [x.strip() for x in planning["init_modalities"].split(',')] if "init_modalities" in planning else self.modalities[0]
        self.planning_goal_modalities = [x.strip() for x in planning["goal_modalities"].split(',')] if "goal_modalities" in planning else self.modalities[0]
        self.planning_goal_modalities_mask = [int(x.strip()) for x in planning["goal_modalities_mask"].split(',')] if "goal_modalities_mask" in planning else None
        self.planning_init_modalities_mask = [int(x.strip()) for x in planning["init_modalities_mask"].split(',')] if "init_modalities_mask" in planning else None
    ## Planning
    self.planning_initial_frame = 0 # start of initial frames
    self.planning_initial_depth = 1 # how many frames to provide at the start
    self.planning_duplicate_initial_frame = False # set to true to duplicate initial frame for padding
    self.planning_goal_frame = -1 # start of goal frames (should be negative index)
    self.planning_goal_depth = None # how many frames to provide at the end (set to None to continue until the end of the sequence)
    self.planning_goal_offset = 0 # move goal position in plan relative to target data
    self.planning_goal_padding = False # set to true to copy the goal frame until the end of the sequence
    self.planning_auto_weight = 1 # increase rec error pressure by a factor of missing frames (0 to disable, >1 to multiply)
    if self.planning:
        if "init_frame" in planning:
            self.planning_initial_frame = int(planning["init_frame"])
        if "init_depth" in planning:
            self.planning_initial_depth = int(planning["init_depth"]) if planning["init_depth"].lower() != "none" else None
        if "goal_frame" in planning:
            self.planning_goal_frame = int(planning["goal_frame"])
        if "goal_depth" in planning:
            self.planning_goal_depth = int(planning["goal_depth"]) if planning["goal_depth"].lower() != "none" else None
        if "init_frame_duplicate" in planning:
            self.planning_duplicate_initial_frame = True if planning["init_frame_duplicate"].lower() == "true" else False
        if "goal_padding" in planning:
            self.planning_goal_padding = True if planning["goal_padding"].lower() == "true" else False
        if "rec_weighting" in planning:
            self.planning_auto_weight = float(planning["rec_weighting"])
    ## Optimizer
    if "optimizer" in config_network:
        self.optimizer_func = config_network["optimizer"].lower()
    else:
        self.optimizer_func = "adam"
        print("model: Using default optimizer adam")
    self.learning_rate = learning_rate
    self.optimizer_epsilon = optimizer_epsilon
    self.gradient_clip = float(config_network["gradient_clip"])
    ## Model
    self.d_neurons = dict()
    self.z_units = dict()
    self.n_layers = dict()
    self.layers = dict()
    self.layers_names = dict()
    self.layers_params = dict()
    self.shared_layer = None
    # Default connections
    self.connect_z = True
    self.connect_topdown_dz = False
    self.connect_topdown_dt = False
    self.connect_horizontal = True # connect between modalities
    self.connect_topdown_d = True # connect from higher to lower layers
    self.connect_bottomup_d = True # connect from lowest to higher layers
    # Overrides
    if "connect_z" in config_network:
        self.connect_z = True if config_network["connect_z"].lower() == "true" else False
    if "connect_topdown_dz" in config_network:
        self.connect_topdown_dz = True if config_network["connect_topdown_dz"].lower() == "true" else False
    if "connect_topdown_dt" in config_network:
        self.connect_topdown_dt = True if config_network["connect_topdown_dt"].lower() == "true" else False
    if "connect_horizontal" in config_network:
        self.connect_horizontal = True if config_network["connect_horizontal"].lower() == "true" else False
    if "connect_topdown_d" in config_network:
        self.connect_topdown_d = True if config_network["connect_topdown_d"].lower() == "true" else False
    if "connect_bottomup_d" in config_network:
        self.connect_bottomup_d = True if config_network["connect_bottomup_d"].lower() == "true" else False
    self.layers_concat_input = False # use concatenation instead of addition when preparing input to neurons
    self.gradient_clip_input = float(config_network["gradient_clip_input"]) # clip layer output values
    self.dropout_mask_error = False if not self.training or ("dropout_mask_error" in overrides and not overrides["dropout_mask_error"]) else True # use masking to manipulate reconstruction error in training and planning
    if "override_d_output" in overrides and overrides["override_d_output"] is not None:
        self.override_d_output = overrides["override_d_output"][1] # set to enable overriding the output of D per timestep with desired value, starting with output layer (L0)
        self.override_d_output_range = overrides["override_d_output"][0]
    else:
        self.override_d_output = None
        self.override_d_output_range = None
    if "kld_range" in overrides:
        self.override_kld_range = [int(x.strip()) for x in overrides["kld_range"].split(',')]
    else:
        self.override_kld_range = None
    for m in self.modalities:
        self.d_neurons[m] = [int(x.strip()) for x in config_network[m + "_layers_neurons"].split(',')]
        if m + "_layers_z_units" in config_network:
            self.z_units[m] = [int(x.strip()) for x in config_network[m + "_layers_z_units"].split(',')]
        else:
            # Default: roughly one z-unit per ten d-neurons.
            self.z_units[m] = [int(round(float(d)/10)) for d in self.d_neurons[m]]
            print("model: Using default z_units " + str(self.z_units[m]))
        if m + "_layers_param" in config_network:
            self.layers_params[m] = [float(x.strip()) for x in config_network[m + "_layers_param"].split(',')]
        else:
            # Default: doubling parameter (tau / forget bias) per layer: 2, 4, 8, ...
            self.layers_params[m] = [float(2**(i+1)) for i in range(len(self.d_neurons[m]))]
            print("model: Using default layer parameters " + str(self.layers_params[m]))
        # Append layer 0
        self.d_neurons[m].insert(0, (self.dims[m] * self.softmax_quant))
        self.z_units[m].insert(0, 0)
        self.layers[m] = [None for _ in range(len(self.d_neurons[m]))] # Layer 0 is for output, no cells
        self.layers_names[m] = ["l" + str(l) + "_" + m for l in range(len(self.d_neurons[m]))]
        self.n_layers[m] = len(self.layers[m]) # including I/O
    # Assume only the top layer might be shared
    if max(self.n_layers.values()) != min(self.n_layers.values()):
        self.shared_layer = max(self.n_layers, key=self.n_layers.get)
    ## Variational Bayes
    self.vb_meta_prior = dict()
    self.vb_seq_prior = dict()
    for m in self.modalities:
        self.vb_meta_prior[m] = [float(x.strip()) for x in config_network[m + "_meta_prior"].split(',')] # meta-prior setting
        if m + "_seq_prior" in config_network:
            self.vb_seq_prior[m] = [False if x.strip().lower() == "false" else True for x in config_network[m + "_seq_prior"].split(',')]
            if len(self.vb_seq_prior[m]) < self.n_layers[m]-1:
                # Too few entries given: broadcast the first entry to all layers.
                self.vb_seq_prior[m] = [self.vb_seq_prior[m][0]] * (self.n_layers[m]-1)
                print("model: Assuming all layers of " + m + " to be sequential prior=" + str(self.vb_seq_prior[m][0]))
        else:
            self.vb_seq_prior[m] = [True] * (self.n_layers[m]-1) # set to false to use unit gaussian in Z calculation
        self.vb_seq_prior[m].insert(0, False) # I/O layer
    if "ugaussian_t_range" in config_network:
        self.vb_ugaussian_t_range = [int(x.strip()) for x in config_network["ugaussian_t_range"].split(',')]
        if "ugaussian_weight" in config_network: # larger weight = less initial sensitivity
            self.vb_ugaussian_weight = float(config_network["ugaussian_weight"])
        else:
            self.vb_ugaussian_weight = 0.001
    else:
        self.vb_ugaussian_t_range = None
        self.vb_ugaussian_weight = None
    self.vb_new_meta_prior_loss = True # use new loss calculation
    self.vb_return_full_loss = config_network.get("return_full_loss", False) # returns loss per timestep per sequence (true or false, not text)
    self.vb_per_t_meta_prior = False # apply W at loss calculation per timestep (always true when vb_ugaussian is used)
    self.vb_zero_initial_out = True # set to true for d=0 at t=0
    self.vb_reset_posterior = reset_posterior_src # reset the posterior's trained input, use in planning
    self.vb_prior_output = prior_generation # use either prior or posterior in calculating output
    self.vb_posterior_past_input = True if config_network.get("connect_posterior_dz", "false").lower() == "true" else False # set to true to include d_{t-1} in posterior calculation
    self.vb_posterior_src_extend = True if config_network.get("posterior_map_src", "false").lower() == "true" else False # set to true apply weights to Z source
    self.vb_posterior_blend_factor = 0.0 # combine posterior and prior during plan generation
    self.vb_limit_sigma = False # hard clip sigma to be non-zero and not too large
    self.vb_hybrid_posterior_src = False if hybrid_posterior_src is None else True # create a window of trained A values
    self.vb_hybrid_posterior_src_range = hybrid_posterior_src # how many trained A samples to provide
    self.vb_hybrid_posterior_src_zero_init = False # set to true to zero A in the window instead of providing a trained A window
    self.vb_hybrid_posterior_src_idx_override = None if hybrid_posterior_src_idx is None else np.repeat(hybrid_posterior_src_idx, self.n_seq)
    self.vb_hybrid_prior_override = overrides.get("hybrid_posterior_override", False) # set to true to use posterior for all Z calculations during window
    self.vb_prior_override_l = None
    self.vb_prior_override_t_range = None
    self.vb_prior_override_sigma = None
    self.vb_prior_override_myu = None
    self.vb_prior_override_epsilon = None
    if "prior_override_l" in overrides:
        # set to None, True or a list of levels to override (starting at 1)
        if overrides["prior_override_l"].lower() == "all" or overrides["prior_override_l"].lower() == "true":
            self.vb_prior_override_l = True
        elif overrides["prior_override_l"].lower() != "none" and overrides["prior_override_l"].lower() != "false":
            # Fix: this condition used `or`, which is tautologically true, so the
            # values "none"/"false" fell through to int() and raised ValueError.
            self.vb_prior_override_l = [int(x.strip()) for x in overrides["prior_override_l"].split(',')]
        if "prior_override_t_range" in overrides and overrides["prior_override_t_range"] is not None:
            self.vb_prior_override_t_range = [int(x.strip()) for x in overrides["prior_override_t_range"].split(',')]
        self.vb_prior_override_sigma = float(overrides["prior_override_sigma"]) if "prior_override_sigma" in overrides else None
        self.vb_prior_override_myu = float(overrides["prior_override_myu"]) if "prior_override_myu" in overrides else None
        self.vb_prior_override_epsilon = float(overrides["prior_override_epsilon"]) if "prior_override_epsilon" in overrides else None # set to 0 to disable noise sampling
    self.vb_posterior_override_l = None
    self.vb_posterior_override_t_range = None
    self.vb_posterior_override_sigma = None
    self.vb_posterior_override_myu = None
    self.vb_posterior_override_epsilon = None
    if "posterior_override_l" in overrides:
        # set to None, True or a list of levels to override (starting at 1)
        if overrides["posterior_override_l"].lower() == "all" or overrides["posterior_override_l"].lower() == "true":
            self.vb_posterior_override_l = True
        elif overrides["posterior_override_l"].lower() != "none" and overrides["posterior_override_l"].lower() != "false":
            # Fix: same tautological `or` as in the prior override parsing above.
            self.vb_posterior_override_l = [int(x.strip()) for x in overrides["posterior_override_l"].split(',')]
        if "posterior_override_t_range" in overrides and overrides["posterior_override_t_range"] is not None:
            self.vb_posterior_override_t_range = [int(x.strip()) for x in overrides["posterior_override_t_range"].split(',')]
        self.vb_posterior_override_sigma = float(overrides["posterior_override_sigma"]) if "posterior_override_sigma" in overrides else None
        self.vb_posterior_override_myu = float(overrides["posterior_override_myu"]) if "posterior_override_myu" in overrides else None
        self.vb_posterior_override_epsilon = float(overrides["posterior_override_epsilon"]) if "posterior_override_epsilon" in overrides else None # set to 0 to disable noise sampling
    # Activation functions
    self.activation_func = dict()
    self.z_activation_func = dict()
    for m in self.modalities:
        # Supported activation functions: ReLU, Leaky ReLU, Sigmoid, Extended Hyperbolic Tangent, Hyperbolic Tangent (default)
        if m + "_activation_func" in config_network:
            if config_network[m + "_activation_func"].lower() == "relu":
                self.activation_func[m] = tf.nn.relu
            elif config_network[m + "_activation_func"].lower() == "leaky_relu":
                self.activation_func[m] = tf.nn.leaky_relu
            elif config_network[m + "_activation_func"].lower() == "sigmoid":
                self.activation_func[m] = tf.nn.sigmoid
            elif config_network[m + "_activation_func"].lower() == "extended_tanh":
                self.activation_func[m] = ops.extended_hyperbolic
            elif config_network[m + "_activation_func"].lower() == "tanh":
                self.activation_func[m] = tf.nn.tanh
            else:
                self.activation_func[m] = tf.nn.tanh
                print("model: Unknown activation function " + config_network[m + "_activation_func"] + ", falling back to tanh")
        else:
            self.activation_func[m] = tf.nn.tanh
            print("model: Using default activation function tanh")
        self.z_activation_func[m] = self.activation_func[m]
        # Layer 1+
        for i in range(1,self.n_layers[m]):
            with tf.compat.v1.variable_scope(self.layers_names[m][i]):
                # Supported celltypes: LSTM and MTRNN (default)
                if m + "_celltype" in config_network:
                    if config_network[m + "_celltype"].lower() == "lstm":
                        self.layers[m][i] = LSTM(self.d_neurons[m][i], activation=self.activation_func[m], forget_bias=self.layers_params[m][i-1])
                    elif config_network[m + "_celltype"].lower() == "mtrnn":
                        self.layers[m][i] = MTRNN(self.d_neurons[m][i], activation=self.activation_func[m], tau=self.layers_params[m][i-1])
                    else:
                        self.layers[m][i] = MTRNN(self.d_neurons[m][i], activation=self.activation_func[m], tau=self.layers_params[m][i-1])
                        print("model: Unknown cell type " + config_network[m + "_celltype"] + ", falling back to MTRNN")
                else:
                    self.layers[m][i] = MTRNN(self.d_neurons[m][i], activation=self.activation_func[m], tau=self.layers_params[m][i-1])
                    print("model: Using default cell type MTRNN")
    # Layer 0
    self.input_provide_data = False # set to false to not provide any input to the network, causes cl_ratio to be ignored
    self.input_cl_ratio = input_cl_ratio
    self.output_z_factor = dict()
    for m in self.modalities:
        if m + "_output_z_factor" in config_network: # how much L1 z-units influence output for additional regularization (0.0 = disabled)
            self.output_z_factor[m] = float(config_network[m + "_output_z_factor"])
        else:
            self.output_z_factor[m] = 0.0
    # Touch the lazy property to build the TensorFlow graph now.
    self.build_model
    # __init__ #
## For backwards compatibility, data must be preprocessed into the P-DVMRNN npy format first
def load_dataset(self):
    """Load (or synthesize) per-modality data and hand it to ``_load_dataset``.

    In training mode the .npy files at ``self.path[m]`` are loaded as-is.
    Otherwise a zero-filled dataset of the configured shape is generated; in
    planning mode the ground-truth files are then loaded separately and only
    the initial frame(s) and goal frame(s) are copied into the zero dataset
    (optionally restricted to a dimension range via the modality masks).

    Returns:
        The dict produced by ``_load_dataset`` (tensors, iterator batch,
        timestep count).
    """
    data_raw = dict()
    if self.planning:
        data_orig_raw = dict()  # ground-truth data, source of init/goal frames
    for m in self.modalities:
        ## Motor data
        if (self.training and not self.planning) or self.override_load_data:
            data_raw[m] = np.load(self.path[m]) # seq, timestep, dim (softmax)
            print("load_dataset: Loaded data for " + m + " with shape " + str(np.shape(data_raw[m])))
        else:
            # No real input: all-zero placeholder with the configured shape.
            data_raw[m] = np.zeros((self.n_seq, self.max_timesteps, self.dims[m]*self.softmax_quant))
            print("load_dataset: Generated null dataset for " + m + " with shape " + str(np.shape(data_raw[m])))
        if self.planning:
            data_orig_raw[m] = np.load(self.path[m]) # seq, timestep, dim (softmax)
            print("load_dataset: Loaded ground truth data for " + m + " with shape " + str(np.shape(data_orig_raw[m])))
            # Slice bounds: goal_frame is normally a negative index; the offset
            # shifts where the goal lands in the plan relative to the target data.
            plan_goal_start = self.planning_goal_frame + self.planning_goal_offset
            plan_goal_end = None if self.planning_goal_depth is None else plan_goal_start + self.planning_goal_depth + self.planning_goal_offset
            load_init_end = self.planning_initial_frame+self.planning_initial_depth if not self.planning_duplicate_initial_frame else self.planning_initial_frame+1
            load_goal_end = None if self.planning_goal_depth is None else plan_goal_start + self.planning_goal_depth
            # Copy initial frame(s)
            if self.planning_init_modalities_mask is not None:
                # Mask entries are dimension indices, scaled by the softmax quantization.
                plan_mask_start = self.planning_init_modalities_mask[0]*self.softmax_quant
                plan_mask_end = (self.planning_init_modalities_mask[1]+1)*self.softmax_quant if self.planning_init_modalities_mask[1] is not None else None
                if not self.planning_duplicate_initial_frame:
                    # NOTE(review): destination slice ends at planning_initial_depth,
                    # not initial_frame+depth — correct only when initial_frame == 0; confirm.
                    data_raw[m][:, self.planning_initial_frame:self.planning_initial_depth, plan_mask_start:plan_mask_end] = data_orig_raw[m][:, self.planning_initial_frame:load_init_end, plan_mask_start:plan_mask_end]
                else: # only mask duplicated frames
                    data_raw[m][:, self.planning_initial_frame, :] = data_orig_raw[m][:, self.planning_initial_frame, :]
                    data_raw[m][:, self.planning_initial_frame+1:self.planning_initial_depth, plan_mask_start:plan_mask_end] = data_orig_raw[m][:, self.planning_initial_frame:load_init_end, plan_mask_start:plan_mask_end]
            else:
                data_raw[m][:, self.planning_initial_frame:self.planning_initial_depth, :] = data_orig_raw[m][:, self.planning_initial_frame:load_init_end, :]
            # Copy goal frame(s)
            if self.planning_goal_modalities_mask is not None:
                plan_mask_start = self.planning_goal_modalities_mask[0]*self.softmax_quant
                plan_mask_end = (self.planning_goal_modalities_mask[1]+1)*self.softmax_quant if self.planning_goal_modalities_mask[1] is not None else None
                data_raw[m][:, plan_goal_start:plan_goal_end, plan_mask_start:plan_mask_end] = data_orig_raw[m][:, self.planning_goal_frame:load_goal_end, plan_mask_start:plan_mask_end] # copy goal frame
            else:
                data_raw[m][:, plan_goal_start:plan_goal_end, :] = data_orig_raw[m][:, self.planning_goal_frame:load_goal_end, :] # copy goal frame
            if self.planning_goal_padding:
                # Presumably requires planning_goal_depth to be set (load_goal_end
                # is None otherwise and this would raise) — TODO confirm.
                data_raw[m][:, plan_goal_end:, :] = data_orig_raw[m][:, load_goal_end, :] # duplicate last frame
    return self._load_dataset(data_raw)
## Using Dataset API
def _load_dataset(self, data_raw):
    """Wrap raw per-modality arrays in a tf.data pipeline.

    Builds one dataset per modality plus a sequence-index dataset, zips them,
    shuffles only when training (not planning), repeats indefinitely, batches
    with ``drop_remainder`` and prefetches. Also reconciles ``self.n_seq``,
    ``self.batch_size`` and ``self.max_timesteps`` with the real data shape.

    Args:
        data_raw: dict modality -> array of shape (seq, step, dim*quant).

    Returns:
        dict with "idx_data"/"idx_next", per-modality "<m>_data"/"<m>_next"
        tensors, and "timesteps".
    """
    data_tensors = dict()
    data_dataset = dict()
    setlist = []
    batch = dict()
    for m in self.modalities:
        data_tensors[m] = tf.convert_to_tensor(value=data_raw[m], dtype=tf.float32, name="load_data_tensors") # [seq, step, dim*quant_level]
        data_dataset[m] = tf.data.Dataset.from_tensor_slices(data_tensors[m])
    # Shape taken from an arbitrary modality; all are assumed to share
    # (seq, step) dimensions — TODO confirm against preprocessing.
    data_shape = np.shape(next(iter(data_raw.values())))
    # Override defined number of sequences with real value
    if self.n_seq == self.batch_size or data_shape[0] < self.batch_size:
        self.batch_size = data_shape[0] # update batch size
        self.n_seq = data_shape[0]
    if self.n_seq < self.batch_size:
        self.batch_size = data_shape[0] # fix batch size
    # Override predefined number of timesteps with real value
    self.max_timesteps = data_shape[1]
    # Build index
    # idxs[j] is the sequence index j repeated for every timestep, so each
    # batched element carries its originating sequence id.
    idxs = [[j for _ in range(self.max_timesteps)] for j in range(self.n_seq)]
    idx_tensors = tf.convert_to_tensor(value=idxs, dtype=tf.int32, name="load_idx_tensors")
    idx_dataset = tf.data.Dataset.from_tensor_slices(idx_tensors)
    # Batching
    # The zip order fixes the batch tuple layout: index first, then one
    # entry per modality in self.modalities order (relied on below).
    setlist.append(idx_dataset)
    for m in self.modalities:
        setlist.append(data_dataset[m])
    dataset = tf.data.Dataset.zip(tuple(setlist))
    if self.training and not self.planning:
        dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.repeat()
    dataset = dataset.batch(self.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=self.n_seq//self.batch_size)
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    batch = iterator.get_next()
    out = dict()
    out["idx_data"] = idx_tensors
    out["idx_next"] = batch[0]
    for i in range(len(self.modalities)):
        out[self.modalities[i] + "_data"] = data_tensors[self.modalities[i]]
        out[self.modalities[i] + "_next"] = batch[i+1]
    out["timesteps"] = self.max_timesteps
    return out
@lazy_property
def build_model(self):
    """Build the complete training/planning graph for all modalities.

    Loads the dataset, unrolls the per-timestep network via scan, collects
    prior/posterior z statistics, computes reconstruction and KLD
    regularization losses, wires the optimizer (training) or the
    least-loss plan selection (planning), and registers summaries.

    Returns a dict of graph tensors (data, generated outputs, z statistics,
    losses, training op, initial states). Side effects: sets
    self.model_train_var and self.saver.
    """
    ## Dataset
    data = self.load_dataset()
    if not self.mask_seq:
        data_mask_reg = None
    else:
        data_mask_reg = ops.data_mask(data[self.modalities[0] + "_next"])  # mask for different sequence lengths
    if not self.input_provide_data:
        data_mask_rec = data_mask_reg
    else:
        # skip_ahead=1: reconstruction targets are shifted by one step when input is fed
        data_mask_rec = ops.data_mask(data[self.modalities[0] + "_next"], skip_ahead=1)
    ## Initialize initial states for all network units
    with tf.compat.v1.variable_scope('initial_state', reuse=tf.compat.v1.AUTO_REUSE):
        init_data = data if self.input_provide_data else None
        initial_states = ops.set_trainable_initial_states(self.modalities, init_data, self.batch_size, self.d_neurons, self.z_units)
    ## Train model
    with tf.compat.v1.variable_scope("training"):
        data_next = []
        data_next.append(tf.transpose(a=data["idx_next"], perm=[1, 0], name="transpose_idx_next"))
        if self.input_provide_data:
            for m in self.modalities:
                data_next.append(tf.transpose(a=data[m + "_next"], perm=[1, 0, 2], name="transpose_data_next"))  # step, batch, dim
        step_data = tuple(data_next)
        # Run one epoch, all timesteps
        output_model = functional_ops.scan(self.build_model_one_step_scan, step_data, initializer=initial_states, parallel_iterations=self.batch_size)
        # Collect output (dicts)
        generated_out = {m: tf.transpose(a=output_model["out"][m][0], perm=[1, 0, 2], name="transpose_generated_out") for m in self.modalities}  # output layer 0: seq, step, dim
        generated_z_prior = dict()
        generated_z_prior_mean = dict()
        generated_z_prior_var = dict()
        generated_z_posterior = dict()
        generated_z_posterior_mean = dict()
        generated_z_posterior_var = dict()
        generated_z_posterior_src = dict()
        for m in self.modalities:
            if max(self.vb_meta_prior[m]) >= 0:
                # Transpose each layer's z traces to (seq, step, dim)
                generated_z_prior[m] = [tf.transpose(a=output_model["z_prior"][m][i], perm=[1, 0, 2], name="transpose_generated_z_p") for i in range(len(output_model["z_prior"][m]))]
                generated_z_prior_mean[m] = [tf.transpose(a=output_model["z_prior_mean"][m][i], perm=[1, 0, 2], name="transpose_generated_zm_p") for i in range(len(output_model["z_prior_mean"][m]))]
                generated_z_prior_var[m] = [tf.transpose(a=output_model["z_prior_var"][m][i], perm=[1, 0, 2], name="transpose_generated_zv_p") for i in range(len(output_model["z_prior_var"][m]))]
                generated_z_posterior[m] = [tf.transpose(a=output_model["z_posterior"][m][i], perm=[1, 0, 2], name="transpose_generated_z_q") for i in range(len(output_model["z_posterior"][m]))]
                generated_z_posterior_mean[m] = [tf.transpose(a=output_model["z_posterior_mean"][m][i], perm=[1, 0, 2], name="transpose_generated_zm_q") for i in range(len(output_model["z_posterior_mean"][m]))]
                generated_z_posterior_var[m] = [tf.transpose(a=output_model["z_posterior_var"][m][i], perm=[1, 0, 2], name="transpose_generated_zv_q") for i in range(len(output_model["z_posterior_var"][m]))]
                if any(self.vb_seq_prior[m]):
                    # Fetch the (already created) posterior source variables for inspection
                    with tf.compat.v1.variable_scope("model_variables", reuse=True):
                        # presumably mean+var are concatenated when not src_extend -- TODO confirm
                        src_z = max(self.z_units[m]) if self.vb_posterior_src_extend else max(self.z_units[m])*2
                        if not self.vb_hybrid_posterior_src:
                            if not self.vb_reset_posterior:
                                z_posterior_src_var_name = m + "_z_posterior_src"
                            else:
                                z_posterior_src_var_name = m + "_z_posterior_src_zero"
                            generated_z_posterior_src[m] = tf.compat.v1.get_variable(z_posterior_src_var_name, shape=[self.max_timesteps, self.n_layers[m], self.n_seq, src_z])
                        else:
                            # Hybrid: splice trained and zero-initialized sources along time
                            trained_src = tf.compat.v1.get_variable(m + "_z_posterior_src", shape=[self.max_timesteps, self.n_layers[m], self.n_seq, src_z])
                            zero_src = tf.compat.v1.get_variable(m + "_z_posterior_src_zero", shape=[self.max_timesteps, self.n_layers[m], self.n_seq, src_z])
                            initial_start = self.vb_hybrid_posterior_src_range[0]
                            initial_end = self.vb_hybrid_posterior_src_range[1]
                            if not self.vb_hybrid_posterior_src_zero_init:
                                generated_z_posterior_src[m] = tf.concat([trained_src[initial_start:initial_end, :, :, :], zero_src[initial_end:, :, :, :]], axis=0, name="concat_hybrid_src_tz")
                            else:
                                generated_z_posterior_src[m] = tf.concat([zero_src[initial_start:initial_end, :, :, :], trained_src[initial_end:, :, :, :]], axis=0, name="concat_hybrid_src_zt")
    ## Calculate loss per modality
    batch_reconstruction_loss = dict.fromkeys(self.modalities, tf.constant(0.0))
    batch_regularization_loss = dict.fromkeys(self.modalities, tf.constant(0.0))
    for m in self.modalities:
        ## Reconstruction loss
        if not self.planning:
            if not self.dropout_mask_error:
                rec_loss = ops.kld_with_mask(data[m + "_next"][:, :, :], generated_out[m][:, :data["timesteps"], :], data_mask_rec)
            # Normalize by dims (and quantization, when returning the full loss)
                if self.vb_return_full_loss:
                    batch_reconstruction_loss[m] = (rec_loss[0] / float(self.dims[m]) / float(self.softmax_quant), rec_loss[1] / float(self.dims[m]) / float(self.softmax_quant))
                else:
                    batch_reconstruction_loss[m] = (rec_loss[0] / float(self.dims[m]) / float(self.softmax_quant), rec_loss[1])
            else:
                # Random per-(seq, step) dropout of the error signal; *2.0 compensates the halved coverage
                dropout_mask = ops.dropout_mask([self.batch_size, self.max_timesteps])
                rec_loss = ops.kld_with_mask(data[m + "_next"][:, :, :], generated_out[m][:, :data["timesteps"], :], dropout_mask)
                if self.vb_return_full_loss:
                    batch_reconstruction_loss[m] = (rec_loss[0] * 2.0 / float(self.dims[m]) / float(self.softmax_quant), rec_loss[1] * 2.0 / float(self.dims[m]) / float(self.softmax_quant))
                else:
                    batch_reconstruction_loss[m] = (rec_loss[0] * 2.0 / float(self.dims[m]) / float(self.softmax_quant), rec_loss[1])
        else:  # Rec loss only exists at initial and goal frames
            selected_loss_idx = -1
            selected_loss = tf.constant(sys.float_info.max)
            selected_loss_rec = selected_loss
            selected_loss_reg = selected_loss
            if m in self.planning_init_modalities:
                plan_iframe_start = self.planning_initial_frame
                plan_iframe_end = self.planning_initial_frame + self.planning_initial_depth
            if m in self.planning_goal_modalities:
                plan_gframe_start = self.planning_goal_frame + self.planning_goal_offset
                plan_gframe_end = None if self.planning_goal_depth is None else plan_gframe_start + self.planning_goal_depth + self.planning_goal_offset
                if self.planning_goal_modalities_mask is not None:
                    # Mask is given in (unquantized) dimension indices; scale to softmax bins
                    plan_mask_start = self.planning_goal_modalities_mask[0]*self.softmax_quant
                    plan_mask_end = (self.planning_goal_modalities_mask[1]+1)*self.softmax_quant if self.planning_goal_modalities_mask[1] != -1 else -1
                else:
                    plan_mask_start = None
                    plan_mask_end = None
            error_weight = 1.0
            if m not in self.planning_init_modalities and m not in self.planning_goal_modalities:
                batch_reconstruction_loss[m] = 0.0
            else:
                if m in self.planning_init_modalities and self.planning_auto_weight > 0:
                    error_weight += self.planning_initial_depth
                if m in self.planning_goal_modalities:
                    if plan_mask_end == -1:
                        plan_mask_end = self.dims[m]*self.softmax_quant
                    if self.planning_auto_weight > 0:
                        error_weight += 1.0
                if self.planning_auto_weight > 0:
                    error_weight = (self.max_timesteps * self.planning_auto_weight) / error_weight
                if self.planning_goal_modalities_mask is not None:
                    # Per-dimension mask restricted to the configured bin range
                    dmask1 = [1 if d >= plan_mask_start and d < plan_mask_end else 0 for d in range(self.dims[m]*self.softmax_quant)]
                    planning_mask = ops.windowed_dmask(dmask1, [self.n_seq, self.max_timesteps, self.dims[m]*self.softmax_quant], start=[plan_iframe_start, plan_iframe_end], end=[plan_gframe_start, plan_gframe_end], end_zeropad=(not self.planning_goal_padding))
                    rec_loss = ops.kld_with_mask(data[m + "_next"][:, :, :], generated_out[m][:, :data["timesteps"], :], dmask=planning_mask)
                    batch_reconstruction_loss[m] = rec_loss  # (reduced, all sequences)
                    error_weight += 1.0
                else:
                    planning_mask = ops.windowed_mask([self.batch_size, self.max_timesteps], start=[plan_iframe_start, plan_iframe_end], end=[plan_gframe_start, plan_gframe_end], end_zeropad=(not self.planning_goal_padding))
                    rec_loss = ops.kld_with_mask(data[m + "_next"][:, :, :], generated_out[m][:, :data["timesteps"], :], mask=planning_mask)
                    batch_reconstruction_loss[m] = rec_loss  # (reduced, all sequences)
                # Average
                batch_reconstruction_loss[m] = (batch_reconstruction_loss[m][0] * error_weight / float(self.dims[m]) / self.softmax_quant, batch_reconstruction_loss[m][1] * error_weight / float(self.dims[m]) / self.softmax_quant)
        ## Regularization loss (per layer)
        if max(self.vb_meta_prior[m]) >= 0:
            if self.override_kld_range is not None:
                # Override mask
                reg_mask = ops.windowed_mask([self.batch_size, self.max_timesteps], start=self.override_kld_range, end=[0,0])
            else:
                reg_mask = data_mask_reg
            batch_regularization_loss[m] = ops.vb_kld_with_mask(output_model, reg_mask, self.z_units[m], m, self.vb_seq_prior[m], self.vb_ugaussian_t_range, self.vb_ugaussian_weight, self.vb_meta_prior[m], seq_kld_weight_by_t=self.vb_per_t_meta_prior)
    ## Calculate total loss
    total_batch_loss = tf.constant(0.0)
    total_batch_reconstruction_loss = tf.constant(0.0)
    total_batch_regularization_loss = tf.constant(0.0)
    # Find least loss for planner
    if self.planning:
        # NOTE(review): m here is the leftover loop variable (last modality) from the loop above -- verify intended
        full_batch_reconstruction_loss = tf.zeros(tf.shape(input=batch_reconstruction_loss[m][1])[0])
        full_batch_regularization_loss = tf.zeros_like(full_batch_reconstruction_loss)
        full_batch_loss = tf.zeros_like(full_batch_reconstruction_loss)
    for m in self.modalities:
        # Reconstruction loss
        rec_loss = batch_reconstruction_loss[m][0]
        total_batch_reconstruction_loss += rec_loss
        if self.planning:
            rec_loss_seq = tf.reduce_sum(input_tensor=batch_reconstruction_loss[m][1], axis=1, name="reduce_recloss")
            full_batch_reconstruction_loss += rec_loss_seq
        # Regularization loss
        if self.vb_meta_prior[m][0] != -1:
            zs = len(self.z_units[m])-1
            for i in range(1, zs+1):
                reg_loss = batch_regularization_loss[m][i][0]
                total_batch_regularization_loss += reg_loss
                if self.vb_ugaussian_t_range is not None or self.vb_per_t_meta_prior:  # W is applied in KLD calculation
                    total_batch_loss += rec_loss/zs - reg_loss
                else:  # apply W to the whole sequence
                    W = self.vb_meta_prior[m][i-1] if self.vb_meta_prior[m][i-1] >= 0 else 0.0
                    if self.vb_new_meta_prior_loss:
                        total_batch_loss += rec_loss/zs - W * reg_loss
                    else:
                        total_batch_loss += ((1.0 - W) * (rec_loss/zs)) - (W * reg_loss)
            if self.planning:
                # Same weighting, but kept per-sequence for plan selection
                for i in range(1, zs+1):
                    reg_loss_seq = tf.reduce_sum(input_tensor=batch_regularization_loss[m][i][1], axis=1, name="reduce_regloss")
                    full_batch_regularization_loss += reg_loss_seq
                    if self.vb_ugaussian_t_range is not None or self.vb_per_t_meta_prior:  # W is applied in KLD calculation
                        full_batch_loss += rec_loss_seq/zs - reg_loss_seq
                    else:  # apply W to the whole sequence
                        W = self.vb_meta_prior[m][i-1] if self.vb_meta_prior[m][i-1] >= 0 else 0.0
                        if self.vb_new_meta_prior_loss:
                            full_batch_loss += rec_loss_seq/zs - W * reg_loss_seq
                        else:
                            full_batch_loss += ((1.0 - W) * (rec_loss_seq/zs)) - (W * reg_loss_seq)
        else:
            total_batch_loss = total_batch_reconstruction_loss
    # Select least loss as recommended plan
    if self.planning:
        selected_loss_idx = tf.argmin(input=full_batch_loss, name="argmin_full_loss")
        selected_loss = full_batch_loss[selected_loss_idx]
        selected_loss_rec = full_batch_reconstruction_loss[selected_loss_idx]
        selected_loss_reg = full_batch_regularization_loss[selected_loss_idx]
    ## Run optimizer (backprop)
    self.model_train_var = tf.compat.v1.trainable_variables()
    if self.training and not self.planning:
        deselect_var_name = ["training/model_variables/" + m + "_z_posterior_src_zero:0" for m in self.modalities]  # don't train src_zero here
        opt_train_var = [var for var in self.model_train_var if var.name not in deselect_var_name]
    elif self.planning:
        # Planning only optimizes the posterior source variables
        if not self.vb_hybrid_posterior_src:
            if not self.vb_reset_posterior:
                select_var_name = ["training/model_variables/" + m + "_z_posterior_src:0" for m in self.modalities]
            else:
                select_var_name = ["training/model_variables/" + m + "_z_posterior_src_zero:0" for m in self.modalities]
        else:
            select_var_name = ["training/model_variables/" + m + "_z_posterior_src:0" for m in self.modalities]
            # NOTE(review): append() inserts a *list* element that can never equal a var.name string -- extend() likely intended; verify
            select_var_name.append(["training/model_variables/" + m + "_z_posterior_src_zero:0" for m in self.modalities])
        if self.vb_posterior_src_extend:
            # NOTE(review): same append-vs-extend concern; also appends Variable objects, not names -- verify
            select_var_name.append([var for var in self.model_train_var if "z_posterior_from_src" in var.name])
        # select_var_name = []
        # for v in self.model_train_var:
        #     select_var_name.append(v.name)
        opt_train_var = [var for var in self.model_train_var if var.name in select_var_name]
    else:
        opt_train_var = None
    print("**Trainable variables in use**")
    vidx = 0
    if opt_train_var is not None:
        for v in opt_train_var:
            print(str(vidx) + " " + str(v.name) + " " + str(v.get_shape()))
            vidx += 1
    else:
        print("None")
    if self.training:
        # Supported optimizers: ADAM (default), Gradient descent, Momentum, Adagrad, RMSProp
        if self.optimizer_func == "gradient_descent":
            optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=self.learning_rate)
        elif self.optimizer_func == "momentum":
            optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.9, use_nesterov=True)
        elif self.optimizer_func == "adagrad":
            optimizer = tf.compat.v1.train.AdagradOptimizer(learning_rate=self.learning_rate)
        elif self.optimizer_func == "rmsprop":
            optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=self.learning_rate, epsilon=self.optimizer_epsilon)#, decay=0.1)#, momentum=0.5)
        elif self.optimizer_func == "rmsprop_momentum":
            optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=self.learning_rate, epsilon=self.optimizer_epsilon, momentum=0.5, centered=True)
        elif self.optimizer_func == "adam":
            optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate, epsilon=self.optimizer_epsilon)
        gradients, variables = list(zip(*optimizer.compute_gradients(total_batch_loss, var_list=opt_train_var)))
        if self.gradient_clip > 0:
            # Replace NaN grads with 0 and Inf grads with the clip value before global-norm clipping
            pruned_gradients = [tf.compat.v1.where(tf.math.is_nan(grad), tf.zeros_like(grad), grad) if grad is not None else None for grad in gradients]
            pruned_gradients = [tf.compat.v1.where(tf.math.is_inf(grad), tf.constant(self.gradient_clip, shape=np.shape(grad)), grad) if grad is not None else None for grad in pruned_gradients]
            clipped_gradients, _ = tf.clip_by_global_norm(pruned_gradients, self.gradient_clip)
            training_batch = optimizer.apply_gradients(list(zip(clipped_gradients, variables)))
        else:
            training_batch = optimizer.apply_gradients(list(zip(gradients, variables)))
    else:
        optimizer = None
        training_batch = None
    ## Save model summary
    with tf.compat.v1.name_scope("loss"):
        for m in self.modalities:
            rec_loss = batch_reconstruction_loss[m][0]
            tf.compat.v1.summary.scalar(m + "_batch_reconstruction_loss", rec_loss / data["timesteps"])
            if self.vb_meta_prior[m][0] != -1 and self.training:
                for i in range(1, len(self.z_units[m])):
                    reg_loss = batch_regularization_loss[m][i][0]
                    tf.compat.v1.summary.scalar(m + "_z" + str(i) + "_batch_regularization_loss", -reg_loss / data["timesteps"])
        tf.compat.v1.summary.scalar("total_batch_loss", total_batch_loss / data["timesteps"])
    self.saver = tf.compat.v1.train.Saver(var_list=self.model_train_var, max_to_keep=None)
    ## Output the model
    model = dict()
    model["data"] = data
    model["data_length"] = data["timesteps"]
    model["generated_out"] = {m: [tf.transpose(a=output_model["out"][m][i], perm=[1, 0, 2], name="transpose_generated_out_all") for i in range(len(output_model["out"][m]))] for m in self.modalities}
    model["initial"] = {m: [tf.transpose(a=output_model["out_initial"][m][i], perm=[1, 0, 2], name="transpose_initial") for i in range(len(output_model["out_initial"][m]))] for m in self.modalities}
    model["generated_z_prior_mean"] = generated_z_prior_mean
    model["generated_z_prior_var"] = generated_z_prior_var
    model["generated_z_prior"] = generated_z_prior
    model["generated_z_posterior_mean"] = generated_z_posterior_mean
    model["generated_z_posterior_var"] = generated_z_posterior_var
    model["generated_z_posterior"] = generated_z_posterior
    model["generated_z_posterior_src"] = generated_z_posterior_src
    model["batch_reconstruction_loss"] = {m: batch_reconstruction_loss[m] for m in self.modalities}
    model["batch_regularization_loss"] = {m: batch_regularization_loss[m] for m in self.modalities}
    model["total_batch_reconstruction_loss"] = total_batch_reconstruction_loss
    model["total_batch_regularization_loss"] = total_batch_regularization_loss
    model["total_batch_loss"] = total_batch_loss
    model["initial_states"] = initial_states
    model["training_batch"] = training_batch
    # if self.training:
    #     model["gradients"] = {variables[i].name: gradients[i].values for i in xrange(len(gradients))}
    if self.planning:
        model["selected_loss_idx"] = selected_loss_idx
        model["selected_loss"] = selected_loss
        model["selected_loss_rec"] = selected_loss_rec
        model["selected_loss_reg"] = selected_loss_reg
    return model
# build_model #
# build_model #
def build_model_one_step_scan(self, previous_states, current_input):
    """tf.scan-compatible wrapper around build_model_one_step.

    Runs the single-step graph builder inside the shared 'model_variables'
    scope so variables are created once and reused on later timesteps.
    """
    with tf.compat.v1.variable_scope('model_variables', reuse=tf.compat.v1.AUTO_REUSE):
        step_states = self.build_model_one_step(previous_states, current_input)
    return step_states
def calculate_z_prior(self, idx_layer, out, modality, scope=None, override_l=None, override_sigma=None, override_myu=None, override_epsilon=None):
    """Compute the prior z distribution for one layer of one modality.

    Delegates to ops.calculate_z_prior under a per-layer/per-modality scope
    name. If override_l enables overriding for this layer (True, or a
    collection containing idx_layer), the explicit sigma/myu/epsilon
    overrides are forwarded; otherwise the configured sigma limit
    (self.vb_limit_sigma) is applied instead.
    """
    scope_name = 'l' + str(idx_layer) + '_' + modality
    if scope is not None:
        scope_name = scope_name + '_' + str(scope)
    use_override = override_l is not None and (override_l == True or idx_layer in override_l)
    if use_override:
        return ops.calculate_z_prior(idx_layer, out, self.d_neurons[modality], self.z_units[modality], self.batch_size, self.z_activation_func[modality], scope_name, override_sigma=override_sigma, override_myu=override_myu, override_epsilon=override_epsilon)
    return ops.calculate_z_prior(idx_layer, out, self.d_neurons[modality], self.z_units[modality], self.batch_size, self.z_activation_func[modality], scope_name, limit_sigma=self.vb_limit_sigma)
def calculate_z_posterior(self, idx_layer, out, source, modality, scope=None, override_l=None, override_sigma=None, override_myu=None, override_epsilon=None):
    """Compute the posterior z distribution for one layer of one modality.

    Delegates to ops.calculate_z_posterior under a per-layer/per-modality
    scope name, feeding the previous output as additional input only when
    vb_posterior_past_input is enabled. Override handling mirrors
    calculate_z_prior; source_extend always follows
    self.vb_posterior_src_extend.
    """
    scope_name = 'l' + str(idx_layer) + '_' + modality
    if scope is not None:
        scope_name = scope_name + '_' + str(scope)
    posterior_in = out if self.vb_posterior_past_input else None
    use_override = override_l is not None and (override_l == True or idx_layer in override_l)
    if use_override:
        return ops.calculate_z_posterior(idx_layer, posterior_in, source, self.d_neurons[modality], self.z_units[modality], self.batch_size, self.z_activation_func[modality], scope_name, override_sigma=override_sigma, override_myu=override_myu, override_epsilon=override_epsilon, source_extend=self.vb_posterior_src_extend)
    return ops.calculate_z_posterior(idx_layer, posterior_in, source, self.d_neurons[modality], self.z_units[modality], self.batch_size, self.z_activation_func[modality], scope_name, limit_sigma=self.vb_limit_sigma, source_extend=self.vb_posterior_src_extend)
def build_model_one_step(self, previous_states, current_input):
    """Build the graph for one timestep of the hierarchical variational RNN.

    Called from tf.scan via build_model_one_step_scan. Processes layers
    top-down (highest index first), computing d-unit inputs from
    lower/higher/current levels, prior and posterior z per layer, and
    finally the softmax output layer (layer 0) per modality.

    Args:
        previous_states: dict of per-layer states from the previous step
            (as produced by ops.internal_states_dict).
        current_input: tuple (sequence indices, then one tensor per modality
            when input_provide_data is set).

    Returns:
        ops.internal_states_dict with the updated t_step, outputs, cell
        states and z statistics.
    """
    if self.vb_hybrid_posterior_src_idx_override is None:
        idx_seq = current_input[0]
    else:
        # Force all sequences to use one stored posterior source
        idx_seq = tf.constant(self.vb_hybrid_posterior_src_idx_override)
    t_step = previous_states["t_step"]
    out_initial = previous_states["out_initial"]
    # Per-modality, per-layer slots for this step's results
    out = {m: self.n_layers[m]*[None] for m in self.modalities}
    state = {m: self.n_layers[m]*[None] for m in self.modalities}
    z_prior_mean = {m: self.n_layers[m]*[None] for m in self.modalities}
    z_prior_var = {m: self.n_layers[m]*[None] for m in self.modalities}
    z_prior = {m: self.n_layers[m]*[None] for m in self.modalities}
    z_posterior_mean = {m: self.n_layers[m]*[None] for m in self.modalities}
    z_posterior_var = {m: self.n_layers[m]*[None] for m in self.modalities}
    z_posterior = {m: self.n_layers[m]*[None] for m in self.modalities}
    if self.override_d_output is not None:
        # Clamp d outputs to externally supplied values within the configured step range
        for m in self.modalities:
            for i in range(max(self.n_layers.values())-1, 0, -1):
                fixed_d = tf.reshape(tf.tile(tf.gather(tf.convert_to_tensor(value=self.override_d_output[m][i], name="load_override_d"), t_step), [self.n_seq]), [self.n_seq, self.d_neurons[m][i]])
                previous_states["out"][m][i] = tf.compat.v1.where(tf.logical_and(tf.greater_equal(t_step, self.override_d_output_range[0]), tf.less(t_step, self.override_d_output_range[1])), fixed_d, previous_states["out"][m][i], name="where_override_d_range")
    if not self.vb_zero_initial_out:
        for m in self.modalities:
            # NOTE(review): shape=[1] variable replacing a list of per-layer tensors at t=0 -- relies on broadcasting; verify
            out_initial_var = tf.compat.v1.get_variable(m + "_out_initial", shape=[1], initializer=tf.compat.v1.zeros_initializer, trainable=self.training)  # TODO: reuse for continued generation?
            previous_states["out"][m] = tf.compat.v1.where(tf.equal(t_step, 0), out_initial_var, previous_states["out"][m], name="where_initial_d_check")
    # Layers 1+ (top-down)
    for i in range(max(self.n_layers.values())-1, 0, -1):
        current_z_logits = {m: None for m in self.modalities}
        higher_level_out_logits = {m: None for m in self.modalities}
        lower_level_out_logits = {m: None for m in self.modalities}
        current_level_out_logits = {m: None for m in self.modalities}
        for mi, m in enumerate(self.modalities):
            if i > self.n_layers[m]-1:
                continue
            lower_level_out = None
            higher_level_out = None
            ## Collect previous timestep outputs
            # Input from lower levels
            if i != 1 and self.connect_bottomup_d:
                # ll input from this modality
                lower_level_out = previous_states["out"][m][i-1]
                # If this is a shared layer, gather ll input from all modalities
                if i == self.n_layers[m]-1 and self.shared_layer == m:
                    for x in self.modalities:
                        if x == m:
                            continue
                        else:
                            lower_level_xout = previous_states["out"][x][i-1]
                            if self.layers_concat_input:
                                lower_level_out = tf.concat([lower_level_out, lower_level_xout], axis=1, name="concat_ll_out")
                            else:
                                lower_level_out = tf.add_n([lower_level_out, lower_level_xout], name="addn_ll_out")
            else:  # lowest level
                if self.input_provide_data:
                    lower_level_out = tf.add(self.input_cl_ratio * previous_states["out"][m][0], (1.0 - self.input_cl_ratio) * current_input[mi+1], name="add_ll_inmix")  # Mix input and previous output
                # else nothing enters the lowest level
            # Input from higher levels
            if i < self.n_layers[m]-1 and self.connect_topdown_d:
                # hl input from this modality; connect_topdown_dt uses the *current* step's
                # higher-layer output (already computed, since layers run top-down)
                if not self.connect_topdown_dt:
                    higher_level_out = previous_states["out"][m][i+1]
                else:
                    higher_level_out = out[m][i+1]
                if i == max(self.n_layers.values())-2 and self.shared_layer is not None and m != self.shared_layer:  # top layer-1
                    if not self.connect_topdown_dt:
                        higher_level_xout = previous_states["out"][self.shared_layer][i+1]
                    else:
                        higher_level_xout = out[self.shared_layer][i+1]
                    if self.layers_concat_input:
                        higher_level_out = tf.concat([higher_level_out, higher_level_xout], axis=1, name="concat_hl_out")
                    else:
                        higher_level_out = tf.add_n([higher_level_out, higher_level_xout], name="addn_hl_out")
            # Input from current level (previous timestep)
            current_level_out = previous_states["out"][m][i]
            with tf.compat.v1.variable_scope('l' + str(i) + '_' + m):
                # Project each input path to this layer's d width; zeros when the path is absent
                lower_level_out_logits[m] = _linear([lower_level_out], self.d_neurons[m][i], bias=True, scope_here=m+"_ll_to_cell") if lower_level_out is not None else tf.zeros([self.batch_size, self.d_neurons[m][i]])
                higher_level_out_logits[m] = _linear([higher_level_out], self.d_neurons[m][i], bias=True, scope_here=m+"_hl_to_cell") if higher_level_out is not None else tf.zeros([self.batch_size, self.d_neurons[m][i]])
                current_level_out_logits[m] = _linear([current_level_out], self.d_neurons[m][i], bias=True, scope_here=m+"_cl_to_cell")
            if self.connect_topdown_dz and i < self.n_layers[m]-1:
                # Independent of connect_d
                if not self.connect_topdown_dt:
                    d_to_z = previous_states["out"][m][i+1]
                else:
                    d_to_z = out[m][i+1]
                if i == max(self.n_layers.values())-2 and self.shared_layer is not None and m != self.shared_layer:  # top layer-1
                    if not self.connect_topdown_dt:
                        higher_level_xout = previous_states["out"][self.shared_layer][i+1]
                    else:
                        higher_level_xout = out[self.shared_layer][i+1]
                    # NOTE(review): combines higher_level_out, not d_to_z, discarding the value set just above -- verify intended
                    if self.layers_concat_input:
                        d_to_z = tf.concat([higher_level_out, higher_level_xout], axis=1, name="concat_dtoz")
                    else:
                        d_to_z = tf.add_n([higher_level_out, higher_level_xout], name="addn_dtoz")
            else:
                d_to_z = current_level_out
            if self.vb_meta_prior[m][i-1] >= 0 and self.connect_z:
                ## Calculate prior
                # Four cases: optional unit-gaussian window (vb_ugaussian_t_range) and
                # optional prior-override window (vb_prior_override_t_range), combined via tf.cond
                if self.vb_ugaussian_t_range is None:
                    if self.vb_prior_override_t_range is None:
                        current_z_prior, current_z_prior_mean, current_z_prior_var = self.calculate_z_prior(i, d_to_z, m, override_l=self.vb_prior_override_l, override_sigma=self.vb_prior_override_sigma, override_myu=self.vb_prior_override_myu, override_epsilon=self.vb_prior_override_epsilon)
                    else:
                        current_z_prior, current_z_prior_mean, current_z_prior_var = tf.cond(pred=tf.logical_and(tf.greater_equal(t_step, self.vb_prior_override_t_range[0]), tf.less(t_step, self.vb_prior_override_t_range[1])), true_fn=lambda: self.calculate_z_prior(i, d_to_z, m, override_l=self.vb_prior_override_l, override_sigma=self.vb_prior_override_sigma, override_myu=self.vb_prior_override_myu, override_epsilon=self.vb_prior_override_epsilon), false_fn=lambda: self.calculate_z_prior(i, d_to_z, m), name="cond_z_p_range_t")
                else:
                    if self.vb_prior_override_t_range is None:
                        current_z_prior, current_z_prior_mean, current_z_prior_var = tf.cond(pred=tf.logical_and(tf.greater_equal(t_step, self.vb_ugaussian_t_range[0]), tf.less(t_step, self.vb_ugaussian_t_range[1])), true_fn=lambda: self.calculate_z_prior(i, d_to_z, m, override_l=True, override_sigma=1.0, override_myu=0.0), false_fn=lambda: self.calculate_z_prior(i, d_to_z, m, override_l=self.vb_prior_override_l, override_sigma=self.vb_prior_override_sigma, override_myu=self.vb_prior_override_myu, override_epsilon=self.vb_prior_override_epsilon), name="cond_z_p_range_u")
                    else:
                        current_z_prior, current_z_prior_mean, current_z_prior_var = tf.cond(pred=tf.logical_and(tf.greater_equal(t_step, self.vb_ugaussian_t_range[0]), tf.less(t_step, self.vb_ugaussian_t_range[1])), true_fn=lambda: self.calculate_z_prior(i, d_to_z, m, override_l=True, override_sigma=1.0, override_myu=0.0), false_fn=lambda: tf.cond(pred=tf.logical_and(tf.greater_equal(t_step, self.vb_prior_override_t_range[0]), tf.less(t_step, self.vb_prior_override_t_range[1])), true_fn=lambda: self.calculate_z_prior(i, d_to_z, m, override_l=self.vb_prior_override_l, override_sigma=self.vb_prior_override_sigma, override_myu=self.vb_prior_override_myu, override_epsilon=self.vb_prior_override_epsilon), false_fn=lambda: self.calculate_z_prior(i, d_to_z, m)), name="cond_z_p_range_tu")
                if self.vb_seq_prior[m][i]:
                    ## Calculate posterior
                    # Load the correct posterior source
                    # presumably mean+var are concatenated when not src_extend -- TODO confirm
                    src_z = max(self.z_units[m]) if self.vb_posterior_src_extend else max(self.z_units[m])*2
                    if not self.vb_hybrid_posterior_src:
                        _ = tf.compat.v1.get_variable(m + "_z_posterior_src_zero", shape=[self.max_timesteps, self.n_layers[m], self.n_seq, src_z], initializer=tf.compat.v1.zeros_initializer, trainable=True)  # this src is reserved in case we don't want to use the primary
                        if not self.vb_reset_posterior:
                            z_posterior_src_var_name = m + "_z_posterior_src"
                        else:
                            z_posterior_src_var_name = m + "_z_posterior_src_zero"
                        full_z_posterior_src = tf.compat.v1.get_variable(z_posterior_src_var_name, shape=[self.max_timesteps, self.n_layers[m], self.n_seq, src_z], initializer=tf.compat.v1.zeros_initializer, trainable=True)
                    else:
                        # Hybrid: splice trained and zero-initialized sources along time
                        initial_start = self.vb_hybrid_posterior_src_range[0]
                        initial_end = self.vb_hybrid_posterior_src_range[1]
                        alt_z_posterior_src = tf.compat.v1.get_variable(m + "_z_posterior_src_zero", shape=[self.max_timesteps, self.n_layers[m], self.n_seq, src_z], initializer=tf.compat.v1.zeros_initializer, trainable=True)
                        z_posterior_src = tf.compat.v1.get_variable(m + "_z_posterior_src", shape=[self.max_timesteps, self.n_layers[m], self.n_seq, src_z], initializer=tf.compat.v1.zeros_initializer, trainable=True)
                        if not self.vb_hybrid_posterior_src_zero_init:
                            full_z_posterior_src = tf.concat([z_posterior_src[initial_start:initial_end, :, :, :], alt_z_posterior_src[initial_end:, :, :, :]], axis=0, name="concat_z_qsrc_hybrid0")
                        else:
                            full_z_posterior_src = tf.concat([alt_z_posterior_src[initial_start:initial_end, :, :, :], z_posterior_src[initial_end:, :, :, :]], axis=0, name="concat_z_qsrc_hybrida")
                    z_posterior_src = tf.gather(full_z_posterior_src, t_step, name="gather_z_qsrc_t")
                    current_z_posterior_src = tf.gather(z_posterior_src[i], idx_seq, name="gather_z_qsrc_idx")  # reorder posterior d of this layer to match data sequences
                    if self.vb_posterior_override_t_range is None:
                        current_z_posterior, current_z_posterior_mean, current_z_posterior_var = self.calculate_z_posterior(i, d_to_z, current_z_posterior_src, m, override_l=self.vb_posterior_override_l, override_sigma=self.vb_posterior_override_sigma, override_myu=self.vb_posterior_override_myu, override_epsilon=self.vb_posterior_override_epsilon)
                    else:
                        current_z_posterior, current_z_posterior_mean, current_z_posterior_var = tf.cond(pred=tf.logical_and(tf.greater_equal(t_step, self.vb_posterior_override_t_range[0]), tf.less(t_step, self.vb_posterior_override_t_range[1])), true_fn=lambda: self.calculate_z_posterior(i, d_to_z, current_z_posterior_src, m, override_l=self.vb_posterior_override_l, override_sigma=self.vb_posterior_override_sigma, override_myu=self.vb_posterior_override_myu, override_epsilon=self.vb_posterior_override_epsilon), false_fn=lambda: self.calculate_z_posterior(i, d_to_z, current_z_posterior_src, m), name="cond_z_q_range")
                else:
                    # No sequence prior: carry previous posterior statistics forward
                    current_z_posterior = None
                    current_z_posterior_mean = previous_states["z_posterior_mean"][m][i]
                    current_z_posterior_var = previous_states["z_posterior_var"][m][i]
            else:  # No-op
                current_z_prior = None
                current_z_prior_mean = previous_states["z_prior_mean"][m][i]
                current_z_prior_var = previous_states["z_prior_var"][m][i]
                current_z_posterior = None
                current_z_posterior_mean = previous_states["z_posterior_mean"][m][i]
                current_z_posterior_var = previous_states["z_posterior_var"][m][i]
            # Select current Z
            if self.vb_seq_prior[m][i]:
                if self.vb_posterior_blend_factor > 0.0:
                    # Linear blend of prior and posterior samples
                    current_z = tf.add_n([tf.multiply(current_z_prior, 1-self.vb_posterior_blend_factor), tf.multiply(current_z_posterior, self.vb_posterior_blend_factor)], name="addn_z_pq_blend")
                elif self.vb_hybrid_posterior_src:
                    # Posterior inside the hybrid window, prior outside
                    initial_start = self.vb_hybrid_posterior_src_range[0]
                    initial_end = self.vb_hybrid_posterior_src_range[1]
                    if self.vb_hybrid_prior_override:
                        current_z_prior = tf.compat.v1.where(tf.logical_and(tf.greater_equal(t_step, initial_start), tf.less(t_step, initial_end)), current_z_posterior, current_z_prior, name="where_z_p_range")
                        current_z_prior_mean = tf.compat.v1.where(tf.logical_and(tf.greater_equal(t_step, initial_start), tf.less(t_step, initial_end)), current_z_posterior_mean, current_z_prior_mean, name="where_zm_p_range")
                        current_z_prior_var = tf.compat.v1.where(tf.logical_and(tf.greater_equal(t_step, initial_start), tf.less(t_step, initial_end)), current_z_posterior_var, current_z_prior_var, name="where_zv_p_range")
                    current_z = tf.compat.v1.where(tf.logical_and(tf.greater_equal(t_step, initial_start), tf.less(t_step, initial_end)), current_z_posterior, current_z_prior, name="where_z_pq_range")
                else:
                    current_z = current_z_posterior if not self.vb_prior_output else current_z_prior
            else:
                current_z = current_z_prior
            z_prior[m][i] = current_z_prior if current_z_prior is not None else previous_states["z_prior"][m][i]
            z_prior_mean[m][i] = current_z_prior_mean
            z_prior_var[m][i] = current_z_prior_var
            z_posterior[m][i] = current_z_posterior if current_z_posterior is not None else previous_states["z_posterior"][m][i]
            z_posterior_mean[m][i] = current_z_posterior_mean
            z_posterior_var[m][i] = current_z_posterior_var
            with tf.compat.v1.variable_scope('l' + str(i) + '_' + m):
                current_z_logits[m] = _linear([current_z], self.d_neurons[m][i], bias=True, scope_here=m+"_z_to_cell") if current_z is not None else tf.zeros([self.batch_size, self.d_neurons[m][i]])
        ## Synchronize level
        for m in self.modalities:
            if i > self.n_layers[m]-1:
                continue
            # Add horizontal and vertical connections in this layer
            if self.gradient_clip_input == -1:  # special case: layer normalization
                # L2 norm on d and z separately
                z_logits_norm = tf.nn.l2_normalize(current_z_logits[m], axis=1, name="normalize_z")
                if self.layers_concat_input:
                    d_logits = tf.concat([lower_level_out_logits[m], higher_level_out_logits[m], current_level_out_logits[m]], axis=1, name="concat_d")
                else:
                    d_logits = tf.add_n([lower_level_out_logits[m], higher_level_out_logits[m], current_level_out_logits[m]], name="addn_d")
                d_logits_norm = tf.nn.l2_normalize(d_logits, axis=1)
                level_output = [z_logits_norm, d_logits_norm]
            else:
                level_output = [current_z_logits[m], lower_level_out_logits[m], higher_level_out_logits[m], current_level_out_logits[m]]
            if self.connect_horizontal:
                for x in self.modalities:
                    if x == m:
                        continue
                    if current_level_out_logits[x] is not None:
                        # NOTE(review): current_level_out is a single tensor left over from the loop above,
                        # so current_level_out[x] indexes a tensor with a string -- current_level_out_logits[x]
                        # was likely intended; verify before enabling connect_horizontal
                        level_output.append(current_level_out[x])  # current level output from all modalities
            if self.layers_concat_input:
                sum_level_output = tf.concat(level_output, axis=1, name="concat_l_out")
            else:
                sum_level_output = tf.add_n(level_output, name="addn_l_out")
            # There's no gradient here, but keeping it for bc
            if self.gradient_clip_input > 0:  # clip input
                sum_level_output = tf.clip_by_norm(sum_level_output, self.gradient_clip_input, name="clip_l_out")
            # Finally compute D
            out[m][i], state[m][i], _ = self.layers[m][i](sum_level_output, previous_states["state"][m][i], scope=self.layers_names[m][i])  # TODO: read internal (gate) states?
    ## Synchronize modalities
    for m in self.modalities:
        # Layer 0 a.k.a. output layer
        with tf.compat.v1.variable_scope('l0_' + m):
            # Layer 0 has no Z units
            z_prior[m][0] = previous_states["z_prior"][m][0]
            z_prior_mean[m][0] = previous_states["z_prior_mean"][m][0]
            z_prior_var[m][0] = previous_states["z_prior_var"][m][0]
            z_posterior[m][0] = previous_states["z_posterior"][m][0]
            z_posterior_mean[m][0] = previous_states["z_posterior_mean"][m][0]
            z_posterior_var[m][0] = previous_states["z_posterior_var"][m][0]
            # Final output
            l0_o = _linear(out[m][1], self.dims[m] * self.softmax_quant, bias=True, scope_here=m+"_to_out")
            if self.output_z_factor[m] > 0.0:
                # Optionally blend layer-1 z into the output logits
                z_to_output = z_prior[m][1] if self.vb_prior_output else z_posterior[m][1]
                l1_z_logits = _linear(z_to_output, self.dims[m] * self.softmax_quant, bias=True, scope_here=m+"_z_blend_output") if z_to_output is not None else tf.zeros([self.batch_size, self.dims[m] * self.softmax_quant])
                l0_o += self.output_z_factor[m] * l1_z_logits
            # Independent softmax over each dimension's quantization bins
            l0_softmax = []
            for i in range(self.dims[m]):
                l0_softmax.append(tf.nn.softmax(l0_o[:, self.softmax_quant*i:self.softmax_quant*(i+1)], name="softmax_l0_output"))
            out[m][0] = tf.concat(l0_softmax, 1, name="concat_l0_output")
            state[m][0] = l0_o
    return ops.internal_states_dict(t_step=t_step+1, out=out, out_initial=out_initial, state=state,
                                    z_prior_mean=z_prior_mean, z_prior_var=z_prior_var, z_prior=z_prior,
                                    z_posterior_mean=z_posterior_mean, z_posterior_var=z_posterior_var, z_posterior=z_posterior)
## Output parts of the model to human readable format
# fig_idx: True = plot all outputs in one figure, False = disabled, <integer> = plot specified output only, None = also save single figure for each output
# fig_xy: True = save lissajous plot of first two dimensions, False = disabled
# NB: motor_output can start at t=0 and have effectively sequence length+1 steps. Other outputs start at t=1
def write_file_csv(self, generated, epoch, modality, idx=None, layer=0, filename_prefix=None, dir=None, fig_idx=True, fig_plot_markers=False, initial=None, fig_plot_dims=None, compute_entropy=False, override_d=None):
# generated_all_layers = np.squeeze(np.asarray(generated), axis=0)
generated_all_layers = generated[0]
# Select which layer to save
if layer is None:
min_layer = 1
max_layer = self.n_layers[modality] # output for all layers (except final)
else: