-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support kwargs in predict() and predict_proba() #113
base: master
Are you sure you want to change the base?
Conversation
cols.insert(0, cols.pop(id_index)) | ||
if target is not None: | ||
target_index = cols.index(target) | ||
cols.append(cols.pop(target_index)) | ||
data = data[cols] | ||
|
||
if static_features is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making id_column
an optional argument might break the pd.merge
in line 45.
I think that keeping the id_column
, timestamp_column
and target
as part of the TimeSeriesSagemakerBackend
API is fine since this class is not user-facing. In the public API of the TimeSeriesCloudPredictor
these arguments are optional.
id_column: str = "item_id", | ||
timestamp_column: str = "timestamp", | ||
id_column: Optional[str] = None, | ||
timestamp_column: Optional[str] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is our motivation for changing the defaults here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was trying to make the .predict()
and .predict_real_time()
API align with what we have in the Chronos tutorial. I see that we might not have id_column
and timestamp_column
in the train_data, but please correct me if I misunderstood the example.
@@ -175,13 +178,18 @@ def predict_real_time( | |||
Pandas.DataFrame | |||
Predict results in DataFrame | |||
""" | |||
self.id_column = id_column or self.id_column | |||
self.timestamp_column = timestamp_column or self.timestamp_column | |||
self.target_column = target or self.target_column |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What will happen if both self.target_column is None
and target
is None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the TabularPredictor
API does not require target
at the moment, that's why I made it optional. See this tutorial, e.g.
predictor = TimeSeriesPredictor(prediction_length=14).fit(train_data)
It has been handled https://github.com/autogluon/autogluon/blob/bda6174f4a1fb8398aef4f375d9eacfd29bb46d9/timeseries/src/autogluon/timeseries/predictor.py#L179
@@ -805,6 +805,7 @@ def predict_proba( | |||
instance_count: int = 1, | |||
custom_image_uri: Optional[str] = None, | |||
wait: bool = True, | |||
inference_kwargs: Optional[Dict[str, Any]] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add docstring
@@ -556,6 +556,7 @@ def predict( | |||
custom_image_uri: Optional[str] = None, | |||
wait: bool = True, | |||
backend_kwargs: Optional[Dict] = None, | |||
**kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add example code should be added to tutorials that showcase specifying kwargs. Otherwise it will be hard for users to realize how to do this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea. I will add some tutorials with this PR.
@@ -648,6 +650,7 @@ def predict_proba( | |||
custom_image_uri: Optional[str] = None, | |||
wait: bool = True, | |||
backend_kwargs: Optional[Dict] = None, | |||
**kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add docstring
@@ -199,6 +210,7 @@ def predict( | |||
custom_image_uri: Optional[str] = None, | |||
wait: bool = True, | |||
backend_kwargs: Optional[Dict] = None, | |||
**kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstring
Issue #, if available:
Description of changes:
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.