Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Service abstraction #171

Merged
merged 130 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
130 commits
Select commit Hold shift + click to select a range
ade732f
wip, model abstraction
wgifford Oct 22, 2024
8b2c637
WIP, abstraction implementation
wgifford Oct 23, 2024
1b9cf9e
wip, abstraction
wgifford Oct 28, 2024
4070d02
use new abstraction
wgifford Oct 28, 2024
bde2ca3
wrapper -> handler
wgifford Oct 28, 2024
8907da8
allow older models
wgifford Oct 28, 2024
3b80d45
fix merge
wgifford Oct 28, 2024
75471d5
dead code
wgifford Oct 28, 2024
98c13d9
model -> handler
wgifford Oct 28, 2024
b483ff1
separate hf handler, move some hf utils
wgifford Oct 29, 2024
58f1279
move hf implementation
wgifford Oct 29, 2024
fd061de
Simplify classes
wgifford Oct 29, 2024
2a4806f
adjust type hints
wgifford Oct 29, 2024
aa6ac13
adjust logger
wgifford Oct 29, 2024
b8a7e1e
return only model
wgifford Oct 29, 2024
98b38d2
tsfm_config -> handler_config
wgifford Oct 29, 2024
6c41147
separate model_id and model_path attributes
wgifford Oct 29, 2024
ef16275
logger.exception
wgifford Oct 29, 2024
e47b448
add HF-like config object
wgifford Oct 29, 2024
e3f149c
config object, docstrings
wgifford Oct 30, 2024
4c863d4
remove old code
wgifford Oct 30, 2024
dde5bcf
docstrings, add additional exogenous specifier
wgifford Oct 30, 2024
51cfe3e
docstrings
wgifford Oct 30, 2024
1203b6e
avoid explicit conversion to datetime, support numeric timestamps
wgifford Oct 31, 2024
018ef3e
handle np.datetime64
wgifford Oct 31, 2024
e33c27e
comprehensive timestamp tests
wgifford Oct 31, 2024
c59dca9
improve coverage by testing lib code
ssiegel95 Oct 31, 2024
ca43e71
renamed
ssiegel95 Oct 31, 2024
3e135f1
call lib functions
ssiegel95 Oct 31, 2024
031ef3f
update signatures, train preprocessor during prepare
wgifford Oct 31, 2024
089b79a
:Merge remote-tracking branch 'origin/service_abstraction' into coverage
ssiegel95 Oct 31, 2024
bf590ec
more fixtures
ssiegel95 Oct 31, 2024
861cfa5
remove unused config options
wgifford Nov 1, 2024
15cdb63
ensure max length calculated
wgifford Nov 1, 2024
373f23c
docstrings, explicit definition of tsfm_config args
wgifford Nov 1, 2024
3ef7463
adjust for new defaults
wgifford Nov 1, 2024
9add858
Merge remote-tracking branch 'origin/service_abstraction' into coverage
ssiegel95 Nov 1, 2024
527b8d3
add more data related tests
ssiegel95 Nov 1, 2024
edf56e8
avoid groupby warning
wgifford Nov 1, 2024
c8548ea
catch a data processing exception
ssiegel95 Nov 1, 2024
25006d9
up coverage
ssiegel95 Nov 1, 2024
cf339d1
Merge remote-tracking branch 'origin/service_abstraction' into coverage
ssiegel95 Nov 2, 2024
5d11172
fixes
ssiegel95 Nov 2, 2024
4f51440
Merge pull request #176 from ibm-granite/coverage
ssiegel95 Nov 2, 2024
a280579
poetry lock
ssiegel95 Nov 2, 2024
02ff3e5
Revert "Coverage"
ssiegel95 Nov 2, 2024
bcc1a69
Merge pull request #177 from ibm-granite/revert-176-coverage
ssiegel95 Nov 2, 2024
5104a78
update lock file
ssiegel95 Nov 2, 2024
55d9ed3
dummy BaseParameters
wgifford Nov 5, 2024
4fcf184
refactor to include task specific handler
wgifford Nov 5, 2024
4f49475
pass kwargs
wgifford Nov 6, 2024
01a85ad
use suffix ForecastingHandler
wgifford Nov 7, 2024
a61abbf
wip, model abstraction
wgifford Oct 22, 2024
0c3d2f2
WIP, abstraction implementation
wgifford Oct 23, 2024
2086ba7
wip, abstraction
wgifford Oct 28, 2024
feab33c
use new abstraction
wgifford Oct 28, 2024
9cb5d7f
wrapper -> handler
wgifford Oct 28, 2024
346ca5a
allow older models
wgifford Oct 28, 2024
4fe86dd
fix merge
wgifford Oct 28, 2024
e675bd3
dead code
wgifford Oct 28, 2024
9fb9fa9
model -> handler
wgifford Oct 28, 2024
661bfeb
separate hf handler, move some hf utils
wgifford Oct 29, 2024
a86c43a
move hf implementation
wgifford Oct 29, 2024
ec003da
Simplify classes
wgifford Oct 29, 2024
91b3785
adjust type hints
wgifford Oct 29, 2024
3d07ba5
adjust logger
wgifford Oct 29, 2024
827b5d7
return only model
wgifford Oct 29, 2024
590af16
tsfm_config -> handler_config
wgifford Oct 29, 2024
fe6b02e
separate model_id and model_path attributes
wgifford Oct 29, 2024
d3f943c
logger.exception
wgifford Oct 29, 2024
d9cc573
add HF-like config object
wgifford Oct 29, 2024
fecebd3
config object, docstrings
wgifford Oct 30, 2024
02cb188
remove old code
wgifford Oct 30, 2024
8ff08a4
docstrings, add additional exogenous specifier
wgifford Oct 30, 2024
14a0412
docstrings
wgifford Oct 30, 2024
50d9161
avoid explicit conversion to datetime, support numeric timestamps
wgifford Oct 31, 2024
b3bd24f
handle np.datetime64
wgifford Oct 31, 2024
16d9e9e
improve coverage by testing lib code
ssiegel95 Oct 31, 2024
51e41d8
renamed
ssiegel95 Oct 31, 2024
31ccfc0
call lib functions
ssiegel95 Oct 31, 2024
a2787f5
comprehensive timestamp tests
wgifford Oct 31, 2024
b90d862
update signatures, train preprocessor during prepare
wgifford Oct 31, 2024
b473ced
more fixtures
ssiegel95 Oct 31, 2024
c83c2ae
remove unused config options
wgifford Nov 1, 2024
f110ead
ensure max length calculated
wgifford Nov 1, 2024
1f1d847
docstrings, explicit definition of tsfm_config args
wgifford Nov 1, 2024
7aed5f3
adjust for new defaults
wgifford Nov 1, 2024
93d1a60
add more data related tests
ssiegel95 Nov 1, 2024
d91a9eb
catch a data processing exception
ssiegel95 Nov 1, 2024
4deb964
up coverage
ssiegel95 Nov 1, 2024
74003a3
avoid groupby warning
wgifford Nov 1, 2024
35b264d
fixes
ssiegel95 Nov 2, 2024
cd6c594
poetry lock
ssiegel95 Nov 2, 2024
9ae4c4e
Revert "Coverage"
ssiegel95 Nov 2, 2024
25a78c8
update lock file
ssiegel95 Nov 2, 2024
74503e7
dummy BaseParameters
wgifford Nov 5, 2024
9d45f50
refactor to include task specific handler
wgifford Nov 5, 2024
7ad7a78
pass kwargs
wgifford Nov 6, 2024
0050680
use suffix ForecastingHandler
wgifford Nov 7, 2024
c81491e
fix merge issues
wgifford Nov 11, 2024
932a7c7
update extend logic
wgifford Nov 11, 2024
91f0ce5
update lock
wgifford Nov 11, 2024
fc232f6
fix future data handling, add tests
wgifford Nov 11, 2024
3ff882e
unique->nunique
wgifford Nov 11, 2024
e12b401
bump to latest tsfm_public
wgifford Nov 11, 2024
e785fc6
Merge branch 'main' into service_abstraction
wgifford Nov 12, 2024
1969190
fix merge error
wgifford Nov 12, 2024
317c5d8
bump version
wgifford Nov 12, 2024
868babd
allow saving the config
wgifford Nov 15, 2024
24462b1
check the schema
wgifford Nov 15, 2024
2f80f4d
test case for fine-tuned model
wgifford Nov 15, 2024
605dfe6
cleanup
wgifford Nov 15, 2024
1a80807
add modelspec path
ssiegel95 Nov 18, 2024
5b4f663
fix spelling mistake in log message
ssiegel95 Nov 18, 2024
ead5e86
ignore dataframe_checks.py
ssiegel95 Nov 18, 2024
0628b7a
Merge branch 'service_abstraction' of github.com:ibm-granite/granite-…
ssiegel95 Nov 18, 2024
7e6aea1
Merge remote-tracking branch 'origin/main' into service_abstraction
ssiegel95 Nov 18, 2024
7c0e48f
fix duplicate keywork arg
ssiegel95 Nov 18, 2024
27f7beb
use different model
ssiegel95 Nov 18, 2024
8fb942c
make parameters optional
ssiegel95 Nov 18, 2024
ebbe732
intercept thrown pandas exceptions
ssiegel95 Nov 18, 2024
560d4c3
first implementation of data point counts
wgifford Nov 18, 2024
a64e63e
add tests for data point counts, update calculation
wgifford Nov 18, 2024
1f6b758
update poetry lock
ssiegel95 Nov 18, 2024
10b1bda
no longer needed
ssiegel95 Nov 18, 2024
4ec5fb2
Reapply "Coverage"
ssiegel95 Nov 18, 2024
e94b51e
reduce verbosity of local server during tests
ssiegel95 Nov 18, 2024
781f796
pytest-cov
ssiegel95 Nov 18, 2024
534dd0a
fix some tests, skip one other
ssiegel95 Nov 18, 2024
fef2003
Better tests, fix bug :bug:
wgifford Nov 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion services/boilerplate/inference_payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class ForecastingMetadataInput(BaseMetadataInput):
)


