
Commit

Simplify TimeSeriesCloudPredictor API (#105)
shchur authored Jan 18, 2024
1 parent 2a8f757 commit f1b198f
Showing 3 changed files with 44 additions and 63 deletions.
28 changes: 8 additions & 20 deletions docs/index.md
@@ -110,40 +110,28 @@ result = cloud_predictor.predict(test_data)
import pandas as pd
from autogluon.cloud import TimeSeriesCloudPredictor

data = pd.read_csv("https://autogluon.s3.amazonaws.com/datasets/cloud/timeseries_train.csv")
id_column="item_id"
timestamp_column="timestamp"
target="target"
data = pd.read_csv("https://autogluon.s3.amazonaws.com/datasets/timeseries/m4_hourly_tiny/train.csv")

predictor_init_args = {
"target": target
"target": "target",
"prediction_length" : 24,
} # args used when creating TimeSeriesPredictor()
predictor_fit_args = {
"train_data": data,
"time_limit": 120
"time_limit": 120,
} # args passed to TimeSeriesPredictor.fit()
cloud_predictor = TimeSeriesCloudPredictor(cloud_output_path="YOUR_S3_BUCKET_PATH")
cloud_predictor.fit(
predictor_init_args=predictor_init_args,
predictor_fit_args=predictor_fit_args,
id_column=id_column,
timestamp_column=timestamp_column
id_column="item_id",
timestamp_column="timestamp",
)
cloud_predictor.deploy()
result = cloud_predictor.predict_real_time(
test_data=data,
id_column=id_column,
timestamp_column=timestamp_column,
target=target
)
result = cloud_predictor.predict_real_time(data)
cloud_predictor.cleanup_deployment()
# Batch inference
result = cloud_predictor.predict(
test_data=data,
id_column=id_column,
timestamp_column=timestamp_column,
target=target
)
result = cloud_predictor.predict(data)
```
:::
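
For reference, the training CSV loaded above is expected in long format: one row per item and timestamp, using the default column names `item_id`, `timestamp`, and `target`. A minimal stand-in frame (illustrative values, not the actual m4_hourly_tiny data) might look like this:

```python
import pandas as pd

# Tiny long-format example with the default column names the predictor expects.
# Item ids and target values are made up for illustration.
toy_data = pd.DataFrame(
    {
        "item_id": ["H1"] * 4 + ["H2"] * 4,
        "timestamp": list(pd.date_range("2024-01-01", periods=4, freq="h")) * 2,
        "target": [10.0, 12.0, 11.0, 13.0, 5.0, 6.0, 5.5, 6.5],
    }
)
print(toy_data.head())
```
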

62 changes: 34 additions & 28 deletions src/autogluon/cloud/predictor/timeseries_cloud_predictor.py
@@ -15,6 +15,23 @@ class TimeSeriesCloudPredictor(CloudPredictor):
predictor_file_name = "TimeSeriesCloudPredictor.pkl"
backend_map = {SAGEMAKER: TIMESERIES_SAGEMAKER}

def __init__(
self,
local_output_path: Optional[str] = None,
cloud_output_path: Optional[str] = None,
backend: str = SAGEMAKER,
verbosity: int = 2,
) -> None:
super().__init__(
local_output_path=local_output_path,
cloud_output_path=cloud_output_path,
backend=backend,
verbosity=verbosity,
)
self.target_column: Optional[str] = None
self.id_column: Optional[str] = None
self.timestamp_column: Optional[str] = None

@property
def predictor_type(self):
"""
@@ -33,8 +50,8 @@ def fit(
*,
predictor_init_args: Dict[str, Any],
predictor_fit_args: Dict[str, Any],
id_column: str,
timestamp_column: str,
id_column: str = "item_id",
timestamp_column: str = "timestamp",
static_features: Optional[Union[str, pd.DataFrame]] = None,
framework_version: str = "latest",
job_name: Optional[str] = None,
@@ -56,10 +73,10 @@
Init args for the predictor
predictor_fit_args: dict
Fit args for the predictor
id_column: str
Name of the 'item_id' column
timestamp_column: str
Name of the 'timestamp' column
id_column: str, default = "item_id"
Name of the item ID column
timestamp_column: str, default = "timestamp"
Name of the timestamp column
static_features: Optional[pd.DataFrame]
An optional data frame describing the metadata attributes of individual items in the item index.
For more detail, please refer to `TimeSeriesDataFrame` documentation:
@@ -102,6 +119,11 @@
), "Predictor is already fit! To fit additional models, create a new `CloudPredictor`"
if backend_kwargs is None:
backend_kwargs = {}

self.target_column = predictor_init_args.get("target", "target")
self.id_column = id_column
self.timestamp_column = timestamp_column

backend_kwargs = self.backend.parse_backend_fit_kwargs(backend_kwargs)
self.backend.fit(
predictor_init_args=predictor_init_args,
@@ -124,9 +146,6 @@ def fit(
def predict_real_time(
self,
test_data: Union[str, pd.DataFrame],
id_column: str,
timestamp_column: str,
target: str,
static_features: Optional[Union[str, pd.DataFrame]] = None,
accept: str = "application/x-parquet",
**kwargs,
@@ -141,16 +160,10 @@
test_data: Union(str, pandas.DataFrame)
The test data to be inferenced.
Can be a pandas.DataFrame or a local path to a csv file.
id_column: str
Name of the 'item_id' column
timestamp_column: str
Name of the 'timestamp' column
static_features: Optional[pd.DataFrame]
An optional data frame describing the metadata attributes of individual items in the item index.
For more detail, please refer to `TimeSeriesDataFrame` documentation:
https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe
target: str
Name of column that contains the target values to forecast
accept: str, default = application/x-parquet
Type of accept output content.
Valid options are application/x-parquet, text/csv, application/json
@@ -164,9 +177,9 @@
"""
return self.backend.predict_real_time(
test_data=test_data,
id_column=id_column,
timestamp_column=timestamp_column,
target=target,
id_column=self.id_column,
timestamp_column=self.timestamp_column,
target=self.target_column,
static_features=static_features,
accept=accept,
)
@@ -177,9 +190,6 @@ def predict_proba_real_time(self, **kwargs) -> pd.DataFrame:
def predict(
self,
test_data: Union[str, pd.DataFrame],
id_column: str,
timestamp_column: str,
target: str,
static_features: Optional[Union[str, pd.DataFrame]] = None,
predictor_path: Optional[str] = None,
framework_version: str = "latest",
@@ -203,10 +213,6 @@ def predict(
test_data: str
The test data to be inferenced.
Can be a pandas.DataFrame or a local path to a csv file.
id_column: str
Name of the 'item_id' column
timestamp_column: str
Name of the 'timestamp' column
static_features: Optional[Union[str, pd.DataFrame]]
An optional data frame describing the metadata attributes of individual items in the item index.
For more detail, please refer to `TimeSeriesDataFrame` documentation:
@@ -262,9 +268,9 @@ def predict(
backend_kwargs = self.backend.parse_backend_predict_kwargs(backend_kwargs)
return self.backend.predict(
test_data=test_data,
id_column=id_column,
timestamp_column=timestamp_column,
target=target,
id_column=self.id_column,
timestamp_column=self.timestamp_column,
target=self.target_column,
static_features=static_features,
predictor_path=predictor_path,
framework_version=framework_version,
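
Since the column names and target recorded at fit time are reused automatically, the only data-related input left on the prediction methods is `static_features`. A minimal sketch, assuming the `cloud_predictor` and `data` objects from the documentation snippet above; the attribute column is made up, and the exact static-features layout (item_id as a column vs. as the index) should be checked against the TimeSeriesDataFrame documentation linked in the docstrings:

```python
import pandas as pd

# Hypothetical per-item metadata: one row per item_id plus an illustrative attribute column.
static_features = pd.DataFrame(
    {
        "item_id": data["item_id"].unique(),
        "domain": "M4-Hourly",  # scalar broadcast to every item; purely illustrative
    }
)

# Real-time endpoint: id/timestamp/target come from the values stored during fit().
forecast = cloud_predictor.predict_real_time(data, static_features=static_features)

# Batch inference follows the same pattern.
forecast = cloud_predictor.predict(data, static_features=static_features)
```
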
17 changes: 2 additions & 15 deletions tests/unittests/timeseries/test_timeseries.py
@@ -7,16 +7,13 @@
def test_timeseries(test_helper, framework_version):
train_data = "timeseries_train.csv"
static_features = "timeseries_static_features.csv"
id_column = "item_id"
timestamp_column = "timestamp"
target = "target"
timestamp = test_helper.get_utc_timestamp_now()
with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
test_helper.prepare_data(train_data, static_features)
time_limit = 60

predictor_init_args = dict(target=target)
predictor_init_args = dict(target="target", prediction_length=3)

predictor_fit_args = dict(
train_data=train_data,
@@ -35,25 +32,15 @@ def test_timeseries(test_helper, framework_version):
predictor_fit_args,
train_data,
fit_kwargs=dict(
id_column=id_column,
timestamp_column=timestamp_column,
static_features=static_features,
framework_version=framework_version,
custom_image_uri=training_custom_image_uri,
),
deploy_kwargs=dict(framework_version=framework_version, custom_image_uri=inference_custom_image_uri),
predict_kwargs=dict(
id_column=id_column,
timestamp_column=timestamp_column,
target=target,
static_features=static_features,
framework_version=framework_version,
custom_image_uri=inference_custom_image_uri,
),
predict_real_time_kwargs=dict(
id_column=id_column,
timestamp_column=timestamp_column,
target=target,
static_features=static_features,
),
predict_real_time_kwargs=dict(static_features=static_features),
)
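
The test can drop the explicit column arguments because the fixture data already uses the default `item_id`/`timestamp`/`target` names. When a dataset does not, the names can still be passed at fit time and are then reused by `predict` and `predict_real_time`. A sketch with hypothetical column names and a hypothetical `my_data` frame:

```python
from autogluon.cloud import TimeSeriesCloudPredictor

cloud_predictor = TimeSeriesCloudPredictor(cloud_output_path="YOUR_S3_BUCKET_PATH")
cloud_predictor.fit(
    predictor_init_args={"target": "demand", "prediction_length": 24},
    predictor_fit_args={"train_data": my_data, "time_limit": 120},
    id_column="series_id",    # overrides the default "item_id"
    timestamp_column="ts",    # overrides the default "timestamp"
)

# No column arguments needed here; the names above were recorded during fit().
forecast = cloud_predictor.predict(my_data)
```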
