Service abstraction #171

Merged
130 commits merged on Nov 21, 2024
Changes from 14 commits
Commits
130 commits
ade732f
wip, model abstraction
wgifford Oct 22, 2024
8b2c637
WIP, abstraction implementation
wgifford Oct 23, 2024
1b9cf9e
wip, abstraction
wgifford Oct 28, 2024
4070d02
use new abstraction
wgifford Oct 28, 2024
bde2ca3
wrapper -> handler
wgifford Oct 28, 2024
8907da8
allow older models
wgifford Oct 28, 2024
3b80d45
fix merge
wgifford Oct 28, 2024
75471d5
dead code
wgifford Oct 28, 2024
98c13d9
model -> handler
wgifford Oct 28, 2024
b483ff1
separate hf handler, move some hf utils
wgifford Oct 29, 2024
58f1279
move hf implementation
wgifford Oct 29, 2024
fd061de
Simplify classes
wgifford Oct 29, 2024
2a4806f
adjust type hints
wgifford Oct 29, 2024
aa6ac13
adjust logger
wgifford Oct 29, 2024
b8a7e1e
return only model
wgifford Oct 29, 2024
98b38d2
tsfm_config -> handler_config
wgifford Oct 29, 2024
6c41147
separate model_id and model_path attributes
wgifford Oct 29, 2024
ef16275
logger.exception
wgifford Oct 29, 2024
e47b448
add HF-like config object
wgifford Oct 29, 2024
e3f149c
config object, docstrings
wgifford Oct 30, 2024
4c863d4
remove old code
wgifford Oct 30, 2024
dde5bcf
docstrings, add additional exogenous specifier
wgifford Oct 30, 2024
51cfe3e
docstrings
wgifford Oct 30, 2024
1203b6e
avoid explicit conversion to datetime, support numeric timestamps
wgifford Oct 31, 2024
018ef3e
handle np.datetime64
wgifford Oct 31, 2024
e33c27e
comprehensive timestamp tests
wgifford Oct 31, 2024
c59dca9
improve coverage by testing lib code
ssiegel95 Oct 31, 2024
ca43e71
renamed
ssiegel95 Oct 31, 2024
3e135f1
call lib functions
ssiegel95 Oct 31, 2024
031ef3f
update signatures, train preprocessor during prepare
wgifford Oct 31, 2024
089b79a
:Merge remote-tracking branch 'origin/service_abstraction' into coverage
ssiegel95 Oct 31, 2024
bf590ec
more fixtures
ssiegel95 Oct 31, 2024
861cfa5
remove unused config options
wgifford Nov 1, 2024
15cdb63
ensure max length calculated
wgifford Nov 1, 2024
373f23c
docstrings, explicit definition of tsfm_config args
wgifford Nov 1, 2024
3ef7463
adjust for new defaults
wgifford Nov 1, 2024
9add858
Merge remote-tracking branch 'origin/service_abstraction' into coverage
ssiegel95 Nov 1, 2024
527b8d3
add more data related tests
ssiegel95 Nov 1, 2024
edf56e8
avoid groupby warning
wgifford Nov 1, 2024
c8548ea
catch a data processing exception
ssiegel95 Nov 1, 2024
25006d9
up coverage
ssiegel95 Nov 1, 2024
cf339d1
Merge remote-tracking branch 'origin/service_abstraction' into coverage
ssiegel95 Nov 2, 2024
5d11172
fixes
ssiegel95 Nov 2, 2024
4f51440
Merge pull request #176 from ibm-granite/coverage
ssiegel95 Nov 2, 2024
a280579
poetry lock
ssiegel95 Nov 2, 2024
02ff3e5
Revert "Coverage"
ssiegel95 Nov 2, 2024
bcc1a69
Merge pull request #177 from ibm-granite/revert-176-coverage
ssiegel95 Nov 2, 2024
5104a78
update lock file
ssiegel95 Nov 2, 2024
55d9ed3
dummy BaseParameters
wgifford Nov 5, 2024
4fcf184
refactor to include task specific handler
wgifford Nov 5, 2024
4f49475
pass kwargs
wgifford Nov 6, 2024
01a85ad
use suffix ForecastingHandler
wgifford Nov 7, 2024
a61abbf
wip, model abstraction
wgifford Oct 22, 2024
0c3d2f2
WIP, abstraction implementation
wgifford Oct 23, 2024
2086ba7
wip, abstraction
wgifford Oct 28, 2024
feab33c
use new abstraction
wgifford Oct 28, 2024
9cb5d7f
wrapper -> handler
wgifford Oct 28, 2024
346ca5a
allow older models
wgifford Oct 28, 2024
4fe86dd
fix merge
wgifford Oct 28, 2024
e675bd3
dead code
wgifford Oct 28, 2024
9fb9fa9
model -> handler
wgifford Oct 28, 2024
661bfeb
separate hf handler, move some hf utils
wgifford Oct 29, 2024
a86c43a
move hf implementation
wgifford Oct 29, 2024
ec003da
Simplify classes
wgifford Oct 29, 2024
91b3785
adjust type hints
wgifford Oct 29, 2024
3d07ba5
adjust logger
wgifford Oct 29, 2024
827b5d7
return only model
wgifford Oct 29, 2024
590af16
tsfm_config -> handler_config
wgifford Oct 29, 2024
fe6b02e
separate model_id and model_path attributes
wgifford Oct 29, 2024
d3f943c
logger.exception
wgifford Oct 29, 2024
d9cc573
add HF-like config object
wgifford Oct 29, 2024
fecebd3
config object, docstrings
wgifford Oct 30, 2024
02cb188
remove old code
wgifford Oct 30, 2024
8ff08a4
docstrings, add additional exogenous specifier
wgifford Oct 30, 2024
14a0412
docstrings
wgifford Oct 30, 2024
50d9161
avoid explicit conversion to datetime, support numeric timestamps
wgifford Oct 31, 2024
b3bd24f
handle np.datetime64
wgifford Oct 31, 2024
16d9e9e
improve coverage by testing lib code
ssiegel95 Oct 31, 2024
51e41d8
renamed
ssiegel95 Oct 31, 2024
31ccfc0
call lib functions
ssiegel95 Oct 31, 2024
a2787f5
comprehensive timestamp tests
wgifford Oct 31, 2024
b90d862
update signatures, train preprocessor during prepare
wgifford Oct 31, 2024
b473ced
more fixtures
ssiegel95 Oct 31, 2024
c83c2ae
remove unused config options
wgifford Nov 1, 2024
f110ead
ensure max length calculated
wgifford Nov 1, 2024
1f1d847
docstrings, explicit definition of tsfm_config args
wgifford Nov 1, 2024
7aed5f3
adjust for new defaults
wgifford Nov 1, 2024
93d1a60
add more data related tests
ssiegel95 Nov 1, 2024
d91a9eb
catch a data processing exception
ssiegel95 Nov 1, 2024
4deb964
up coverage
ssiegel95 Nov 1, 2024
74003a3
avoid groupby warning
wgifford Nov 1, 2024
35b264d
fixes
ssiegel95 Nov 2, 2024
cd6c594
poetry lock
ssiegel95 Nov 2, 2024
9ae4c4e
Revert "Coverage"
ssiegel95 Nov 2, 2024
25a78c8
update lock file
ssiegel95 Nov 2, 2024
74503e7
dummy BaseParameters
wgifford Nov 5, 2024
9d45f50
refactor to include task specific handler
wgifford Nov 5, 2024
7ad7a78
pass kwargs
wgifford Nov 6, 2024
0050680
use suffix ForecastingHandler
wgifford Nov 7, 2024
c81491e
fix merge issues
wgifford Nov 11, 2024
932a7c7
update extend logic
wgifford Nov 11, 2024
91f0ce5
update lock
wgifford Nov 11, 2024
fc232f6
fix future data handling, add tests
wgifford Nov 11, 2024
3ff882e
unique->nunique
wgifford Nov 11, 2024
e12b401
bump to latest tsfm_public
wgifford Nov 11, 2024
e785fc6
Merge branch 'main' into service_abstraction
wgifford Nov 12, 2024
1969190
fix merge error
wgifford Nov 12, 2024
317c5d8
bump version
wgifford Nov 12, 2024
868babd
allow saving the config
wgifford Nov 15, 2024
24462b1
check the schema
wgifford Nov 15, 2024
2f80f4d
test case for fine-tuned model
wgifford Nov 15, 2024
605dfe6
cleanup
wgifford Nov 15, 2024
1a80807
add modelspec path
ssiegel95 Nov 18, 2024
5b4f663
fix spelling mistake in log message
ssiegel95 Nov 18, 2024
ead5e86
ignore dataframe_checks.py
ssiegel95 Nov 18, 2024
0628b7a
Merge branch 'service_abstraction' of github.com:ibm-granite/granite-…
ssiegel95 Nov 18, 2024
7e6aea1
Merge remote-tracking branch 'origin/main' into service_abstraction
ssiegel95 Nov 18, 2024
7c0e48f
fix duplicate keywork arg
ssiegel95 Nov 18, 2024
27f7beb
use different model
ssiegel95 Nov 18, 2024
8fb942c
make parameters optional
ssiegel95 Nov 18, 2024
ebbe732
intercept thrown pandas exceptions
ssiegel95 Nov 18, 2024
560d4c3
first implementation of data point counts
wgifford Nov 18, 2024
a64e63e
add tests for data point counts, update calculation
wgifford Nov 18, 2024
1f6b758
update poetry lock
ssiegel95 Nov 18, 2024
10b1bda
no longer needed
ssiegel95 Nov 18, 2024
4ec5fb2
Reapply "Coverage"
ssiegel95 Nov 18, 2024
e94b51e
reduce verbosity of local server during tests
ssiegel95 Nov 18, 2024
781f796
pytest-cov
ssiegel95 Nov 18, 2024
534dd0a
fix some tests, skip one other
ssiegel95 Nov 18, 2024
fef2003
Better tests, fix bug :bug:
wgifford Nov 19, 2024
293 changes: 293 additions & 0 deletions services/inference/tsfminference/hf_service_handler.py
@@ -0,0 +1,293 @@
"""Service handler for HuggingFace models"""

import copy
import importlib
import logging
import pathlib
from pathlib import Path
from typing import Any, Dict, Optional, Union

import pandas as pd
import transformers
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

from tsfm_public import TimeSeriesForecastingPipeline, TimeSeriesPreprocessor

from .inference_payloads import (
    ForecastingMetadataInput,
    ForecastingParameters,
)
from .service_handler import ServiceHandler


LOGGER = logging.getLogger(__file__)


class HuggingFaceHandler(ServiceHandler):
    def __init__(
        self,
        model_id: Union[str, Path],
        tsfm_config: Dict[str, Any],
    ):
        super().__init__(model_id=model_id, tsfm_config=tsfm_config)

        if "model_type" in tsfm_config and "model_config_name" in tsfm_config and "module_path" in tsfm_config:
            register_config(
                tsfm_config["model_type"],
                tsfm_config["model_config_name"],
                tsfm_config["module_path"],
            )
            LOGGER.info(f"registered {tsfm_config['model_type']}")

    def load_preprocessor(self, model_path: str) -> Optional[TimeSeriesPreprocessor]:
        # load the preprocessor; a missing preprocessor is not an error
        try:
            preprocessor = TimeSeriesPreprocessor.from_pretrained(model_path)
            LOGGER.info("Successfully loaded preprocessor")
        except OSError:
            preprocessor = None
            LOGGER.info("No preprocessor found")

        return preprocessor

    def load_hf_config(self, model_path: str, **extra_config_kwargs: Dict[str, Any]) -> PretrainedConfig:
        # load the config separately from the model, since we may need to inspect the config first
        conf = load_config(model_path, **extra_config_kwargs)

        return conf

    def load_hf_model(self, model_path: str, config: PretrainedConfig) -> PreTrainedModel:
        model = load_model(
            model_path,
            config=config,
            module_path=self.tsfm_config.get("module_path", None),
        )

        LOGGER.info("Successfully loaded model")
        return model

    def _get_config_kwargs(
        self,
        parameters: Optional[ForecastingParameters] = None,
        preprocessor: Optional[TimeSeriesPreprocessor] = None,
    ) -> Dict[str, Any]:
        return {}

    def _prepare(
        self,
        schema: Optional[ForecastingMetadataInput] = None,
        parameters: Optional[ForecastingParameters] = None,
    ) -> "HuggingFaceHandler":
        # TODO: use config parameters below
        # issue: may need to know the data length to set parameters upon model load (multst)

        preprocessor_params = copy.deepcopy(schema.model_dump())
        preprocessor_params["prediction_length"] = parameters.prediction_length

        LOGGER.info(f"Preprocessor params: {preprocessor_params}")

        preprocessor = self.load_preprocessor(self.model_id)

        if preprocessor is None:
            preprocessor = TimeSeriesPreprocessor(
                **preprocessor_params,
                scaling=False,
                encode_categorical=False,
            )
            # we don't set context length or prediction length above because they are not needed for inference

        model_config_kwargs = self._get_config_kwargs(
            parameters=parameters,
            preprocessor=preprocessor,
        )
        LOGGER.info(f"model_config_kwargs: {model_config_kwargs}")
        model_config = self.load_hf_config(self.model_id, **model_config_kwargs)

        model = self.load_hf_model(self.model_id, config=model_config)

        self.config = model_config
        self.model = model
        self.preprocessor = preprocessor

        return self

    def _run(
        self,
        data: pd.DataFrame,
        future_data: Optional[pd.DataFrame] = None,
        schema: Optional[ForecastingMetadataInput] = None,
        parameters: Optional[ForecastingParameters] = None,
    ) -> pd.DataFrame:
        # TBD: can this be moved to the HFWrapper?
        # error checking once data is available
        if self.preprocessor.freq is None:
            # train to estimate freq if not available
            self.preprocessor.train(data)
            LOGGER.info(f"Data frequency determined: {self.preprocessor.freq}")

        # raise if future data is needed by the model but was not provided
        if self.preprocessor.exogenous_channel_indices and future_data is None:
            raise ValueError(
                "Future data should be provided for exogenous columns where the future is known (`control_columns` and `observable_columns`)"
            )

        forecast_pipeline = TimeSeriesForecastingPipeline(
            model=self.model,
            explode_forecasts=True,
            feature_extractor=self.preprocessor,
            add_known_ground_truth=False,
            freq=self.preprocessor.freq,
        )
        forecasts = forecast_pipeline(data, future_time_series=future_data, inverse_scale_outputs=True)

        return forecasts

    def _train(
        self,
    ) -> "HuggingFaceHandler": ...


def register_config(model_type: str, model_config_name: str, module_path: str) -> None:
    """Register a configuration for a particular model architecture

    Args:
        model_type (str): The type of the model, from the model implementation.
        model_config_name (str): The name of the configuration class for the model.
        module_path (str): Python module path that can be used to load the config/model.

    Raises:
        RuntimeError: Raised when the module cannot be imported from the provided module path.
    """
    # example:
    # model_type: "tinytimemixer"
    # model_config_name: "TinyTimeMixerConfig"
    # module_path: "tsfm"  # place where the config should be importable

    # equivalent to: AutoConfig.register("tinytimemixer", TinyTimeMixerConfig)
    try:
        mod = importlib.import_module(module_path)
        conf_class = getattr(mod, model_config_name, None)
    except ModuleNotFoundError as exc:
        raise RuntimeError(f"Could not load {model_config_name} from {module_path}") from exc

    if conf_class is not None:
        AutoConfig.register(model_type, conf_class)
    else:
        LOGGER.warning(f"Could not find configuration class {model_config_name} in module {module_path}")


def load_config(
    model_path: Union[str, pathlib.Path],
    model_type: Optional[str] = None,
    model_config_name: Optional[str] = None,
    module_path: Optional[str] = None,
    **extra_config_kwargs: Dict[str, Any],
) -> PretrainedConfig:
    """Load configuration

    Attempts to load the configuration; if it is not loadable, registers it with the AutoConfig
    mechanism and tries again.

    Args:
        model_path (Union[str, pathlib.Path]): The path from which to load the config.
        model_type (Optional[str], optional): The type of the model, from the model implementation. Defaults to None.
        model_config_name (Optional[str], optional): The name of the configuration class for the model. Defaults to None.
        module_path (Optional[str], optional): Python module path that can be used to load the
            config/model. Defaults to None.

    Returns:
        PretrainedConfig: The configuration object corresponding to the pretrained model.
    """
    # first try AutoConfig directly; if that fails, register the config and retry
    try:
        conf = AutoConfig.from_pretrained(model_path, **extra_config_kwargs)
    except (KeyError, ValueError) as exc:  # errors raised by AutoConfig for unknown model types
        if model_type is None or model_config_name is None or module_path is None:
            raise ValueError("model_type, model_config_name, and module_path should be specified.") from exc

        register_config(model_type, model_config_name, module_path)
        conf = AutoConfig.from_pretrained(model_path, **extra_config_kwargs)

    return conf


def _get_model_class(config: PretrainedConfig, module_path: Optional[str] = None) -> Optional[type]:
    """Helper to find the model class based on a config object

    If module_path is provided, the model class is resolved from that module; otherwise
    the transformers library is used.

    Args:
        config (PretrainedConfig): HF configuration for the model.
        module_path (Optional[str], optional): Python module path that can be used to load the
            config/model. Defaults to None.

    Raises:
        AttributeError: Raised if the module at module_path cannot be loaded.
        AttributeError: Raised if the architecture provided by the config cannot be loaded from
            the module.

    Returns:
        Optional[type]: The class for the model, or None if the config lists no architectures.
    """
    if module_path is not None:
        try:
            mod = importlib.import_module(module_path)
        except ModuleNotFoundError as exc:
            raise AttributeError(f"Could not load module '{module_path}'.") from exc
    else:
        mod = transformers

    # get the architecture from the model config and resolve it within the module
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        try:
            model_class = getattr(mod, arch)
            return model_class
        except AttributeError as exc:
            raise AttributeError(f"Could not load model class for architecture '{arch}'.") from exc


def load_model(
    model_path: Union[str, pathlib.Path],
    config: Optional[PretrainedConfig] = None,
    module_path: Optional[str] = None,
) -> PreTrainedModel:
    """Load a pretrained model.

    If module_path is provided, load the model using the provided module path.

    Args:
        model_path (Union[str, pathlib.Path]): Path to a location where the model can be loaded.
        config (Optional[PretrainedConfig], optional): HF configuration object. Defaults to None.
        module_path (Optional[str], optional): Python module path that can be used to load the
            config/model. Defaults to None.

    Raises:
        ValueError: Raised if loading from a module_path and a configuration object is not provided.

    Returns:
        PreTrainedModel: The loaded pretrained model.
    """
    if module_path is not None and config is None:
        raise ValueError("Config must be provided when loading from a custom module_path")

    if config is not None:
        model_class = _get_model_class(config, module_path=module_path)
        LOGGER.info(f"Found model class: {model_class.__name__}")
        return model_class.from_pretrained(model_path, config=config)

    return AutoModel.from_pretrained(model_path)
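For reference, a minimal sketch (not part of the diff) of how the module-level helpers compose for a custom architecture. The model type, config class name, and module path mirror the example comments in register_config; the model directory path is hypothetical:

    config = load_config(
        "path/to/tinytimemixer-model",
        model_type="tinytimemixer",
        model_config_name="TinyTimeMixerConfig",
        module_path="tsfm",
    )
    model = load_model("path/to/tinytimemixer-model", config=config, module_path="tsfm")

If AutoConfig does not recognize the model type, load_config registers the config class from the given module and retries; load_model then resolves the architecture class from that same module.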