class BaseParameters(BaseModel):
model_config = ConfigDict(extra="forbid", protected_namespaces=())


class ForecastingParameters(BaseModel):
model_config = ConfigDict(extra="forbid", protected_namespaces=())

Expand Down Expand Up @@ -142,7 +146,9 @@ class ForecastingInferenceInput(BaseInferenceInput):
description="An object of ForecastingMetadataInput that contains the schema" " metadata of the data input.",
)

parameters: ForecastingParameters
parameters: ForecastingParameters = Field(
description="additional parameters affecting behavior of the forecast.", default_factory=dict
)

data: Dict[str, List[Any]] = Field(
description="A payload of data matching the schema provided."
Expand Down Expand Up @@ -277,3 +283,6 @@ class PredictOutput(BaseModel):
description="List of prediction results.",
default=None,
)

input_data_points: int = Field(description="Count of input data points.", default=None)
output_data_points: int = Field(description="Count of output data points.", default=None)
678 changes: 343 additions & 335 deletions services/finetuning/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion services/finetuning/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ __version_tuple__ = (0, 0, 0)
# including 3.9 causes poetry lock to run forever
python = ">=3.10,<3.13"
numpy = { version = "<2" }
tsfm_public = { git = "https://github.com/IBM-granite/granite-tsfm.git", tag = "v0.2.13", markers = "sys_platform != 'win32'" }
tsfm_public = { git = "https://github.com/IBM-granite/granite-tsfm.git", tag = "v0.2.16", markers = "sys_platform != 'win32'" }


