Skip to content

Commit

Permalink
Merge branch 'devel' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
cdbethune committed Sep 25, 2024
2 parents 1efd51f + 1479d75 commit 79618c2
Show file tree
Hide file tree
Showing 75 changed files with 1,835 additions and 1,918 deletions.
Binary file modified .DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ pipelines/metadata_extraction/working/
pipelines/metadata_extraction/deploy/**/
docker-compose.yml
**/results_viz/*.tif
results_*/
41 changes: 33 additions & 8 deletions cdr/request_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Optional

from pika.adapters.blocking_connection import BlockingChannel as Channel
from pika import BlockingConnection, ConnectionParameters
from pika import BlockingConnection, ConnectionParameters, PlainCredentials
from pika.exceptions import AMQPChannelError, AMQPConnectionError
from tasks.common.queue import Request

Expand All @@ -23,10 +23,22 @@ class LaraRequestPublisher:
HEARTBEAT_INTERVAL = 900
BLOCKED_CONNECTION_TIMEOUT = 600

def __init__(self, request_queues: List[str], host="localhost") -> None:
def __init__(
self,
request_queues: List[str],
host="localhost",
port=5672,
vhost="/",
uid="",
pwd="",
) -> None:
self._request_connection: Optional[BlockingConnection] = None
self._request_channel: Optional[Channel] = None
self._host = host
self._port = port
self._vhost = vhost
self._uid = uid
self._pwd = pwd
self._request_queues = request_queues

def start_lara_request_queue(self):
Expand Down Expand Up @@ -77,13 +89,26 @@ def _create_channel(self) -> Channel:
The created channel.
"""
logger.info(f"creating channel on host {self._host}")
connection = BlockingConnection(
ConnectionParameters(
self._host,
heartbeat=self.HEARTBEAT_INTERVAL,
blocked_connection_timeout=self.BLOCKED_CONNECTION_TIMEOUT,
if self._uid != "":
credentials = PlainCredentials(self._uid, self._pwd)
connection = BlockingConnection(
ConnectionParameters(
self._host,
self._port,
self._vhost,
credentials,
heartbeat=self.HEARTBEAT_INTERVAL,
blocked_connection_timeout=self.BLOCKED_CONNECTION_TIMEOUT,
)
)
else:
connection = BlockingConnection(
ConnectionParameters(
self._host,
heartbeat=self.HEARTBEAT_INTERVAL,
blocked_connection_timeout=self.BLOCKED_CONNECTION_TIMEOUT,
)
)
)
channel = connection.channel()
for queue in self._request_queues:
channel.queue_declare(
Expand Down
154 changes: 52 additions & 102 deletions cdr/result_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import pika.spec as spec
from pydantic import BaseModel
from regex import P
from cdr.json_log import JSONLog
from cdr.request_publisher import LaraRequestPublisher
from schema.cdr_schemas.feature_results import FeatureResults
from schema.cdr_schemas.georeference import GeoreferenceResults, GroundControlPoint
Expand All @@ -34,7 +33,11 @@
RequestResult,
)
from schema.mappers.cdr import GeoreferenceMapper, get_mapper
from tasks.geo_referencing.entities import GeoreferenceResult as LARAGeoreferenceResult
from tasks.geo_referencing.entities import (
GeoreferenceResult as LARAGeoreferenceResult,
GroundControlPoint as LARAGroundControlPoint,
)
from tasks.geo_referencing.util import cps_to_transform, project_image
from tasks.metadata_extraction.entities import MetadataExtraction as LARAMetadata
from tasks.point_extraction.entities import PointLabels as LARAPoints
from tasks.segmentation.entities import MapSegmentation as LARASegmentation
Expand Down Expand Up @@ -91,10 +94,10 @@ class LaraResultSubscriber:

# map of pipeline name to system version
PIPELINE_SYSTEM_VERSIONS = {
SEGMENTATION_PIPELINE: "0.0.4",
METADATA_PIPELINE: "0.0.4",
POINTS_PIPELINE: "0.0.4",
GEOREFERENCE_PIPELINE: "0.0.5",
SEGMENTATION_PIPELINE: "0.0.5",
METADATA_PIPELINE: "0.0.5",
POINTS_PIPELINE: "0.0.5",
GEOREFERENCE_PIPELINE: "0.0.6",
}

def __init__(
Expand All @@ -105,7 +108,6 @@ def __init__(
cdr_token: str,
output: str,
workdir: str,
json_log: JSONLog,
host="localhost",
pipeline_sequence: List[str] = DEFAULT_PIPELINE_SEQUENCE,
) -> None:
Expand All @@ -117,7 +119,6 @@ def __init__(
self._cdr_token = cdr_token
self._workdir = workdir
self._output = output
self._json_log = json_log
self._host = host
self._pipeline_sequence = (
pipeline_sequence
Expand Down Expand Up @@ -231,11 +232,6 @@ def _process_lara_result(
case _:
logger.info("unsupported output type received from queue")

self._json_log.log(
"result",
{"type": result.output_type, "cog_id": result.request.image_id},
)

# in the serial case we call the next pipeline in the sequence
if self._request_publisher:
# find the next pipeline
Expand Down Expand Up @@ -357,16 +353,29 @@ def _push_georeferencing(self, result: RequestResult):
gcps = cdr_result.gcps
output_file_name = projection.file_name
output_file_name_full = os.path.join(self._workdir, output_file_name)

assert gcps is not None
lara_gcps = [
LARAGroundControlPoint(
id=f"gcp.{i}",
pixel_x=gcp.px_geom.columns_from_left,
pixel_y=gcp.px_geom.rows_from_top,
latitude=gcp.map_geom.latitude if gcp.map_geom.latitude else 0,
longitude=gcp.map_geom.longitude if gcp.map_geom.longitude else 0,
confidence=gcp.confidence if gcp.confidence else 0,
)
for i, gcp in enumerate(gcps)
]

logger.info(
f"projecting image {result.image_path} to {output_file_name_full} using crs {GeoreferenceMapper.DEFAULT_OUTPUT_CRS}"
)
self._project_georeference(
result.image_path,
output_file_name_full,
projection.crs,
GeoreferenceMapper.DEFAULT_OUTPUT_CRS,
gcps,
lara_gcps,
)

files_.append(
Expand Down Expand Up @@ -415,6 +424,35 @@ def _push_georeferencing(self, result: RequestResult):
except:
logger.info("error when attempting to submit georeferencing results")

def _project_georeference(
    self,
    source_image_path: str,
    target_image_path: str,
    source_crs: str,
    target_crs: str,
    gcps: List[LARAGroundControlPoint],
):
    """
    Projects an image to a new coordinate reference system using ground control points.

    Args:
        source_image_path (str): The path to the source image.
        target_image_path (str): The path the projected image is written to.
        source_crs (str): The coordinate reference system passed to cps_to_transform as the source.
        target_crs (str): The target coordinate reference system.
        gcps (List[LARAGroundControlPoint]): The ground control points.
    """
    # open the image inside a context manager so the file handle is released
    # as soon as the projection is done
    # NOTE(review): assumes Image is PIL.Image -- the import is not visible here
    with Image.open(source_image_path) as img:
        # create the transform and use it to project the image
        geo_transform = cps_to_transform(gcps, source_crs, target_crs)
        image_bytes = project_image(img, geo_transform, target_crs)

    # create the output directory if it doesn't exist; dirname is "" for a bare
    # file name and os.makedirs("") raises, so only create a real directory
    target_dir = os.path.dirname(target_image_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # write the projected image to disk
    with open(target_image_path, "wb") as f:
        f.write(image_bytes.getvalue())

def _push_features(self, result: RequestResult, model: FeatureResults):
"""
Pushes the features result to the CDR
Expand Down Expand Up @@ -517,91 +555,3 @@ def _push_metadata(self, result: RequestResult):
)

