From b4c6a23028f3127a54c76d7c6d3c03b17b27a09a Mon Sep 17 00:00:00 2001
From: perhalvorsen <31341520+pmhalvor@users.noreply.github.com>
Date: Sat, 5 Oct 2024 10:08:33 +0200
Subject: [PATCH] Apply inline suggestions from code review

---
 src/config/common.yaml    | 2 +-
 src/config/local.yaml     | 2 +-
 src/model_server.py       | 3 +--
 src/stages/classify.py    | 2 +-
 src/stages/postprocess.py | 6 ++----
 5 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/config/common.yaml b/src/config/common.yaml
index 50810a3..4abb209 100644
--- a/src/config/common.yaml
+++ b/src/config/common.yaml
@@ -5,7 +5,7 @@ pipeline:
   show_plots: false
   is_local: false
 
-  # gcp
+  # gcp - bigquery
   project: "bioacoustics-2024"
   dataset_id: "whale_speech"
 
diff --git a/src/config/local.yaml b/src/config/local.yaml
index 94dd9e7..fe7ca78 100644
--- a/src/config/local.yaml
+++ b/src/config/local.yaml
@@ -3,7 +3,7 @@ pipeline:
   verbose: true
   debug: true
   show_plots: false
-  # is_local: true
+  is_local: true
 
 search:
   export_template: "data/encounters/{filename}-{timeframe}.csv"
diff --git a/src/model_server.py b/src/model_server.py
index 9e2997d..d4b613d 100644
--- a/src/model_server.py
+++ b/src/model_server.py
@@ -10,8 +10,7 @@
 
 # Load the TensorFlow model
 logging.info("Loading model...")
-model = hub.load("https://www.kaggle.com/models/google/humpback-whale/TensorFlow2/humpback-whale/1")
-# model = hub.load("https://tfhub.dev/google/humpback_whale/1")
+model = hub.load(config.classify.model_uri)
 score_fn = model.signatures["score"]
 
 logging.info("Model loaded.")
diff --git a/src/stages/classify.py b/src/stages/classify.py
index e340349..429a39d 100644
--- a/src/stages/classify.py
+++ b/src/stages/classify.py
@@ -32,7 +32,7 @@ def __init__(self, config: SimpleNamespace):
 
         self.batch_duration = config.classify.batch_duration
         self.model_sample_rate = config.classify.model_sample_rate
-        self.inference_url = config.classify.inference_url 
+        self.inference_url = config.classify.inference_url
 
         # plotting parameters
         self.hydrophone_sensitivity = config.classify.hydrophone_sensitivity
diff --git a/src/stages/postprocess.py b/src/stages/postprocess.py
index 76c0eb8..5d863f7 100644
--- a/src/stages/postprocess.py
+++ b/src/stages/postprocess.py
@@ -4,14 +4,12 @@
 import pandas as pd
 import os
 
-# from google.cloud import bigquery
 from apache_beam.io.gcp.internal.clients import bigquery
 
 from typing import Dict, Any, Tuple
 from types import SimpleNamespace
 
 
-# class PostprocessLabels(beam.PTransform):
 class PostprocessLabels(beam.DoFn):
     def __init__(self, config: SimpleNamespace):
         self.config = config
@@ -30,7 +28,7 @@ def process(self, element, search_output):
         # convert element to dataframe
         classifications_df = self._build_classification_df(element)
 
-        # convert search_output to dataframe
+        # clean up search_output dataframe
         search_output_df = self._build_search_output_df(search_output)
 
         # join dataframes
@@ -73,7 +71,6 @@ def _build_search_output_df(self, search_output: Dict[str, Any]) -> pd.DataFrame
         search_output = search_output.rename(columns={"id": "encounter_id"})
         search_output["encounter_id"] = search_output["encounter_id"].astype(str)
         search_output = search_output[[
-            # TODO refactor to confing
             "encounter_id",
             "latitude",
             "longitude",
@@ -180,6 +177,7 @@ def _write_local(self, element):
         element_df = pd.DataFrame(element, columns=self.columns)
         final_df = pd.concat([stored_df, element_df], ignore_index=True)
         final_df = final_df.drop_duplicates()
+        logging.debug(f"Appending df to {self.output_path} \n{final_df}")
 
         # store as json (hack: to remove \/\/ escapes)
         final_df_json = final_df.to_json(orient="records").replace("\\/", "/")
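
Note: the model_server.py hunk assumes a `classify.model_uri` key is available in the loaded pipeline config, but no hunk in this patch adds it. A minimal sketch of what that entry might look like in src/config/common.yaml, reusing the Kaggle URL that the patch removes from the code (key placement is an assumption, not part of this patch):

    # hypothetical entry in src/config/common.yaml -- not included in this patch
    classify:
      model_uri: "https://www.kaggle.com/models/google/humpback-whale/TensorFlow2/humpback-whale/1"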