-
Notifications
You must be signed in to change notification settings - Fork 13
/
utils.py
161 lines (132 loc) · 5.72 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
import numpy as np
import math
import datetime
class CoordEncoder:
    """Maps raw (lon, lat) coordinates to model input features.

    input_enc selects the encoding: 'sin_cos' (sinusoidal), 'env'
    (bioclim raster features), or 'sin_cos_env' (both concatenated).
    raster is the H x W x C environmental raster, required for the
    'env' variants.
    """

    def __init__(self, input_enc, raster=None):
        self.input_enc = input_enc
        self.raster = raster

    def encode(self, locs, normalize=True):
        """Encode locations (lon in [-180, 180], lat in [-90, 90]) to features."""
        if normalize:
            locs = normalize_coords(locs)
        if self.input_enc == 'sin_cos':
            # sinusoidal position encoding only
            return encode_loc(locs)
        if self.input_enc == 'env':
            # bioclim raster features only
            return bilinear_interpolate(locs, self.raster)
        if self.input_enc == 'sin_cos_env':
            # both encodings, concatenated along the feature dimension
            pos_feats = encode_loc(locs)
            env_feats = bilinear_interpolate(locs, self.raster)
            return torch.cat((pos_feats, env_feats), 1)
        raise NotImplementedError('Unknown input encoding.')
def normalize_coords(locs):
    """Rescale lon from [-180, 180] and lat from [-90, 90] into [-1, 1].

    NOTE: operates in place — the caller's tensor is modified and the
    same object is returned.
    """
    locs[:, 0] = locs[:, 0] / 180.0
    locs[:, 1] = locs[:, 1] / 90.0
    return locs
def encode_loc(loc_ip, concat_dim=1):
    """Sinusoidal encoding of (lon, lat) already scaled to [-1, 1].

    Returns [sin(pi*x), cos(pi*x)] concatenated along concat_dim,
    doubling the feature width.
    """
    scaled = math.pi * loc_ip
    return torch.cat((torch.sin(scaled), torch.cos(scaled)), concat_dim)
def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
    """Bilinearly sample raster features at normalized locations.

    loc_ip : N x 2 tensor of [lon, lat], each entry in [-1, 1]
    data   : H x W x C raster (rows are latitude, columns are longitude)
    returns: N x C tensor of interpolated channel values
    """
    assert data is not None
    # map [-1, 1] into [0, 1]
    loc = (loc_ip.clone() + 1) / 2.0
    # flip latitude: raster row 0 is +90 (top), but normalized lat grows upward
    loc[:, 1] = 1 - loc[:, 1]
    assert not torch.any(torch.isnan(loc))
    if remove_nans_raster:
        # 0 is the mean post-normalization; NOTE: mutates `data` in place
        data[torch.isnan(data)] = 0.0
    h, w = data.shape[0], data.shape[1]
    # scale into continuous pixel coordinates
    loc[:, 0] *= (w - 1)
    loc[:, 1] *= (h - 1)
    base = torch.floor(loc).long()  # top-left integer pixel per location
    x0 = base[:, 0]
    y0 = base[:, 1]
    # neighbours one pixel over, clamped at the raster edge
    x1 = torch.clamp(x0 + 1, max=w - 1)
    y1 = torch.clamp(y0 + 1, max=h - 1)
    frac = loc - torch.floor(loc)  # fractional offsets within the pixel
    dx = frac[:, 0].unsqueeze(1)
    dy = frac[:, 1].unsqueeze(1)
    return (data[y0, x0, :] * (1 - dx) * (1 - dy)
            + data[y0, x1, :] * dx * (1 - dy)
            + data[y1, x0, :] * (1 - dx) * dy
            + data[y1, x1, :] * dx * dy)
