test.py
import gc
import os
from typing import Optional

import fire
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.loggers.wandb import WandbLogger

from transformer.model import Transformer
from data.ar_dataset import ARDataModule
from utils.seed import seed_everything

seed_everything(42, benchmark=False)


def test(
    ds_name,
    krn_encoding: str = "bekern",
    input_modality: str = "audio",  # "audio" or "image" or "both"
    use_distorted_images: bool = False,  # Only used if input_modality == "image" or "both"
    img_height: Optional[int] = None,  # If None, the original image height is used (only used if input_modality == "image" or "both")
    checkpoint_path: str = "",
):
    gc.collect()
    torch.cuda.empty_cache()

    # TODO
    # Implement multimodal testing
    if input_modality == "both":
        raise NotImplementedError(
            "We can only perform unimodal model testing right now."
        )

    # Check if checkpoint path is empty or does not exist
    if checkpoint_path == "":
        raise ValueError("Checkpoint path not provided")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint path {checkpoint_path} does not exist")

    # Get source dataset name
    _, src_ds_name, model_name = checkpoint_path.split("/")

    # Experiment info
    print("TEST EXPERIMENT")
    print(f"\tSource dataset: {src_ds_name}")
    print(f"\tTest dataset: {ds_name}")
    print(f"\tKern encoding: {krn_encoding}")
    print(f"\tInput modality: {input_modality}")
    print(
        f"\tUse distorted images: {use_distorted_images} (used if input_modality in ['image', 'both'])"
    )
    print(f"\tImage height: {img_height} (used if input_modality in ['image', 'both'])")
    print(f"\tCheckpoint path: {checkpoint_path}")

    # Data module
    datamodule = ARDataModule(
        ds_name=ds_name,
        krn_encoding=krn_encoding,
        input_modality=input_modality,
        use_distorted_images=use_distorted_images,
        img_height=img_height,
    )
    datamodule.setup(stage="test")
    ytest_i2w = datamodule.test_ds.i2w

    # Model
    model = Transformer.load_from_checkpoint(checkpoint_path, ytest_i2w=ytest_i2w)

    # Test
    trainer = Trainer(
        logger=WandbLogger(
            project="OMR-A2S-Poly-Multimodal",
            group=model_name.split(".ckpt")[0],
            name=f"Train-{src_ds_name}_Test-{ds_name}",
            log_model=False,
        ),
        precision="16-mixed",  # Mixed precision inference
    )
    model.freeze()
    trainer.test(model, datamodule=datamodule)


if __name__ == "__main__":
    fire.Fire(test)
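

# Example CLI usage (a minimal sketch): fire.Fire(test) exposes the keyword
# arguments of test() as command-line flags. The dataset name and checkpoint
# path below are hypothetical placeholders; note that the checkpoint path must
# have exactly three components ("<dir>/<src_ds_name>/<model_name>.ckpt") so
# that the split("/") above yields the source dataset and model name.
#
#   python test.py \
#       --ds_name grandstaff \
#       --krn_encoding bekern \
#       --input_modality audio \
#       --checkpoint_path weights/grandstaff/transformer.ckpt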