From 9c056c8ffaf56ee6983119008e7835f316124436 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:54:34 -0400 Subject: [PATCH] simplify path handling --- services/inference/Makefile | 8 ++++---- services/inference/tests/test_inference.py | 8 +------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/services/inference/Makefile b/services/inference/Makefile index 914737ad..3c20dffb 100644 --- a/services/inference/Makefile +++ b/services/inference/Makefile @@ -10,7 +10,8 @@ boilerplate: # starts the tsfminference service (used mainly for test cases) start_service_local: boilerplate - TSFM_MODEL_DIR=./mytest-tsfm python -m gunicorn \ + TSFM_MODEL_DIR=./mytest-tsfm TSFM_ALLOW_LOAD_FROM_HF_HUB=1 \ + python -m gunicorn \ -w 1 \ -k uvicorn.workers.UvicornWorker \ --bind 127.0.0.1:8000 \ @@ -36,13 +37,12 @@ stop_service_image: $(CONTAINER_BUILDER) stop tsfmserver test_local: clone_models boilerplate start_service_local - TSFM_ALLOW_LOAD_FROM_HF_HUB=1 TSFM_MODEL_DIR=./mytest-tsfm \ - pytest tests + pytest tests $(MAKE) stop_service_local $(MAKE) delete_models test_image: clone_models start_service_image - TSFM_MODEL_DIR=./ pytest tests + pytest tests $(MAKE) stop_service_image $(MAKE) delete_models diff --git a/services/inference/tests/test_inference.py b/services/inference/tests/test_inference.py index 010ac0bc..90c8eac7 100644 --- a/services/inference/tests/test_inference.py +++ b/services/inference/tests/test_inference.py @@ -1,8 +1,5 @@ # Copyright contributors to the TSFM project # -import os -import tempfile -from pathlib import Path from typing import Any, Dict import numpy as np @@ -100,10 +97,7 @@ def test_zero_shot_forecast_inference(ts_data): prediction_length = params["prediction_length"] context_length = params["context_length"] model_id = params["model_id"] - model_dir = ( - Path(os.getenv("TSFM_MODEL_DIR")) if os.getenv("TSFM_MODEL_DIR") else Path(tempfile.gettempdir()) / "test-tsfm" - ) - model_id_path: Path = (model_dir / model_id).as_posix() + model_id_path: str = model_id id_columns = params["id_columns"]