diff --git a/src/openeo_gfmap/fetching/generic.py b/src/openeo_gfmap/fetching/generic.py index 46edb74..21dc209 100644 --- a/src/openeo_gfmap/fetching/generic.py +++ b/src/openeo_gfmap/fetching/generic.py @@ -1,10 +1,10 @@ """ Generic extraction of features, supporting VITO backend. """ -from typing import Callable, Optional +from functools import partial +from typing import Callable import openeo -from openeo.rest import OpenEoApiError from openeo_gfmap.backend import Backend, BackendContext from openeo_gfmap.fetching import CollectionFetcher, FetchType, _log @@ -28,18 +28,15 @@ "vapour-pressure": "AGERA5-VAPOUR", "wind-speed": "AGERA5-WIND", } -KNOWN_UNTEMPORAL_COLLECTIONS = ["COPERNICUS_30"] -def _get_generic_fetcher( - collection_name: str, fetch_type: FetchType, backend: Backend -) -> Callable: - band_mapping: Optional[dict] = None - +def _get_generic_fetcher(collection_name: str, fetch_type: FetchType) -> Callable: if collection_name == "COPERNICUS_30": - band_mapping = BASE_DEM_MAPPING + BASE_MAPPING = BASE_DEM_MAPPING elif collection_name == "AGERA5": - band_mapping = BASE_WEATHER_MAPPING + BASE_MAPPING = BASE_WEATHER_MAPPING + else: + raise Exception("Please choose a valid collection.") def generic_default_fetcher( connection: openeo.Connection, @@ -48,34 +45,23 @@ def generic_default_fetcher( bands: list, **params, ) -> openeo.DataCube: - if band_mapping is not None: - bands = convert_band_names(bands, band_mapping) + bands = convert_band_names(bands, BASE_MAPPING) - if (collection_name in KNOWN_UNTEMPORAL_COLLECTIONS) and ( - temporal_extent is not None - ): + if (collection_name == "COPERNICUS_30") and (temporal_extent is not None): _log.warning( - "Ignoring the temporal extent provided by the user as the collection %s is known to be untemporal.", - collection_name, + "User set a non-None temporal extent for the DEM collection. Ignoring it." ) temporal_extent = None - try: - cube = _load_collection( - connection, - bands, - collection_name, - spatial_extent, - temporal_extent, - fetch_type, - **params, - ) - except OpenEoApiError as e: - if "CollectionNotFound" in str(e): - raise ValueError( - f"Collection {collection_name} not found in the selected backend {backend.value}." - ) from e - raise e + cube = _load_collection( + connection, + bands, + collection_name, + spatial_extent, + temporal_extent, + fetch_type, + **params, + ) # # Apply if the collection is a GeoJSON Feature collection # if isinstance(spatial_extent, GeoJSON): @@ -90,11 +76,12 @@ def _get_generic_processor(collection_name: str, fetch_type: FetchType) -> Calla """Builds the preprocessing function from the collection name as it stored in the target backend. """ - band_mapping: Optional[dict] = None if collection_name == "COPERNICUS_30": - band_mapping = BASE_DEM_MAPPING + BASE_MAPPING = BASE_DEM_MAPPING elif collection_name == "AGERA5": - band_mapping = BASE_WEATHER_MAPPING + BASE_MAPPING = BASE_WEATHER_MAPPING + else: + raise Exception("Please choose a valid collection.") def generic_default_processor(cube: openeo.DataCube, **params): """Default collection preprocessing method for generic datasets.
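A standalone sketch of the dispatch-and-closure pattern used above by `_get_generic_fetcher`/`_get_generic_processor` (not gfmap code; the mapping contents and band names are illustrative stand-ins for `BASE_DEM_MAPPING`/`BASE_WEATHER_MAPPING`). The factory resolves the mapping once, so an unknown collection fails fast instead of failing inside the returned fetcher:

```python
from typing import Callable

# Illustrative stand-ins for BASE_DEM_MAPPING / BASE_WEATHER_MAPPING above.
ILLUSTRATIVE_DEM_MAPPING = {"DEM": "COPERNICUS30-DEM"}
ILLUSTRATIVE_WEATHER_MAPPING = {"temperature-mean": "AGERA5-TMEAN"}


def get_fetcher(collection_name: str) -> Callable[[list], list]:
    """Resolve the band mapping once, then return a closure that applies it."""
    if collection_name == "COPERNICUS_30":
        mapping = ILLUSTRATIVE_DEM_MAPPING
    elif collection_name == "AGERA5":
        mapping = ILLUSTRATIVE_WEATHER_MAPPING
    else:
        raise ValueError(f"Unknown collection: {collection_name}")

    def fetch(bands: list) -> list:
        # Translate band names through the mapping before loading the collection.
        return [mapping.get(band, band) for band in bands]

    return fetch


print(get_fetcher("AGERA5")(["temperature-mean"]))  # -> ['AGERA5-TMEAN']
```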
@@ -112,14 +99,51 @@ def generic_default_processor(cube: openeo.DataCube, **params): if collection_name == "COPERNICUS_30": cube = cube.min_time() - if band_mapping is not None: - cube = rename_bands(cube, band_mapping) + cube = rename_bands(cube, BASE_MAPPING) return cube return generic_default_processor +OTHER_BACKEND_MAP = { + "AGERA5": { + Backend.TERRASCOPE: { + "fetch": partial(_get_generic_fetcher, collection_name="AGERA5"), + "preprocessor": partial(_get_generic_processor, collection_name="AGERA5"), + }, + Backend.CDSE: { + "fetch": partial(_get_generic_fetcher, collection_name="AGERA5"), + "preprocessor": partial(_get_generic_processor, collection_name="AGERA5"), + }, + Backend.FED: { + "fetch": partial(_get_generic_fetcher, collection_name="AGERA5"), + "preprocessor": partial(_get_generic_processor, collection_name="AGERA5"), + }, + }, + "COPERNICUS_30": { + Backend.TERRASCOPE: { + "fetch": partial(_get_generic_fetcher, collection_name="COPERNICUS_30"), + "preprocessor": partial( + _get_generic_processor, collection_name="COPERNICUS_30" + ), + }, + Backend.CDSE: { + "fetch": partial(_get_generic_fetcher, collection_name="COPERNICUS_30"), + "preprocessor": partial( + _get_generic_processor, collection_name="COPERNICUS_30" + ), + }, + Backend.FED: { + "fetch": partial(_get_generic_fetcher, collection_name="COPERNICUS_30"), + "preprocessor": partial( + _get_generic_processor, collection_name="COPERNICUS_30" + ), + }, + }, +} + + def build_generic_extractor( backend_context: BackendContext, bands: list, @@ -128,7 +152,13 @@ def build_generic_extractor( **params, ) -> CollectionFetcher: """Creates a generic extractor adapted to the given backend. Currently only tested with VITO backend""" - fetcher = _get_generic_fetcher(collection_name, fetch_type, backend_context.backend) - preprocessor = _get_generic_processor(collection_name, fetch_type) + backend_functions = OTHER_BACKEND_MAP.get(collection_name).get( + backend_context.backend + ) + + fetcher, preprocessor = ( + backend_functions["fetch"](fetch_type=fetch_type), + backend_functions["preprocessor"](fetch_type=fetch_type), + ) return CollectionFetcher(backend_context, bands, fetcher, preprocessor, **params) diff --git a/src/openeo_gfmap/fetching/s1.py b/src/openeo_gfmap/fetching/s1.py index 97d6fc2..6081d40 100644 --- a/src/openeo_gfmap/fetching/s1.py +++ b/src/openeo_gfmap/fetching/s1.py @@ -67,6 +67,8 @@ def s1_grd_fetch_default( """ bands = convert_band_names(bands, BASE_SENTINEL1_GRD_MAPPING) + load_collection_parameters = params.get("load_collection", {}) + cube = _load_collection( connection, bands, @@ -74,7 +76,7 @@ def s1_grd_fetch_default( spatial_extent, temporal_extent, fetch_type, - **params, + **load_collection_parameters, ) if fetch_type is not FetchType.POINT and isinstance(spatial_extent, GeoJSON): diff --git a/src/openeo_gfmap/manager/job_manager.py b/src/openeo_gfmap/manager/job_manager.py index 986bbf4..9a9f7b2 100644 --- a/src/openeo_gfmap/manager/job_manager.py +++ b/src/openeo_gfmap/manager/job_manager.py @@ -1,11 +1,9 @@ import json -import pickle import threading -import time from concurrent.futures import ThreadPoolExecutor +from enum import Enum from functools import partial from pathlib import Path -from threading import Lock from typing import Callable, Optional, Union import pandas as pd @@ -18,62 +16,28 @@ from openeo_gfmap.stac import constants # Lock to use when writing to the STAC collection -_stac_lock = Lock() - - -def retry_on_exception(max_retries: int, delay_s: int = 180): - """Decorator to 
retry a function if an exception occurs. - Used for post-job actions that can crash due to internal backend issues. Restarting the action - usually helps to solve the issue. - - Parameters - ---------- - max_retries: int - The maximum number of retries to attempt before finally raising the exception. - delay: int (default=180 seconds) - The delay in seconds to wait before retrying the decorated function. - """ - - def decorator(func): - def wrapper(*args, **kwargs): - latest_exception = None - for _ in range(max_retries): - try: - return func(*args, **kwargs) - except Exception as e: - time.sleep( - delay_s - ) # Waits before retrying, while allowing other futures to run. - latest_exception = e - raise latest_exception - - return wrapper - - return decorator +_stac_lock = threading.Lock() def done_callback(future, df, idx): - """Changes the status of the job when the post-job action future is done.""" + """Sets the status of the job in the dataframe when the post-job future is done.""" current_status = df.loc[idx, "status"] - exception = future.exception() - if exception is None: + if not future.exception(): if current_status == "postprocessing": df.loc[idx, "status"] = "finished" elif current_status == "postprocessing-error": df.loc[idx, "status"] = "error" - elif current_status == "running": - df.loc[idx, "status"] = "running" else: raise ValueError( f"Invalid status {current_status} for job {df.loc[idx, 'id']} for done_callback!" ) - else: - _log.exception( - "Exception occurred in post-job future for job %s:\n%s", - df.loc[idx, "id"], - exception, - ) - df.loc[idx, "status"] = "error" + + +class PostJobStatus(Enum): + """Indicates to the workers whether the job finished successfully or with an error.""" + + FINISHED = "finished" + ERROR = "error" class GFMAPJobManager(MultiBackendJobManager): @@ -89,50 +53,13 @@ def __init__( post_job_action: Optional[Callable] = None, poll_sleep: int = 5, n_threads: int = 1, + post_job_params: dict = {}, resume_postproc: bool = True, # If we need to check for post-job actions that crashed restart_failed: bool = False, # If we need to restart failed jobs - stac_enabled: bool = True, ): - """ - Initializes the GFMAP job manager. - - Parameters - ---------- - output_dir: Path - The base output directory where the results/stac/logs of the jobs will be stored. - output_path_generator: Callable - User defined function that generates the output path for the job results. Expects as - inputs the output directory, the index of the job in the job dataframe - and the row of the job, and returns the final path where to save a job result asset. - collection_id: Optional[str] - The ID of the STAC collection that is being generated. Can be left empty if the STAC - catalogue is not being generated or if it is being resumed from an existing catalogue. - collection_description: Optional[str] - The description of the STAC collection that is being generated. - stac: Optional[Union[str, Path]] - The path to the STAC collection to be saved or resumed. - If None, the default path will be used. - post_job_action: Optional[Callable] - A user defined function that will be called after a job is finished. It will receive - the list of items generated by the job and the row of the job, and should return the - updated list of items. - poll_sleep: int - The time in seconds to wait between polling the backend for job status. - n_threads: int - The number of threads to execute `on_job_done` and `on_job_error` functions.
- resume_postproc: bool - If set to true, all `on_job_done` and `on_job_error` functions that failed are resumed. - restart_failed: bool - If set to true, all jobs that failed within the OpenEO backend are restarted. - stac_enabled: bool (default=True) - If the STAC generation is enabled or not. Disabling it will prevent the creation, - update and loading of the STAC collection. - """ self._output_dir = output_dir - self._catalogue_cache = output_dir / "catalogue_cache.bin" self.stac = stac - self.stac_enabled = stac_enabled self.collection_id = collection_id self.collection_description = collection_description @@ -147,73 +74,41 @@ def __init__( self._output_path_gen = output_path_generator self._post_job_action = post_job_action + self._post_job_params = post_job_params # Monkey patching the _normalize_df method to ensure we have no modification on the # geometry column MultiBackendJobManager._normalize_df = self._normalize_df super().__init__(poll_sleep) - if self.stac_enabled: - self._root_collection = self._initialize_stac() + self._root_collection = self._normalize_stac() - def _load_stac(self) -> Optional[pystac.Collection]: - """ - Loads the STAC collection from the cache, the specified `stac` path or the default path. - If no STAC collection is found, returns None. - """ + def _normalize_stac(self): default_collection_path = self._output_dir / "stac/collection.json" - if self._catalogue_cache.exists(): + if self.stac is not None: _log.info( - "Loading the STAC collection from the persisted binary file: %s.", - self._catalogue_cache, + f"Reloading the STAC collection from the provided path: {self.stac}." ) - with open(self._catalogue_cache, "rb") as file: - return pickle.load(file) - elif self.stac is not None: - _log.info( - "Reloading the STAC collection from the provided path: %s.", self.stac - ) - return pystac.read_file(str(self.stac)) + root_collection = pystac.read_file(str(self.stac)) elif default_collection_path.exists(): _log.info( - "Reload the STAC collection from the default path: %s.", - default_collection_path, + f"Reload the STAC collection from the default path: {default_collection_path}." ) self.stac = default_collection_path - return pystac.read_file(str(self.stac)) - - _log.info( - "No STAC collection found as cache, in the default path or in the provided path." - ) - return None - - def _create_stac(self) -> pystac.Collection: - """ - Creates and returns new STAC collection. The created stac collection will use the - `collection_id` and `collection_description` parameters set in the constructor. - """ - if self.collection_id is None: - raise ValueError( - "A collection ID is required to generate a STAC collection." - ) - collection = pystac.Collection( - id=self.collection_id, - description=self.collection_description, - extent=None, - ) - collection.license = constants.LICENSE - collection.add_link(constants.LICENSE_LINK) - collection.stac_extensions = constants.STAC_EXTENSIONS - return collection - - def _initialize_stac(self) -> pystac.Collection: - """ - Loads and returns if possible an existing stac collection, otherwise creates a new one. - """ - root_collection = self._load_stac() - if not root_collection: + root_collection = pystac.read_file(str(self.stac)) + else: _log.info("Starting a fresh STAC collection.") - root_collection = self._create_stac() + assert ( + self.collection_id is not None + ), "A collection ID is required to generate a STAC collection." 
+ root_collection = pystac.Collection( + id=self.collection_id, + description=self.collection_description, + extent=None, + ) + root_collection.license = constants.LICENSE + root_collection.add_link(constants.LICENSE_LINK) + root_collection.stac_extensions = constants.STAC_EXTENSIONS return root_collection @@ -255,30 +150,16 @@ def _resume_postjob_actions(self, df: pd.DataFrame): job = connection.job(row.id) if row.status == "postprocessing": _log.info( - "Resuming postprocessing of job %s, queueing on_job_finished...", - row.id, - ) - future = self._executor.submit(self.on_job_done, job, row, _stac_lock) - future.add_done_callback( - partial( - done_callback, - df=df, - idx=idx, - ) + f"Resuming postprocessing of job {row.id}, queueing on_job_finished..." ) + future = self._executor.submit(self.on_job_done, job, row) + future.add_done_callback(partial(done_callback, df=df, idx=idx)) else: _log.info( - "Resuming postprocessing of job %s, queueing on_job_error...", - row.id, + f"Resuming postprocessing of job {row.id}, queueing on_job_error..." ) future = self._executor.submit(self.on_job_error, job, row) - future.add_done_callback( - partial( - done_callback, - df=df, - idx=idx, - ) - ) + future.add_done_callback(partial(done_callback, df=df, idx=idx)) self._futures.append(future) def _restart_failed_jobs(self, df: pd.DataFrame): @@ -286,9 +167,7 @@ def _restart_failed_jobs(self, df: pd.DataFrame): failed_tasks = df[df.status.isin(["error", "start_failed"])] not_started_tasks = df[df.status == "not_started"] _log.info( - "Resetting %s failed jobs to 'not_started'. %s jobs are already 'not_started'.", - len(failed_tasks), - len(not_started_tasks), + f"Resetting {len(failed_tasks)} failed jobs to 'not_started'. {len(not_started_tasks)} jobs are already 'not_started'." ) for idx, _ in failed_tasks.iterrows(): df.loc[idx, "status"] = "not_started" @@ -324,53 +203,38 @@ def _update_statuses(self, df: pd.DataFrame): job_metadata["status"] == "finished" ): _log.info( - "Job %s finished successfully, queueing on_job_done...", job.job_id + f"Job {job.job_id} finished successfully, queueing on_job_done..." ) job_status = "postprocessing" - future = self._executor.submit(self.on_job_done, job, row, _stac_lock) + future = self._executor.submit(self.on_job_done, job, row) # Future will setup the status to finished when the job is done - future.add_done_callback( - partial( - done_callback, - df=df, - idx=idx, - ) - ) + future.add_done_callback(partial(done_callback, df=df, idx=idx)) self._futures.append(future) - if "costs" in job_metadata: - df.loc[idx, "costs"] = job_metadata["costs"] - df.loc[idx, "memory"] = ( - job_metadata["usage"] - .get("max_executor_memory", {}) - .get("value", None) - ) - - else: - _log.warning( - "Costs not found in job %s metadata. Costs will be set to 'None'.", - job.job_id, - ) + df.loc[idx, "costs"] = job_metadata["costs"] + df.loc[idx, "memory"] = ( + job_metadata["usage"] + .get("max_executor_memory", {}) + .get("value", None) + ) + df.loc[idx, "cpu"] = ( + job_metadata["usage"].get("cpu", {}).get("value", None) + ) + df.loc[idx, "duration"] = ( + job_metadata["usage"].get("duration", {}).get("value", None) + ) # Case in which it failed if (df.loc[idx, "status"] != "error") and ( job_metadata["status"] == "error" ): _log.info( - "Job %s finished with error, queueing on_job_error...", - job.job_id, + f"Job {job.job_id} finished with error, queueing on_job_error..." 
) job_status = "postprocessing-error" future = self._executor.submit(self.on_job_error, job, row) # Future will setup the status to error when the job is done - future.add_done_callback( - partial( - done_callback, - df=df, - idx=idx, - ) - ) + future.add_done_callback(partial(done_callback, df=df, idx=idx)) self._futures.append(future) - if "costs" in job_metadata: df.loc[idx, "costs"] = job_metadata["costs"] df.loc[idx, "status"] = job_status @@ -378,7 +242,6 @@ def _update_statuses(self, df: pd.DataFrame): # Clear the futures that are done and raise their potential exceptions if they occurred. self._clear_queued_actions() - @retry_on_exception(max_retries=2, delay_s=180) def on_job_error(self, job: BatchJob, row: pd.Series): """Method called when a job finishes with an error. @@ -389,14 +252,7 @@ def on_job_error(self, job: BatchJob, row: pd.Series): row: pd.Series The row in the dataframe that contains the job relative information. """ - try: - logs = job.logs() - except Exception as e: # pylint: disable=broad-exception-caught - _log.exception( - "Error getting logs in `on_job_error` for job %s:\n%s", job.job_id, e - ) - logs = [] - + logs = job.logs() error_logs = [log for log in logs if log.level.lower() == "error"] job_metadata = job.describe_job() @@ -415,21 +271,15 @@ def on_job_error(self, job: BatchJob, row: pd.Series): f"Couldn't find any error logs. Please check the error manually on job ID: {job.job_id}." ) - @retry_on_exception(max_retries=2, delay_s=30) - def on_job_done( - self, job: BatchJob, row: pd.Series, lock: Lock - ): # pylint: disable=arguments-differ + def on_job_done(self, job: BatchJob, row: pd.Series): """Method called when a job finishes successfully. It will first download the results of the job and then call the `post_job_action` method. """ - job_products = {} for idx, asset in enumerate(job.get_results().get_assets()): try: _log.debug( - "Generating output path for asset %s from job %s...", - asset.name, - job.job_id, + f"Generating output path for asset {asset.name} from job {job.job_id}..." ) output_path = self._output_path_gen(self._output_dir, idx, row) # Make the output path @@ -438,17 +288,11 @@ def on_job_done( # Add to the list of downloaded products job_products[f"{job.job_id}_{asset.name}"] = [output_path] _log.debug( - "Downloaded %s from job %s -> %s", - asset.name, - job.job_id, - output_path, + f"Downloaded {asset.name} from job {job.job_id} -> {output_path}" ) except Exception as e: _log.exception( - "Error downloading asset %s from job %s:\n%s", - asset.name, - job.job_id, - e, + f"Error downloading asset {asset.name} from job {job.job_id}", e ) raise e @@ -469,35 +313,53 @@ def on_job_done( asset.href = str( asset_path ) # Update the asset href to the output location set by the output_path_generator - + # item.id = f"{job.job_id}_{item.id}" # Add the item to the the current job items. job_items.append(item) - _log.info("Parsed item %s from job %s", item.id, job.job_id) + _log.info(f"Parsed item {item.id} from job {job.job_id}") except Exception as e: _log.exception( - "Error failed to add item %s from job %s to STAC collection:\n%s", - item.id, - job.job_id, + f"Error failed to add item {item.id} from job {job.job_id} to STAC collection", e, ) + raise e # _post_job_action returns an updated list of stac items. Post job action can therefore # update the stac items and access their products through the HREF. It is also the # reponsible of adding the appropriate metadata/assets to the items. 
if self._post_job_action is not None: - _log.debug("Calling post job action for job %s...", job.job_id) - job_items = self._post_job_action(job_items, row) + _log.debug(f"Calling post job action for job {job.job_id}...") + job_items = self._post_job_action(job_items, row, self._post_job_params) - _log.info("Adding %s items to the STAC collection...", len(job_items)) + _log.info(f"Adding {len(job_items)} items to the STAC collection...") - if self.stac_enabled: - with lock: - self._update_stac(job.job_id, job_items) + with _stac_lock: # Take the STAC lock to avoid concurrence issues + # Filters the job items to only keep the ones that are not already in the collection + existing_ids = [item.id for item in self._root_collection.get_all_items()] + job_items = [item for item in job_items if item.id not in existing_ids] - _log.info("Job %s and post job action finished successfully.", job.job_id) + self._root_collection.add_items(job_items) + _log.info(f"Added {len(job_items)} items to the STAC collection.") + + _log.info(f"Writing STAC collection for {job.job_id} to file...") + try: + self._write_stac() + except Exception as e: + _log.exception( + f"Error writing STAC collection for job {job.job_id} to file.", e + ) + raise e + _log.info(f"Wrote STAC collection for {job.job_id} to file.") + + _log.info(f"Job {job.job_id} and post job action finished successfully.") def _normalize_df(self, df: pd.DataFrame) -> pd.DataFrame: - """Ensure we have the required columns and the expected type for the geometry column.""" + """Ensure we have the required columns and the expected type for the geometry column. + + :param df: The dataframe to normalize. + :return: a new dataframe that is normalized. + """ + # check for some required columns. required_with_default = [ ("status", "not_started"), @@ -515,7 +377,7 @@ def _normalize_df(self, df: pd.DataFrame) -> pd.DataFrame: } df = df.assign(**new_columns) - _log.debug("Normalizing dataframe. Columns: %s", df.columns) + _log.debug(f"Normalizing dataframe. Columns: {df.columns}") return df @@ -550,7 +412,7 @@ def run_jobs( The file to track the results of the jobs. 
""" # Starts the thread pool to work on the on_job_done and on_job_error methods - _log.info("Starting ThreadPoolExecutor with %s workers.", self._n_threads) + _log.info(f"Starting ThreadPoolExecutor with {self._n_threads} workers.") with ThreadPoolExecutor(max_workers=self._n_threads) as executor: _log.info("Creating and running jobs.") self._executor = executor @@ -561,13 +423,6 @@ def run_jobs( self._wait_queued_actions() _log.info("Exiting ThreadPoolExecutor.") self._executor = None - _log.info("All jobs finished running.") - if self.stac_enabled: - _log.info("Saving persisted STAC collection to final .json collection.") - self._write_stac() - _log.info("Saved STAC catalogue to JSON format, all tasks finished!") - else: - _log.info("STAC was disabled, skipping generation of the catalogue.") def _write_stac(self): """Writes the STAC collection to the output directory.""" @@ -584,36 +439,6 @@ def _write_stac(self): self._root_collection.normalize_hrefs(str(root_path)) self._root_collection.save(catalog_type=CatalogType.SELF_CONTAINED) - def _persist_stac(self): - """Persists the STAC collection by saving it into a binary file.""" - _log.debug("Validating the STAC collection before persisting.") - self._root_collection.validate_all() - _log.info("Persisting STAC collection to temp file %s.", self._catalogue_cache) - with open(self._catalogue_cache, "wb") as file: - pickle.dump(self._root_collection, file) - - def _update_stac(self, job_id: str, job_items: list[pystac.Item]): - """Updates the STAC collection by adding the items generated by the job. - Does not add duplicates or override with the same item ID. - """ - try: - _log.info("Thread %s entered the STAC lock.", threading.get_ident()) - # Filters the job items to only keep the ones that are not already in the collection - existing_ids = [item.id for item in self._root_collection.get_all_items()] - job_items = [item for item in job_items if item.id not in existing_ids] - - self._root_collection.add_items(job_items) - _log.info("Added %s items to the STAC collection.", len(job_items)) - - self._persist_stac() - except Exception as e: - _log.exception( - "Error adding items to the STAC collection for job %s:\n%s ", - job_id, - str(e), - ) - raise e - def setup_stac( self, constellation: Optional[str] = None, diff --git a/src/openeo_gfmap/manager/job_splitters.py b/src/openeo_gfmap/manager/job_splitters.py index 19ede9c..7ed9a5f 100644 --- a/src/openeo_gfmap/manager/job_splitters.py +++ b/src/openeo_gfmap/manager/job_splitters.py @@ -2,7 +2,6 @@ form of a GeoDataFrames. """ -from functools import lru_cache from pathlib import Path from typing import List @@ -31,11 +30,6 @@ def load_s2_grid(web_mercator: bool = False) -> gpd.GeoDataFrame: url, timeout=180, # 3mins ) - if response.status_code != 200: - raise ValueError( - "Failed to download the S2 grid from the artifactory. 
" - f"Status code: {response.status_code}" - ) with open(gdf_path, "wb") as f: f.write(response.content) return gpd.read_parquet(gdf_path) @@ -72,10 +66,6 @@ def split_job_s2grid( if polygons.crs is None: raise ValueError("The GeoDataFrame must contain a CRS") -<<<<<<< HEAD - polygons = polygons.to_crs(epsg=4326) - polygons["geometry"] = polygons.geometry.centroid -======= epsg = 3857 if web_mercator else 4326 original_crs = polygons.crs @@ -83,26 +73,16 @@ def split_job_s2grid( polygons = polygons.to_crs(epsg=epsg) polygons["centroid"] = polygons.geometry.centroid ->>>>>>> 1110e4aa35cfbe72a9dbd9b56e40048ea40ca2d8 # Dataset containing all the S2 tiles, find the nearest S2 tile for each point s2_grid = load_s2_grid(web_mercator) s2_grid["geometry"] = s2_grid.geometry.centroid -<<<<<<< HEAD - # Filter tiles on CDSE availability - s2_grid = s2_grid[s2_grid.cdse_valid] - - polygons = gpd.sjoin_nearest(polygons, s2_grid[["tile", "geometry"]]).drop( - columns=["index_right"] - ) -======= polygons = gpd.sjoin_nearest( polygons.set_geometry("centroid"), s2_grid[["tile", "geometry"]] ).drop(columns=["index_right", "centroid"]) polygons = polygons.set_geometry("geometry").to_crs(original_crs) ->>>>>>> 1110e4aa35cfbe72a9dbd9b56e40048ea40ca2d8 split_datasets = [] for _, sub_gdf in polygons.groupby("tile"): diff --git a/src/openeo_gfmap/utils/catalogue.py b/src/openeo_gfmap/utils/catalogue.py index e31f3af..2553a5d 100644 --- a/src/openeo_gfmap/utils/catalogue.py +++ b/src/openeo_gfmap/utils/catalogue.py @@ -1,12 +1,9 @@ """Functionalities to interract with product catalogues.""" -from typing import Optional - import geojson import requests from pyproj.crs import CRS from rasterio.warp import transform_bounds -from requests import adapters from shapely.geometry import box, shape from shapely.ops import unary_union @@ -18,20 +15,6 @@ TemporalContext, ) -request_sessions: Optional[requests.Session] = None - - -def _request_session() -> requests.Session: - global request_sessions - - if request_sessions is None: - request_sessions = requests.Session() - retries = adapters.Retry( - total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504] - ) - request_sessions.mount("https://", adapters.HTTPAdapter(max_retries=retries)) - return request_sessions - class UncoveredS1Exception(Exception): """Exception raised when there is no product available to fully cover spatially a given @@ -56,14 +39,6 @@ def _query_cdse_catalogue( temporal_extent: TemporalContext, **additional_parameters: dict, ) -> dict: - """ - Queries the CDSE catalogue for a given collection, spatio-temporal context and additional - parameters. 
- - Params - ------ - - """ minx, miny, maxx, maxy = bounds # The date format should be YYYY-MM-DD @@ -73,14 +48,13 @@ def _query_cdse_catalogue( url = ( f"https://catalogue.dataspace.copernicus.eu/resto/api/collections/" f"{collection}/search.json?box={minx},{miny},{maxx},{maxy}" - f"&sortParam=startDate&maxRecords=1000&dataset=ESA-DATASET" - f"&startDate={start_date}&completionDate={end_date}" + f"&sortParam=startDate&maxRecords=100" + f"&dataset=ESA-DATASET&startDate={start_date}&completionDate={end_date}" ) for key, value in additional_parameters.items(): url += f"&{key}={value}" - session = _request_session() - response = session.get(url, timeout=60) + response = requests.get(url) if response.status_code != 200: raise Exception( @@ -133,20 +107,19 @@ def _check_cdse_catalogue( return len(grd_tiles) > 0 -def s1_area_per_orbitstate_vvvh( +def s1_area_per_orbitstate( backend: BackendContext, spatial_extent: SpatialContext, temporal_extent: TemporalContext, ) -> dict: - """ - Evaluates for both the ascending and descending state orbits the area of interesection for the - available products with a VV&VH polarisation. + """Evaluates for both the ascending and descending state orbits the area of intersection + between the given spatio-temporal context and the products available in the backend's + catalogue. Parameters ---------- backend : BackendContext - The backend to be within, as each backend might use different catalogues. Only the CDSE, - CDSE_STAGING and FED backends are supported. + The backend to be within, as each backend might use different catalogues. spatial_extent : SpatialContext The spatial extent to be checked, it will check within its bounding box. temporal_extent : TemporalContext @@ -186,11 +159,7 @@ if backend.backend in [Backend.CDSE, Backend.CDSE_STAGING, Backend.FED]: ascending_products = _parse_cdse_products( _query_cdse_catalogue( - "Sentinel1", - bounds, - temporal_extent, - orbitDirection="ASCENDING", - polarisation="VV%26VH", + "Sentinel1", bounds, temporal_extent, orbitDirection="ASCENDING" ) ) descending_products = _parse_cdse_products( @@ -199,7 +168,6 @@ bounds, temporal_extent, orbitDirection="DESCENDING", - polarisation="VV%26VH", ) ) else: @@ -236,19 +204,18 @@ } -def select_s1_orbitstate_vvvh( +def select_S1_orbitstate( backend: BackendContext, spatial_extent: SpatialContext, temporal_extent: TemporalContext, ) -> str: - """Selects the orbit state that covers the most area of intersection for the - available products with a VV&VH polarisation. + """Selects the orbit state that covers the most area of the given spatio-temporal context + for the Sentinel-1 collection. Parameters ---------- backend : BackendContext - The backend to be within, as each backend might use different catalogues. Only the CDSE, - CDSE_STAGING and FED backends are supported. + The backend to be within, as each backend might use different catalogues. spatial_extent : SpatialContext The spatial extent to be checked, it will check within its bounding box.
temporal_extent : TemporalContext @@ -261,7 +228,7 @@ def select_s1_orbitstate_vvvh( """ # Queries the products in the catalogues - areas = s1_area_per_orbitstate_vvvh(backend, spatial_extent, temporal_extent) + areas = s1_area_per_orbitstate(backend, spatial_extent, temporal_extent) ascending_overlap = areas["ASCENDING"]["full_overlap"] descending_overlap = areas["DESCENDING"]["full_overlap"] diff --git a/tests/test_openeo_gfmap/test_generic_fetchers.py b/tests/test_openeo_gfmap/test_generic_fetchers.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_openeo_gfmap/test_s1_fetchers.py b/tests/test_openeo_gfmap/test_s1_fetchers.py index 01e41e2..d805fb8 100644 --- a/tests/test_openeo_gfmap/test_s1_fetchers.py +++ b/tests/test_openeo_gfmap/test_s1_fetchers.py @@ -50,7 +50,7 @@ def sentinel1_grd( "elevation_model": "COPERNICUS_30", "coefficient": "gamma0-ellipsoid", "load_collection": { - "polarization": lambda polar: polar == "VV&VH", + "polarization": lambda polar: (polar == "VV") or (polar == "VH"), }, } @@ -156,7 +156,7 @@ def sentinel1_grd_point_based( "elevation_model": "COPERNICUS_30", "coefficient": "gamma0-ellipsoid", "load_collection": { - "polarization": lambda polar: polar == "VV&VH", + "polarization": lambda polar: (polar == "VV") or (polar == "VH"), }, } extractor = build_sentinel1_grd_extractor( diff --git a/tests/test_openeo_gfmap/test_utils.py b/tests/test_openeo_gfmap/test_utils.py index 95c241a..e63dc1f 100644 --- a/tests/test_openeo_gfmap/test_utils.py +++ b/tests/test_openeo_gfmap/test_utils.py @@ -7,10 +7,7 @@ from openeo_gfmap import Backend, BackendContext, BoundingBoxExtent, TemporalContext from openeo_gfmap.utils import split_collection_by_epsg, update_nc_attributes -from openeo_gfmap.utils.catalogue import ( - s1_area_per_orbitstate_vvvh, - select_s1_orbitstate_vvvh, -) +from openeo_gfmap.utils.catalogue import s1_area_per_orbitstate, select_S1_orbitstate # Region of Paris, France SPATIAL_CONTEXT = BoundingBoxExtent( @@ -24,7 +21,7 @@ def test_query_cdse_catalogue(): backend_context = BackendContext(Backend.CDSE) - response = s1_area_per_orbitstate_vvvh( + response = s1_area_per_orbitstate( backend=backend_context, spatial_extent=SPATIAL_CONTEXT, temporal_extent=TEMPORAL_CONTEXT, @@ -45,7 +42,7 @@ def test_query_cdse_catalogue(): assert response["DESCENDING"]["full_overlap"] is True # Testing the decision maker, it should return DESCENDING - decision = select_s1_orbitstate_vvvh( + decision = select_S1_orbitstate( backend=backend_context, spatial_extent=SPATIAL_CONTEXT, temporal_extent=TEMPORAL_CONTEXT,
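For context, a minimal usage sketch of the renamed catalogue helpers as exercised by the test above. The constructor arguments for `BoundingBoxExtent` (west, south, east, north, epsg) and `TemporalContext` (start_date, end_date) are assumed from their usual gfmap signatures, and the extent and dates are illustrative values, not taken from the test:

```python
from openeo_gfmap import Backend, BackendContext, BoundingBoxExtent, TemporalContext
from openeo_gfmap.utils.catalogue import s1_area_per_orbitstate, select_S1_orbitstate

backend_context = BackendContext(Backend.CDSE)
# Assumed constructor arguments; any small area and date range of interest works here.
spatial_extent = BoundingBoxExtent(west=2.25, south=48.80, east=2.42, north=48.92, epsg=4326)
temporal_extent = TemporalContext(start_date="2023-06-01", end_date="2023-06-30")

# Intersection statistics per orbit state, including a "full_overlap" flag.
areas = s1_area_per_orbitstate(backend_context, spatial_extent, temporal_extent)
print(areas["ASCENDING"]["full_overlap"], areas["DESCENDING"]["full_overlap"])

# Returns "ASCENDING" or "DESCENDING", whichever best covers the requested extent.
orbit_state = select_S1_orbitstate(backend_context, spatial_extent, temporal_extent)
print(orbit_state)
```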