-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
72 lines (60 loc) · 2.11 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import cv2
import ipdb
import os
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch import nn
import torch
import lightning as L
import timm
from timm.models.registry import model_entrypoint
import pickle
import torch.nn.functional as F
from dataloader import build_test_data_loader,build_all_data
from model import ConvNext_S
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from eval import score
def main():
    """Evaluate the ConvNext_S checkpoint on the test crops and print the score.

    Pipeline:
      1. Load ``./ubc/baseline1_fix-lr_ziwbreee_softmax.ckpt`` into
         ``ConvNext_S`` and switch it to eval mode (CUDA when available,
         otherwise CPU).
      2. Run the model on every item of ``build_test_data_loader()``, one item
         per forward pass, and collect per-crop softmax probabilities.
      3. Collapse crops to one prediction per ``image_id``: take the
         element-wise max probability over that image's crops, then argmax.
      4. Attach the predicted class names to a copy of ``build_all_data()``
         (matched by ``image_id``) and print ``eval.score`` against the
         ground-truth labels.

    Raises:
        ValueError: if the test dataset yields no items, or if some image in
            ``build_all_data()`` received no prediction.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # init model; extra kwargs such as train_dataloader are forwarded to the
    # model constructor -- no training dataloader is needed for inference
    model = ConvNext_S.load_from_checkpoint(
        "./ubc/baseline1_fix-lr_ziwbreee_softmax.ckpt",
        map_location=device,
        train_dataloader=None,
    )
    model.eval()
    model = model.to(device)

    # test data: collect per-crop class probabilities, one item per forward pass
    test_dataset = build_test_data_loader()
    crop_probs = []
    with torch.no_grad():
        for sample in tqdm(test_dataset, total=len(test_dataset)):
            image = sample["data"].unsqueeze(0).to(device)  # add batch dim of 1
            logits = model(image)
            # explicit dim: implicit-dim softmax is deprecated; dim=1 is what the
            # implicit rule picks for 2-D logits (assumes (batch, num_classes)
            # output -- TODO confirm against the model)
            crop_probs.append(F.softmax(logits, dim=1).cpu().numpy())

    if not crop_probs:
        raise ValueError("test dataset is empty; nothing to evaluate")
    probs = np.vstack(crop_probs)
    print(probs.shape)

    num_classes = probs.shape[-1]
    prob_cols = [f"cat{i}" for i in range(num_classes)]
    # copy so the dataset's own dataframe is not mutated; its rows are assumed
    # to line up one-to-one with the dataset's iteration order
    df_crop = test_dataset.df.copy()
    for i, col in enumerate(prob_cols):
        df_crop[col] = probs[:, i]

    # image-level class score = max over that image's crops
    # (.mean() would be an alternative aggregation)
    image_scores = df_crop.groupby("image_id")[prob_cols].max()
    pred_idx = image_scores.to_numpy().argmax(axis=1)
    label_by_image = {
        image_id: test_dataset.class_name[idx]
        for image_id, idx in zip(image_scores.index, pred_idx)
    }

    gt_df = build_all_data()
    # copy of the ground truth: data is built once, and writing predictions
    # into the submission frame can never touch gt_df
    test_df = gt_df.copy()
    # align predictions by image_id instead of relying on identical row order
    test_df["label"] = test_df["image_id"].map(label_by_image)
    missing = test_df["label"].isna()
    if missing.any():
        raise ValueError(
            f"{int(missing.sum())} image_id(s) from build_all_data() have no prediction"
        )

    scores = score(
        solution=gt_df[["image_id", "label"]],
        submission=test_df[["image_id", "label"]],
        row_id_column_name="image_id",
    )
    print(scores)
if __name__ == "__main__":
    # script entry point: run checkpoint inference and print the evaluation score
    main()