-
Notifications
You must be signed in to change notification settings - Fork 27
/
hubconf.py
172 lines (136 loc) · 4.4 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
from sslearning.models.accNet import Resnet
import torch
import copy
# Packages torch.hub checks for before loading the entrypoints in this
# hubconf.py (the module-level ``dependencies`` list is a torch.hub convention).
dependencies = ["torch"]
def load_weights(
    weight_path, model, my_device="cpu", name_start_idx=2, is_dist=False
):
    """Copy matching weights from a checkpoint file into ``model`` in place.

    Every checkpoint entry whose (optionally renamed) key also exists in
    ``model.state_dict()`` is loaded, except entries whose first dotted key
    component is ``classifier``. All other model parameters keep their
    current values.

    Args:
        weight_path (str): Path to a checkpoint file holding a state dict.
        model (torch.nn.Module): Model to load into; modified in place.
        my_device (str): ``map_location`` passed to ``torch.load``.
        name_start_idx (int): Only used when ``is_dist`` is True. This is the
            number of leading dotted components stripped from each
            checkpoint key, e.g. 1 turns ``"module.layer.weight"`` into
            ``"layer.weight"``.
        is_dist (bool): Set to True when the model was trained in a
            distributed manner and its checkpoint keys therefore carry
            extra leading components that the plain model does not have.

    Returns:
        None. Prints how many tensors were loaded.
    """
    # NOTE(review): torch.load unpickles the file; only point this at
    # trusted checkpoints.
    state = torch.load(weight_path, map_location=my_device)
    if is_dist:
        # Rename into a fresh dict instead of deep-copying and popping in
        # place. Deep-copying duplicated every tensor just to rename keys.
        # In-place renaming could also clobber an original key that a
        # stripped key happens to equal before that key was processed.
        state = {
            ".".join(key.split(".")[name_start_idx:]): value
            for key, value in state.items()
        }
    model_dict = model.state_dict()
    # 1. Keep only keys the model knows about, and never the classifier
    #    head, whose shape depends on the number of target classes.
    pretrained_dict = {
        k: v
        for k, v in state.items()
        if k in model_dict and k.split(".")[0] != "classifier"
    }
    # 2. Overwrite those entries in the model's current state dict.
    model_dict.update(pretrained_dict)
    # 3. Load the merged state dict back into the model.
    model.load_state_dict(model_dict)
    print("%d Weights loaded" % len(pretrained_dict))
def harnet5(pretrained=False, my_device="cpu", class_num=2, **kwargs):
    """Build the harnet5 model (ResNet for 5-second windows).

    Args:
        pretrained (bool): If True, load weights from the bundled checkpoint
            ``model_check_point/mtl_5_best.mdl`` via ``load_weights``.
            Checkpoint entries whose first key component is ``classifier``
            are skipped, so ``class_num`` does not have to match the
            checkpoint.
        my_device (str): Passed as ``map_location`` when reading the
            checkpoint. The returned model itself is not moved to this
            device.
        class_num (int): Output size of the network, i.e. the number of
            classes to predict.
        **kwargs: Accepted but ignored.

    Input:
        X of size N x 3 x 150. N is the number of examples, 3 is the xyz
        channels, and 150 samples make up a 5-second recording at 30 Hz.

    Returns:
        Resnet: the constructed (optionally pretrained) model.

    Example:
        import numpy as np
        import torch

        repo = 'OxWearables/ssl-wearables'
        model = torch.hub.load(repo, 'harnet5', pretrained=True)
        x = torch.FloatTensor(np.random.rand(1, 3, 150))
        model(x)
    """
    model = Resnet(
        output_size=class_num,
        is_eva=True,
        epoch_len=5,
        resnet_version=1,
    )
    if pretrained:
        checkpoint = os.path.join(
            os.path.dirname(__file__), "model_check_point", "mtl_5_best.mdl"
        )
        # Checkpoint keys carry one extra leading component; strip it.
        load_weights(
            checkpoint, model, my_device, is_dist=True, name_start_idx=1
        )
    return model
def harnet10(pretrained=False, my_device="cpu", class_num=2, **kwargs):
    """Build the harnet10 model (ResNet for 10-second windows).

    Args:
        pretrained (bool): If True, load weights from the bundled checkpoint
            ``model_check_point/mtl_best.mdl`` via ``load_weights``.
            Checkpoint entries whose first key component is ``classifier``
            are skipped, so ``class_num`` does not have to match the
            checkpoint.
        my_device (str): Passed as ``map_location`` when reading the
            checkpoint. The returned model itself is not moved to this
            device.
        class_num (int): Output size of the network, i.e. the number of
            classes to predict.
        **kwargs: Accepted but ignored.

    Input:
        X of size N x 3 x 300. N is the number of examples, 3 is the xyz
        channels, and 300 samples make up a 10-second recording at 30 Hz.

    Returns:
        Resnet: the constructed (optionally pretrained) model.

    Example:
        import numpy as np
        import torch

        repo = 'OxWearables/ssl-wearables'
        model = torch.hub.load(repo, 'harnet10', pretrained=True)
        x = torch.FloatTensor(np.random.rand(1, 3, 300))
        model(x)
    """
    # NOTE(review): unlike harnet5/harnet30, epoch_len is not passed, so
    # Resnet's default is used; presumably that default is 10 (seconds) —
    # confirm against Resnet.
    model = Resnet(
        output_size=class_num,
        is_eva=True,
        resnet_version=1,
    )
    if pretrained:
        checkpoint = os.path.join(
            os.path.dirname(__file__), "model_check_point", "mtl_best.mdl"
        )
        # Checkpoint keys carry one extra leading component; strip it.
        load_weights(
            checkpoint, model, my_device, is_dist=True, name_start_idx=1
        )
    return model
def harnet30(pretrained=False, my_device="cpu", class_num=2, **kwargs):
    """Build the harnet30 model (ResNet for 30-second windows).

    Args:
        pretrained (bool): If True, load weights from the bundled checkpoint
            ``model_check_point/mtl_30_best.mdl`` via ``load_weights``.
            Checkpoint entries whose first key component is ``classifier``
            are skipped, so ``class_num`` does not have to match the
            checkpoint.
        my_device (str): Passed as ``map_location`` when reading the
            checkpoint. The returned model itself is not moved to this
            device.
        class_num (int): Output size of the network, i.e. the number of
            classes to predict.
        **kwargs: Accepted but ignored.

    Input:
        X of size N x 3 x 900. N is the number of examples, 3 is the xyz
        channels, and 900 samples make up a 30-second recording at 30 Hz.

    Returns:
        Resnet: the constructed (optionally pretrained) model.

    Example:
        import numpy as np
        import torch

        repo = 'OxWearables/ssl-wearables'
        model = torch.hub.load(repo, 'harnet30', pretrained=True)
        x = torch.FloatTensor(np.random.rand(1, 3, 900))
        model(x)
    """
    model = Resnet(
        output_size=class_num,
        is_eva=True,
        epoch_len=30,
        resnet_version=1,
    )
    if pretrained:
        checkpoint = os.path.join(
            os.path.dirname(__file__), "model_check_point", "mtl_30_best.mdl"
        )
        # Checkpoint keys carry one extra leading component; strip it.
        load_weights(
            checkpoint, model, my_device, is_dist=True, name_start_idx=1
        )
    return model