Commit
Refactor code from tf.v1 to tf.v2 with improved tests
LINSANITY03 committed Oct 6, 2024
1 parent 46f3406 commit a690089
Showing 5 changed files with 56 additions and 88 deletions.
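At a high level, the refactor drops the v1 compatibility shim (the tensorflow._api.v2.compat.v1 import, disable_v2_behavior(), and the explicit Session/initializer dance) and lets TF2 eager execution drive model.predict directly. A minimal sketch of the resulting pattern, using a toy dense model and random features as stand-ins for the real BERT sentiment model and text dataset:

```python
import numpy as np
import tensorflow as tf

# Stand-ins so the sketch runs on its own; the real code builds `dataset`
# from review texts and uses the saved BERT sentiment model instead.
features = np.random.rand(4, 3).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices(features).batch(1)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# TF2 executes eagerly: no tf.compat.v1.Session(), no
# global_variables_initializer(), no tables_initializer().
# model.predict consumes the tf.data pipeline directly.
predictions = model.predict(dataset, steps=len(features))
print(predictions)
```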
6 changes: 2 additions & 4 deletions main.py
@@ -18,23 +18,21 @@
from contextlib import asynccontextmanager

import tensorflow_hub as hub
import tensorflow._api.v2.compat.v1 as tf
import tensorflow as tf
import tensorflow_text as text # pylint: disable=unused-import

from fastapi import FastAPI

from .routers import predict

tf.disable_v2_behavior() # pylint: disable=no-member

@asynccontextmanager
async def lifespan(apps: FastAPI):
"""
Loading the ML models such that it will be executed before the application
starts taking requests, during the startup.
"""
# Load the ML model
ml_models = tf.keras.models.load_model("nlp-models/saved-model/sent-model",
ml_models = tf.keras.models.load_model("nlp-models/saved-model/sent-model1.h5",
custom_objects={'KerasLayer': hub.KerasLayer},
)

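For reference, a sketch of the TF2-style lifespan hook main.py converges on. The model path, the custom_objects mapping, and the sentiment_analysis_model state key come from the diff and the tests; the yield/cleanup lines and the FastAPI(lifespan=...) wiring fall outside the hunk and are assumptions:

```python
from contextlib import asynccontextmanager

import tensorflow_hub as hub
import tensorflow as tf
import tensorflow_text as text  # pylint: disable=unused-import

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(apps: FastAPI):
    """Load the ML model before the app starts taking requests."""
    ml_models = tf.keras.models.load_model(
        "nlp-models/saved-model/sent-model1.h5",
        custom_objects={"KerasLayer": hub.KerasLayer},
    )
    apps.state.sentiment_analysis_model = ml_models
    yield
    # Drop the reference on shutdown.
    apps.state.sentiment_analysis_model = None


app = FastAPI(lifespan=lifespan)
```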
40 changes: 4 additions & 36 deletions nlp-models/notebook/Sentiment_analysis.ipynb
@@ -1,32 +1,5 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Dj1a7JYBGphY",
"outputId": "7bd599e8-85bf-4547-f628-936417e47bc2"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.10/dist-packages/tensorflow/python/compat/v2_compat.py:107: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"non-resource variables are not supported in the long term\n"
]
}
],
"source": [
"import tensorflow._api.v2.compat.v1 as tf\n",
"\n",
"tf.disable_v2_behavior()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -140,7 +113,7 @@
"metadata": {},
"outputs": [],
"source": [
"model.save(\"../saved-model/sent-model\") # saving the model"
"model.save(\"../saved-model/sent-model1.h5\") # saving the model"
]
},
{
@@ -189,16 +162,11 @@
"dataset = tf.data.Dataset.from_tensor_slices(texts)\n",
"dataset = dataset.batch(1) # Adjust the batch size as needed\n",
"\n",
"\n",
"with tf.compat.v1.Session() as sess:\n",
" sess.run(tf.compat.v1.global_variables_initializer()) # Use tf.compat.v1.global_variables_initializer()\n",
" sess.run(tf.compat.v1.tables_initializer())\n",
"\n",
" # Perform prediction using the dataset\n",
" predictions = model.predict(dataset, steps=len(texts)) # Specify steps based on the number of texts\n",
"# Perform prediction using the dataset\n",
"predictions = model.predict(dataset, steps=len(texts)) # Specify steps based on the number of texts\n",
"\n",
"# Print the predictions\n",
" print(predictions)"
"print(predictions)"
]
}
],
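A note on the model.save path change in this notebook: switching from the sent-model directory to sent-model1.h5 makes Keras write the legacy HDF5 format rather than a SavedModel directory. A sketch of the save/reload round trip, assuming the same hub.KerasLayer custom object that main.py passes:

```python
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # pylint: disable=unused-import

# A ".h5" extension makes Keras write the legacy HDF5 format instead of the
# SavedModel directory used before this commit.
# model.save("../saved-model/sent-model1.h5")

# Reloading an HDF5 model that contains hub.KerasLayer layers requires the
# class to be passed back in via custom_objects, as main.py now does:
model = tf.keras.models.load_model(
    "../saved-model/sent-model1.h5",
    custom_objects={"KerasLayer": hub.KerasLayer},
)
```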
20 changes: 6 additions & 14 deletions nlp-models/src/sentiment_analysis.py
@@ -14,9 +14,7 @@

import tensorflow_hub as hub
import tensorflow_text as text # pylint: disable=unused-import
import tensorflow._api.v2.compat.v1 as tf

tf.disable_v2_behavior() # pylint: disable=no-member
import tensorflow as tf

PRE_PROCESS_URL= "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
ENCODER_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
@@ -76,15 +74,9 @@ def build_model():
dataset = tf.data.Dataset.from_tensor_slices(texts)
dataset = dataset.batch(1) # Adjust the batch size as needed

# Perform prediction using the dataset
# Specify steps based on the number of texts
predictions = model.predict(dataset, steps=len(texts))

with tf.compat.v1.Session() as sess:
# Use tf.compat.v1.global_variables_initializer()
sess.run(tf.compat.v1.global_variables_initializer())
sess.run(tf.compat.v1.tables_initializer())

# Perform prediction using the dataset
# Specify steps based on the number of texts
predictions = model.predict(dataset, steps=len(texts))

# Print the predictions
print(predictions)
# Print the predictions
print(predictions)
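The PRE_PROCESS_URL and ENCODER_URL constants above feed hub.KerasLayer wrappers inside build_model(). The hunk header only shows the function's location, so the following is a sketch of the usual TF Hub BERT classification head under that assumption; the dropout rate, layer names, and single sigmoid unit are illustrative, and only the two hub URLs come from the file:

```python
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # pylint: disable=unused-import

PRE_PROCESS_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
ENCODER_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"


def build_model():
    """Functional-API BERT classifier built from the two hub modules."""
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name="text")
    preprocessed = hub.KerasLayer(PRE_PROCESS_URL, name="preprocessing")(text_input)
    encoder_outputs = hub.KerasLayer(ENCODER_URL, trainable=True, name="BERT_encoder")(preprocessed)
    pooled = encoder_outputs["pooled_output"]  # [batch, 768] sentence embedding
    dropped = tf.keras.layers.Dropout(0.1)(pooled)
    score = tf.keras.layers.Dense(1, activation="sigmoid", name="classifier")(dropped)
    return tf.keras.Model(inputs=text_input, outputs=score)
```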
18 changes: 6 additions & 12 deletions routers/predict.py
@@ -22,13 +22,11 @@

from typing import List

import tensorflow._api.v2.compat.v1 as tf
import tensorflow as tf

from fastapi import APIRouter, Request, status, HTTPException
from pydantic import BaseModel

tf.disable_v2_behavior() # pylint: disable=no-member

router = APIRouter(
prefix="",
tags=["models"],
@@ -73,17 +71,13 @@ def prediction(dataset, sent_model, text_len):
predictions.item(): Prediction score
"""
with tf.compat.v1.Session() as sess:
# Use tf.compat.v1.global_variables_initializer()
sess.run(tf.compat.v1.global_variables_initializer())
sess.run(tf.compat.v1.tables_initializer())

# Perform prediction using the dataset
# Specify steps based on the number of texts
predictions = sent_model.predict(dataset, steps=text_len)
# Perform prediction using the dataset
# Specify steps based on the number of texts
predictions = sent_model.predict(dataset, steps=text_len)

# Return the prediction score
return predictions.item()
# Return the prediction score
return predictions.item()

@router.post("/", status_code=status.HTTP_200_OK)
async def predict_sentiment(request:Request, data: TextStr):
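Putting the routers/predict.py hunks together, the prediction helper now runs in eager mode and the endpoint reads the model from app.state. The request-body field name and the 500/"score" behavior are taken from the tests in this commit; the body of predict_sentiment itself is not shown in the diff, so the endpoint below is an assumed reconstruction:

```python
import tensorflow as tf
from fastapi import APIRouter, HTTPException, Request, status
from pydantic import BaseModel

router = APIRouter(prefix="", tags=["models"])


class TextStr(BaseModel):
    """Request body; the field name matches the tests' {"texts": ...} payload."""
    texts: str


def prediction(dataset, sent_model, text_len):
    """TF2-style inference: no Session or initializer calls needed."""
    # Specify steps based on the number of texts.
    predictions = sent_model.predict(dataset, steps=text_len)
    # Return the prediction score as a plain Python float.
    return predictions.item()


@router.post("/", status_code=status.HTTP_200_OK)
async def predict_sentiment(request: Request, data: TextStr):
    sent_model = request.app.state.sentiment_analysis_model
    if sent_model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    texts = [data.texts]
    dataset = tf.data.Dataset.from_tensor_slices(texts).batch(1)
    return {"score": prediction(dataset, sent_model, len(texts))}
```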
60 changes: 38 additions & 22 deletions tests/test_routers/test_predict.py
@@ -25,17 +25,14 @@

import os

import tensorflow._api.v2.compat.v1 as tf
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text # pylint: disable=unused-import
import pytest

from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.routers import predict


tf.disable_v2_behavior() # pylint: disable=no-member
from ...routers import predict

# Create a test FastAPI app
app = FastAPI()
@@ -62,47 +59,66 @@ def client():
model_path = os.path.join(os.path.dirname(__file__), '../../nlp-models/saved-model/sent-model')

# Set up the mock model in the app state
app.state.sentiment_analysis_model = tf.keras.models.load_model(
ml_models = tf.keras.models.load_model(
model_path,
custom_objects={'KerasLayer': hub.KerasLayer}
)

ml_models.trainable = False
app.state.sentiment_analysis_model = ml_models

# Use TestClient for testing
with TestClient(app) as c:
yield c

def test_predict_sentiment_positive(test_client, caplog):
def test_model_loaded(client):
"""
Tests the testclient for loaded model.
This test simulates a successful prediction model in app lifecycle.
Parameters:
client (TestClient): The test client instance used to make requests
to the FastAPI application.
Assertions:
- Asserts that the app state has sentiment_analysis_model attribute.
- Asserts that the sentiment_analysis_model attribute is not empty.
"""
# Check if the model is loaded in app.state
assert hasattr(client.app.state, 'sentiment_analysis_model'), "Model is not set in app.state"
assert client.app.state.sentiment_analysis_model is not None, "Model is None in app.state"

def test_predict_sentiment_positive(client):
"""
Tests the sentiment prediction endpoint for positive sentiment input.
This test simulates a POST request to the sentiment prediction endpoint with a
positive input text ("I love this product!") and verifies that the response
positive input text and verifies that the response
is successful (HTTP 200) and includes a sentiment score.
Parameters:
test_client (TestClient): The test client instance used to make requests
to the FastAPI application.
caplog (pytest.LogCaptureFixture): A fixture for capturing log messages during the test.
client (TestClient): The test client instance used to make requests
to the FastAPI application.
Assertions:
- Asserts that the response status code is 200.
- Asserts that the response JSON contains a "score" key.
"""
# Prepare test data
test_data = {"texts": "I love this product!"}

# Make a POST request to the predict endpoint
response = test_client.post("/", json=test_data)
# Prepare test data
test_data = {
"texts": "I recently stayed at a hotel that was highly disappointing. The room was dirty, and the staff were unhelpful and rude. Despite requesting multiple times, the issues were never addressed. The amenities were outdated, and the overall experience was far below what I expected. I would not recommend this place to anyone and will avoid it in the future."
}

if response.status_code != 200:
caplog.set_level("ERROR")
print(f"Error Response JSON: {response.json()}") # Capture in logs
# Make a POST request to the predict endpoint
response = client.post("/", json=test_data)

# Check the response
assert response.status_code == 200, f"Response JSON: {response.json()}"
assert "score" in response.json()

def test_predict_sentiment_model_not_loaded(test_client):
def test_predict_sentiment_model_not_loaded(client):
"""
Tests the sentiment prediction endpoint when the sentiment analysis model is not loaded.
Expand All @@ -111,7 +127,7 @@ def test_predict_sentiment_model_not_loaded(test_client):
It verifies that the endpoint returns an HTTP 500 error with the appropriate error message.
Parameters:
test_client (TestClient): The test client instance used to make requests
client (TestClient): The test client instance used to make requests
to the FastAPI application.
Assertions:
Expand All @@ -122,8 +138,8 @@ def test_predict_sentiment_model_not_loaded(test_client):
# Clear the model from the app state
client.app.state.sentiment_analysis_model = None

test_data = {"texts": "This is bad!"}
response = test_client.post("/", json=test_data)
test_data = {"texts": "I recently stayed at a hotel that was highly disappointing. The room was dirty, and the staff were unhelpful and rude."}
response = client.post("/", json=test_data)

# Check the error response
assert response.status_code == 500
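The reworked fixture loads the saved model once, marks it non-trainable, and attaches it to app.state before yielding a TestClient, so each test exercises the same state the lifespan hook would set up. A condensed sketch of that fixture; the include_router call sits outside the shown hunks and is an assumption:

```python
import os

import pytest
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # pylint: disable=unused-import
from fastapi import FastAPI
from fastapi.testclient import TestClient

from ...routers import predict

app = FastAPI()
app.include_router(predict.router)  # assumed; router wiring is outside the hunks


@pytest.fixture
def client():
    """Load the saved model into app.state, then yield a TestClient."""
    model_path = os.path.join(
        os.path.dirname(__file__), "../../nlp-models/saved-model/sent-model"
    )
    ml_models = tf.keras.models.load_model(
        model_path, custom_objects={"KerasLayer": hub.KerasLayer}
    )
    ml_models.trainable = False  # mirrors the committed fixture; not required for predict
    app.state.sentiment_analysis_model = ml_models

    with TestClient(app) as c:
        yield c
```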
