Skip to content

Commit

Permalink
support the report value in the dml and dql request
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG committed Apr 23, 2024
1 parent 59bf5e8 commit 0162583
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 16 deletions.
20 changes: 16 additions & 4 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ujson

from pymilvus.exceptions import DataTypeNotMatchException, ExceptionsMessage, MilvusException
from pymilvus.grpc_gen import schema_pb2
from pymilvus.grpc_gen import schema_pb2, common_pb2
from pymilvus.settings import Config

from . import entity_helper
Expand Down Expand Up @@ -195,6 +195,7 @@ def __init__(self, raw: Any):
self._timestamp = 0
self._succ_index = []
self._err_index = []
self._cost = 0

self._pack(raw)

Expand Down Expand Up @@ -234,10 +235,15 @@ def succ_index(self):
def err_index(self):
return self._err_index

@property
def cost(self):
return self._cost

def __str__(self):
return (
f"(insert count: {self._insert_cnt}, delete count: {self._delete_cnt}, upsert count: {self._upsert_cnt}, "
f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count})"
f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count}, "
f"cost: {self._cost})"
)

__repr__ = __str__
Expand All @@ -262,6 +268,8 @@ def _pack(self, raw: Any):
self._timestamp = raw.timestamp
self._succ_index = raw.succ_index
self._err_index = raw.err_index
self._cost = raw.status.extra_info["report_value"] \
if raw.status and raw.status.extra_info else 0


class SequenceIterator:
Expand Down Expand Up @@ -374,10 +382,14 @@ def __str__(self):
class SearchResult(list):
"""nq results: List[Hits]"""

def __init__(self, res: schema_pb2.SearchResultData, round_decimal: Optional[int] = None):
def __init__(self, res: schema_pb2.SearchResultData,
round_decimal: Optional[int] = None,
status: Optional[common_pb2.Status] = None):
self._nq = res.num_queries
all_topks = res.topks

self.cost = status.extra_info["report_value"] if status and status.extra_info else 0

output_fields = res.output_fields
fields_data = res.fields_data

Expand Down Expand Up @@ -497,7 +509,7 @@ def __iter__(self) -> SequenceIterator:

def __str__(self) -> str:
"""Only print at most 10 query results"""
return str(list(map(str, self[:10])))
return f"cost: {self.cost}, part data: {str(list(map(str, self[:10])))}"

__repr__ = __str__

Expand Down
2 changes: 1 addition & 1 deletion pymilvus/client/asynch.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def exception(self):
class SearchFuture(Future):
def on_response(self, response: milvus_pb2.SearchResults):
check_status(response.status)
return SearchResult(response.results)
return SearchResult(response.results, status=response.status)


class MutationFuture(Future):
Expand Down
1 change: 0 additions & 1 deletion pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def pack_field_value_to_field_data(
message="invalid input for float32 vector, expect np.ndarray with dtype=float32"
)
f_value = field_value.view(np.float32).tolist()

field_data.vectors.dim = len(f_value)
field_data.vectors.float_vector.data.extend(f_value)