self._push_features(result, final_result)

def _project_georeference(
    self,
    source_image_path: str,
    target_image_path: str,
    target_crs: str,
    gcps: List[GroundControlPoint],
):
    """
    Projects an image using a transform built from ground control points.

    Args:
        source_image_path (str): The path to the source image.
        target_image_path (str): The path the projected image is written to.
        target_crs (str): The coordinate reference system the control points are converted to.
        gcps (List[GroundControlPoint]): The ground control points.
    """
    # the image is opened only to read its height; close the handle right away
    # NOTE(review): assumes Image is PIL.Image -- the import is not visible here
    with Image.open(source_image_path) as img:
        _, height = img.size

    # create the transform
    geo_transform = self._cps_to_transform(gcps, height=height, to_crs=target_crs)

    # use the transform to project the image
    self._project_image(
        source_image_path, target_image_path, geo_transform, target_crs
    )

def _project_image(
    self,
    source_image_path: str,
    target_image_path: str,
    geo_transform: Affine,
    crs: str,
):
    """
    Warps the source raster into a new COG file using the supplied transform.

    Args:
        source_image_path (str): The path of the raster to read.
        target_image_path (str): The path of the raster to write.
        geo_transform (Affine): The transform used as the source transform of the warp.
        crs (str): The coordinate reference system used as both the source and the destination CRS.
    """
    # NOTE(review): indentation was lost in this view; the nesting below is
    # reconstructed -- confirm against the original file
    with rio.open(source_image_path) as raw:
        # bounds of the source raster under the supplied transform
        bounds = riot.array_bounds(raw.height, raw.width, geo_transform)
        # output grid; source and destination CRS are the same value
        pro_transform, pro_width, pro_height = calculate_default_transform(
            crs, crs, raw.width, raw.height, *tuple(bounds)
        )
        # start from the source profile and override the output-specific fields
        pro_kwargs = raw.profile.copy()
        pro_kwargs.update(
            {
                "driver": "COG",
                "crs": {"init": crs},
                "transform": pro_transform,
                "width": pro_width,
                "height": pro_height,
            }
        )
        # read every band of the source raster in one call
        _raw_data = raw.read()
        with rio.open(target_image_path, "w", **pro_kwargs) as pro:
            # _raw_data is indexed from 0, rio.band is given i + 1
            for i in range(raw.count):
                _ = reproject(
                    source=_raw_data[i],
                    destination=rio.band(pro, i + 1),
                    src_transform=geo_transform,
                    src_crs=crs,
                    dst_transform=pro_transform,
                    dst_crs=crs,
                    resampling=Resampling.bilinear,
                    num_threads=8,
                    warp_mem_limit=256,
                )

