[bug fix] Fix the time series target, timestamp and id column issue for attach_job call (#167)
tonyhoo authored Jan 16, 2025
1 parent cf38369 commit 9a1d727
Showing 4 changed files with 124 additions and 16 deletions.
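For orientation, here is a minimal sketch of the workflow this commit fixes; it is not an excerpt from the repository. The S3 bucket, job name and test file path are placeholders, and the package-level import is assumed. After attaching a fresh TimeSeriesCloudPredictor to an existing training job, the target, timestamp and id columns are restored from that job's hyperparameters, so predict() no longer needs them to be set by hand.

from autogluon.cloud import TimeSeriesCloudPredictor  # assumed package-level import

# Attach to a previously completed time series training job. The fixed
# attach_job() reads the "predictor_metadata" hyperparameter of that job and
# restores id_column, timestamp_column and target_column on the predictor.
predictor = TimeSeriesCloudPredictor(
    cloud_output_path="s3://my-bucket/ts-demo",  # placeholder bucket
    local_output_path="ts_cloud_predictor",
)
predictor.attach_job(job_name="my-timeseries-training-job")  # hypothetical job name

# Before this fix the three columns stayed None after attach_job(), which is
# the issue #167 reports; with the metadata restored, prediction can run directly.
predictions = predictor.predict(test_data="test.csv")  # placeholder test file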
53 changes: 49 additions & 4 deletions src/autogluon/cloud/job/sagemaker_job.py
@@ -1,6 +1,7 @@
import json
import logging
from abc import abstractmethod
from typing import Optional
from typing import Dict, Optional, Union

import sagemaker

@@ -61,6 +62,10 @@ def _get_job_status(self):
def _get_output_path(self):
raise NotImplementedError

@abstractmethod
def _get_hyperparameters(self):
raise NotImplementedError

@property
def job_name(self):
return self._job_name
@@ -103,6 +108,17 @@ def get_output_path(self) -> Optional[str]:
return None
return self._get_output_path()

def get_hyperparameters(self) -> Dict[str, Union[int, str]]:
"""
Get hyperparameters of the job
Returns:
--------
dict:
Hyperparameters of the training job
"""
return self._get_hyperparameters()

def __getstate__(self):
state_dict = self.__dict__.copy()
state_dict["session"] = None
@@ -140,6 +156,7 @@ def info(self):
status=self.get_job_status(),
framework_version=self.framework_version,
artifact_path=self.get_output_path(),
hyperparameters=self.get_hyperparameters(),
)
return info

@@ -152,6 +169,15 @@ def _get_output_path(self):
assert self._output_path is not None
return self._output_path + "/" + self._output_filename

def _get_hyperparameters(self):
if self.job_name:
hyperparameters = self.session.describe_training_job(self.job_name)["HyperParameters"]
if "predictor_metadata" in hyperparameters:
hyperparameters["predictor_metadata"] = json.loads(hyperparameters["predictor_metadata"])
return hyperparameters
else:
return None

def run(
self,
role,
@@ -219,6 +245,7 @@ def info(self):
info = dict(
name=self.job_name,
status=self.get_job_status(),
hyperparameters=self.get_hyperparameters(),
result_path=self._get_output_path(),
)
return info
@@ -280,14 +307,22 @@ def run(
logger.log(20, "Inference model created successfully")
logger.log(20, "Creating transformer...")
transformer = model.transformer(
instance_count=instance_count, instance_type=instance_type, output_path=output_path, **transformer_kwargs
instance_count=instance_count,
instance_type=instance_type,
output_path=output_path,
**transformer_kwargs,
)
logger.log(20, "Transformer created successfully")

try:
logger.log(20, "Transforming")
transformer.transform(
test_input, job_name=job_name, split_type=split_type, content_type=content_type, wait=wait, **kwargs
test_input,
job_name=job_name,
split_type=split_type,
content_type=content_type,
wait=wait,
**kwargs,
)
self._job_name = job_name

@@ -309,4 +344,14 @@ def run(
transformer.delete_model()
logger.log(20, f"Predict results have been saved to {self.get_output_path()}")
else:
logger.log(20, "Predict asynchronously. You can use `info()` or `get_job_status()` to check the status.")
logger.log(
20,
"Predict asynchronously. You can use `info()` or `get_job_status()` to check the status.",
)

def _get_hyperparameters(self):
"""
Get hyperparameters of the batch transformation job
Currently batch transformation jobs don't have hyperparameters
"""
return {}
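A short follow-up sketch of what the new hyperparameter plumbing exposes, assuming the predictor attached in the sketch above; reading the backend attribute directly is only for illustration. The fit job's info() now carries the training hyperparameters, and the "predictor_metadata" entry is decoded from its JSON string back into a dict.

# get_fit_job_info() is the backend call used by attach_job() in the next file;
# its dict now includes a "hyperparameters" key thanks to the change in info() above.
job_info = predictor.backend.get_fit_job_info()
metadata = job_info.get("hyperparameters", {}).get("predictor_metadata", {})

# These are exactly the values attach_job() copies back onto the predictor.
print(metadata.get("id_column"), metadata.get("timestamp_column"), metadata.get("target_column"))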
43 changes: 41 additions & 2 deletions src/autogluon/cloud/predictor/timeseries_cloud_predictor.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import logging
from typing import Any, Dict, Optional, Union

@@ -42,8 +43,7 @@ def predictor_type(self):
def _get_local_predictor_cls(self):
from autogluon.timeseries import TimeSeriesPredictor

predictor_cls = TimeSeriesPredictor
return predictor_cls
return TimeSeriesPredictor

def fit(
self,
@@ -124,6 +124,18 @@ def fit(
self.id_column = id_column
self.timestamp_column = timestamp_column

# Create predictor metadata dict
predictor_metadata = {
"id_column": self.id_column,
"timestamp_column": self.timestamp_column,
"target_column": self.target_column,
}

# Add to backend kwargs
backend_kwargs.setdefault("autogluon_sagemaker_estimator_kwargs", {}).setdefault("hyperparameters", {})[
"predictor_metadata"
] = json.dumps(predictor_metadata)

backend_kwargs = self.backend.parse_backend_fit_kwargs(backend_kwargs)
self.backend.fit(
predictor_init_args=predictor_init_args,
@@ -175,6 +187,10 @@ def predict_real_time(
Pandas.DataFrame
Predict results in DataFrame
"""
if self.id_column is None or self.timestamp_column is None or self.target_column is None:
raise ValueError(
"Please set id_column, timestamp_column and target_column before calling predict_real_time"
)
return self.backend.predict_real_time(
test_data=test_data,
id_column=self.id_column,
@@ -266,6 +282,8 @@ def predict(
if backend_kwargs is None:
backend_kwargs = {}
backend_kwargs = self.backend.parse_backend_predict_kwargs(backend_kwargs)
if self.id_column is None or self.timestamp_column is None or self.target_column is None:
raise ValueError("Please set id_column, timestamp_column and target_column before calling predict")
return self.backend.predict(
test_data=test_data,
id_column=self.id_column,
@@ -287,3 +305,24 @@ def predict_proba(
**kwargs,
) -> Optional[pd.DataFrame]:
raise ValueError(f"{self.__class__.__name__} does not support predict_proba operation.")

def attach_job(self, job_name: str) -> TimeSeriesCloudPredictor:
"""Attach to existing training job"""
super().attach_job(job_name)

# Get full job description including hyperparameters
job_desc = self.backend.get_fit_job_info()
hyperparameters = job_desc.get("hyperparameters", {})

# Extract and set predictor metadata
if hyperparameters and "predictor_metadata" in hyperparameters:
metadata = hyperparameters["predictor_metadata"]
self.id_column = metadata.get("id_column")
self.timestamp_column = metadata.get("timestamp_column")
self.target_column = metadata.get("target_column")
else:
logger.warning(
"No predictor metadata found in training job. Please set id_column, timestamp_column and target_column manually."
)

return self
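For completeness, a hedged sketch of the fit-side counterpart, continuing the earlier sketch; the presets, time limit, paths and column names are placeholders, and the exact fit() signature is assumed from the hunk above. fit() now packs the id, timestamp and target columns into the training job's hyperparameters as a JSON-encoded "predictor_metadata" entry, which is what attach_job() later decodes.

# fit() serializes {"id_column": ..., "timestamp_column": ..., "target_column": ...}
# with json.dumps and stores it under the "predictor_metadata" hyperparameter
# of the SageMaker estimator, so a later attach_job() can recover it.
predictor = TimeSeriesCloudPredictor(
    cloud_output_path="s3://my-bucket/ts-demo-fit",  # placeholder bucket
    local_output_path="ts_cloud_predictor_fit",
)
predictor.fit(
    predictor_init_args=dict(target="target", prediction_length=3),
    predictor_fit_args=dict(train_data="train.csv", presets="medium_quality", time_limit=600),
    id_column="item_id",          # assumed column names
    timestamp_column="timestamp",
)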
33 changes: 25 additions & 8 deletions tests/conftest.py
@@ -164,7 +164,10 @@ def test_functionality(
predict_real_time_kwargs = dict()
cloud_predictor.deploy(**deploy_kwargs)
CloudTestHelper.test_endpoint(
cloud_predictor, test_data, inference_kwargs=inference_kwargs, **predict_real_time_kwargs
cloud_predictor,
test_data,
inference_kwargs=inference_kwargs,
**predict_real_time_kwargs,
)
detached_endpoint = cloud_predictor.detach_endpoint()
cloud_predictor.attach_endpoint(detached_endpoint)
@@ -182,21 +185,35 @@ def test_functionality(

if predict_kwargs is None:
predict_kwargs = dict()
pred, pred_proba = cloud_predictor.predict_proba(test_data, **predict_kwargs)
assert isinstance(pred, pd.Series) and isinstance(pred_proba, pd.DataFrame)
if isinstance(cloud_predictor, TimeSeriesCloudPredictor):
pred = cloud_predictor.predict(test_data, **predict_kwargs)
assert isinstance(pred, pd.DataFrame)
else:
pred, pred_proba = cloud_predictor.predict_proba(test_data, **predict_kwargs)
assert isinstance(pred, pd.Series) and isinstance(pred_proba, pd.DataFrame)
info = cloud_predictor.info()
assert info["recent_batch_inference_job"]["status"] == "Completed"

# Test deploy with already trained predictor
trained_predictor_path = cloud_predictor.get_fit_job_output_path()
cloud_predictor_no_train.deploy(predictor_path=trained_predictor_path, **deploy_kwargs)
if isinstance(cloud_predictor_no_train, TimeSeriesCloudPredictor):
cloud_predictor_no_train.attach_job(job_name)
cloud_predictor_no_train.deploy(**deploy_kwargs)
else:
cloud_predictor_no_train.deploy(predictor_path=trained_predictor_path, **deploy_kwargs)
CloudTestHelper.test_endpoint(cloud_predictor_no_train, test_data, **predict_real_time_kwargs)
cloud_predictor_no_train.cleanup_deployment()

pred, pred_proba = cloud_predictor_no_train.predict_proba(
test_data, predictor_path=trained_predictor_path, **predict_kwargs
)
assert isinstance(pred, pd.Series) and isinstance(pred_proba, pd.DataFrame)
if isinstance(cloud_predictor_no_train, TimeSeriesCloudPredictor):
cloud_predictor_no_train.attach_job(job_name)
pred = cloud_predictor_no_train.predict(test_data, predictor_path=trained_predictor_path, **predict_kwargs)
assert isinstance(pred, pd.DataFrame)
else:
pred, pred_proba = cloud_predictor_no_train.predict_proba(
test_data, predictor_path=trained_predictor_path, **predict_kwargs
)
assert isinstance(pred, pd.Series) and isinstance(pred_proba, pd.DataFrame)

info = cloud_predictor_no_train.info()
assert info["recent_batch_inference_job"]["status"] == "Completed"

11 changes: 9 additions & 2 deletions tests/unittests/timeseries/test_timeseries.py
@@ -14,22 +14,29 @@ def test_timeseries(test_helper, framework_version):
time_limit = 60

predictor_init_args = dict(target="target", prediction_length=3)

predictor_fit_args = dict(
train_data=train_data,
presets="medium_quality",
time_limit=time_limit,
)

cloud_predictor = TimeSeriesCloudPredictor(
cloud_output_path=f"s3://autogluon-cloud-ci/test-timeseries/{framework_version}/{timestamp}",
local_output_path="test_timeseries_cloud_predictor",
)
cloud_predictor_no_train = TimeSeriesCloudPredictor(
cloud_output_path=f"s3://autogluon-cloud-ci/test-timeseries-no-train/{framework_version}/{timestamp}",
local_output_path="test_timeseries_cloud_predictor_no_train",
)

training_custom_image_uri = test_helper.get_custom_image_uri(framework_version, type="training", gpu=False)
inference_custom_image_uri = test_helper.get_custom_image_uri(framework_version, type="inference", gpu=False)
test_helper.test_basic_functionality(

test_helper.test_functionality(
cloud_predictor,
predictor_init_args,
predictor_fit_args,
cloud_predictor_no_train,
train_data,
fit_kwargs=dict(
static_features=static_features,
