-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add makefile + move model_server to src
- Loading branch information
Showing
3 changed files
with
82 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
local-run: | ||
python3 src/model_server.py & python3 src/pipeline/app.py | ||
bash scripts/kill_model_server.sh | ||
|
||
run-pipeline: | ||
python3 src/pipeline/app.py | ||
|
||
model-server: | ||
python3 src/model_server.py | ||
|
||
kill-model-server: | ||
bash scripts/kill_model_server.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
echo "Checking for running model_server.py files..." | ||
ps aux | grep python | grep whale | grep model_server.py | ||
|
||
PID=$(ps aux | grep python | grep whale | grep model_server.py | awk '{print $2}') | ||
if [ -z "$PID" ]; then | ||
echo "No model_server.py files running." | ||
else | ||
echo "Killing PID: $PID" | ||
kill -9 $PID | ||
sleep 2 | ||
|
||
echo "Checking for running model_server.py files..." | ||
ps aux | grep python | grep whale | grep model_server.py | ||
fi | ||
|
||
echo "Done." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from flask import Flask, request, jsonify | ||
import tensorflow_hub as hub | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
import logging | ||
|
||
|
||
# Load the TensorFlow model | ||
print("Loading model...") | ||
# model = hub.load("https://www.kaggle.com/models/google/humpback-whale/TensorFlow2/humpback-whale/1") | ||
model = hub.load("https://tfhub.dev/google/humpback_whale/1") | ||
score_fn = model.signatures["score"] | ||
print("Model loaded.") | ||
|
||
# Initialize Flask app | ||
app = Flask(__name__) | ||
|
||
# Define the predict endpoint | ||
@app.route('/predict', methods=['POST']) | ||
def predict(): | ||
try: | ||
# Parse the request data | ||
data = request.json | ||
batch = np.array(data['batch'], dtype=np.float32) # Assuming batch is passed as a list | ||
key = data['key'] | ||
print(f"batch.shape = {batch.shape}") | ||
|
||
# Prepare the input for the model | ||
waveform_exp = tf.expand_dims(batch, 0) # Expanding dimensions to fit model input shape | ||
print(f"waveform_exp.shape = {waveform_exp.shape}") | ||
|
||
# Run inference | ||
results = score_fn( | ||
waveform=waveform_exp, # waveform_exp, | ||
context_step_samples=10_000 | ||
)["scores"][0] # NOTE currently only support batch size 1 | ||
print(f"results.shape = {results.shape}") | ||
print("results = ", results) | ||
|
||
# Return the predictions and key as JSON | ||
return jsonify({ | ||
'key': key, | ||
'predictions': results.numpy().tolist() | ||
}) | ||
|
||
except Exception as e: | ||
logging.error(f"An error occurred: {str(e)}") | ||
print(f"An error occurred: {str(e)}") | ||
return jsonify({'error': str(e)}), 500 | ||
|
||
# Main entry point | ||
if __name__ == "__main__": | ||
app.run(host='0.0.0.0', port=5000, debug=True) |