-
Notifications
You must be signed in to change notification settings - Fork 5
/
generate.py
28 lines (21 loc) · 1005 Bytes
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import torch
from numpy import linalg as LA
from mmdetection.splits import get_unseen_class_ids ,get_seen_class_ids
def load_all_att(opt):
    """Load the full class-embedding table and a matching label vector.

    The array stored at ``opt.class_embedding`` is read with ``np.load`` and
    divided in place by ``LA.norm(array, ord=2)``. For a 2-D array that norm
    is the matrix spectral norm (largest singular value), so every entry is
    scaled by the same single factor; it is NOT a per-row L2 normalisation.

    Args:
        opt: options object; only ``opt.class_embedding`` (any path accepted
            by ``np.load``) is read here.

    Returns:
        ``(attributes, labels)`` as torch tensors created with
        ``torch.from_numpy`` (so they share memory with the loaded arrays).
        ``labels`` is ``0 .. len(attributes) - 1``, i.e. ``labels[i] == i``.
    """
    embeddings = np.load(opt.class_embedding)
    class_indices = np.arange(len(embeddings))

    # NOTE(review): presumably `embeddings` is (num_classes, dim). If unit
    # vectors per class were intended, this would need axis=1 / keepdims —
    # kept as-is because downstream models depend on the current scaling.
    global_scale = LA.norm(embeddings, ord=2)
    embeddings /= global_scale

    attributes_t = torch.from_numpy(embeddings)
    labels_t = torch.from_numpy(class_indices)
    return attributes_t, labels_t
def load_seen_att(opt):
    """Select the embeddings and labels of row 0 followed by the seen classes.

    Args:
        opt: options object; reads ``opt.class_embedding`` (via
            ``load_all_att``), ``opt.dataset`` and ``opt.classes_split``.

    Returns:
        ``(attributes, labels)`` restricted to index 0 plus the ids returned
        by ``get_seen_class_ids(opt.dataset, split=opt.classes_split)``, in
        that order.
    """
    all_attributes, all_labels = load_all_att(opt)
    seen_ids = get_seen_class_ids(opt.dataset, split=opt.classes_split)

    # Index 0 is always placed first; presumably it is the background
    # embedding — confirm against the embedding file layout.
    selection = np.concatenate(([0], seen_ids))

    return all_attributes[selection], all_labels[selection]
def load_unseen_att_with_bg(opt):
    """Select the embeddings and labels of row 0 followed by the unseen classes.

    Args:
        opt: options object; reads ``opt.class_embedding`` (via
            ``load_all_att``), ``opt.dataset`` and ``opt.classes_split``.

    Returns:
        ``(attributes, labels)`` restricted to index 0 plus the ids returned
        by ``get_unseen_class_ids(opt.dataset, split=opt.classes_split)``, in
        that order.
    """
    all_attributes, all_labels = load_all_att(opt)
    unseen_ids = get_unseen_class_ids(opt.dataset, split=opt.classes_split)

    # The "_with_bg" suffix suggests index 0 is the background entry —
    # assumption, confirm against the embedding file layout.
    selection = np.concatenate(([0], unseen_ids))

    return all_attributes[selection], all_labels[selection]
def load_unseen_att(opt):
    """Select the embeddings and labels of the unseen classes only.

    Unlike ``load_unseen_att_with_bg``, index 0 is not prepended.

    Args:
        opt: options object; reads ``opt.class_embedding`` (via
            ``load_all_att``), ``opt.dataset`` and ``opt.classes_split``.

    Returns:
        ``(attributes, labels)`` indexed by the ids returned from
        ``get_unseen_class_ids(opt.dataset, split=opt.classes_split)``.
    """
    all_attributes, all_labels = load_all_att(opt)
    unseen_ids = get_unseen_class_ids(opt.dataset, split=opt.classes_split)
    picked_attributes = all_attributes[unseen_ids]
    picked_labels = all_labels[unseen_ids]
    return picked_attributes, picked_labels