-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathData_Retriever_Inference.py
59 lines (39 loc) · 1.15 KB
/
Data_Retriever_Inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import pdb
import os
import cv2
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Dataset
from albumentations import (Normalize, Compose, Resize)
from albumentations.pytorch import ToTensor
import torch.utils.data as data
# In[2]:
class TestDataset(Dataset):
    """Dataset for test-time prediction over a directory of images.

    Each item corresponds to one distinct ``ImageId`` from ``df``. The file
    is read from ``root`` with OpenCV and passed through the albumentations
    pipeline Normalize -> Resize -> ToTensor.

    Args:
        root: Directory that contains the image files named in
            ``df['ImageId']``.
        df: DataFrame with an ``ImageId`` column. Repeated ids are collapsed
            to one item each, keeping first-occurrence order
            (``Series.unique``).
        mean: Per-channel mean forwarded to albumentations ``Normalize``.
        std: Per-channel std forwarded to albumentations ``Normalize``.
        size: ``(height, width)`` forwarded to albumentations ``Resize``.
            Defaults to ``(224, 224)``, the size previously hard-coded here.
    """

    def __init__(self, root, df, mean, std, *, size=(224, 224)):
        self.root = root
        # De-duplicate: the same ImageId may appear on several rows of df
        # (presumably one row per predicted class -- TODO confirm).
        self.fnames = df['ImageId'].unique().tolist()
        self.num_samples = len(self.fnames)
        height, width = size
        # NOTE(review): albumentations.pytorch.ToTensor was removed in newer
        # albumentations releases in favour of ToTensorV2; pin the version or
        # migrate deliberately.
        self.transform = Compose(
            [
                Normalize(mean=mean, std=std, p=1),
                Resize(height, width),
                ToTensor(),
            ]
        )

    def __getitem__(self, idx):
        """Return ``(fname, image)`` for the ``idx``-th distinct ImageId.

        ``image`` is the output of the transform pipeline built in
        ``__init__``.

        Raises:
            OSError: if OpenCV cannot read the file. ``cv2.imread`` signals
                failure by returning ``None`` rather than raising, which would
                otherwise surface as an obscure error inside the transforms.
        """
        fname = self.fnames[idx]
        path = os.path.join(self.root, fname)
        # NOTE(review): cv2 loads channels in BGR order; confirm the model was
        # trained on BGR input before converting to RGB here.
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"cv2.imread could not read image: {path}")
        images = self.transform(image=image)["image"]
        return fname, images

    def __len__(self):
        """Return the number of distinct images in the dataset."""
        return self.num_samples
# In[ ]:
# In[ ]: