Skip to content

Commit

Permalink
Apply inline suggestions from code review
Browse files: browse the repository at this point in the history
  • Loading branch information
pmhalvor authored Oct 5, 2024
1 parent 50109f2 commit b4c6a23
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/config/common.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pipeline:
show_plots: false
is_local: false

# gcp
# gcp - bigquery
project: "bioacoustics-2024"
dataset_id: "whale_speech"

Expand Down
2 changes: 1 addition & 1 deletion src/config/local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pipeline:
verbose: true
debug: true
show_plots: false
# is_local: true
is_local: true

search:
export_template: "data/encounters/{filename}-{timeframe}.csv"
Expand Down
3 changes: 1 addition & 2 deletions src/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

# Load the TensorFlow model
logging.info("Loading model...")
model = hub.load("https://www.kaggle.com/models/google/humpback-whale/TensorFlow2/humpback-whale/1")
# model = hub.load("https://tfhub.dev/google/humpback_whale/1")
model = hub.load(config.classify.model_uri)
score_fn = model.signatures["score"]
logging.info("Model loaded.")

Expand Down
2 changes: 1 addition & 1 deletion src/stages/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, config: SimpleNamespace):

self.batch_duration = config.classify.batch_duration
self.model_sample_rate = config.classify.model_sample_rate
self.inference_url = config.classify.inference_url
self.inference_url = config.classify.inference_url

# plotting parameters
self.hydrophone_sensitivity = config.classify.hydrophone_sensitivity
Expand Down
6 changes: 2 additions & 4 deletions src/stages/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
import pandas as pd
import os

# from google.cloud import bigquery
from apache_beam.io.gcp.internal.clients import bigquery

from typing import Dict, Any, Tuple
from types import SimpleNamespace


# class PostprocessLabels(beam.PTransform):
class PostprocessLabels(beam.DoFn):
def __init__(self, config: SimpleNamespace):
self.config = config
Expand All @@ -30,7 +28,7 @@ def process(self, element, search_output):
# convert element to dataframe
classifications_df = self._build_classification_df(element)

# convert search_output to dataframe
# clean up search_output dataframe
search_output_df = self._build_search_output_df(search_output)

# join dataframes
Expand Down Expand Up @@ -73,7 +71,6 @@ def _build_search_output_df(self, search_output: Dict[str, Any]) -> pd.DataFrame
search_output = search_output.rename(columns={"id": "encounter_id"})
search_output["encounter_id"] = search_output["encounter_id"].astype(str)
search_output = search_output[[
# TODO refactor to config
"encounter_id",
"latitude",
"longitude",
Expand Down Expand Up @@ -180,6 +177,7 @@ def _write_local(self, element):
element_df = pd.DataFrame(element, columns=self.columns)
final_df = pd.concat([stored_df, element_df], ignore_index=True)
final_df = final_df.drop_duplicates()
logging.debug(f"Appending df to {self.output_path} \n{final_df}")

# store as json (hack: to remove \/\/ escapes)
final_df_json = final_df.to_json(orient="records").replace("\\/", "/")
Expand Down

0 comments on commit b4c6a23

Please sign in to comment.