Skip to content

Commit

Permalink
Use geojson gcps as lon lat source
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud committed Mar 7, 2024
1 parent eb6e25f commit 4ce31f7
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 51 deletions.
91 changes: 50 additions & 41 deletions satpy/readers/sar_c_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@

import functools
import logging
import os
from collections import defaultdict
from datetime import timezone as tz
from functools import cached_property
from threading import Lock

import defusedxml.ElementTree as ET
Expand All @@ -50,6 +51,7 @@

from satpy.dataset.data_dict import DatasetDict
from satpy.dataset.dataid import DataID
from satpy.readers import open_file_or_filename
from satpy.readers.file_handlers import BaseFileHandler
from satpy.readers.yaml_reader import GenericYAMLReader
from satpy.utils import get_legacy_chunk_size
Expand Down Expand Up @@ -101,10 +103,10 @@ def __init__(self, filename, filename_info, filetype_info,
"""Init the xml filehandler."""
super().__init__(filename, filename_info, filetype_info)

self._start_time = filename_info["start_time"]
self._end_time = filename_info["end_time"]
self._start_time = filename_info["start_time"].replace(tzinfo=tz.utc)
self._end_time = filename_info["end_time"].replace(tzinfo=tz.utc)
self._polarization = filename_info["polarization"]
self.root = ET.parse(self.filename)
self.root = ET.parse(open_file_or_filename(self.filename))
self._image_shape = image_shape

def get_metadata(self):
Expand Down Expand Up @@ -507,8 +509,13 @@ def interpolate_xarray(xpoints, ypoints, values, shape,
"""Interpolate, generating a dask array."""
from scipy.interpolate import RectBivariateSpline

vchunks = range(0, shape[0], blocksize)
hchunks = range(0, shape[1], blocksize)
try:
blocksize_row, blocksize_col = blocksize
except ValueError:
blocksize_row = blocksize_col = blocksize

vchunks = range(0, shape[0], blocksize_row)
hchunks = range(0, shape[1], blocksize_col)

token = tokenize(blocksize, xpoints, ypoints, values, shape)
name = "interpolate-" + token
Expand All @@ -520,15 +527,15 @@ def interpolator(xnew, ynew):
return spline(xnew, ynew).T

dskx = {(name, i, j): (interpolate_slice,
slice(vcs, min(vcs + blocksize, shape[0])),
slice(hcs, min(hcs + blocksize, shape[1])),
slice(vcs, min(vcs + blocksize_row, shape[0])),
slice(hcs, min(hcs + blocksize_col, shape[1])),
interpolator)
for i, vcs in enumerate(vchunks)
for j, hcs in enumerate(hchunks)
}

res = da.Array(dskx, name, shape=list(shape),
chunks=(blocksize, blocksize),
chunks=(blocksize_row, blocksize_col),
dtype=values.dtype)
return DataArray(res, dims=("y", "x"))

Expand Down Expand Up @@ -573,9 +580,8 @@ class SAFEGRD(BaseFileHandler):
def __init__(self, filename, filename_info, filetype_info, calibrator, denoiser):
"""Init the grd filehandler."""
super().__init__(filename, filename_info, filetype_info)

self._start_time = filename_info["start_time"]
self._end_time = filename_info["end_time"]
self._start_time = filename_info["start_time"].replace(tzinfo=tz.utc)
self._end_time = filename_info["end_time"].replace(tzinfo=tz.utc)

self._polarization = filename_info["polarization"]

Expand All @@ -585,7 +591,6 @@ def __init__(self, filename, filename_info, filetype_info, calibrator, denoiser)
self.denoiser = denoiser
self.read_lock = Lock()

self.filehandle = rasterio.open(self.filename, "r", sharing=False)
self.get_lonlatalts = functools.lru_cache(maxsize=2)(
self._get_lonlatalts_uncached
)
Expand All @@ -606,18 +611,25 @@ def get_dataset(self, key, info):
data.attrs.update(info)

else:
data = xr.open_dataarray(self.filename, engine="rasterio",
chunks={"band": 1, "y": CHUNK_SIZE, "x": CHUNK_SIZE}).squeeze()
data = data.assign_coords(x=np.arange(len(data.coords["x"])),
y=np.arange(len(data.coords["y"])))
data = self._calibrate_and_denoise(data, key)
data = self._calibrate_and_denoise(self._data, key)
data.attrs.update(info)
data.attrs.update({"platform_name": self._mission_id})

data = self._change_quantity(data, key["quantity"])

return data

@cached_property
def _data(self):
data = xr.open_dataarray(self.filename, engine="rasterio",
chunks="auto"
).squeeze()
self.chunks = data.data.chunksize
data = data.assign_coords(x=np.arange(len(data.coords["x"])),
y=np.arange(len(data.coords["y"])))

return data

@staticmethod
def _change_quantity(data, quantity):
"""Change quantity to dB if needed."""
Expand All @@ -631,11 +643,9 @@ def _change_quantity(data, quantity):

def _calibrate_and_denoise(self, data, key):
"""Calibrate and denoise the data."""
chunks = CHUNK_SIZE

dn = self._get_digital_number(data)
dn = self.denoiser(dn, chunks)
data = self.calibrator(dn, key["calibration"], chunks)
dn = self.denoiser(dn, self.chunks)
data = self.calibrator(dn, key["calibration"], self.chunks)

return data

Expand All @@ -646,13 +656,6 @@ def _get_digital_number(self, data):
dn = data * data
return dn

def _denoise(self, dn, chunks):
"""Denoise the data."""
logger.debug("Reading noise data.")
noise = self.noise.get_noise_correction(chunks=chunks).fillna(0)
dn = dn - noise
return dn

def _get_lonlatalts_uncached(self):
"""Obtain GCPs and construct latitude and longitude arrays.
Expand All @@ -662,16 +665,16 @@ def _get_lonlatalts_uncached(self):
Returns:
coordinates (tuple): A tuple with longitude and latitude arrays
"""
band = self.filehandle
shape = self._data.shape

(xpoints, ypoints), (gcp_lons, gcp_lats, gcp_alts), (gcps, crs) = self.get_gcps()

# FIXME: do interpolation on cartesian coordinates if the area is
# problematic.

longitudes = interpolate_xarray(xpoints, ypoints, gcp_lons, band.shape)
latitudes = interpolate_xarray(xpoints, ypoints, gcp_lats, band.shape)
altitudes = interpolate_xarray(xpoints, ypoints, gcp_alts, band.shape)
longitudes = interpolate_xarray(xpoints, ypoints, gcp_lons, shape, self.chunks)
latitudes = interpolate_xarray(xpoints, ypoints, gcp_lats, shape, self.chunks)
altitudes = interpolate_xarray(xpoints, ypoints, gcp_alts, shape, self.chunks)

longitudes.attrs["gcps"] = gcps
longitudes.attrs["crs"] = crs
Expand All @@ -694,9 +697,12 @@ def get_gcps(self):
gcp_coords (tuple): longitude and latitude 1d arrays
"""
gcps = self.filehandle.gcps
gcps = self._data.coords["spatial_ref"].attrs["gcps"]
crs = self._data.rio.crs

gcp_array = np.array([(p.row, p.col, p.x, p.y, p.z) for p in gcps[0]])
gcp_list = [(feature["properties"]["row"], feature["properties"]["col"], *feature["geometry"]["coordinates"])
for feature in gcps["features"]]
gcp_array = np.array(gcp_list)

ypoints = np.unique(gcp_array[:, 0])
xpoints = np.unique(gcp_array[:, 1])
Expand All @@ -705,7 +711,10 @@ def get_gcps(self):
gcp_lats = gcp_array[:, 3].reshape(ypoints.shape[0], xpoints.shape[0])
gcp_alts = gcp_array[:, 4].reshape(ypoints.shape[0], xpoints.shape[0])

return (xpoints, ypoints), (gcp_lons, gcp_lats, gcp_alts), gcps
rio_gcps = [rasterio.control.GroundControlPoint(*gcp) for gcp in gcp_list]


return (xpoints, ypoints), (gcp_lons, gcp_lats, gcp_alts), (rio_gcps, crs)

@property
def start_time(self):
Expand All @@ -731,12 +740,12 @@ def __init__(self, config, filter_parameters=None):
@property
def start_time(self):
"""Get the start time."""
return self.storage_items.values()[0].filename_info["start_time"]
return self.storage_items.values()[0].filename_info["start_time"].replace(tzinfo=tz.utc)

@property
def end_time(self):
"""Get the end time."""
return self.storage_items.values()[0].filename_info["end_time"]
return self.storage_items.values()[0].filename_info["end_time"].replace(tzinfo=tz.utc)

def load(self, dataset_keys, **kwargs):
"""Load some data."""
Expand All @@ -752,20 +761,20 @@ def load(self, dataset_keys, **kwargs):
if key["name"] not in ["longitude", "latitude"]:
lonlats = self.load([DataID(self._id_keys, name="longitude", polarization=key["polarization"]),
DataID(self._id_keys, name="latitude", polarization=key["polarization"])])
gcps = val.coords["spatial_ref"].attrs["gcps"]
from pyresample.future.geometry import SwathDefinition
val.attrs["area"] = SwathDefinition(lonlats["longitude"], lonlats["latitude"],
attrs=dict(gcps=None))
attrs=dict(gcps=gcps))
datasets[key] = val
continue
return datasets

def create_storage_items(self, files, **kwargs):
"""Create the storage items."""
filenames = [os.fspath(filename) for filename in files]
filenames = files
files_by_type = defaultdict(list)
for file_type, type_info in self.config["file_types"].items():
files_by_type[file_type].extend(self.filename_items_for_filetype(filenames, type_info))

image_shapes = dict()
for annotation_file, annotation_info in files_by_type["safe_annotation"]:
annotation_fh = SAFEXMLAnnotation(annotation_file,
Expand Down
22 changes: 12 additions & 10 deletions satpy/tests/reader_tests/test_sar_c_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
dirname_suffix = "20190201T024655_20190201T024720_025730_02DC2A_AE07"
filename_suffix = "20190201t024655-20190201t024720-025730-02dc2a"

START_TIME = datetime(2019, 2, 1, 2, 46, 55)
END_TIME = datetime(2019, 2, 1, 2, 47, 20)

@pytest.fixture(scope="module")
def granule_directory(tmp_path_factory):
Expand All @@ -62,7 +64,7 @@ def annotation_file(granule_directory):
@pytest.fixture(scope="module")
def annotation_filehandler(annotation_file):
"""Create an annotation filehandler."""
filename_info = dict(start_time=None, end_time=None, polarization="vv")
filename_info = dict(start_time=START_TIME, end_time=END_TIME, polarization="vv")
return SAFEXMLAnnotation(annotation_file, filename_info, None)


Expand All @@ -74,16 +76,16 @@ def calibration_file(granule_directory):
calibration_file = cal_dir / f"calibration-s1a-iw-grd-vv-{filename_suffix}-001.xml"
with open(calibration_file, "wb") as fd:
fd.write(calibration_xml)
return calibration_file
return Path(calibration_file)

@pytest.fixture(scope="module")
def calibration_filehandler(calibration_file, annotation_filehandler):
"""Create a calibration filehandler."""
filename_info = dict(start_time=None, end_time=None, polarization="vv")
filename_info = dict(start_time=START_TIME, end_time=END_TIME, polarization="vv")
return Calibrator(calibration_file,
filename_info,
None,
image_shape=annotation_filehandler.image_shape)
filename_info,
None,
image_shape=annotation_filehandler.image_shape)

@pytest.fixture(scope="module")
def noise_file(granule_directory):
Expand All @@ -99,14 +101,14 @@ def noise_file(granule_directory):
@pytest.fixture(scope="module")
def noise_filehandler(noise_file, annotation_filehandler):
"""Create a noise filehandler."""
filename_info = dict(start_time=None, end_time=None, polarization="vv")
filename_info = dict(start_time=START_TIME, end_time=END_TIME, polarization="vv")
return Denoiser(noise_file, filename_info, None, image_shape=annotation_filehandler.image_shape)


@pytest.fixture(scope="module")
def noise_with_holes_filehandler(annotation_filehandler):
"""Create a noise filehandler from data with holes."""
filename_info = dict(start_time=None, end_time=None, polarization="vv")
filename_info = dict(start_time=START_TIME, end_time=END_TIME, polarization="vv")
noise_filehandler = Denoiser(BytesIO(noise_xml_with_holes),
filename_info, None,
image_shape=annotation_filehandler.image_shape)
Expand Down Expand Up @@ -151,13 +153,13 @@ def measurement_file(granule_directory):
crs="+proj=latlong",
gcps=gcps) as dst:
dst.write(Z, 1)
return filename
return Path(filename)


@pytest.fixture(scope="module")
def measurement_filehandler(measurement_file, noise_filehandler, calibration_filehandler):
"""Create a measurement filehandler."""
filename_info = {"mission_id": "S1A", "dataset_name": "foo", "start_time": 0, "end_time": 0,
filename_info = {"mission_id": "S1A", "dataset_name": "foo", "start_time": START_TIME, "end_time": END_TIME,
"polarization": "vv"}
filetype_info = None
from satpy.readers.sar_c_safe import SAFEGRD
Expand Down

0 comments on commit 4ce31f7

Please sign in to comment.