-
Notifications
You must be signed in to change notification settings - Fork 736
/
dpt.py
924 lines (726 loc) · 26.9 KB
/
dpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
import math
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device('cpu'))
if "optimizer" in parameters:
parameters = parameters["model"]
self.load_state_dict(parameters)
def unflatten_with_named_tensor(input, dim, sizes):
"""Workaround for unflattening with named tensor."""
# tracer acts up with unflatten. See https://github.com/pytorch/pytorch/issues/49538
new_shape = list(input.shape)[:dim] + list(sizes) + list(input.shape)[dim+1:]
return input.view(*new_shape)
class Slice(nn.Module):
def __init__(self, start_index=1):
super(Slice, self).__init__()
self.start_index = start_index
def forward(self, x):
return x[:, self.start_index :]
class AddReadout(nn.Module):
def __init__(self, start_index=1):
super(AddReadout, self).__init__()
self.start_index = start_index
def forward(self, x):
if self.start_index == 2:
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index :] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
def __init__(self, in_features, start_index=1):
super(ProjectReadout, self).__init__()
self.start_index = start_index
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
features = torch.cat((x[:, self.start_index :], readout), -1)
return self.project(features)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x
def forward_vit(pretrained, x):
b, c, h, w = x.shape
glob = pretrained.model.forward_flex(x)
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_4 = pretrained.activations["4"]
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
unflattened_dim = 2
unflattened_size = (
int(torch.div(h, pretrained.model.patch_size[1], rounding_mode='floor')),
int(torch.div(w, pretrained.model.patch_size[0], rounding_mode='floor')),
)
unflatten = nn.Sequential(nn.Unflatten(unflattened_dim, unflattened_size))
if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten_with_named_tensor(layer_3, unflattened_dim, unflattened_size)
if layer_4.ndim == 3:
layer_4 = unflatten_with_named_tensor(layer_4, unflattened_dim, unflattened_size)
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
return layer_1, layer_2, layer_3, layer_4
def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, : self.start_index],
posemb[0, self.start_index :],
)
gs_old = int(math.sqrt(posemb_grid.shape[0]))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def forward_flex(self, x):
b, c, h, w = x.shape
pos_embed = self._resize_pos_embed(
self.pos_embed, torch.div(h, self.patch_size[1], rounding_mode='floor'), torch.div(w, self.patch_size[0], rounding_mode='floor')
)
B = x.shape[0]
if hasattr(self.patch_embed, "backbone"):
x = self.patch_embed.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
if getattr(self, "dist_token", None) is not None:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
else:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
activations = {}
def get_activation(name):
def hook(model, input, output):
activations[name] = output
return hook
def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == "ignore":
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == "add":
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == "project":
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
assert (
False
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
return readout_oper
def _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model(
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
start_index=2,
)
def _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=[0, 1, 8, 11],
vit_features=768,
use_vit_only=False,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
if use_vit_only == True:
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
else:
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
get_activation("1")
)
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
get_activation("2")
)
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
if use_vit_only == True:
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
else:
pretrained.act_postprocess1 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess2 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitb_rn50_384(
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
hooks = [0, 1, 8, 11] if hooks == None else hooks
return _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
if backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == "vitb_rn50_384":
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups, expand=expand
) # ViT-H/16 - 85.0% Top1 (backbone)
elif backbone == "vitb16_384":
pretrained = _make_pretrained_vitb16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
elif backbone == "efficientnet_lite3":
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand==True:
out_shape1 = out_shape
out_shape2 = out_shape*2
out_shape3 = out_shape*4
out_shape4 = out_shape*8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
efficientnet = torch.hub.load(
"rwightman/gen-efficientnet-pytorch",
"tf_efficientnet_lite3",
pretrained=use_pretrained,
exportable=exportable
)
return _make_efficientnet_backbone(efficientnet)
def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
)
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
return pretrained
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
"""Interpolation module.
"""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
)
return x
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=True
)
return output
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
if self.bn==True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn==True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn==True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
# return out + x
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand==True:
out_features = features//2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
def _make_fusion_block(features, use_bn):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
)
class DPT(BaseModel):
def __init__(
self,
head,
features=256,
backbone="vitb_rn50_384",
readout="project",
channels_last=False,
use_bn=False,
):
super(DPT, self).__init__()
self.channels_last = channels_last
hooks = {
"vitb_rn50_384": [0, 1, 8, 11],
"vitb16_384": [2, 5, 8, 11],
"vitl16_384": [5, 11, 17, 23],
}
# Instantiate backbone and reassemble blocks
self.pretrained, self.scratch = _make_encoder(
backbone,
features,
True, # Set to true of you want to train from scratch, uses ImageNet weights
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
use_readout=readout,
)
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.output_conv = head
def forward(self, x):
if self.channels_last == True:
x.contiguous(memory_format=torch.channels_last)
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return out
class DPTDepthModel(DPT):
def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs):
features = kwargs["features"] if "features" in kwargs else 256
head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
super().__init__(head, **kwargs)
if path is not None:
self.load(path)
def forward(self, x):
return super().forward(x).squeeze(dim=1)