forked from Gabrysse/MLDL2024_project1
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvalidation.py
149 lines (118 loc) · 5.02 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from typing import List
import utils
from callbacks import Callback
def val(
epoch : int,
model : torch.nn.Module,
val_loader : DataLoader,
num_classes : int,
device : str = 'cpu',
callbacks : list[Callback] = []):
'''Validation function
Args:
epoch (int): Current epoch number
model (nn.Module): Model to evaluate
val_loader (DataLoader): Validation loader
num_classes (int): Number of classes in the dataset
device (str): Device to run the evaluation on
callbacks (list): List of callback functions to run during validation
Returns:
float: Mean IoU for the validation set
'''
# Run validation begin callbacks
for callback in callbacks:
callback.on_validation_begin()
model.eval()
total_hist = np.zeros((num_classes, num_classes)) # Initialize total histogram for all classes
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(val_loader):
inputs = inputs.to(device)
targets = targets.to(device).squeeze(1) # targets should be [batch_size, height, width]
outputs = model(inputs)
# Verify the shape of the outputs
if isinstance(outputs, tuple):
outputs = outputs[0]
predicted = torch.argmax(outputs, dim=1)
# Calculate fast histogram and accumulate
hist = utils.fast_hist(targets.cpu().numpy(), predicted.cpu().numpy(), num_classes)
total_hist += hist
# calculate True Positive on the confusion matrix
TP = np.diag(total_hist)
# calculate the Accuracy
pixel_acc = np.sum(TP) / np.sum(total_hist)
# loss
loss = 1. - pixel_acc
# Run validation end callbacks
for callback in callbacks:
callback.on_validation_batch_end(batch_idx, loss)
# Compute per-class IoU from the accumulated histogram
ious = utils.per_class_iou(total_hist)
mean_iou = np.nanmean(ious) # Mean IoU for reporting
print(f'Validation Mean IoU for Epoch {epoch + 1}: {mean_iou:.4f}')
# Run validation end callbacks
for callback in callbacks:
callback.on_validation_end(mean_iou)
return mean_iou
def val_GTA5(
epoch : int,
model : torch.nn.Module,
val_loader : DataLoader,
num_classes : int,
class_names : list[str],
callbacks : list[Callback] = [],
device : str = 'cpu'):
'''Validation function for GTA5 dataset, using the provided class names for reporting IoU per class.
Args:
epoch (int): Current epoch number
model (nn.Module): Model to evaluate
val_loader (DataLoader): Validation loader
num_classes (int): Number of classes in the dataset
class_names (list): List of class names for reporting
callbacks (list): List of callback functions to run during validation
device (str): Device to run the evaluation on, default is 'cpu'
Returns:
float: Mean IoU for the validation set
'''
model.eval()
total_miou = 0
# Run validation begin callbacks
for callback in callbacks:
callback.on_validation_begin()
# Initialize confusion matrix
confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(val_loader):
inputs = inputs.to(device)
targets = targets.to(device).squeeze(1) # Ensure targets are [batch_size, height, width]
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = outputs[0] # Select the main output if model returns a tuple
predicted = torch.argmax(outputs, dim=1).cpu().numpy()
targets = targets.cpu().numpy()
# Update confusion matrix using fast_hist
confusion_matrix += utils.fast_hist(targets, predicted, num_classes)
# calculate True Positive on the confusion matrix
TP = np.diag(confusion_matrix)
# calculate the Accuracy
pixel_acc = np.sum(TP) / np.sum(confusion_matrix)
# loss
loss = 1. - pixel_acc
# Run validation end callbacks
for callback in callbacks:
callback.on_validation_batch_end(batch_idx, loss)
# Calculate per class IoU from the confusion matrix
IoUs = utils.per_class_iou(confusion_matrix)
total_miou = np.nanmean(IoUs) # Calculate mean IoU, ignoring NaNs
print(f'Validation mIoU for Epoch {epoch + 1}: {total_miou:.4f}')
class_result_df = pd.DataFrame({'Class': class_names, 'IoU': [f'{iou:.4f}' for iou in IoUs]})
print(class_result_df)
# Run validation end callbacks
for callback in callbacks:
callback.on_validation_end({
'validation_mIoU': total_miou
},data=class_result_df)
return total_miou, class_result_df