-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
42 lines (34 loc) · 1.18 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from model_stealing.predictors import Alexnet, AlexnetHalf
from dbma.dbma import DBMA
from dbma.WNR import WNR
from dbma.unet import UnetGenerator, get_norm_layer
from model_stealing.predictors import ResNet18
def get_model(name, path):
    """Build a predictor network by name and load its weights from ``path``.

    Args:
        name: Architecture key; one of ``"alexnet"``, ``"alexnet_half"`` or
            ``"resnet18"``.
        path: Checkpoint location handed to ``torch.load``; the loaded object
            is passed to the model's ``load_state_dict``.

    Returns:
        The freshly constructed model with the checkpoint's state dict
        loaded. Checkpoint tensors are mapped to the CPU on load.

    Raises:
        ValueError: If ``name`` is not one of the supported keys.

    NOTE(review): an earlier comment here said the returned model has a
    normalization layer on top of it. Nothing in this function adds one, so
    that presumably happens inside the predictor classes -- confirm.
    """
    # Validate first so a typo fails fast, before any network is constructed.
    # ValueError is still an Exception subclass, so callers that catch the
    # broad Exception raised previously keep working.
    if name not in ("alexnet", "alexnet_half", "resnet18"):
        raise ValueError("unknown models {}".format(name))

    if name == "alexnet":
        model = Alexnet()
    elif name == "alexnet_half":
        model = AlexnetHalf()
    else:
        model = ResNet18()

    # NOTE: torch.load deserializes via pickle; only point this at checkpoint
    # files from a trusted source.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model
def create_dbma_model(args=None):
    """Assemble a ``DBMA`` model from a ``WNR`` stage and a U-Net generator.

    Args:
        args: Object exposing ``wvlt``, ``mode``, ``levels`` and
            ``keep_percentage`` attributes, which are forwarded positionally
            to ``WNR``. When ``None`` (the default), ``WNR`` is constructed
            with no arguments so that its own defaults apply.

    Returns:
        The result of ``DBMA(wnr, regen)``.
    """
    # The signature advertises args=None, but reading attributes off None
    # used to fail with AttributeError. Fall back to the no-argument WNR()
    # constructor, which this module's __main__ block also relies on.
    if args is None:
        wnr = WNR()
    else:
        wnr = WNR(args.wvlt, args.mode, args.levels, args.keep_percentage)

    norm_layer = get_norm_layer("instance")
    # NOTE(review): the positional arguments presumably are input channels,
    # output channels, number of down-samplings, base filter count, norm
    # layer and a dropout flag -- confirm against dbma.unet.UnetGenerator.
    regen = UnetGenerator(3, 3, 5, 64, norm_layer, False)
    return DBMA(wnr, regen)
if __name__ == "__main__":
    # Manual smoke test: build every component with default settings, push
    # two random batches through the DBMA model and print the keys of the
    # mapping it returns.
    wnr_module = WNR()
    unet = UnetGenerator(3, 3, 5, 64, get_norm_layer("instance"), False)
    predictor = ResNet18()
    dbma_model = DBMA(wnr_module, unet, predictor)

    # Both inputs share one shape; they are drawn in the same order as
    # before so the random stream is consumed identically.
    batch_shape = (32, 3, 32, 32)
    first_batch = torch.rand(size=batch_shape)
    second_batch = torch.rand(size=batch_shape)

    result = dbma_model(first_batch, second_batch)
    print(result.keys())