forked from czifan/DeepSurv.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
69 lines (59 loc) · 2.33 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# ------------------------------------------------------------------------------
# --coding='utf-8'--
# Written by czifan ([email protected])
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
class SurvivalDataset(Dataset):
    ''' The dataset class performs loading data from .h5 file.

    Each sample is a triple (X, y, e):
        X: feature vector, min-max scaled per feature column
        y: (1,) time value read from the 't' dataset
        e: (1,) event indicator read from the 'e' dataset
    '''

    def __init__(self, h5_file, is_train, norm_stats=None):
        ''' Loading data from .h5 file based on (h5_file, is_train).

        :param h5_file: (String) the path of .h5 file
        :param is_train: (bool) which split of the file to load
            is_train=True: loading the 'train' group
            is_train=False: loading the 'test' group
        :param norm_stats: optional (x_min, x_range) pair used for min-max
            scaling. When None (default), the statistics are computed from
            the loaded split itself (original behaviour). To scale a test
            split consistently with training, pass the training dataset's
            ``norm_stats`` attribute here.
        '''
        # loads data
        self.X, self.e, self.y = self._read_h5_file(h5_file, is_train)
        # normalizes data
        self._normalize(norm_stats)
        print('=> load {} samples'.format(self.X.shape[0]))

    def _read_h5_file(self, h5_file, is_train):
        ''' Parse one split ('train' or 'test') from the .h5 file.

        :return X: (np.array) (n, m), m is the feature dimension
        :return e: (np.array) (n, 1), event indicator from the 'e' dataset
        :return y: (np.array) (n, 1), time from the 't' dataset
        '''
        split = 'train' if is_train else 'test'
        with h5py.File(h5_file, 'r') as f:
            X = f[split]['x'][()]
            # reshape to column vectors so each item indexes to a (1,) array
            e = f[split]['e'][()].reshape(-1, 1)
            y = f[split]['t'][()].reshape(-1, 1)
        return X, e, y

    def _normalize(self, norm_stats=None):
        ''' Min-max scale each column of X in place.

        :param norm_stats: optional (x_min, x_range) pair to reuse; when None
            the statistics are computed column-wise from self.X.

        Columns whose range is zero (constant features) are divided by 1
        instead of 0, so they become all zeros rather than NaN.
        The statistics actually used are stored in ``self.norm_stats``.
        '''
        if norm_stats is None:
            x_min = self.X.min(axis=0)
            x_range = self.X.max(axis=0) - x_min
        else:
            x_min, x_range = norm_stats
        x_min = np.asarray(x_min)
        # copy so that caller-supplied stats are never mutated
        x_range = np.array(x_range, copy=True)
        # guard against 0/0 -> NaN for constant columns
        x_range[x_range == 0] = 1
        self.norm_stats = (x_min, x_range)
        self.X = (self.X - x_min) / x_range

    def __getitem__(self, item):
        ''' Return (X, y, e) tensors for sample `item`.

        Note the order: time (y) comes before event (e), which differs from
        the attribute load order (X, e, y).
        '''
        # torch.from_numpy shares memory with the underlying arrays
        X_tensor = torch.from_numpy(self.X[item])  # (m)
        e_tensor = torch.from_numpy(self.e[item])  # (1)
        y_tensor = torch.from_numpy(self.y[item])  # (1)
        return X_tensor, y_tensor, e_tensor

    def __len__(self):
        return self.X.shape[0]