Commit

Issue #16 blackify src/ and tests/
soxofaan committed Feb 5, 2024
1 parent 28388e4 commit 4e30454
Showing 19 changed files with 69 additions and 176 deletions.
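The reformatting below is consistent with running black with a maximum line length of 100 characters. That value is an assumption inferred from the joined lines in the hunks; the repository's actual configuration is not shown in this commit. A minimal sketch of reproducing the effect locally with black's Python API:

    import black

    # Signature taken from the first hunk below; `Optional` does not need to be importable,
    # black only parses the source text.
    src = (
        "def _create_connection(\n"
        "    url: str, *, env_var_suffix: str, connect_kwargs: Optional[dict] = None\n"
        "):\n"
        "    pass\n"
    )

    # With black's default line length of 88 the signature stays split; with 100 it
    # collapses onto a single line, matching the diff below.
    print(black.format_str(src, mode=black.Mode(line_length=100)))
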
8 changes: 2 additions & 6 deletions src/openeo_gfmap/backend.py
@@ -33,9 +33,7 @@ class BackendContext:
     backend: Backend
 
 
-def _create_connection(
-    url: str, *, env_var_suffix: str, connect_kwargs: Optional[dict] = None
-):
+def _create_connection(url: str, *, env_var_suffix: str, connect_kwargs: Optional[dict] = None):
     """
     Generic helper to create an openEO connection
     with support for multiple client credential configurations from environment variables
@@ -63,9 +61,7 @@ def _create_connection(
 
     # Use a shorter max poll time by default to alleviate the default impression that the test seem to hang
     # because of the OIDC device code poll loop.
-    max_poll_time = int(
-        os.environ.get("OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME") or 30
-    )
+    max_poll_time = int(os.environ.get("OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME") or 30)
     connection.authenticate_oidc(max_poll_time=max_poll_time)
     return connection
 
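For context, the helper changed above authenticates an openEO connection either from client credentials in environment variables or interactively. A hedged sketch of that pattern; the environment variable names OPENEO_AUTH_CLIENT_ID_* / OPENEO_AUTH_CLIENT_SECRET_* are illustrative, not taken from this commit:

    import os

    import openeo


    def connect_with_env_credentials(url: str, env_var_suffix: str) -> openeo.Connection:
        connection = openeo.connect(url)
        client_id = os.environ.get(f"OPENEO_AUTH_CLIENT_ID_{env_var_suffix}")
        client_secret = os.environ.get(f"OPENEO_AUTH_CLIENT_SECRET_{env_var_suffix}")
        if client_id and client_secret:
            # Non-interactive client credentials flow, e.g. for CI runs.
            connection.authenticate_oidc_client_credentials(client_id=client_id, client_secret=client_secret)
        else:
            # Interactive fallback with a short device-code poll time, as in the hunk above.
            connection.authenticate_oidc(max_poll_time=30)
        return connection
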
4 changes: 1 addition & 3 deletions src/openeo_gfmap/extractions/s2.py
@@ -76,9 +76,7 @@ def s2_l2a_fetch_default(
             ), "CRS not defined within GeoJSON collection."
             spatial_extent = dict(spatial_extent)
 
-        cube = connection.load_collection(
-            collection_name, spatial_extent, temporal_extent, bands
-        )
+        cube = connection.load_collection(collection_name, spatial_extent, temporal_extent, bands)
 
         # Apply if the collection is a GeoJSON Feature collection
         if isinstance(spatial_extent, GeoJSON):
14 changes: 4 additions & 10 deletions src/openeo_gfmap/features/feature_extractor.py
@@ -75,9 +75,7 @@ class FeatureExtractor(ABC):
     point based extraction or dense Cubes for tile/polygon based extraction.
     """
 
-    def _common_preparations(
-        self, inarr: xr.DataArray, parameters: dict
-    ) -> xr.DataArray:
+    def _common_preparations(self, inarr: xr.DataArray, parameters: dict) -> xr.DataArray:
         """Common preparations to be executed before the feature extractor is
         executed. This method should be called by the `_execute` method of the
         feature extractor.
@@ -102,8 +100,7 @@ def output_labels(self) -> list:
 
     def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
         raise NotImplementedError(
-            "FeatureExtractor is a base abstract class, please implement the "
-            "_execute method."
+            "FeatureExtractor is a base abstract class, please implement the " "_execute method."
         )
 
 
@@ -130,8 +127,7 @@ def get_latlons(self, inarr: xr.DataArray) -> xr.DataArray:
 
         if self.epsg is None:
             raise Exception(
-                "EPSG code was not defined, cannot extract lat/lon array "
-                "as the CRS is unknown."
+                "EPSG code was not defined, cannot extract lat/lon array " "as the CRS is unknown."
             )
 
         # If the coordiantes are not in EPSG:4326, we need to reproject them
@@ -257,9 +253,7 @@ def apply_feature_extractor(
     udf = openeo.UDF(code=udf_code, context=parameters)
 
     cube = cube.apply_neighborhood(process=udf, size=size, overlap=overlap)
-    return cube.rename_labels(
-        dimension="bands", target=feature_extractor_class().output_labels()
-    )
+    return cube.rename_labels(dimension="bands", target=feature_extractor_class().output_labels())
 
 
 def apply_feature_extractor_local(
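The apply_feature_extractor hunk above wraps a UDF in apply_neighborhood and then renames the band labels to the extractor's declared outputs. A hedged, stand-alone illustration of that openEO pattern; the backend URL, collection, extents, sizes and output labels are placeholders, not values from this repository:

    import openeo

    connection = openeo.connect("openeo.dataspace.copernicus.eu")  # example backend
    cube = connection.load_collection(
        "SENTINEL2_L2A",
        spatial_extent={"west": 5.0, "south": 51.0, "east": 5.1, "north": 51.1},
        temporal_extent=["2023-06-01", "2023-07-01"],
        bands=["B04", "B08"],
    )

    udf = openeo.UDF(code="# UDF source elided", context={"rescale": True})
    cube = cube.apply_neighborhood(
        process=udf,
        size=[{"dimension": "x", "value": 128, "unit": "px"}, {"dimension": "y", "value": 128, "unit": "px"}],
        overlap=[{"dimension": "x", "value": 16, "unit": "px"}, {"dimension": "y", "value": 16, "unit": "px"}],
    )
    cube = cube.rename_labels(dimension="bands", target=["feature_1", "feature_2"])
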
3 changes: 1 addition & 2 deletions src/openeo_gfmap/fetching/commons.py
@@ -128,8 +128,7 @@ def load_collection(
     pre_mask = params.get("pre_mask", None)
     if pre_mask is not None:
         assert isinstance(pre_mask, openeo.DataCube), (
-            f"The 'pre_mask' parameter must be an openeo datacube, "
-            f"got {pre_mask}."
+            f"The 'pre_mask' parameter must be an openeo datacube, " f"got {pre_mask}."
         )
         cube = cube.mask(pre_mask.resample_cube_spatial(cube))
 
16 changes: 4 additions & 12 deletions src/openeo_gfmap/fetching/s1.py
@@ -84,9 +84,7 @@ def s1_grd_fetch_default(
     return s1_grd_fetch_default
 
 
-def get_s1_grd_default_processor(
-    collection_name: str, fetch_type: FetchType
-) -> Callable:
+def get_s1_grd_default_processor(collection_name: str, fetch_type: FetchType) -> Callable:
     """Builds the preprocessing function from the collection name as it is stored
     in the target backend.
     """
@@ -113,9 +111,7 @@ def s1_grd_default_processor(cube: openeo.DataCube, **params):
         )
 
         cube = resample_reproject(
-            cube,
-            params.get("target_resolution", 10.0),
-            params.get("target_crs", None)
+            cube, params.get("target_resolution", 10.0), params.get("target_crs", None)
         )
 
         cube = rename_bands(cube, BASE_SENTINEL1_GRD_MAPPING)
@@ -128,15 +124,11 @@ def s1_grd_default_processor(cube: openeo.DataCube, **params):
 SENTINEL1_GRD_BACKEND_MAP = {
     Backend.TERRASCOPE: {
         "default": partial(get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD"),
-        "preprocessor": partial(
-            get_s1_grd_default_processor, collection_name="SENTINEL1_GRD"
-        ),
+        "preprocessor": partial(get_s1_grd_default_processor, collection_name="SENTINEL1_GRD"),
     },
     Backend.CDSE: {
         "default": partial(get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD"),
-        "preprocessor": partial(
-            get_s1_grd_default_processor, collection_name="SENTINEL1_GRD"
-        ),
+        "preprocessor": partial(get_s1_grd_default_processor, collection_name="SENTINEL1_GRD"),
     },
 }
 
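The backend maps above specialise one builder per backend with functools.partial. A small stand-alone illustration of that dispatch pattern; the names are simplified stand-ins, not the actual gfmap classes:

    from functools import partial


    def get_fetcher(collection_name: str, fetch_type: str):
        def fetch(connection, spatial_extent, temporal_extent, bands, **params):
            return connection.load_collection(collection_name, spatial_extent, temporal_extent, bands)

        return fetch


    BACKEND_MAP = {
        "terrascope": {"default": partial(get_fetcher, collection_name="SENTINEL1_GRD")},
        "cdse": {"default": partial(get_fetcher, collection_name="SENTINEL1_GRD")},
    }

    # The remaining keyword argument is filled in at lookup time.
    fetcher = BACKEND_MAP["cdse"]["default"](fetch_type="tile")
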
16 changes: 4 additions & 12 deletions src/openeo_gfmap/fetching/s2.py
@@ -114,9 +114,7 @@ def s2_l2a_fetch_default(
     return s2_l2a_fetch_default
 
 
-def get_s2_l2a_element84_fetcher(
-    collection_name: str, fetch_type: FetchType
-) -> Callable:
+def get_s2_l2a_element84_fetcher(collection_name: str, fetch_type: FetchType) -> Callable:
     """Fetches the collections from the Sentinel-2 Cloud-Optimized GeoTIFFs
     bucket provided by Amazon and managed by Element84.
     """
@@ -157,9 +155,7 @@ def s2_l2a_element84_fetcher(
     return s2_l2a_element84_fetcher
 
 
-def get_s2_l2a_default_processor(
-    collection_name: str, fetch_type: FetchType
-) -> Callable:
+def get_s2_l2a_default_processor(collection_name: str, fetch_type: FetchType) -> Callable:
     """Builds the preprocessing function from the collection name as it stored
     in the target backend.
     """
@@ -188,15 +184,11 @@ def s2_l2a_default_processor(cube: openeo.DataCube, **params):
 SENTINEL2_L2A_BACKEND_MAP = {
     Backend.TERRASCOPE: {
         "fetch": partial(get_s2_l2a_default_fetcher, collection_name="SENTINEL2_L2A"),
-        "preprocessor": partial(
-            get_s2_l2a_default_processor, collection_name="SENTINEL2_L2A"
-        ),
+        "preprocessor": partial(get_s2_l2a_default_processor, collection_name="SENTINEL2_L2A"),
     },
     Backend.CDSE: {
         "fetch": partial(get_s2_l2a_default_fetcher, collection_name="SENTINEL2_L2A"),
-        "preprocessor": partial(
-            get_s2_l2a_default_processor, collection_name="SENTINEL2_L2A"
-        ),
+        "preprocessor": partial(get_s2_l2a_default_processor, collection_name="SENTINEL2_L2A"),
     },
 }
 
4 changes: 1 addition & 3 deletions src/openeo_gfmap/inference/inference_models.py
@@ -29,9 +29,7 @@ class ModelInference(ABC):
     methods and attributes to be used by other model inference classes.
     """
 
-    def _common_preparations(
-        self, inarr: xr.DataArray, parameters: dict
-    ) -> xr.DataArray:
+    def _common_preparations(self, inarr: xr.DataArray, parameters: dict) -> xr.DataArray:
         """Common preparations for all inference models. This method will be
         executed at the very beginning of the process.
         """
8 changes: 2 additions & 6 deletions src/openeo_gfmap/preprocessing/cloudmasking.py
@@ -33,9 +33,7 @@ def mask_scl_dilation(cube: openeo.DataCube, **params: dict) -> openeo.DataCube:
     )
 
     nonoptical_cube = cube.filter_bands(
-        bands=list(
-            filter(lambda band: not band.startswith("S2"), cube.metadata.band_names)
-        )
+        bands=list(filter(lambda band: not band.startswith("S2"), cube.metadata.band_names))
     )
 
     optical_cube = optical_cube.process(
@@ -224,9 +222,7 @@ def bap_masking(cube: openeo.DataCube, period: Union[str, list], **params: dict)
     )
 
     nonoptical_cube = cube.filter_bands(
-        bands=list(
-            filter(lambda band: not band.startswith("S2"), cube.metadata.band_names)
-        )
+        bands=list(filter(lambda band: not band.startswith("S2"), cube.metadata.band_names))
     )
 
     rank_mask = get_bap_mask(optical_cube, period, **params)
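Both hunks above split a cube into optical and non-optical bands purely by name prefix. A tiny illustration of that filter; the band names are made up for the example:

    band_names = ["S2-L2A-B04", "S2-L2A-B08", "S2-L2A-SCL", "S1-SIGMA0-VV", "S1-SIGMA0-VH"]

    optical = list(filter(lambda band: band.startswith("S2"), band_names))
    nonoptical = list(filter(lambda band: not band.startswith("S2"), band_names))

    print(optical)     # ['S2-L2A-B04', 'S2-L2A-B08', 'S2-L2A-SCL']
    print(nonoptical)  # ['S1-SIGMA0-VV', 'S1-SIGMA0-VH']
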
1 change: 1 addition & 0 deletions src/openeo_gfmap/preprocessing/compositing.py
@@ -14,6 +14,7 @@ def median_compositing(cube: openeo.DataCube, period: Union[str, list]) -> openeo.DataCube:
     elif isinstance(period, list):
         return cube.aggregate_temporal(intervals=period, reducer="median", dimension="t")
 
+
 def mean_compositing(cube: openeo.DataCube, period: str) -> openeo.DataCube:
     """Perfrom mean compositing on the given datacube."""
     if isinstance(period, str):
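The compositing helpers accept either a calendar period string or an explicit list of intervals. A hedged usage sketch; the backend URL, collection and extents are placeholders:

    import openeo

    connection = openeo.connect("openeo.dataspace.copernicus.eu")  # example backend
    cube = connection.load_collection(
        "SENTINEL2_L2A",
        spatial_extent={"west": 5.0, "south": 51.0, "east": 5.1, "north": 51.1},
        temporal_extent=["2023-01-01", "2023-04-01"],
        bands=["B04", "B08"],
    )

    # Calendar period string -> aggregate_temporal_period
    monthly = cube.aggregate_temporal_period(period="month", reducer="median")

    # Explicit intervals -> aggregate_temporal, as in median_compositing above
    dekads = cube.aggregate_temporal(
        intervals=[["2023-01-01", "2023-01-11"], ["2023-01-11", "2023-01-21"]],
        reducer="median",
        dimension="t",
    )
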
4 changes: 3 additions & 1 deletion src/openeo_gfmap/preprocessing/interpolation.py
@@ -4,6 +4,8 @@
 import openeo
 
 
-def linear_interpolation(cube: openeo.DataCube,) -> openeo.DataCube:
+def linear_interpolation(
+    cube: openeo.DataCube,
+) -> openeo.DataCube:
     """Perform linear interpolation on the given datacube."""
     return cube.apply_dimension(dimension="t", process="array_interpolate_linear")
5 changes: 1 addition & 4 deletions src/openeo_gfmap/preprocessing/udf_rank.py
@@ -1,4 +1,3 @@
-
 import numpy as np
 import xarray as xr
 from openeo.udf import XarrayDataCube
@@ -31,9 +30,7 @@ def select_maximum(score: xr.DataArray):
         # Convert YYYY-mm-dd to datetime64 objects
         time_bins = [np.datetime64(interval[0]) for interval in intervals]
 
-        rank_mask = bap_score.groupby_bins('t', bins=time_bins).map(
-            select_maximum
-        )
+        rank_mask = bap_score.groupby_bins("t", bins=time_bins).map(select_maximum)
     else:
         raise ValueError("Period is not defined in the UDF. Cannot run it.")
 
18 changes: 7 additions & 11 deletions src/openeo_gfmap/preprocessing/udf_score.py
@@ -11,21 +11,17 @@ def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
     cube_array: xr.DataArray = cube.get_array()
     cube_array = cube_array.transpose("t", "bands", "y", "x")
 
-    clouds = np.logical_or(
-        np.logical_and(cube_array < 11, cube_array >= 8), cube_array == 3
-    ).isel(bands=0)
+    clouds = np.logical_or(np.logical_and(cube_array < 11, cube_array >= 8), cube_array == 3).isel(
+        bands=0
+    )
 
     weights = [1, 0.8, 0.5]
 
     # Calculate the Day Of Year score
     times = cube_array.t.dt.day.values  # returns day of the month for each date
     sigma = 5
     mu = 15
-    score_doy = (
-        1
-        / (sigma * math.sqrt(2 * math.pi))
-        * np.exp(-0.5 * ((times - mu) / sigma) ** 2)
-    )
+    score_doy = 1 / (sigma * math.sqrt(2 * math.pi)) * np.exp(-0.5 * ((times - mu) / sigma) ** 2)
     score_doy = np.broadcast_to(
         score_doy[:, np.newaxis, np.newaxis],
         [cube_array.sizes["t"], cube_array.sizes["y"], cube_array.sizes["x"]],
@@ -77,9 +73,9 @@ def erode(image, selem):
     )
 
     # Final score is weighted average
-    score = (
-        weights[0] * score_clouds + weights[1] * score_doy + weights[2] * score_cov
-    ) / sum(weights)
+    score = (weights[0] * score_clouds + weights[1] * score_doy + weights[2] * score_cov) / sum(
+        weights
+    )
     score = np.where(cube_array.values[:, 0, :, :] == 0, 0, score)
 
     score_da = xr.DataArray(
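The score in the hunk above combines a Gaussian day-of-month term with cloud and coverage terms through a weighted average. A small self-contained NumPy illustration of those two formulas, with placeholder values for the cloud and coverage scores:

    import math

    import numpy as np

    times = np.array([1, 8, 15, 22, 28])  # day of the month for each acquisition
    sigma, mu = 5, 15
    score_doy = 1 / (sigma * math.sqrt(2 * math.pi)) * np.exp(-0.5 * ((times - mu) / sigma) ** 2)

    weights = [1, 0.8, 0.5]
    score_clouds = np.ones_like(score_doy)  # placeholder: 1.0 = cloud free
    score_cov = np.ones_like(score_doy)     # placeholder: 1.0 = full coverage
    score = (weights[0] * score_clouds + weights[1] * score_doy + weights[2] * score_cov) / sum(weights)

    print(score_doy.round(3))
    print(score.round(3))
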
10 changes: 3 additions & 7 deletions src/openeo_gfmap/utils/catalogue.py
@@ -160,9 +160,7 @@ def s1_area_per_orbitstate(
             )
         )
     else:
-        raise NotImplementedError(
-            f"This feature is not supported for backend: {backend.backend}."
-        )
+        raise NotImplementedError(f"This feature is not supported for backend: {backend.backend}.")
 
     # Builds the shape of the spatial extent and computes the area
     spatial_extent = spatial_extent.to_geometry()
@@ -179,15 +177,13 @@
         "ASCENDING": {
             "full_overlap": ascending_covers,
             "area": sum(
-                product.intersection(spatial_extent).area
-                for product in ascending_products
+                product.intersection(spatial_extent).area for product in ascending_products
             ),
         },
         "DESCENDING": {
             "full_overlap": descending_covers,
             "area": sum(
-                product.intersection(spatial_extent).area
-                for product in descending_products
+                product.intersection(spatial_extent).area for product in descending_products
             ),
         },
     }
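The per-orbit-state bookkeeping above sums the intersection area between the requested extent and each product footprint. A hedged shapely sketch of that computation; the geometries are toy boxes, and the full_overlap check via contains() is an assumption rather than the repository's exact logic:

    from shapely.geometry import box

    spatial_extent = box(5.0, 51.0, 5.5, 51.5)  # toy request extent
    ascending_products = [box(4.8, 50.9, 5.3, 51.4), box(5.2, 51.2, 5.8, 51.8)]  # toy footprints

    area = sum(product.intersection(spatial_extent).area for product in ascending_products)
    full_overlap = any(product.contains(spatial_extent) for product in ascending_products)

    print(area, full_overlap)
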
12 changes: 3 additions & 9 deletions src/openeo_gfmap/utils/tile_processing.py
@@ -20,19 +20,15 @@ def normalize_array(inarr: xr.DataArray, percentile: float = 0.99) -> xr.DataArray:
 def select_optical_bands(inarr: xr.DataArray) -> xr.DataArray:
     """Filters and keep only the optical bands for a given array."""
     return inarr.sel(
-        bands=[
-            band for band in inarr.coords["bands"].to_numpy() if band.startswith("S2-B")
-        ]
+        bands=[band for band in inarr.coords["bands"].to_numpy() if band.startswith("S2-B")]
     )
 
 
 def select_sar_bands(inarr: xr.DataArray) -> xr.DataArray:
     """Filters and keep only the SAR bands for a given array."""
     return inarr.sel(
         bands=[
-            band
-            for band in inarr.coords["bands"].to_numpy()
-            if band in ["VV", "VH", "HH", "HV"]
+            band for band in inarr.coords["bands"].to_numpy() if band in ["VV", "VH", "HH", "HV"]
         ]
     )
 
@@ -47,9 +43,7 @@ def array_bounds(inarr: xr.DataArray) -> tuple:
     )
 
 
-def arrays_cosine_similarity(
-    first_array: xr.DataArray, second_array: xr.DataArray
-) -> float:
+def arrays_cosine_similarity(first_array: xr.DataArray, second_array: xr.DataArray) -> float:
     """Returns a similarity score based on normalized cosine distance. The
     input arrays must have similar ranges to obtain a valid score.
     1.0 represents the best score (same tiles), while 0.0 is the worst score.
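arrays_cosine_similarity, reformatted above, is documented as a normalized cosine score where 1.0 means identical tiles. A hedged sketch of such a score for two xarray tiles; the repository's exact normalization may differ:

    import numpy as np
    import xarray as xr


    def cosine_similarity(first_array: xr.DataArray, second_array: xr.DataArray) -> float:
        x = first_array.values.ravel().astype("float64")
        y = second_array.values.ravel().astype("float64")
        return float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))


    a = xr.DataArray(np.array([[1.0, 2.0], [3.0, 4.0]]))
    b = xr.DataArray(np.array([[1.0, 2.0], [3.0, 5.0]]))
    print(cosine_similarity(a, b))  # close to 1.0 for similar tiles
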
(Diffs for the remaining changed files were not loaded in this view.)