def _cps_to_transform(
    self, gcps: List[GroundControlPoint], height: int, to_crs: str
) -> Affine:
    """
    Builds a transform from ground control points, converting each point to to_crs first.

    Args:
        gcps (List[GroundControlPoint]): The ground control points.
        height (int): Not used by this method; kept so existing callers still work.
        to_crs (str): The coordinate reference system the control points are converted to.

    Returns:
        Affine: The result of riot.from_gcps on the converted control points.
    """
    cps_p = []
    for gcp in gcps:
        x = float(gcp.map_geom.longitude)  # type: ignore
        y = float(gcp.map_geom.latitude)  # type: ignore

        # only points expressed in a different crs need to be converted
        if gcp.crs != to_crs:
            proj = Transformer.from_crs(gcp.crs, to_crs, always_xy=True)
            x, y = proj.transform(xx=x, yy=y)

        cps_p.append(
            riot.GroundControlPoint(
                row=float(gcp.px_geom.rows_from_top),
                col=float(gcp.px_geom.columns_from_left),
                x=x,
                y=y,
            )
        )

    # report the control points through the logger rather than printing to stdout
    logger.debug("cps_p: %s", cps_p)

    return riot.from_gcps(cps_p)
13 changes: 8 additions & 5 deletions cdr/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from flask import Flask, request, Response

from cdr.json_log import JSONLog
from cdr.request_publisher import LaraRequestPublisher
from cdr.result_subscriber import LaraResultSubscriber
from tasks.common.io import download_file
Expand Down Expand Up @@ -56,7 +55,6 @@ class Settings:
callback_url: str
registration_id: Dict[str, str] = {}
rabbitmq_host: str
json_log: JSONLog
serial: bool
sequence: List[str] = []

Expand All @@ -82,7 +80,6 @@ def prefetch_image(working_dir: Path, image_id: str, image_url: str) -> None:
def process_cdr_event():
logger.info("event callback started")
evt = request.get_json(force=True)
settings.json_log.log("event", evt)
logger.info(f"event data received {evt['event']}")
lara_reqs: Dict[str, Request] = {}

Expand Down Expand Up @@ -317,6 +314,10 @@ def main():
parser.add_argument("--imagedir", type=str, required=True)
parser.add_argument("--cog_id", type=str, required=False)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--rabbit_port", type=int, default=5672)
parser.add_argument("--rabbit_vhost", type=str, default="/")
parser.add_argument("--rabbit_uid", type=str, default="")
parser.add_argument("--rabbit_pwd", type=str, default="")
parser.add_argument("--cdr_event_log", type=str, default=CDR_EVENT_LOG)
parser.add_argument("--input", type=str, default=None)
parser.add_argument("--output", type=str, default=None)
Expand All @@ -335,7 +336,6 @@ def main():
settings.callback_secret = CDR_CALLBACK_SECRET
settings.serial = True
settings.sequence = p.sequence
settings.json_log = JSONLog(os.path.join(p.workdir, p.cdr_event_log))

# check parameter consistency: either the mode is process and a cog id is supplied or the mode is host without a cog id
if p.mode == "process":
Expand All @@ -358,6 +358,10 @@ def main():
METADATA_REQUEST_QUEUE,
],
host=p.host,
port=p.rabbit_port,
vhost=p.rabbit_vhost,
uid=p.rabbit_uid,
pwd=p.rabbit_pwd,
)
request_publisher.start_lara_request_queue()

Expand All @@ -370,7 +374,6 @@ def main():
settings.cdr_api_token,
settings.output,
settings.workdir,
settings.json_log,
host=p.host,
pipeline_sequence=settings.sequence,
)
Expand Down
Loading

0 comments on commit 79618c2

Please sign in to comment.