# trying to pick up cpu version for tsfmfinetuning
Expand Down
6 changes: 4 additions & 2 deletions services/inference/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ create_prometheus_metrics_dir:
# starts the tsfminference service (used mainly for test cases)
start_service_local: create_prometheus_metrics_dir boilerplate
PROMETHEUS_MULTIPROC_DIR=./prometheus_metrics \
TSFM_PYTHON_LOGGING_LEVEL="ERROR" \
TSFM_MODEL_DIR=./mytest-tsfm \
TSFM_ALLOW_LOAD_FROM_HF_HUB=1 \
python -m gunicorn \
Expand All @@ -23,7 +24,7 @@ start_service_local: create_prometheus_metrics_dir boilerplate
--bind 127.0.0.1:8000 \
tsfminference.main:app && true &
stop_service_local:
pkill -f 'python.*tsfminference.*'
pkill -f 'python.*gunicorn.*tsfminference\.main\:app'

image: boilerplate
$(CONTAINER_BUILDER) build -t tsfminference -f Dockerfile .
Expand All @@ -45,9 +46,10 @@ stop_service_image:
$(CONTAINER_BUILDER) stop tsfmserver

test_local: clone_models boilerplate start_service_local
pytest -s tests ../tests
pytest --cov=tsfminference --cov-report term-missing tests ../tests
$(MAKE) stop_service_local
$(MAKE) delete_models
$(MAKE) stop_service_local

