forked from anosorae/IRRA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
35 lines (30 loc) · 1.14 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from prettytable import PrettyTable
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import torch
import numpy as np
import time
import os.path as op
from datasets import build_dataloader
from processor.processor import do_inference
from utils.checkpoint import Checkpointer
from utils.logger import setup_logger
from model import build_model
from utils.metrics import Evaluator
import argparse
from utils.iotools import load_train_configs
if __name__ == '__main__':
    # Evaluate a trained IRRA checkpoint on the test split.
    #
    # The run is driven by a saved training configuration: the CLI only takes
    # the path to that configs.yaml, every other setting (dataset, model,
    # output_dir, ...) is reloaded from it, and the weights are read from
    # <output_dir>/best.pth.
    parser = argparse.ArgumentParser(description="IRRA Test")
    parser.add_argument("--config_file", default='logs/CUHK-PEDES/iira/configs.yaml')
    args = parser.parse_args()
    # Replace the CLI namespace with the configuration loaded from the file.
    args = load_train_configs(args.config_file)
    # Force evaluation mode; this flag is read by setup_logger (if_train) and
    # travels inside `args` to build_dataloader.
    args.training = False
    logger = setup_logger('IRRA', save_dir=args.output_dir, if_train=args.training)
    logger.info(args)

    # Prefer the GPU, but fall back to CPU instead of crashing at model.to()
    # on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("Using device: %s", device)

    # With training disabled, build_dataloader returns separate image and text
    # test loaders plus the class count used to size the model
    # (NOTE(review): contract inferred from this call site — confirm in
    # datasets.build_dataloader).
    test_img_loader, test_txt_loader, num_classes = build_dataloader(args)
    model = build_model(args, num_classes=num_classes)
    checkpointer = Checkpointer(model)
    checkpointer.load(f=op.join(args.output_dir, 'best.pth'))
    model.to(device)
    do_inference(model, test_img_loader, test_txt_loader)