diff --git a/routers/predict.py b/routers/predict.py
index cdd54d1..7815067 100644
--- a/routers/predict.py
+++ b/routers/predict.py
@@ -91,7 +91,7 @@ def prediction(dataset, sent_model, text_len):
     return predictions.item()
 
 @router.post("/", status_code=status.HTTP_200_OK, response_model=PredictResponse)
-@limiter.limit("2/hour") # Limit to 2 requests per minute
+@limiter.limit("2/minute") # Limit to 2 requests per minute
 async def predict_sentiment(request:Request, data: TextStr):
     """
     Predict the sentiment of the input text using a pre-loaded machine learning model.
diff --git a/tests/test_routers/test_predict.py b/tests/test_routers/test_predict.py
index 4a586b8..2bad28e 100644
--- a/tests/test_routers/test_predict.py
+++ b/tests/test_routers/test_predict.py
@@ -25,6 +25,7 @@
 
 import os
 
+import httpx
 import tensorflow as tf
 import tensorflow_hub as hub
 import tensorflow_text as text # pylint: disable=unused-import
@@ -37,7 +38,7 @@
 # Create a test FastAPI app
 app = FastAPI()
 app.include_router(predict.router)
- 
+
 @pytest.fixture
 def client():
     """
@@ -71,6 +72,35 @@ def client():
     with TestClient(app) as c:
         yield c
 
+import pytest_asyncio
+import httpx
+
+@pytest_asyncio.fixture
+async def async_client():
+    """
+    Asynchronous fixture to set up a test client for the FastAPI application with
+    a loaded sentiment analysis model.
+
+    This fixture initializes the FastAPI application by loading the sentiment analysis
+    model from the specified directory and sets it in the app's state. It provides an
+    async HTTP client for testing FastAPI endpoints.
+    """
+    # Load the model as before
+    model_path = os.path.join(os.path.dirname(__file__), '../../nlp-models/saved-model/sent-model')
+
+    ml_models = tf.keras.models.load_model(
+        model_path,
+        custom_objects={'KerasLayer': hub.KerasLayer}
+    )
+    ml_models.trainable = False
+
+    # Assign the model to the FastAPI app state
+    app.state.sentiment_analysis_model = ml_models
+
+    # Create an asynchronous client for testing
+    async with httpx.AsyncClient(app=app, base_url="http://127.0.0.1:8000/") as client:
+        yield client
+
 def test_model_loaded(client):
     """
     Tests the testclient for loaded model.
@@ -145,7 +175,8 @@ def test_predict_sentiment_model_not_loaded(client):
     assert response.status_code == 500
     assert response.json() == {"detail": "Model not loaded properly"}
 
-def test_rate_limiting(client):
+@pytest.mark.asyncio
+async def test_rate_limiting(async_client):
     """
     Test that the /predict route is limited to 2 requests per minute.
     The 3rd request should return a 429 Too Many Requests error.
@@ -161,14 +192,13 @@
 
     """
     test_data = {"texts": "I recently stayed at a hotel that was highly disappointing. The room was dirty, and the staff were unhelpful and rude."}
 
-    # Make 2 successful requests within the rate limit
     for _ in range(2):
-        response = client.post("/", json=test_data)
+        response = await async_client.post("/", json=test_data)
         assert response.status_code == 200, "Expected status code 200 for valid requests"
 
     # The 3rd request should hit the rate limit and return status 429
-    response = client.post("/", json=test_data)
+    response = await async_client.post("/", json=test_data)
     assert response.status_code == 429, "Expected status code 429 after rate limit exceeded"
-    assert response.json() == {'detail': '2 per 1 hour'}, \
+    assert response.json() == {'detail': '2 per 1 minute'}, \
        "Expected rate limit exceeded error message"
\ No newline at end of file
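
Note (not part of the diff): the limiter.limit decorator and the request: Request parameter on predict_sentiment suggest a slowapi-style limiter is already wired into the app in routers/predict.py. A minimal sketch of that assumed setup follows; the handler name is hypothetical, not taken from the repo.

# Hypothetical sketch of the slowapi wiring this diff appears to assume;
# rate_limit_handler is an illustrative name, not an identifier from the repo.
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter

def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    # A handler shaped like this would produce the {"detail": "2 per 1 minute"}
    # body asserted in test_rate_limiting.
    return JSONResponse(status_code=429, content={"detail": str(exc.detail)})

app.add_exception_handler(RateLimitExceeded, rate_limit_handler)

Routes decorated with limiter.limit must accept a parameter named request, which is why predict_sentiment keeps request: Request in its signature.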
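
A version-dependent caveat in the new async_client fixture: the app= shortcut on httpx.AsyncClient is deprecated in recent httpx releases and removed in later ones. If the project pins a newer httpx, the last two lines of the fixture would need the explicit ASGI transport form, roughly:

# Assumed alternative for newer httpx versions; drops into the existing fixture
# (httpx is already imported at the top of the test module) and behaves the same.
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:8000/") as client:
    yield client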