import os
import glob
import re
import logging
from collections import OrderedDict

from osgeo import gdal
import scipy.ndimage

logger = logging.getLogger('s2_export')

# Sentinel-2 band names grouped by their native resolution
_res_bands = OrderedDict([
    ('10m', ['B2', 'B3', 'B4', 'B8']),
    ('20m', ['B5', 'B6', 'B7', 'B8A', 'B11', 'B12']),
    ('60m', ['B1', 'B9', 'B10'])])


def get_bands_res():
    bands_res = OrderedDict()
    for reskey in _res_bands:
        for bandname in _res_bands[reskey]:
            bands_res[bandname] = reskey
    return bands_res


_bands_res = get_bands_res()


def res_to_float(reskey):
    return float(reskey[:2])


def open_res_datasets(subdatasets):
    res_dss = OrderedDict(
        (reskey, gdal.Open(subdatasets[k][0]))
        for k, reskey in enumerate(_res_bands))
    return res_dss


def get_gdal_bands(dss):
    bands_gdal = OrderedDict()
    for reskey in dss:
        for b, bandname in enumerate(_res_bands[reskey]):
            bands_gdal[bandname] = dss[reskey].GetRasterBand(b+1)
    return bands_gdal


def create_outfile_from_templates(outfile, nbands, template_ds, template_band,
                                  driver_name='GTiff',
                                  gdal_dtype=gdal.GDT_Int16,
                                  create_options=['COMPRESS=LZW', 'BIGTIFF=IF_SAFER']):
    geotransform = template_ds.GetGeoTransform()
    projection = template_ds.GetProjection()
    nx = template_band.XSize
    ny = template_band.YSize
    # take the data type from the template band rather than the default
    gdal_dtype = template_band.DataType
    drv = gdal.GetDriverByName(driver_name)
    # gdal.Driver.Create expects (filename, xsize, ysize, nbands, dtype, options)
    out = drv.Create(outfile, nx, ny, nbands, gdal_dtype, create_options)
    if out is None:
        raise IOError('Unable to create new dataset in {}.'.format(outfile))
    out.SetGeoTransform(geotransform)
    out.SetProjection(projection)
    tgt_nodata = template_band.GetNoDataValue()
    if tgt_nodata is not None:
        for b in range(nbands):
            out.GetRasterBand(b+1).SetNoDataValue(tgt_nodata)
    return out


def write_bands(outds, bands_gdal, bands, tgt_res, tgt_nodata=None):
    """Read band data, zoom and write to output dataset

    Parameters
    ----------
    outds : gdal Dataset
        dataset to write to
    bands_gdal : dict
        dictionary mapping band names to open gdal bands
    bands : list of str
        names of bands to write, in output band order
    tgt_res : str
        target resolution
        e.g. 10m
    tgt_nodata : float
        target nodata
    """
    tgt_res_float = res_to_float(tgt_res)
    for b, bandname in enumerate(bands):
        logger.info('Adding data from band {}'.format(bandname))
        data = bands_gdal[bandname].ReadAsArray()
        zoom = res_to_float(_bands_res[bandname]) / tgt_res_float
        if zoom != 1:
            logger.info('Scaling data to {} resolution ...'.format(tgt_res))
            logger.debug('Zoom factor is {}'.format(zoom))
            # nearest-neighbour resampling to the target resolution
            data = scipy.ndimage.zoom(data, zoom, order=0)
        outds.GetRasterBand(b+1).WriteArray(data)
    logger.info('Writing to disk ...')
    # dereference the dataset to flush data to disk and close it
    outds = None
    logger.info('Done.')


def export_from_subdatasets(subdatasets, bands, tgt_res, outfile):
    """Export selected bands from subdatasets

    Parameters
    ----------
    subdatasets : list of (str, str)
        returned from ds.GetSubDatasets()
    bands : list of str
        names of bands to extract
    tgt_res : str
        target resolution
        e.g. 10m
    outfile : str
        path to output file
    """
    res_dss = open_res_datasets(subdatasets)
    bands_gdal = get_gdal_bands(res_dss)
    nbands = len(bands)
    template_ds = res_dss[tgt_res]
    template_bandname = _res_bands[tgt_res][0]
    template_band = bands_gdal[template_bandname]
    outds = create_outfile_from_templates(outfile, nbands, template_ds, template_band)
    write_bands(outds, bands_gdal, bands, tgt_res)


def find_granule_name(s):
    try:
        return re.search(r'(?<=T)\d{2}[A-Z]{3}', s).group(0)
    except AttributeError:
        return None
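# For illustration (the path is hypothetical): a granule file name such as
# '.../GRANULE/S2A_OPER_MSI_L1C_TL_xxxx_T32UNG_xxxx.xml' would yield '32UNG',
# while a string without a 'T' followed by two digits and three capital
# letters yields None.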


def get_granule_xml(filelist):
    granule_xml = OrderedDict()
    for fname in filelist:
        if not fname.endswith('.xml'):
            continue
        granule = find_granule_name(fname)
        if granule is None:
            continue
        granule_xml[granule] = fname
    return granule_xml


def get_multi_granule_xml(ds):
    granule_xml = OrderedDict()
    subfiles = [e[0] for e in ds.GetSubDatasets() if 'PREVIEW' not in e[0]]
    for sf in subfiles:
        subsubfiles = gdal.Open(sf).GetFileList()
        subgx = get_granule_xml(subsubfiles)
        granule_xml.update(subgx)
    return granule_xml


def get_multi_granule_subdatasets(ds):
    granule_xml = get_multi_granule_xml(ds)
    granule_subdatasets = OrderedDict()
    for granule in granule_xml:
        fname = granule_xml[granule]
        granule_subdatasets[granule] = gdal.Open(fname).GetSubDatasets()
    return granule_subdatasets


def get_single_granule_subdatasets(ds):
    subdatasets = ds.GetSubDatasets()
    filelist = gdal.Open(subdatasets[0][0]).GetFileList()
    for fname in filelist:
        if not fname.endswith('.xml'):
            continue
        granule = find_granule_name(fname)
        if granule is None:
            raise ValueError('Unable to determine granule name.')
        return {granule: subdatasets}


def get_subdatasets(ds):
    subdatasets = ds.GetSubDatasets()
    if len(subdatasets) <= 5:
        logger.info('Got single-tile S2 product')
        return get_single_granule_subdatasets(ds)
    else:
        logger.info('Got multi-tile S2 product')
        return get_multi_granule_subdatasets(ds)


def ensure_xml(infile):
    if not infile.endswith('.xml'):
        pattern = os.path.join(infile, '*MTD*.xml')
        try:
            return glob.glob(pattern)[0]
        except IndexError:
            raise ValueError('Unable to find MTD XML file with pattern \'{}\'.'.format(pattern))
    else:
        return infile


def get_metadata_from_xml(xmlfile):
    """Very simple XML reader"""
    date = None
    platform = None
    with open(xmlfile, 'r') as f:
        for line in f:
            if date and platform:
                break
            try:
                if 'PRODUCT_START_TIME' in line:
                    date = re.search(r'\d{4}-\d{2}-\d{2}', line).group(0).replace('-', '')
                elif 'SPACECRAFT_NAME' in line:
                    platform = 'S' + re.search(r'(?<=Sentinel-)2[AB]', line).group(0)
            except AttributeError:
                pass
    if date is None or platform is None:
        raise ValueError('Unable to get all metadata from XML file \'{}\'.'.format(xmlfile))
    return dict(date=date, platform=platform)


def generate_outfilename(xmlfile, granule):
    meta = get_metadata_from_xml(xmlfile)
    return '{platform}_{date}_T{granule}'.format(granule=granule, **meta)
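# For illustration (values are hypothetical): a product sensed on 2017-01-01
# by Sentinel-2A over granule 32UNG would yield 'S2A_20170101_T32UNG'.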


def export(infile, outdir, bands, tgt_res, granules=None):
    """Export selected S2 bands and granules to GeoTIFF

    Parameters
    ----------
    infile : str
        path to input SAFE or MTD xml file
    outdir : str
        path to output dir
    bands : list of str
        names of bands to extract
        e.g. B1, B10 (no leading zero!)
    tgt_res : str
        target resolution e.g. 10m
        data not in this resolution will be interpolated
    granules : list of str
        extract only these granules

    Returns
    -------
    outfiles : list of str
        list of output files generated
    """
    infile = ensure_xml(infile)
    ds = gdal.Open(infile)
    if ds is None:
        raise IOError('Failed to read input file \'{}\'. Please provide a valid S2 MTD XML file.'.format(infile))
    logger.info('Retrieving granule subdatasets ...')
    granule_subdatasets = get_subdatasets(ds)
    if granules:
        logger.info('Selecting granules ...')
        # iterate over a copy of the keys so entries can be dropped while looping
        for granule in list(granule_subdatasets):
            if granule not in granules:
                granule_subdatasets.pop(granule)
    if not granule_subdatasets:
        logger.info('No granules left to export. Exiting.')
        return []
    logger.info('Granules to export: {}'.format(list(granule_subdatasets)))
    outfiles = []
    for granule in granule_subdatasets:
        logger.info('Exporting granule {} ...'.format(granule))
        subdatasets = granule_subdatasets[granule]
        outfname = generate_outfilename(infile, granule) + '.tif'
        outfile = os.path.join(outdir, outfname)
        export_from_subdatasets(subdatasets, bands=bands, tgt_res=tgt_res, outfile=outfile)
        outfiles.append(outfile)
    return outfiles
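# Minimal usage sketch (paths, product name and band selection are
# hypothetical; assumes a standard L1C SAFE product is available locally):
#
#     import logging
#     logging.basicConfig(level=logging.INFO)
#     outfiles = export(
#         infile='S2A_MSIL1C_20170101T104432_N0204_R008_T32UNG_20170101T104428.SAFE',
#         outdir='/tmp/s2_exports',
#         bands=['B2', 'B3', 'B4', 'B8'],
#         tgt_res='10m')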