def rand_samples(batch_size, device, rand_type='uniform'):
    """Randomly sample `batch_size` background locations in [-1, 1] x [-1, 1].

    rand_type:
      'uniform'   - uniform over the normalized (lon, lat) square
      'spherical' - uniform over the sphere, mapped to normalized lon/lat

    Returns a batch_size x 2 tensor of [lon, lat] on `device`.

    Raises:
        ValueError: if rand_type is not one of the supported modes.
    """
    if rand_type == 'spherical':
        rand_loc = torch.rand(batch_size, 2).to(device)
        theta1 = 2.0 * math.pi * rand_loc[:, 0]
        # acos transform makes the latitude sample area-preserving on the sphere
        theta2 = torch.acos(2.0 * rand_loc[:, 1] - 1.0)
        lat = 1.0 - 2.0 * theta2 / math.pi
        lon = (theta1 / math.pi) - 1.0
        rand_loc = torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)
    elif rand_type == 'uniform':
        rand_loc = torch.rand(batch_size, 2).to(device) * 2.0 - 1.0
    else:
        # previously an unknown rand_type fell through to a confusing
        # NameError on `rand_loc`; fail loudly with a clear message instead
        raise ValueError('Unknown rand_type: {}'.format(rand_type))
    return rand_loc
def get_time_stamp():
    """Return the current local time formatted as 'YYYY-MM-DD-HH-MM-SS'."""
    # strftime with zero-padded fields reproduces the original
    # split/reassemble of str(datetime.now()) exactly
    return datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
def coord_grid(grid_size, split_ids=None, split_of_interest=None):
    """Build an evenly spaced lon/lat grid over the whole globe.

    grid_size is (rows, cols). Without split arguments, returns all
    locations as an N x 2 float32 array. With both split_ids and
    split_of_interest, returns only the matching subset (N_subset x 2).
    """
    rows, cols = grid_size[0], grid_size[1]
    feats = np.zeros((rows, cols, 2), dtype=np.float32)
    # lon spans left-to-right, lat spans top (+90) to bottom (-90)
    lon_grid, lat_grid = np.meshgrid(np.linspace(-180, 180, cols),
                                     np.linspace(90, -90, rows))
    feats[:, :, 0] = lon_grid
    feats[:, :, 1] = lat_grid
    if split_ids is None or split_of_interest is None:
        # flatten to an N x 2 array covering every grid cell
        return feats.reshape(rows * cols, 2)
    # select only the cells assigned to the split of interest
    ind_y, ind_x = np.where(split_ids == split_of_interest)
    return feats[ind_y, ind_x, :]
def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25):
    """Generate a checkerboard-style train/test split over a raster.

    Output codes: 0 invalid, 1 train, 2 test. cell_size is in pixels.
    mask zeroes out invalid locations. If train_amt < 1.0, a random
    fraction of the train pixels is demoted to invalid.
    """
    split_ids = np.ones((raster.shape[0], raster.shape[1]))
    # paint the test cells: alternate the column offset on each row band
    # so the cells form a checkerboard
    for band, ii in enumerate(np.arange(0, split_ids.shape[0], cell_size)):
        offset = (band % 2) * cell_size
        for jj in np.arange(offset, split_ids.shape[1], cell_size * 2):
            split_ids[ii:ii + cell_size, jj:jj + cell_size] = 2
    split_ids = split_ids * mask
    if train_amt < 1.0:
        # randomly demote (1 - train_amt) of the train pixels to invalid
        tr_y, tr_x = np.where(split_ids == 1)
        drop = np.random.choice(len(tr_y), int(len(tr_y) * (1.0 - train_amt)), replace=False)
        split_ids[tr_y[drop], tr_x[drop]] = 0
    return split_ids
def average_precision_score_faster(y_true, y_scores):
    """Average precision — drop-in replacement for sklearn's
    average_precision_score, comparable up to floating point differences.

    y_true is a binary label array, y_scores the predicted scores.
    """
    num_positives = y_true.sum()
    # rank predictions from highest to lowest score
    order = np.argsort(y_scores)[::-1]
    sorted_labels = y_true[order]
    true_pos_c = np.cumsum(sorted_labels)
    false_pos_c = np.cumsum(1.0 - sorted_labels)
    recall = true_pos_c / num_positives
    # guard the denominator with eps (tp + fp is the rank, always >= 1)
    denom = np.maximum(true_pos_c + false_pos_c, np.finfo(np.float32).eps)
    precision = true_pos_c / denom
    # weight each precision value by the recall gained at that threshold
    padded = np.hstack((0, recall, 1))
    recall_gain = (padded[1:] - padded[:-1])[:-1]
    return (recall_gain * precision).sum()