test_image: clone_models start_service_image
pytest -s tests ../tests
Expand Down
356 changes: 0 additions & 356 deletions services/inference/openapi.json

This file was deleted.

650 changes: 380 additions & 270 deletions services/inference/poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion services/inference/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ __version_tuple__ = (0, 0, 0)
# including 3.9 causes poetry lock to run forever
python = ">=3.10,<3.13"
numpy = { version = "<2" }
tsfm_public = { git = "https://github.com/IBM-granite/granite-tsfm.git", tag = "v0.2.13", markers = "sys_platform != 'win32'" }
tsfm_public = { git = "https://github.com/IBM-granite/granite-tsfm.git", tag = "v0.2.16", markers = "sys_platform != 'win32'" }

# trying to pick up cpu version for tsfminference
# to make image smaller
Expand Down Expand Up @@ -60,6 +60,7 @@ optional = true
[tool.poetry.group.dev.dependencies]
pytest = "*"
locust = "*"
pytest-coverage = "*"

[build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
Expand Down
2 changes: 1 addition & 1 deletion services/inference/tests/locust/payload.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"model_id": "ibm/test-ttm-v1",
"model_id": "mytest-tsfm/ttm-1024-96-r2",
"parameters": {
"prediction_length": 1
},
Expand Down
204 changes: 204 additions & 0 deletions services/inference/tests/test_inference_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright contributors to the TSFM project
#

import copy
from datetime import timedelta

import numpy as np
import pandas as pd
import pytest
import yaml
from fastapi import HTTPException
from tsfminference import TSFM_CONFIG_FILE
from tsfminference.inference import InferenceRuntime
from tsfminference.inference_payloads import (
ForecastingInferenceInput,
ForecastingMetadataInput,
ForecastingParameters,
PredictOutput,
)


SERIES_LENGTH = 512
FORECAST_LENGTH = 96
MODEL_ID = "mytest-tsfm/ttm-r1"


@pytest.fixture(scope="module")
def ts_data_base() -> pd.DataFrame:
# Generate a date range
length = SERIES_LENGTH
date_range = pd.date_range(start="2023-10-01", periods=length, freq="H")

# Create a DataFrame
df = pd.DataFrame(
{
"date": date_range,
"ID": "1",
"VAL": np.random.rand(length),
}
)

return df


if TSFM_CONFIG_FILE:
with open(TSFM_CONFIG_FILE, "r") as file:
config = yaml.safe_load(file)
else:
config = {}


@pytest.fixture(scope="module")
def forecasting_input_base() -> ForecastingInferenceInput:
# df: pd.DataFrame = ts_data_base
schema: ForecastingMetadataInput = ForecastingMetadataInput(
timestamp_column="date", id_columns=["ID"], target_columns=["VAL"]
)
parameters: ForecastingParameters = ForecastingParameters()
input: ForecastingInferenceInput = ForecastingInferenceInput(
model_id=MODEL_ID,
schema=schema,
parameters=parameters,
data={
"date": [
"2024-10-18T01:00:21+00:00",
],
"ID1": [
"I1",
],
"VAL": [
10.0,
],
}, # this should get replaced in each test case anyway,
)
return input


def _basic_result_checks(results: PredictOutput, df: pd.DataFrame):
# expected length
assert len(results) == FORECAST_LENGTH
# expected start time
assert results["date"].iloc[0] - df["date"].iloc[-1] == timedelta(hours=1)
# expected end time
assert results["date"].iloc[-1] - df["date"].iloc[-1] == timedelta(hours=FORECAST_LENGTH)


def test_forecast_with_good_data(ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput):
input = forecasting_input_base
df = copy.deepcopy(ts_data_base)
input.data = df.to_dict(orient="list")
runtime: InferenceRuntime = InferenceRuntime(config=config)
po: PredictOutput = runtime.forecast(input=input)
results = pd.DataFrame.from_dict(po.results[0])
_basic_result_checks(results, df)


def test_forecast_with_schema_missing_target_columns(
ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput
):
input = forecasting_input_base
input.schema.target_columns = []
df = copy.deepcopy(ts_data_base)
input.data = df.to_dict(orient="list")
runtime: InferenceRuntime = InferenceRuntime(config=config)
po: PredictOutput = runtime.forecast(input=input)
results = pd.DataFrame.from_dict(po.results[0])
_basic_result_checks(results, df)


def test_forecast_with_integer_timestamps(
ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput
):
input: ForecastingInferenceInput = copy.deepcopy(forecasting_input_base)
df = copy.deepcopy(ts_data_base)

timestamp_column = input.schema.timestamp_column
df[timestamp_column] = df[timestamp_column].astype(int)
df[timestamp_column] = range(1, SERIES_LENGTH + 1)
input.data = df.to_dict(orient="list")
runtime: InferenceRuntime = InferenceRuntime(config=config)
po: PredictOutput = runtime.forecast(input=input)
results = pd.DataFrame.from_dict(po.results[0])
assert results[timestamp_column].iloc[0] == SERIES_LENGTH + 1
assert results[timestamp_column].iloc[-1] - df[timestamp_column].iloc[-1] == FORECAST_LENGTH
assert results.dtypes[timestamp_column] == df.dtypes[timestamp_column]


def test_forecast_with_bogus_timestamps(ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput):
input: ForecastingInferenceInput = copy.deepcopy(forecasting_input_base)
df = copy.deepcopy(ts_data_base)

timestamp_column = input.schema.timestamp_column
df[timestamp_column] = df[timestamp_column].astype(str)
df[timestamp_column] = [str(x) for x in range(1, SERIES_LENGTH + 1)]
input.data = df.to_dict(orient="list")
runtime: InferenceRuntime = InferenceRuntime(config=config)
with pytest.raises(ValueError) as _:
runtime.forecast(input=input)


def test_forecast_with_bogus_values(ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput):
input: ForecastingInferenceInput = copy.deepcopy(forecasting_input_base)
df = copy.deepcopy(ts_data_base)
df["VAL"] = df["VAL"].astype(str)
df["VAL"] = [str(x) for x in range(1, SERIES_LENGTH + 1)]
input.data = df.to_dict(orient="list")
runtime: InferenceRuntime = InferenceRuntime(config=config)
with pytest.raises(HTTPException) as _:
runtime.forecast(input=input)


def test_forecast_with_bogus_model_id(ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput):
input: ForecastingInferenceInput = copy.deepcopy(forecasting_input_base)
df = copy.deepcopy(ts_data_base)
input.data = df.to_dict(orient="list")
input.model_id = "hoo-hah"

runtime: InferenceRuntime = InferenceRuntime(config=config)
with pytest.raises(HTTPException) as _:
runtime.forecast(input=input)


def test_forecast_with_insufficient_context_length(
ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput
):
input: ForecastingInferenceInput = copy.deepcopy(forecasting_input_base)
df = copy.deepcopy(ts_data_base)
df = df.iloc[0:-100]

input.data = df.to_dict(orient="list")

runtime: InferenceRuntime = InferenceRuntime(config=config)
with pytest.raises(HTTPException) as _:
runtime.forecast(input=input)


@pytest.mark.skip
def test_forecast_with_nan_data(ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput):
input: ForecastingInferenceInput = copy.deepcopy(forecasting_input_base)
df = copy.deepcopy(ts_data_base)
df.iloc[0, df.columns.get_loc("VAL")] = np.nan

input.data = df.to_dict(orient="list")

runtime: InferenceRuntime = InferenceRuntime(config=config)
# with pytest.raises(HTTPException) as _:
runtime.forecast(input=input)


# @pytest.mark.skip
def test_forecast_with_missing_row(ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput):
input: ForecastingInferenceInput = copy.deepcopy(forecasting_input_base)
df = copy.deepcopy(ts_data_base)
df = df.drop(index=10)

# append a row to give it the correct length
# don't forget to update the timestamp accordingly in the
# appended row

input.data = df.to_dict(orient="list")

runtime: InferenceRuntime = InferenceRuntime(config=config)
with pytest.raises(HTTPException) as _:
runtime.forecast(input=input)
Loading