-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathutils.py
89 lines (74 loc) · 2.9 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# ------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Written by czifan ([email protected])
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ast
import configparser
import logging
import os
import time

import numpy as np
from lifelines.utils import concordance_index
def read_config(ini_file):
    ''' Reads a .ini config file and parses each option value into a Python object.
    :param ini_file: (String) the path of a .ini file.
    :return config: (dict) mapping section name -> {option: parsed value}.
    '''
    def _parse_value(text):
        # Use ast.literal_eval instead of eval(): it only accepts Python
        # literals (numbers, strings, lists, dicts, booleans, None), so a
        # malicious or malformed config entry cannot execute arbitrary code.
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. a bare word such as "relu") --
            # keep the raw string instead of crashing like eval() would.
            return text

    def _build_dict(items):
        return {key: _parse_value(value) for key, value in items}

    # create configparser object and read the .ini file
    cf = configparser.ConfigParser()
    cf.read(ini_file)
    return {sec: _build_dict(cf.items(sec)) for sec in cf.sections()}
def c_index(risk_pred, y, e):
    ''' Computes the concordance index (c-index) of the predictions.
    :param risk_pred: (np.ndarray or torch.Tensor) model prediction
    :param y: (np.ndarray or torch.Tensor) the times of event e
    :param e: (np.ndarray or torch.Tensor) flag that records whether the event occurs
    :return c_index: the c_index calculated from (risk_pred, y, e)
    '''
    def _as_numpy(tensor):
        # torch tensors must be detached from the graph and moved to CPU
        # before conversion; numpy arrays pass through untouched.
        if isinstance(tensor, np.ndarray):
            return tensor
        return tensor.detach().cpu().numpy()

    return concordance_index(_as_numpy(y), _as_numpy(risk_pred), _as_numpy(e))
def adjust_learning_rate(optimizer, epoch, lr, lr_decay_rate):
    ''' Decays the learning rate of every parameter group in-place.
    :param optimizer: (torch.optim object)
    :param epoch: (int)
    :param lr: (float) the initial learning rate
    :param lr_decay_rate: (float) learning rate decay rate
    :return lr_: (float) updated learning rate
    '''
    # Inverse-time decay: lr_t = lr / (1 + t * decay_rate); identical for
    # every group, so compute it once outside the loop.
    decayed = lr / (1 + epoch * lr_decay_rate)
    for group in optimizer.param_groups:
        group['lr'] = decayed
    return optimizer.param_groups[0]['lr']
def create_logger(logs_dir):
    ''' Creates a logger that writes INFO records both to a timestamped
    file under logs_dir and to the console.
    :param logs_dir: (String) the path of logs (directory must exist)
    :return logger: (logging object)
    '''
    # logs settings: one file per call, named by the current local time
    log_file = os.path.join(logs_dir,
        time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + '.log')
    # initialize logger
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    # logging.getLogger returns the SAME logger object on every call, so
    # without this cleanup each call would stack two more handlers and
    # every record would then be emitted multiple times.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()
    # initialize file handler
    handler = logging.FileHandler(log_file)
    handler.setLevel(logging.INFO)
    handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    # initialize console handler
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    # attach handlers
    logger.addHandler(handler)
    logger.addHandler(console)
    return logger