Expand Down
8 changes: 5 additions & 3 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
CompactionPlans,
CompactionState,
DataType,
ExtraList,
GrantInfo,
Group,
IndexState,
Expand All @@ -56,6 +57,7 @@
State,
Status,
UserInfo,
get_cost_extra,
)
from .utils import (
check_invalid_binary_vector,
Expand Down Expand Up @@ -731,7 +733,7 @@ def _execute_search(
response = self._stub.Search(request, timeout=timeout)
check_status(response.status)
round_decimal = kwargs.get("round_decimal", -1)
return SearchResult(response.results, round_decimal)
return SearchResult(response.results, round_decimal, status=response.status)

except Exception as e:
if kwargs.get("_async", False):
Expand All @@ -750,7 +752,7 @@ def _execute_hybrid_search(
response = self._stub.HybridSearch(request, timeout=timeout)
check_status(response.status)
round_decimal = kwargs.get("round_decimal", -1)
return SearchResult(response.results, round_decimal)
return SearchResult(response.results, round_decimal, status=response.status)

except Exception as e:
if kwargs.get("_async", False):
Expand Down Expand Up @@ -1503,7 +1505,7 @@ def query(
response.fields_data, index, dynamic_fields
)
results.append(entity_row_data)
return results
return ExtraList(results, extra=get_cost_extra(response.status))

@retry_on_rpc_failure()
def load_balance(
Expand Down
33 changes: 31 additions & 2 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from enum import IntEnum
from typing import Any, ClassVar, Dict, List, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union

from pymilvus.exceptions import (
AutoIDException,
Expand Down Expand Up @@ -763,7 +763,6 @@ def groups(self):

class ResourceGroupInfo:
def __init__(self, resource_group: Any) -> None:

self._name = resource_group.name
self._capacity = resource_group.capacity
self._num_available_node = resource_group.num_available_node
Expand Down Expand Up @@ -888,3 +887,33 @@ def hostname(self) -> str:
Attributes:
resource_group (str): The name of the resource group that can be transferred to or from.
"""


class ExtraList(list):
"""
A list that can hold extra information.
Attributes:
extra (dict): The extra information of the list.
Example:
ExtraList([1, 2, 3], extra={"total": 3})
"""

def __init__(self, *args, extra: Optional[Dict] = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.extra = extra or {}

def __str__(self) -> str:
"""Only print at most 10 query results"""
return f"part data: {str(list(map(str, self[:10])))}, extra_info: {self.extra}"

__repr__ = __str__


def get_cost_from_status(status: Optional[common_pb2.Status] = None):
return status.extra_info["report_value"] if status and status.extra_info else 0

def get_cost_extra(status: Optional[common_pb2.Status] = None):
return {"cost": get_cost_from_status(status)}

def construct_cost_extra(cost: int):
return {"cost": cost}
23 changes: 18 additions & 5 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from uuid import uuid4

from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL
from pymilvus.client.types import ExceptionsMessage, LoadState
from pymilvus.client.types import (
ExceptionsMessage,
LoadState,
ExtraList,
get_cost_extra,
construct_cost_extra,
)
from pymilvus.exceptions import (
DataTypeNotMatchException,
MilvusException,
Expand Down Expand Up @@ -216,7 +222,11 @@ def insert(
)
except Exception as ex:
raise ex from ex
return {"insert_count": res.insert_count, "ids": res.primary_keys}
return {
"insert_count": res.insert_count,
"ids": res.primary_keys,
"cost": res.cost,
}

def upsert(
self,
Expand Down Expand Up @@ -263,7 +273,10 @@ def upsert(
except Exception as ex:
raise ex from ex

return {"upsert_count": res.upsert_count}
return {
"upsert_count": res.upsert_count,
"cost": res.cost,
}

def search(
self,
Expand Down Expand Up @@ -325,7 +338,7 @@ def search(
query_result.append(hit.to_dict())
ret.append(query_result)

return ret
return ExtraList(ret, extra=construct_cost_extra(res.cost))

def query(
self,
Expand Down Expand Up @@ -543,7 +556,7 @@ def delete(
if ret_pks:
return ret_pks

return {"delete_count": res.delete_count}
return {"delete_count": res.delete_count, "cost": res.cost}

def get_collection_stats(self, collection_name: str, timeout: Optional[float] = None) -> Dict:
conn = self._get_connection()
Expand Down
4 changes: 4 additions & 0 deletions pymilvus/orm/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def succ_index(self):
def err_index(self):
return self._mr.err_index if self._mr else []

@property
def cost(self):
return self._mr.cost if self._mr else 0

def __str__(self) -> str:
"""
Return the information of mutation result
Expand Down

0 comments on commit 0162583

Please sign in to comment.