diff --git a/datacube-wps-config.yaml b/datacube-wps-config.yaml
index 1231a70..64cfea1 100644
--- a/datacube-wps-config.yaml
+++ b/datacube-wps-config.yaml
@@ -30,7 +30,6 @@ processes:
         fuse_func: datacube_wps.processes.wofls_fuser
 
     style:
-      csv: None
       table:
        columns:
          Wet:
@@ -107,7 +106,6 @@ processes:
        resolution: [30, -30]
        output_crs: EPSG:3577
     style:
-      csv: None
      table:
        columns:
          Bare Soil:
@@ -146,7 +144,6 @@ processes:
        resolution: [30, -30]
 
     style:
-      csv: None
      table:
        columns:
          Woodland:
@@ -206,5 +203,33 @@ processes:
        - product: ga_ls_wo_3
          measurements: [water]
     style:
-      csv: None
-      table: None
+
+
+  - process: datacube_wps.processes.ls_s2_fc_drill.LS_S2_FC_Drill
+
+    about:
+      identifier: LS S2 FC Drill
+      version: '0.1'
+      title: Landsat/Sentinel-2 Fractional Cover Drill
+      abstract: Performs Landsat/Sentinel-2 Fractional Cover Drill
+      store_supported: False
+      status_supported: True
+      geometry_type: polygon
+
+    input:
+      reproject:
+        output_crs: EPSG:3577
+        resolution: [-30, 30]
+        resampling: nearest
+      input:
+        product: ls_s2_fc_c3
+        measurements: [tc]
+
+    style:
+      csv: True
+      table:
+        columns:
+          Total Cover %:
+            units: "#"
+            chartLineColor: "#3B7F00"
+      active: True
diff --git a/datacube_wps/processes/__init__.py b/datacube_wps/processes/__init__.py
index 36534b2..e682901 100644
--- a/datacube_wps/processes/__init__.py
+++ b/datacube_wps/processes/__init__.py
@@ -6,7 +6,6 @@
 from collections import Counter
 
 import altair
-# import altair_saver
 import boto3
 import botocore
 import datacube
@@ -18,7 +17,7 @@ import rasterio.features
 import xarray
 from botocore.client import Config
-from dask.distributed import Client, worker_client
+from dask.distributed import Client
 from datacube.utils.geometry import CRS, Geometry
 from datacube.utils.rio import configure_s3_access
 from datacube.virtual.impl import Product, Juxtapose
@@ -77,17 +76,9 @@ def log_wrapper(*args, **kwargs):
 
 @log_call
 def _uploadToS3(filename, data, mimetype):
-    # AWS_S3_CREDS = {
-    #     "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
-    #     "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
-    # }
-    # s3 = session.client("s3", **AWS_S3_CREDS)
     session = boto3.Session(profile_name="default")
     bucket = config.get_config_value("s3", "bucket")
     s3 = session.client("s3")
-
-    # bucket = s3.Bucket('test-wps')
-
     s3.upload_fileobj(
         data,
         bucket,
@@ -95,8 +86,6 @@ def _uploadToS3(filename, data, mimetype):
         ExtraArgs={"ACL": "public-read", "ContentType": mimetype},
     )
-
-    print('Made it to before the presigned url generation')
-
     bucket = config.get_config_value("s3", "bucket")
     # Create unsigned s3 client for determining public s3 url
     s3 = session.client("s3", config=Config(signature_version=botocore.UNSIGNED))
     return s3.generate_presigned_url(
@@ -108,14 +97,14 @@ def _uploadToS3(filename, data, mimetype):
 
 def upload_chart_html_to_S3(chart: altair.Chart, process_id: str):
     html_io = io.StringIO()
-    chart.save(html_io, format="html", engine="vl-convert")
+    chart.save(html_io, format="html")
     html_bytes = io.BytesIO(html_io.getvalue().encode())
     return _uploadToS3(process_id + "/chart.html", html_bytes, "text/html")
 
 
 def upload_chart_svg_to_S3(chart: altair.Chart, process_id: str):
     img_io = io.StringIO()
-    chart.save(img_io, format="svg", engine="vl-convert")
+    chart.save(img_io, format="svg")
     img_bytes = io.BytesIO(img_io.getvalue().encode())
     return _uploadToS3(process_id + "/chart.svg", img_bytes, "image/svg+xml")
 
@@ -193,7 +182,6 @@ def _guard_rail(input, box):
             byte_count *= x
     byte_count *= sum(np.dtype(m.dtype).itemsize for m in measurement_dicts.values())
-    print("byte count for query: ", byte_count)
     if byte_count > MAX_BYTES_IN_GB * GB:
         raise ProcessError(
             ("requested area requires {}GB data to load - "
              "maximum is {}GB").format(
@@ -203,7 +191,6 @@ def _guard_rail(input, box):
 
     grouped = box.box
-    print("grouped shape", grouped.shape)
 
     assert len(grouped.shape) == 1
     if grouped.shape[0] == 0:
@@ -364,13 +351,7 @@ def _render_outputs(
 
 
 def _populate_response(response, outputs):
-    print('before response is populated')
-    print(response.outputs)
-    print('------------')
     for ident, output_value in outputs.items():
-        print('TESTING')
-        print(ident)
-        print(output_value)
         if ident in response.outputs:
             if "data" in output_value:
                 response.outputs[ident].data = output_value["data"]
@@ -407,16 +388,6 @@ def __init__(self, about, input, style):
         self.style = style
         self.json_version = "v8"
-
-        # self.dask_client = dask_client = Client(
-        #     n_workers=num_dask_workers(), processes=True, threads_per_worker=1
-        # )
-
-        self.dask_enabled = True
-
-        if self.dask_enabled:
-            # get the Dask Client associated with the current Gunicorn worker
-            self.dask_client = worker_client()
 
     def input_formats(self):
         return [
             ComplexInput(
@@ -443,29 +414,35 @@ def request_handler(self, request, response):
         parameters = _get_parameters(request)
 
         result = self.query_handler(time, feature, parameters=parameters)
 
-        if self.style['csv']:
+        if 'csv' in self.style:
             outputs = self.render_outputs(result["data"], None)
-
-        elif self.style['table']:
+        elif 'table' in self.style:
             outputs = self.render_outputs(result["data"], result["chart"])
-
+        else:
+            raise ProcessError('No output style configured for process!')
 
         _populate_response(response, outputs)
         return response
 
     @log_call
-    def query_handler(self, time, feature, parameters=None):
+    def query_handler(self, time, feature, dask_client=None, parameters=None):
         if parameters is None:
             parameters = {}
 
-        configure_s3_access(
-            # aws_unsigned=True,
-            region_name=os.getenv("AWS_DEFAULT_REGION", "auto"),
-            client=self.dask_client,
-        )
+        if dask_client is None:
+            dask_client = Client(
+                n_workers=1, processes=False, threads_per_worker=num_workers()
+            )
+
+        with dask_client:
+            configure_s3_access(
+                aws_unsigned=True,
+                region_name=os.getenv("AWS_DEFAULT_REGION", "auto"),
+                client=dask_client,
+            )
 
-        with datacube.Datacube() as dc:
-            data = self.input_data(dc, time, feature)
+            with datacube.Datacube() as dc:
+                data = self.input_data(dc, time, feature)
 
         df = self.process_data(data, {"time": time, "feature": feature, **parameters})
         chart = self.render_chart(df)
@@ -493,11 +470,8 @@ def input_data(self, dc, time, feature):
         lonlat = feature.coords[0]
         measurements = self.input.output_measurements(bag.product_definitions)
 
-        if self.dask_enabled:
-            data = self.input.fetch(box, dask_chunks={"time": 1})
-            data = data.compute()
-        else:
-            data = self.input.fetch(box)
+        data = self.input.fetch(box, dask_chunks={"time": 1})
+        data = data.compute()
 
         coords = {
             "longitude": np.array([lonlat[0]]),
@@ -567,15 +541,6 @@ def __init__(self, about, input, style):
         self.mask_all_touched = False
         self.json_version = "v8"
-
-        # self.dask_client = dask_client = Client(
-        #     n_workers=num_dask_workers(), processes=True, threads_per_worker=1
-        # )
-        self.dask_enabled = True
-
-        if self.dask_enabled:
-            # get the Dask Client associated with the current Gunicorn worker
-            self.dask_client = worker_client()
 
     def input_formats(self):
         return [
             ComplexInput(
@@ -604,90 +569,88 @@ def request_handler(self, request, response):
         result = self.query_handler(time, feature, parameters=parameters)
 
-        if self.style['csv']:
+        if 'csv' in self.style:
             outputs = self.render_outputs(result["data"], None)
-
-        elif self.style['table']:
+        elif 'table' in self.style:
             outputs = self.render_outputs(result["data"], result["chart"])
-
+        else:
+            raise ProcessError('No output style configured for process!')
 
         _populate_response(response, outputs)
         return response
 
     @log_call
-    def query_handler(self, time, feature, parameters=None):
+    def query_handler(self, time, feature, dask_client=None, parameters=None):
         if parameters is None:
             parameters = {}
 
-        configure_s3_access(
-            # aws_unsigned=True,
-            region_name=os.getenv("AWS_DEFAULT_REGION", "auto"),
-            client=self.dask_client,
-        )
+        if dask_client is None:
+            dask_client = Client(
+                n_workers=num_workers(), processes=True, threads_per_worker=1
+            )
+
+        with dask_client:
+            configure_s3_access(
+                aws_unsigned=True,
+                region_name=os.getenv("AWS_DEFAULT_REGION", "auto"),
+                client=dask_client,
+            )
 
-        with datacube.Datacube() as dc:
-            data = self.input_data(dc, time, feature)
+            with datacube.Datacube() as dc:
+                data = self.input_data(dc, time, feature)
 
         df = self.process_data(data, {"time": time, "feature": feature, **parameters})
 
         # If csv specified, return timeseries in csv form
-        if self.style['csv']:
+        if 'csv' in self.style:
             return {"data": df}
 
         # If table style specified in config, return chart (static timeseries)
-        elif self.style['table'] is not None:
+        elif 'table' in self.style:
             chart = self.render_chart(df)
             return {"data": df, "chart": chart}
-
-
+        else:
+            return {}
 
     def input_data(self, dc, time, feature):
         if time is None:
             bag = self.input.query(dc, geopolygon=feature)
         else:
             bag = self.input.query(dc, time=time, geopolygon=feature)
 
         output_crs = self.input.get('output_crs')
         resolution = self.input.get('resolution')
         align = self.input.get('align')
 
         if not (output_crs and resolution):
-            print('parameters for Geobox not found in inputs')
             if type(self.input) in (Product,):
-                print('Checking grid_spec in product')
-                if bag.product_definitions[self.input._product].grid_spec:
-                    print('grid_spec exists - do nothing')
-                else:
+                if not bag.product_definitions[self.input._product].grid_spec:
                     output_crs = mostcommon_crs(list(bag.bag))
-
             elif type(self.input) in (Juxtapose,):
-                print('Checking grid_spec of each product')
-                print(list(bag.product_definitions.values()))
                 grid_specs = [product_definition.grid_spec for product_definition in list(bag.product_definitions.values()) if getattr(product_definition, 'grid_spec', None)]
-                if len(set(grid_specs)) == 1:
-                    print('grid_spec exists for all products and are all the same - do nothing')
-
-                elif len(set(grid_specs)) > 1:
+                if len(set(grid_specs)) > 1:
                     raise ValueError('Multiple grid_spec detected across all products - override target output_crs, resolution in config')
-
                 else:
                     if not resolution:
                         raise ValueError('add target resolution to config')
-
                     elif not output_crs:
                         output_crs = mostcommon_crs(bag.contained_datasets())
 
         box = self.input.group(bag, output_crs=output_crs, resolution=resolution, align=align)
 
         if self.about.get("guard_rail", True):
+            # HACK: a VirtualDatasetBox can have a geobox yet report that it doesn't when its
+            # load_natively flag is True; box.shape() is only usable inside the _guard_rail
+            # check when load_natively is False, so temporarily clear the flag and restore it.
+            run_hack = box.load_natively and box.geobox is not None
+            if run_hack:
+                load_natively = box.load_natively
+                box.load_natively = False
             _guard_rail(self.input, box)
+            if run_hack:
+                box.load_natively = load_natively
 
         # TODO customize the number of processes
-        if self.dask_enabled:
-            data = self.input.fetch(box, dask_chunks={"time": 1})
-        else:
-            data = self.input.fetch(box)
-
+        data = self.input.fetch(box, dask_chunks={"time": 1})
 
         mask = geometry_mask(
             feature, data.geobox, all_touched=self.mask_all_touched, invert=True
         )
diff --git a/datacube_wps/processes/ls_s2_fc_drill.py b/datacube_wps/processes/ls_s2_fc_drill.py
new file mode 100644
index 0000000..fd59b27
--- /dev/null
+++ b/datacube_wps/processes/ls_s2_fc_drill.py
@@ -0,0 +1,69 @@
+import altair
+from math import ceil
+import numpy as np
+import xarray as xr
+from pywps import ComplexOutput, LiteralOutput
+
+from . import FORMATS, PolygonDrill, chart_dimensions, log_call
+
+
+class LS_S2_FC_Drill(PolygonDrill):
+    SHORT_NAMES = ['TC']
+    LONG_NAMES = ['Total Cover']
+
+    def output_formats(self):
+        return [
+            # LiteralOutput('image', 'Fractional Cover Polygon Drill Preview'),
+            # LiteralOutput('url', 'Fractional Cover Polygon Drill Graph'),
+            ComplexOutput('timeseries', 'Fractional Cover Polygon Drill Timeseries',
+                          supported_formats=[FORMATS['output_json']])
+        ]
+
+    @log_call
+    def process_data(self, data, parameters):  # returns pandas.DataFrame
+        NO_DATA = 255
+
+        # Exclude no-data pixels before taking the spatial mean
+        mask_da = data['tc'] != NO_DATA
+        masked_da = data['tc'].where(mask_da)
+
+        mean_da = masked_da.mean(dim=['x', 'y'], skipna=True).compute()
+
+        df = mean_da.to_dataframe()
+        df = df.drop('spatial_ref', axis=1)
+        df.reset_index(inplace=True)
+
+        return df
+
+    def render_chart(self, df):
+        MONTHS_IN_YEAR = 12
+        QUARTERS_IN_YEAR = 4
+
+        width, height = chart_dimensions(self.style)
+
+        chart = altair.Chart(df,
+                             width=width,
+                             height=height,
+                             title='Mean Percentage of Total Cover')
+        chart = chart.mark_line()
+
+        # Roughly one tick per quarter of data
+        n_time_ticks = ceil(df.shape[0] / MONTHS_IN_YEAR) * QUARTERS_IN_YEAR
+
+        try:
+            line_colour = self.style['table']['columns']['Total Cover %']['chartLineColor']
+        except KeyError:
+            line_colour = '#3B7F00'
+
+        chart = chart.encode(
+            x=altair.X('time:T', axis=altair.Axis(title='Time', format='%b %Y', tickCount=n_time_ticks)),
+            y=altair.Y('tc:Q', axis=altair.Axis(title='Mean TC%')),
+            color=altair.ColorValue(line_colour)
+        )
+
+        return chart
+
+    def render_outputs(self, df, chart):
+        return super().render_outputs(df, chart, is_enabled=True, name='tc', header=self.LONG_NAMES)
diff --git a/datacube_wps/processes/lsfcdrill.py b/datacube_wps/processes/lsfcdrill.py
index 3edcb3e..373d7e6 100644
--- a/datacube_wps/processes/lsfcdrill.py
+++ b/datacube_wps/processes/lsfcdrill.py
@@ -69,12 +69,7 @@ def process_data(self, data, parameters):
             'Unobservable': (not_pixels / total_valid)['bs'] * 100
         })
 
-        if self.dask_client:
-            print('dask compute')
-            dask_time = default_timer()
-            new_ds = new_ds.compute()
-            print('dask took', default_timer() - dask_time, 'seconds')
-            print(new_ds)
+        new_ds = new_ds.compute()
 
         df = new_ds.to_dataframe()
         df = df.drop('spatial_ref', axis=1)
diff --git a/datacube_wps/processes/mangrovedrill.py b/datacube_wps/processes/mangrovedrill.py
index cf4f94b..93d1eeb 100644
--- a/datacube_wps/processes/mangrovedrill.py
+++ b/datacube_wps/processes/mangrovedrill.py
@@ -7,8 +7,7 @@ class MangroveDrill(PolygonDrill):
     @log_call
     def process_data(self, data, parameters):
-        if self.dask_client:
-            data = data.compute()
+        data = data.compute()
 
         # TODO raise ProcessError('query returned no data') when appropriate
         woodland = data.where(data == 1).count(['x', 'y'])
diff --git a/gunicorn.conf.py b/gunicorn.conf.py
index 72f38eb..68b534e 100644
--- a/gunicorn.conf.py
+++ b/gunicorn.conf.py
@@ -1,31 +1,17 @@
-import os
-import gevent.monkey
-gevent.monkey.patch_all()
+# import gevent.monkey
+# gevent.monkey.patch_all()
 
 from prometheus_flask_exporter.multiprocess import \
     GunicornInternalPrometheusMetrics
 
 from datacube_wps.startup_utils import get_pod_vcpus
-from dask.distributed import LocalCluster, Client
 
 # Settings, https://docs.gunicorn.org/en/stable/settings.html
 # Check what the server sees: gunicorn --print-config datacube_wps:app
 
 timeout = 600  # 10 mins
-worker_class = 'gevent'
+# worker_class = 'gevent'
 workers = get_pod_vcpus() * 2 + 1
 reload = True
 
-def _num_dask_workers():
-    """Number of dask workers"""
-    return int(os.getenv("DATACUBE_WPS_NUM_WORKERS", "4"))
-
-def _create_dask_cluster():
-    cluster = LocalCluster(n_workers=_num_dask_workers(), scheduler_port=0, threads_per_worker=1)
-    return cluster
-
-def post_fork(server, worker):
-    worker.dask_cluster = _create_dask_cluster()
-    worker.dask_client = Client(worker.dask_cluster)
-
 def child_exit(server, worker):
     GunicornInternalPrometheusMetrics.mark_process_dead_on_child_exit(worker.pid)
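
A minimal usage sketch of the reworked query_handler signature (not part of the diff; the process instance, time range and polygon below are hypothetical, and dask_client is optional — when omitted the handler creates a local Client exactly as in the code above):

    from dask.distributed import Client
    from datacube.utils.geometry import CRS, Geometry

    # Hypothetical area of interest as a lon/lat polygon
    feature = Geometry(
        {
            "type": "Polygon",
            "coordinates": [[
                [149.00, -35.00], [149.10, -35.00], [149.10, -35.10],
                [149.00, -35.10], [149.00, -35.00],
            ]],
        },
        crs=CRS("EPSG:4326"),
    )

    # `process` would be a PolygonDrill subclass (e.g. MangroveDrill or LS_S2_FC_Drill)
    # instantiated by the WPS service from datacube-wps-config.yaml.
    client = Client(n_workers=2, processes=True, threads_per_worker=1)
    result = process.query_handler(("2022-01-01", "2022-12-31"), feature, dask_client=client)

    df = result["data"]           # pandas.DataFrame timeseries
    chart = result.get("chart")   # altair.Chart when a 'table' style is configured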