Fix local auto labeling #600

Open
wants to merge 3 commits into master
18 changes: 13 additions & 5 deletions dagshub/data_engine/model/query_result.py
@@ -534,6 +534,7 @@
post_hook=post_hook,
batch_size=batch_size,
log_to_field=log_to_field,
is_prediction=False,
)
self.datasource.metadata_field(log_to_field).set_annotation().apply()
return res
@@ -548,6 +549,7 @@
post_hook: Callable[[Any], Any] = identity_func,
batch_size: int = 1,
log_to_field: Optional[str] = None,
is_prediction: Optional[bool] = True
) -> Dict[str, Any]:
Comment on lines 551 to 553 (Member):
Having an is_prediction kwarg on a predict_with_... function is, IMO, really confusing API design.
I won't put my foot down and I will push this through for whatever you need right now, but we will probably need to change this in the near future if you want people to actually be using this.

Either way, thanks for adding the docstring.

"""
Fetch an MLflow model from a specific repository and use it to predict on the datapoints in this QueryResult.
@@ -577,6 +579,8 @@
batch_size: Size of the file batches that are sent to ``model.predict()``.
Default batch size is 1, but it is still being sent as a list for consistency.
log_to_field: If set, writes prediction results to this metadata field in the datasource.
is_prediction: If True, log as a prediction (it will need to be manually approved as an annotation).
If False, it will be automatically approved as an annotation.
"""
if not host:
host = self.datasource.source.repoApi.host
@@ -602,7 +606,7 @@
if "torch" in loader_module:
model.predict = model.__call__

return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field)
return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field, is_prediction)

Check failure on line 609 in dagshub/data_engine/model/query_result.py (GitHub Actions / Flake8):
Line too long (130 > 120 characters) (E501)

def get_annotations(self, **kwargs) -> "QueryResult":
"""
@@ -860,8 +864,9 @@
return ds

@staticmethod
def _get_predict_dict(predictions, remote_path, log_to_field):
res = {log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8")}
def _get_predict_dict(predictions, remote_path, log_to_field, is_prediction=False):
ls_json_key = "annotations" if not is_prediction else "predictions"
res = {log_to_field: json.dumps({ls_json_key: [predictions[remote_path][0]]}).encode("utf-8")}
if len(predictions[remote_path]) == 2:
res[f"{log_to_field}_score"] = predictions[remote_path][1]

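As a side note on the changed hunk above: the value written to the metadata field is now a Label Studio style envelope whose top-level key depends on is_prediction. A minimal standalone sketch of what _get_predict_dict produces, using hypothetical sample predictions:

```python
import json


def get_predict_dict(predictions, remote_path, log_to_field, is_prediction=False):
    # Mirrors the diff: annotations are auto-approved, predictions need review.
    ls_json_key = "annotations" if not is_prediction else "predictions"
    res = {log_to_field: json.dumps({ls_json_key: [predictions[remote_path][0]]}).encode("utf-8")}
    if len(predictions[remote_path]) == 2:
        res[f"{log_to_field}_score"] = predictions[remote_path][1]
    return res


# Hypothetical prediction: (payload, score) keyed by remote path.
preds = {"img/1.png": ({"label": "cat"}, 0.9)}
out = get_predict_dict(preds, "img/1.png", "annotation")
# out["annotation"] == b'{"annotations": [{"label": "cat"}]}'
# out["annotation_score"] == 0.9
```

With is_prediction=True the same payload is written under a "predictions" key instead, which (per the docstring) shows up as a pending prediction rather than an approved annotation.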
@@ -934,6 +939,7 @@
predict_fn: CustomPredictor,
batch_size: int = 1,
log_to_field: Optional[str] = None,
is_prediction: Optional[bool] = False,
) -> Dict[str, Tuple[str, Optional[float]]]:
"""
Sends all the datapoints returned in this QueryResult as prediction targets for
@@ -943,6 +949,7 @@
predict_fn: function that handles batched input and returns predictions with an optional prediction score.
batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
is_prediction: (optional, default: False) whether we're creating predictions or annotations.
If None, just returns predictions.
(in addition to logging to a field, iff that parameter is set)
"""
@@ -956,7 +963,7 @@
_Batcher(dset, batch_size) if batch_size != 1 else dset
): # encapsulates dataset with batcher if necessary and iterates over it
for prediction, remote_path in zip(
predict_fn(local_paths),
[predict_fn(local_paths)],
[result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
Comment on lines 965 to 967 (@kbolashev, Member, Feb 27, 2025):

[blocker] This doesn't look right.
The signature of CustomPredictor (predict_fn) is:

CustomPredictor = Callable[
    [List[str]],
    List[Tuple[Any, Optional[float]]],
]

So the function should be accepting a list of datapoints, and return a list of prediction tuples.
Why are you wrapping the result in another list? This zip will end up processing all of the returned predictions for one single path.
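For reference, a predict_fn that does conform to that signature returns one (prediction, score) tuple per input path, so no extra list wrapping is needed. A minimal sketch with made-up labels:

```python
from typing import Any, List, Optional, Tuple


def predict_fn(local_paths: List[str]) -> List[Tuple[Any, Optional[float]]]:
    # One (prediction, score) tuple per input path, so the zip in
    # generate_predictions pairs each prediction with its remote path 1:1.
    return [({"source": path, "label": "cat"}, 0.5) for path in local_paths]


out = predict_fn(["a.png", "b.png"])
# len(out) == 2; each element is a (payload, score) tuple
```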

Member:

What does the return from the prediction function you're using right now look like?

):
predictions[remote_path] = prediction
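To make the blocker above concrete, a small sketch (with made-up predictions) of how the extra list changes what zip pairs up:

```python
preds = [("cat", 0.9), ("dog", 0.8)]  # hypothetical predict_fn output, one tuple per path
paths = ["a.png", "b.png"]

# Wrapped in a list, as in the diff: zip exhausts after a single pair,
# attaching the entire prediction list to the first path only.
wrapped = list(zip([preds], paths))
# wrapped == [([("cat", 0.9), ("dog", 0.8)], "a.png")]

# Unwrapped, as in the original code: one prediction per path.
correct = list(zip(preds, paths))
# correct == [(("cat", 0.9), "a.png"), (("dog", 0.8), "b.png")]
```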
@@ -965,7 +972,7 @@
if log_to_field:
with self.datasource.metadata_context() as ctx:
for remote_path in predictions:
ctx.update_metadata(remote_path, self._get_predict_dict(predictions, remote_path, log_to_field))
ctx.update_metadata(remote_path, self._get_predict_dict(predictions, remote_path, log_to_field, is_prediction))

Check failure on line 975 in dagshub/data_engine/model/query_result.py (GitHub Actions / Flake8):
Line too long (131 > 120 characters) (E501)
return predictions

def generate_annotations(self, predict_fn: CustomPredictor, batch_size: int = 1, log_to_field: str = "annotation"):
@@ -982,6 +989,7 @@
predict_fn,
batch_size=batch_size,
log_to_field=log_to_field,
is_prediction=False
)
self.datasource.metadata_field(log_to_field).set_annotation().apply()
