diff --git a/lib/inat_inferrer.py b/lib/inat_inferrer.py index dfcb608..5ad56b7 100644 --- a/lib/inat_inferrer.py +++ b/lib/inat_inferrer.py @@ -144,6 +144,9 @@ def setup_elevation_dataframe(self): self.geo_elevation_cells = InatInferrer.add_lat_lng_to_h3_geo_dataframe( self.geo_elevation_cells ) + self.geo_elevation_cell_indices = { + index: idx for idx, index in enumerate(self.geo_elevation_cells.index) + } def setup_elevation_dataframe_from_worldclim(self, resolution): # preventing from processing at too high a resolution @@ -414,6 +417,21 @@ def aggregate_results(self, leaf_scores, debug=False, # InatInferrer.print_aggregated_scores(all_node_scores) return all_node_scores + def h3_04_geo_results_for_taxon_and_cell(self, taxon_id, lat, lng): + try: + taxon = self.taxonomy.df.loc[taxon_id] + except: + return None + # print(f"taxon `{taxon_id}` does not exist in the taxonomy") + # raise e + if pd.isna(taxon["leaf_class_id"]): + return None + h3_cell = h3.geo_to_h3(float(lat), float(lng), 4) + return float(self.geo_elevation_model.eval_one_class_elevation_from_features( + [self.geo_model_features[self.geo_elevation_cell_indices[h3_cell]]], + int(taxon["leaf_class_id"]) + )[0][0]) / taxon["geo_threshold"] + def h3_04_geo_results_for_taxon(self, taxon_id, bounds=[], thresholded=False, raw_results=False): if (self.geo_elevation_cells is None) or (self.geo_elevation_model is None): diff --git a/lib/inat_vision_api.py b/lib/inat_vision_api.py index a590484..3dde5fe 100644 --- a/lib/inat_vision_api.py +++ b/lib/inat_vision_api.py @@ -28,6 +28,8 @@ def __init__(self, config): self.h3_04_taxon_range_comparison_route, methods=["GET"]) self.app.add_url_rule("/h3_04_bounds", "h3_04_bounds", self.h3_04_bounds_route, methods=["GET"]) + self.app.add_url_rule("/geo_scores_for_taxa", "geo_scores_for_taxa", + self.geo_scores_for_taxa_route, methods=["POST"]) self.app.add_url_rule("/build_info", "build_info", self.build_info_route, methods=["GET"]) def setup_inferrer(self, config): @@ -86,6 +88,14 @@ def build_info_route(self): "build_date": os.getenv("BUILD_DATE", "") } + def geo_scores_for_taxa_route(self): + return { + obs["id"]: self.inferrer.h3_04_geo_results_for_taxon_and_cell( + obs["taxon_id"], obs["lat"], obs["lng"] + ) + for obs in request.json["observations"] + } + def index_route(self): form = ImageForm() if "observation_id" in request.args: