Skip to content

Commit

Permalink
enhance: enable describe_replica api in milvus client
Browse files Browse the repository at this point in the history
Signed-off-by: Wei Liu <[email protected]>
  • Loading branch information
weiliu1031 committed Jan 3, 2025
1 parent c70d44c commit 21a4e2b
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
milvus_client.release_collection(collection_name)
milvus_client.load_partitions(collection_name, partition_names =["p1", "p2"])

replicas=milvus_client.describe_replica(collection_name)
print("replicas:", replicas)

print(fmt.format("Start search in partiton p1"))
vectors_to_search = rng.random((1, dim))
result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["pk", "a", "b"], partition_names = ["p1"])
Expand Down
31 changes: 31 additions & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Plan,
PrivilegeGroupInfo,
Replica,
ReplicaInfo,
ResourceGroupConfig,
ResourceGroupInfo,
RoleInfo,
Expand Down Expand Up @@ -1747,6 +1748,36 @@ def get_replicas(

return Replica(groups)

@retry_on_rpc_failure()
def describe_replica(
self, collection_name: str, timeout: Optional[float] = None, **kwargs
) -> List[ReplicaInfo]:
collection_id = self.describe_collection(collection_name, timeout, **kwargs)[
"collection_id"
]

req = Prepare.get_replicas(collection_id)
future = self._stub.GetReplicas.future(req, timeout=timeout)
response = future.result()
check_status(response.status)

groups = []
for replica in response.replicas:
shards = [
Shard(s.dm_channel_name, s.node_ids, s.leaderID) for s in replica.shard_replicas
]
groups.append(
ReplicaInfo(
replica.replicaID,
shards,
replica.node_ids,
replica.resource_group_name,
replica.num_outbound_node,
)
)

return groups

@retry_on_rpc_failure()
def do_bulk_insert(
self,
Expand Down
52 changes: 52 additions & 0 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,11 @@ def shard_leader(self) -> int:


class Group:
"""
This class represents replica info in orm format api, which is deprecated in milvus client api.
use `ReplicaInfo` instead.
"""

def __init__(
self,
group_id: int,
Expand Down Expand Up @@ -400,6 +405,10 @@ def num_outbound_node(self):

class Replica:
"""
This class represents replica info list in orm format api,
which is deprecated in milvus client api.
use `List[ReplicaInfo]` instead.
Replica groups:
- Group: <group_id:2>, <group_nodes:(1, 2, 3)>,
<shards:[Shard: <shard_id:10>,
Expand Down Expand Up @@ -428,6 +437,49 @@ def groups(self):
return self._groups


class ReplicaInfo:
def __init__(
self,
replica_id: int,
shards: List[str],
nodes: List[tuple],
resource_group: str,
num_outbound_node: dict,
) -> None:
self._id = replica_id
self._shards = shards
self._nodes = tuple(nodes)
self._resource_group = resource_group
self._num_outbound_node = num_outbound_node

def __repr__(self) -> str:
return (
f"ReplicaInfo: <id:{self.id}>, <nodes:{self.group_nodes}>, "
f"<shards:{self.shards}>, <resource_group: {self.resource_group}>, "
f"<num_outbound_node: {self.num_outbound_node}>"
)

@property
def id(self):
return self._id

@property
def group_nodes(self):
return self._nodes

@property
def shards(self):
return self._shards

@property
def resource_group(self):
return self._resource_group

@property
def num_outbound_node(self):
return self._num_outbound_node


class BulkInsertState:
"""enum states of bulk insert task"""

Expand Down
17 changes: 17 additions & 0 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ExtraList,
LoadState,
OmitZeroDict,
ReplicaInfo,
ResourceGroupConfig,
construct_cost_extra,
)
Expand Down Expand Up @@ -1696,3 +1697,19 @@ def transfer_replica(
return conn.transfer_replica(
source_group, target_group, collection_name, num_replicas, timeout
)

def describe_replica(
self, collection_name: str, timeout: Optional[float] = None, **kwargs
) -> List[ReplicaInfo]:
"""Get the current loaded replica information
Args:
collection_name (``str``): The name of the given collection.
timeout (``float``, optional): An optional duration of time in seconds to allow
for the RPC. When timeout is set to None, client waits until server response
or error occur.
Returns:
List[ReplicaInfo]: All the replica information.
"""
conn = self._get_connection()
return conn.describe_replica(collection_name, timeout=timeout, **kwargs)

0 comments on commit 21a4e2b

Please sign in to comment.