updating naming for model uri and inference url
pmhalvor committed Oct 5, 2024
1 parent a2eeb95 commit 50109f2
Showing 4 changed files with 20 additions and 22 deletions.

src/config/common.yaml (2 additions, 2 deletions)
@@ -62,8 +62,8 @@ pipeline:
     plot_scores: true
     plot_path_template: "data/plots/results/{year}/{month:02}/{plot_name}.png"
     classification_path: "data/classifications.tsv"
-    url: https://tfhub.dev/google/humpback_whale/1
-    model_url: "http://127.0.0.1:5000/predict"
+    model_uri: https://tfhub.dev/google/humpback_whale/1
+    inference_url: "http://127.0.0.1:5000/predict"
     med_filter_size: 3
 
   postprocess:
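
The rename makes the two settings' roles explicit: model_uri identifies the TensorFlow Hub model itself, while inference_url is the local endpoint the pipeline posts audio to. For orientation only, a minimal sketch of the serving side that inference_url could point to might look like the following; the Flask app, the "batch" payload key, and the exact scoring call are assumptions, not code from this repository:

    # Minimal sketch of a local model server, assuming Flask and the TF Hub scoring
    # signature; the "batch" payload key is illustrative, not the pipeline's real schema.
    import flask
    import numpy as np
    import tensorflow_hub as hub

    MODEL_URI = "https://tfhub.dev/google/humpback_whale/1"

    app = flask.Flask(__name__)
    model = hub.load(MODEL_URI)  # load the model once at startup

    @app.route("/predict", methods=["POST"])
    def predict():
        payload = flask.request.get_json()
        waveform = np.asarray(payload["batch"], dtype=np.float32)  # assumed key
        # humpback_whale/1 scores 10 kHz audio shaped [batch, samples, 1];
        # the exact call signature may differ from this sketch.
        scores = model.score(
            waveform=waveform[np.newaxis, :, np.newaxis],
            context_step_samples=10_000,
        )
        # "predictions" matches the key the pipeline's InferenceClient reads back
        return flask.jsonify({"predictions": scores.numpy().flatten().tolist()})

    if __name__ == "__main__":
        app.run(host="127.0.0.1", port=5000)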

src/stages/classify.py (5 additions, 5 deletions)
@@ -32,7 +32,7 @@ def __init__(self, config: SimpleNamespace):
 
         self.batch_duration = config.classify.batch_duration
         self.model_sample_rate = config.classify.model_sample_rate
-        self.model_url = config.classify.model_url
+        self.inference_url = config.classify.inference_url
 
         # plotting parameters
         self.hydrophone_sensitivity = config.classify.hydrophone_sensitivity
@@ -271,7 +271,7 @@ def _postprocess(self, pcoll, grouped_outputs):
 class InferenceClient(beam.DoFn):
     def __init__(self, config: SimpleNamespace):
 
-        self.model_url = config.classify.model_url
+        self.inference_url = config.classify.inference_url
         self.retries = config.classify.inference_retries
 
     def process(self, element):
@@ -292,7 +292,7 @@ def process(self, element):
         wait = 0
         while wait < 5:
             try:
-                response = requests.post(self.model_url, json=data)
+                response = requests.post(self.inference_url, json=data)
                 response.raise_for_status()
                 break
             except requests.exceptions.ConnectionError as e:
@@ -301,7 +301,7 @@ def process(self, element):
                 wait += 1
                 time.sleep(wait*wait)
 
-        response = requests.post(self.model_url, json=data)
+        response = requests.post(self.inference_url, json=data)
         response.raise_for_status()
 
         predictions = response.json().get("predictions", [])
@@ -424,7 +424,7 @@ def sample_run():
         batch_duration=30, # seconds
         hydrophone_sensitivity=-168.8,
         model_sample_rate=10_000,
-        model_url="http://127.0.0.1:5000/predict",
+        inference_url="http://127.0.0.1:5000/predict",
         plot_scores=True,
         plot_path_template="data/plots/results/{year}/{month:02}/{plot_name}.png",
        med_filter_size=3,
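
For reference, the request logic around inference_url in InferenceClient boils down to the pattern below. This is an illustrative standalone sketch, not the repository's code: the function name and the while/else structure are mine, but the quadratic backoff on connection errors and the "predictions" response key match the diff above:

    import time
    import requests

    def post_with_backoff(inference_url: str, data: dict, max_retries: int = 5) -> list:
        """Illustrative standalone version of the retry loop in InferenceClient.process."""
        wait = 0
        while wait < max_retries:
            try:
                response = requests.post(inference_url, json=data)
                response.raise_for_status()
                break
            except requests.exceptions.ConnectionError:
                wait += 1
                time.sleep(wait * wait)  # back off 1s, 4s, 9s, ...
        else:
            # every attempt hit a connection error; let one final attempt raise
            response = requests.post(inference_url, json=data)
            response.raise_for_status()
        return response.json().get("predictions", [])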

src/stages/postprocess.py (12 additions, 14 deletions)
@@ -46,29 +46,27 @@ def process(self, element, search_output):
 
     def _build_classification_df(self, element: Tuple) -> pd.DataFrame:
         # convert element to dataframe
-        classifications_df = pd.DataFrame([element], columns=["audio", "start", "end", "encounter_ids", "classifications"])
+        df = pd.DataFrame([element], columns=["audio", "start", "end", "encounter_ids", "classifications"])
+        df = df[df["classifications"].apply(lambda x: len(x) > 0)]  # rm empty rows
 
         # explode encounter_ids
-        classifications_df = classifications_df.explode("encounter_ids").rename(columns={"encounter_ids": "encounter_id"})
-        classifications_df["encounter_id"] = classifications_df["encounter_id"].astype(str)
+        df = df.explode("encounter_ids").rename(columns={"encounter_ids": "encounter_id"})
+        df["encounter_id"] = df["encounter_id"].astype(str)
 
-        # TODO replace classifications check w/ pooled_score check
-        classifications_df = classifications_df[classifications_df["classifications"].apply(lambda x: len(x) > 0)]
         # pool classifications in postprocessing
-        classifications_df["pooled_score"] = classifications_df["classifications"].apply(self._pool_classifications)
+        df["pooled_score"] = df["classifications"].apply(self._pool_classifications)
 
         # convert start and end to isoformat
-        classifications_df["start"] = classifications_df["start"].apply(lambda x: x.isoformat())
-        classifications_df["end"] = classifications_df["end"].apply(lambda x: x.isoformat())
+        df["start"] = df["start"].apply(lambda x: x.isoformat())
+        df["end"] = df["end"].apply(lambda x: x.isoformat())
 
         # drop audio and classification columns
-        classifications_df = classifications_df.drop(columns=["audio"])
-        classifications_df = classifications_df.drop(columns=["classifications"])
+        df = df.drop(columns=["audio"])
+        df = df.drop(columns=["classifications"])
 
-
-        logging.info(f"Classifications: \n{classifications_df.head()}")
-        logging.info(f"Classifications shape: {classifications_df.shape}")
-        return classifications_df.reset_index(drop=True)
+        logging.info(f"Classifications: \n{df.head()}")
+        logging.info(f"Classifications shape: {df.shape}")
+        return df.reset_index(drop=True)
 
     def _build_search_output_df(self, search_output: Dict[str, Any]) -> pd.DataFrame:
         # convert search_output to dataframe
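
As a toy illustration of what the renamed df pipeline in _build_classification_df produces, the snippet below runs the same transformations on a hand-made element; the example values and the use of max as the pooling function are assumptions, since _pool_classifications is not shown in this diff:

    from datetime import datetime
    import pandas as pd

    element = (
        b"\x00\x01",                        # audio (dropped at the end)
        datetime(2024, 10, 5, 12, 0, 0),    # start
        datetime(2024, 10, 5, 12, 1, 0),    # end
        ["enc1", "enc2"],                   # encounter_ids
        [0.1, 0.9, 0.4],                    # classifications
    )

    df = pd.DataFrame([element], columns=["audio", "start", "end", "encounter_ids", "classifications"])
    df = df[df["classifications"].apply(lambda x: len(x) > 0)]   # rm empty rows
    df = df.explode("encounter_ids").rename(columns={"encounter_ids": "encounter_id"})
    df["encounter_id"] = df["encounter_id"].astype(str)
    df["pooled_score"] = df["classifications"].apply(max)        # stand-in for _pool_classifications
    df["start"] = df["start"].apply(lambda x: x.isoformat())
    df["end"] = df["end"].apply(lambda x: x.isoformat())
    df = df.drop(columns=["audio", "classifications"])
    print(df)  # two rows (one per encounter_id) with pooled_score and ISO-format start/end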

tests/test_classify.py (1 addition, 1 deletion)
@@ -20,7 +20,7 @@ def example_config():
         batch_duration=30, # seconds
         hydrophone_sensitivity=-168.8,
         model_sample_rate=10_000,
-        model_url="http://127.0.0.1:5000/predict",
+        inference_url="http://127.0.0.1:5000/predict",
         plot_scores=True,
         plot_path_template="data/plots/results/{year}/{month:02}/{plot_name}.png",
         med_filter_size=3,
