Fix local auto labeling #600
base: master
```diff
@@ -534,6 +534,7 @@
             post_hook=post_hook,
             batch_size=batch_size,
             log_to_field=log_to_field,
+            is_prediction=False,
         )
         self.datasource.metadata_field(log_to_field).set_annotation().apply()
         return res
```
```diff
@@ -548,6 +549,7 @@
         post_hook: Callable[[Any], Any] = identity_func,
         batch_size: int = 1,
         log_to_field: Optional[str] = None,
+        is_prediction: Optional[bool] = True,
     ) -> Dict[str, Any]:
         """
         Fetch an MLflow model from a specific repository and use it to predict on the datapoints in this QueryResult.
```
```diff
@@ -577,6 +579,8 @@
             batch_size: Size of the file batches that are sent to ``model.predict()``.
                 Default batch size is 1, but it is still being sent as a list for consistency.
             log_to_field: If set, writes prediction results to this metadata field in the datasource.
+            is_prediction: If True, log the results as predictions (they will need to be manually approved as annotations).
+                If False, the results will be automatically approved as annotations.
         """
         if not host:
             host = self.datasource.source.repoApi.host
```
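For orientation, here is a minimal usage sketch of the flag added above. Only `predict_with_mlflow_model`, `log_to_field`, and `is_prediction` come from this diff; the datasource setup, repository, and model names are hypothetical, and the exact positional arguments may differ:

```python
from dagshub.data_engine import datasources

# Hypothetical datasource and model identifiers, for illustration only.
ds = datasources.get_datasource("user/repo", name="my-datasource")
q = ds.all()  # QueryResult over every datapoint

# Default (is_prediction=True): outputs are logged as predictions and
# must be manually approved before they count as annotations.
q.predict_with_mlflow_model("user/repo", "my-model", log_to_field="prediction")

# is_prediction=False: outputs are logged and automatically approved
# as annotations.
q.predict_with_mlflow_model("user/repo", "my-model", log_to_field="annotation", is_prediction=False)
```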
```diff
@@ -602,7 +606,7 @@
         if "torch" in loader_module:
             model.predict = model.__call__

-        return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field)
+        return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field, is_prediction)

     def get_annotations(self, **kwargs) -> "QueryResult":
         """
```
```diff
@@ -860,8 +864,9 @@
         return ds

     @staticmethod
-    def _get_predict_dict(predictions, remote_path, log_to_field):
-        res = {log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8")}
+    def _get_predict_dict(predictions, remote_path, log_to_field, is_prediction=False):
+        ls_json_key = "annotations" if not is_prediction else "predictions"
+        res = {log_to_field: json.dumps({ls_json_key: [predictions[remote_path][0]]}).encode("utf-8")}
         if len(predictions[remote_path]) == 2:
             res[f"{log_to_field}_score"] = predictions[remote_path][1]
```
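To make the new `_get_predict_dict` behavior concrete, here is a small self-contained sketch of the payload it builds; the prediction value is a made-up Label-Studio-style result:

```python
import json

# Inputs as _get_predict_dict sees them: a map from remote path to a
# (prediction, score) tuple. The prediction value is invented for this example.
predictions = {"images/cat_001.jpg": ({"result": [{"value": {"choices": ["cat"]}}]}, 0.93)}
remote_path, log_to_field, is_prediction = "images/cat_001.jpg", "annotation", False

# Mirrors the diff: wrap the prediction under "annotations" or "predictions".
ls_json_key = "annotations" if not is_prediction else "predictions"
res = {log_to_field: json.dumps({ls_json_key: [predictions[remote_path][0]]}).encode("utf-8")}
if len(predictions[remote_path]) == 2:
    res[f"{log_to_field}_score"] = predictions[remote_path][1]

print(res)
# {'annotation': b'{"annotations": [{"result": [{"value": {"choices": ["cat"]}}]}]}',
#  'annotation_score': 0.93}
```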
```diff
@@ -934,6 +939,7 @@
         predict_fn: CustomPredictor,
         batch_size: int = 1,
         log_to_field: Optional[str] = None,
+        is_prediction: Optional[bool] = False,
     ) -> Dict[str, Tuple[str, Optional[float]]]:
         """
         Sends all the datapoints returned in this QueryResult as prediction targets for
```

```diff
@@ -943,6 +949,7 @@
             predict_fn: function that handles batched input and returns predictions with an optional prediction score.
             batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
             log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
+            is_prediction: (optional, default: False) whether we're creating predictions or annotations.
                 If None, just returns predictions.
                 (in addition to logging to a field, iff that parameter is set)
         """
```
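As a usage sketch of the documented contract (the classifier below is hypothetical; only `generate_predictions` and its parameters come from the diff):

```python
from typing import Any, List, Optional, Tuple

# Hypothetical batched predictor: one (prediction, score) tuple per input path.
def classify_batch(local_paths: List[str]) -> List[Tuple[Any, Optional[float]]]:
    return [({"label": "cat"}, 0.9) for _ in local_paths]

# q is a QueryResult; results are also written to the "prediction"
# metadata field because log_to_field is set.
results = q.generate_predictions(classify_batch, batch_size=8, log_to_field="prediction")
```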
```diff
@@ -956,7 +963,7 @@
             _Batcher(dset, batch_size) if batch_size != 1 else dset
         ):  # encapsulates dataset with batcher if necessary and iterates over it
             for prediction, remote_path in zip(
-                predict_fn(local_paths),
+                [predict_fn(local_paths)],
                 [result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
             ):
                 predictions[remote_path] = prediction
```

Review comment on lines 965 to 967:

> [blocker] This doesn't look right.
>
> ```python
> CustomPredictor = Callable[
>     [List[str]],
>     List[Tuple[Any, Optional[float]]],
> ]
> ```
>
> So the function should accept a list of datapoints and return a list of prediction tuples.

Follow-up comment:

> What does the return value of the prediction function you're using right now look like?
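A sketch of the reviewer's point, runnable on its own: under the `CustomPredictor` type, `predict_fn(local_paths)` already yields one tuple per path, so the extra list wrapping pairs the whole batch with only the first path:

```python
from typing import Any, Callable, List, Optional, Tuple

CustomPredictor = Callable[[List[str]], List[Tuple[Any, Optional[float]]]]

def predict_fn(paths: List[str]) -> List[Tuple[Any, Optional[float]]]:
    return [("label", 0.5) for _ in paths]  # one tuple per datapoint

local_paths = ["a.jpg", "b.jpg"]
remote_paths = ["data/a.jpg", "data/b.jpg"]

# Contract-conforming: each prediction is paired with its remote path.
print(list(zip(predict_fn(local_paths), remote_paths)))
# [(('label', 0.5), 'data/a.jpg'), (('label', 0.5), 'data/b.jpg')]

# With the extra list added in this diff, the entire batch is paired with
# the first path only, and every other datapoint is silently dropped.
print(list(zip([predict_fn(local_paths)], remote_paths)))
# [([('label', 0.5), ('label', 0.5)], 'data/a.jpg')]
```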
```diff
@@ -965,7 +972,7 @@
         if log_to_field:
             with self.datasource.metadata_context() as ctx:
                 for remote_path in predictions:
-                    ctx.update_metadata(remote_path, self._get_predict_dict(predictions, remote_path, log_to_field))
+                    ctx.update_metadata(remote_path, self._get_predict_dict(predictions, remote_path, log_to_field, is_prediction))
         return predictions

     def generate_annotations(self, predict_fn: CustomPredictor, batch_size: int = 1, log_to_field: str = "annotation"):
```
```diff
@@ -982,6 +989,7 @@
             predict_fn,
             batch_size=batch_size,
             log_to_field=log_to_field,
+            is_prediction=False,
         )
         self.datasource.metadata_field(log_to_field).set_annotation().apply()
```
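Reading the diff as a whole, `generate_annotations` now appears to be a thin wrapper; a sketch of the equivalent calls, reusing the hypothetical `classify_batch` predictor from the earlier sketch:

```python
# q is a QueryResult; classify_batch is the hypothetical predictor above.
q.generate_annotations(classify_batch, batch_size=8, log_to_field="annotation")

# Per this diff, that should behave like:
q.generate_predictions(classify_batch, batch_size=8, log_to_field="annotation", is_prediction=False)
q.datasource.metadata_field("annotation").set_annotation().apply()
```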
Review comment:

> Having an `is_prediction` kwarg on a `predict_with_...` function is, IMO, really confusing API design. I won't put my foot down, and I will push this through for whatever you need right now, but we will probably need to change this in the near future if you want people to actually use this.
>
> Either way, thanks for adding the docstring.