add makefile + move model_server to src
pmhalvor committed Sep 29, 2024
1 parent 4a8da13 commit eb0bf2e
Showing 3 changed files with 82 additions and 0 deletions.
12 changes: 12 additions & 0 deletions makefile
@@ -0,0 +1,12 @@
local-run:
	python3 src/model_server.py & python3 src/pipeline/app.py
	bash scripts/kill_model_server.sh

run-pipeline:
	python3 src/pipeline/app.py

model-server:
	python3 src/model_server.py

kill-model-server:
	bash scripts/kill_model_server.sh
16 changes: 16 additions & 0 deletions scripts/kill_model_server.sh
@@ -0,0 +1,16 @@
echo "Checking for running model_server.py files..."
ps aux | grep python | grep whale | grep model_server.py

PID=$(ps aux | grep python | grep whale | grep model_server.py | awk '{print $2}')
if [ -z "$PID" ]; then
echo "No model_server.py files running."
else
echo "Killing PID: $PID"
kill -9 $PID
sleep 2

echo "Checking for running model_server.py files..."
ps aux | grep python | grep whale | grep model_server.py
fi

echo "Done."
54 changes: 54 additions & 0 deletions src/model_server.py
@@ -0,0 +1,54 @@
from flask import Flask, request, jsonify
import tensorflow_hub as hub
import numpy as np
import tensorflow as tf

import logging


# Load the TensorFlow model
print("Loading model...")
# model = hub.load("https://www.kaggle.com/models/google/humpback-whale/TensorFlow2/humpback-whale/1")
model = hub.load("https://tfhub.dev/google/humpback_whale/1")
score_fn = model.signatures["score"]
print("Model loaded.")

# Initialize Flask app
app = Flask(__name__)

# Define the predict endpoint
@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Parse the request data
        data = request.json
        batch = np.array(data['batch'], dtype=np.float32)  # assuming batch is passed as a list
        key = data['key']
        print(f"batch.shape = {batch.shape}")

        # Prepare the input for the model
        waveform_exp = tf.expand_dims(batch, 0)  # expand dimensions to fit the model's input shape
        print(f"waveform_exp.shape = {waveform_exp.shape}")

        # Run inference
        results = score_fn(
            waveform=waveform_exp,
            context_step_samples=10_000
        )["scores"][0]  # NOTE: currently only supports batch size 1
        print(f"results.shape = {results.shape}")
        print("results = ", results)

        # Return the predictions and key as JSON
        return jsonify({
            'key': key,
            'predictions': results.numpy().tolist()
        })

    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        print(f"An error occurred: {str(e)}")
        return jsonify({'error': str(e)}), 500

# Main entry point
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000, debug=True)

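For reference, a minimal client sketch against the /predict endpoint above, assuming the server is running locally on port 5000 and that the model takes float32 samples in a (time, channels) layout at 10 kHz; the key "example-chunk-0" and the random waveform are placeholder values, not part of the commit.

# Minimal client sketch for the /predict endpoint. Assumes the server
# runs locally on port 5000; the key and random waveform are placeholders.
import numpy as np
import requests

# One second of fake 10 kHz mono audio, shape (10000, 1); the server
# adds the leading batch dimension before calling the model.
waveform = np.random.uniform(-1.0, 1.0, size=(10_000, 1)).astype(np.float32)

response = requests.post(
    "http://localhost:5000/predict",
    json={"key": "example-chunk-0", "batch": waveform.tolist()},
)
response.raise_for_status()

result = response.json()
print(result["key"], len(result["predictions"]))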