evaluate.py
from obiwan.new_models import CBM, FuseCBM
from obiwan.datasets.cub import get_cub_dataloaders
from obiwan.utils import recall
import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn.functional as F
from torchvision.models.resnet import resnet18, resnet50
from torchmetrics.aggregation import MeanMetric
from torchmetrics.classification import MultilabelAccuracy, Accuracy, MultilabelF1Score
import os
from dotenv import load_dotenv
import uuid
import json
load_dotenv()
import wandb #noqa
try:
    # Optional rich-styled progress bar; fall back to plain tqdm if unavailable.
    from tqdm.rich import tqdm
except ImportError:
    from tqdm import tqdm


def evaluate(model: CBM, dataloader, device, num_classes, num_concepts):
    """Compute concept accuracy/F1 and class accuracy for a CBM over a dataloader."""
    model.eval()
    model.to(device)
    concept_accuracy = MultilabelAccuracy(num_labels=num_concepts)
    concept_f1 = MultilabelF1Score(num_labels=num_concepts)
    class_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    concept_accuracy.to(device)
    class_accuracy.to(device)
    concept_f1.to(device)
    with torch.no_grad():
        for imgs, attrs, labels in tqdm(dataloader):
            imgs = imgs.to(device)
            attrs = attrs.to(device)
            labels = labels.to(device)
            concepts, classes = model(imgs)
            concept_accuracy.update(concepts, attrs)
            class_accuracy.update(classes, labels.long().squeeze())
            concept_f1.update(concepts, attrs)
    final_concept_accuracy = concept_accuracy.compute()
    final_class_accuracy = class_accuracy.compute()
    final_concept_f1 = concept_f1.compute()
    return final_concept_accuracy, final_class_accuracy, final_concept_f1


def evaluate_recall(model: FuseCBM, dataloader, device, intervene: bool, pre_concept: bool):
    """Compute Recall@{1,5,10} using the dataloader's own embeddings as both queries and gallery."""
    model.eval()
    model.to(device)
    embeddings_list = []
    labels_list = []
    with torch.no_grad():
        for imgs, attrs, labels in tqdm(dataloader):
            imgs = imgs.to(device)
            attrs = attrs.to(device)
            labels = labels.to(device)
            if pre_concept:
                # Embedding taken before the concept bottleneck.
                embeddings = model.get_pre_concept_embedding(imgs)
            else:
                # NOTE: both branches currently compute the fused embedding without intervention.
                if intervene:
                    embeddings = model.get_fused_embedding(imgs, False)
                else:
                    embeddings = model.get_fused_embedding(imgs, False)
            embeddings = F.normalize(embeddings, dim=1)
            embeddings_list.append(embeddings)
            labels_list.append(labels)
    embeddings = torch.cat(embeddings_list, dim=0)
    labels = torch.cat(labels_list, dim=0)
    recall_list = recall(embeddings, labels, rank=[1, 5, 10])
    return recall_list


def evaluate_recall_with_gallery(model: FuseCBM, dataloader, device, intervene: bool, pre_concept: bool, gallery_features, gallery_labels):
    """Compute Recall@{1,5,10} for the dataloader's embeddings against a precomputed gallery."""
    model.eval()
    model.to(device)
    embeddings_list = []
    labels_list = []
    with torch.no_grad():
        for imgs, attrs, labels in tqdm(dataloader):
            imgs = imgs.to(device)
            attrs = attrs.to(device)
            labels = labels.to(device)
            if pre_concept:
                # Embedding taken before the concept bottleneck.
                embeddings = model.get_pre_concept_embedding(imgs)
            else:
                if intervene:
                    # Replace predicted concepts with ground-truth attributes before fusing.
                    embeddings = model.get_fused_embedding_with_intervention(imgs, attrs, False, False)
                else:
                    embeddings = model.get_embedding(imgs)
            embeddings = F.normalize(embeddings, dim=1)
            embeddings_list.append(embeddings)
            labels_list.append(labels)
    embeddings = torch.cat(embeddings_list, dim=0)
    labels = torch.cat(labels_list, dim=0)
    recall_list = recall(embeddings, labels, rank=[1, 5, 10], gallery_features=gallery_features, gallery_labels=gallery_labels)
    return recall_list
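

# Minimal usage sketch (illustrative only): run `evaluate` on the CUB test split.
# The `get_cub_dataloaders` return signature, the `CBM` constructor arguments, and the
# checkpoint path below are assumptions for illustration, not part of this file's tested API.
# CUB-200-2011 has 200 classes and 312 binary attributes, which is why those counts are used here.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Assumed to return (train_loader, test_loader); adjust unpacking to the actual signature.
    _, test_loader = get_cub_dataloaders(batch_size=64)
    # Assumed constructor; match the arguments used at training time.
    model = CBM(num_concepts=312, num_classes=200)
    # Hypothetical checkpoint path.
    model.load_state_dict(torch.load("cbm.pt", map_location=device))
    concept_acc, class_acc, concept_f1 = evaluate(
        model, test_loader, device, num_classes=200, num_concepts=312
    )
    print(
        f"concept acc={concept_acc.item():.4f}  "
        f"class acc={class_acc.item():.4f}  "
        f"concept F1={concept_f1.item():.4f}"
    )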