Skip to content

Commit

Permalink
fix: Use wrong placeholder type for bf16 and float16 (#2011)
Browse files Browse the repository at this point in the history
See also: #2004

---------

Signed-off-by: yangxuan <[email protected]>
  • Loading branch information
XuanYang-cn authored Apr 2, 2024
1 parent b9a10c9 commit 722d2b5
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 129 deletions.
28 changes: 12 additions & 16 deletions examples/bfloat16_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ def gen_bf16_vectors(num, dim):
for _ in range(num):
raw_vector = [random.random() for _ in range(dim)]
raw_vectors.append(raw_vector)
# bf16_vector = np.array(raw_vector, dtype=tf.bfloat16).view(np.uint8).tolist()
bf16_vector = tf.cast(raw_vector, dtype=tf.bfloat16).numpy().view(np.uint8).tolist()
bf16_vectors.append(bytes(bf16_vector))
bf16_vector = tf.cast(raw_vector, dtype=tf.bfloat16).numpy()
bf16_vectors.append(bf16_vector)
return raw_vectors, bf16_vectors

def bf16_vector_search():
Expand All @@ -35,23 +34,20 @@ def bf16_vector_search():
bf16_vector = FieldSchema(name=vector_field_name, dtype=DataType.BFLOAT16_VECTOR, dim=dim)
schema = CollectionSchema(fields=[int64_field, bf16_vector])

has = utility.has_collection("hello_milvus_fp16")
if has:
hello_milvus = Collection("hello_milvus_fp16")
hello_milvus.drop()
else:
hello_milvus = Collection("hello_milvus_fp16", schema)
if utility.has_collection("hello_milvus_fp16"):
utility.drop_collection("hello_milvus_fp16")
hello_milvus = Collection("hello_milvus_fp16", schema, consistency_level="Strong")

_, vectors = gen_bf16_vectors(nb, dim)
hello_milvus.insert([vectors[:6]])
rows = [
{vector_field_name: vectors[0]},
{vector_field_name: vectors[1]},
{vector_field_name: vectors[2]},
{vector_field_name: vectors[3]},
{vector_field_name: vectors[4]},
{vector_field_name: vectors[5]},
{vector_field_name: vectors[6]},
{vector_field_name: vectors[7]},
{vector_field_name: vectors[8]},
{vector_field_name: vectors[9]},
{vector_field_name: vectors[10]},
{vector_field_name: vectors[11]},
]

hello_milvus.insert(rows)
hello_milvus.flush()

Expand Down
30 changes: 14 additions & 16 deletions examples/float16_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def gen_fp16_vectors(num, dim):
for _ in range(num):
raw_vector = [random.random() for _ in range(dim)]
raw_vectors.append(raw_vector)
fp16_vector = np.array(raw_vector, dtype=np.float16).view(np.uint8).tolist()
fp16_vectors.append(bytes(fp16_vector))
fp16_vector = np.array(raw_vector, dtype=np.float16)
fp16_vectors.append(fp16_vector)
return raw_vectors, fp16_vectors

def fp16_vector_search():
Expand All @@ -33,23 +33,21 @@ def fp16_vector_search():
fp16_vector = FieldSchema(name=vector_field_name, dtype=DataType.FLOAT16_VECTOR, dim=dim)
schema = CollectionSchema(fields=[int64_field, fp16_vector])

has = utility.has_collection("hello_milvus_fp16")
if has:
hello_milvus = Collection("hello_milvus_fp16")
hello_milvus.drop()
else:
hello_milvus = Collection("hello_milvus_fp16", schema)
if utility.has_collection("hello_milvus_fp16"):
utility.drop_collection("hello_milvus_fp16")

hello_milvus = Collection("hello_milvus_fp16", schema)

_, vectors = gen_fp16_vectors(nb, dim)
hello_milvus.insert([vectors[:6]])
rows = [
{vector_field_name: vectors[0]},
{vector_field_name: vectors[1]},
{vector_field_name: vectors[2]},
{vector_field_name: vectors[3]},
{vector_field_name: vectors[4]},
{vector_field_name: vectors[5]},
{vector_field_name: vectors[6]},
{vector_field_name: vectors[7]},
{vector_field_name: vectors[8]},
{vector_field_name: vectors[9]},
{vector_field_name: vectors[10]},
{vector_field_name: vectors[11]},
]

hello_milvus.insert(rows)
hello_milvus.flush()

Expand All @@ -67,4 +65,4 @@ def fp16_vector_search():
hello_milvus.drop()

if __name__ == "__main__":
fp16_vector_search()
fp16_vector_search()
7 changes: 1 addition & 6 deletions pymilvus/client/blob.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import struct
from typing import List

# reference: https://docs.python.org/3/library/struct.html#struct.pack


def vector_binary_to_bytes(v: bytes):
return bytes(v)


# reference: https://docs.python.org/3/library/struct.html#struct.pack
def vector_float_to_bytes(v: List[float]):
# pack len(v) number of float
return struct.pack(f"{len(v)}f", *v)
20 changes: 16 additions & 4 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,23 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info
field_data.vectors.dim = len(field_value) * 8
field_data.vectors.binary_vector += bytes(field_value)
elif field_type == DataType.FLOAT16_VECTOR:
field_data.vectors.dim = len(field_value) // 2
field_data.vectors.float16_vector += bytes(field_value)
v_bytes = (
bytes(field_value)
if not isinstance(field_value, np.ndarray)
else field_value.view(np.uint8).tobytes()
)

field_data.vectors.dim = len(v_bytes) // 2
field_data.vectors.float16_vector += v_bytes
elif field_type == DataType.BFLOAT16_VECTOR:
field_data.vectors.dim = len(field_value) // 2
field_data.vectors.bfloat16_vector += bytes(field_value)
v_bytes = (
bytes(field_value)
if not isinstance(field_value, np.ndarray)
else field_value.view(np.uint8).tobytes()
)

field_data.vectors.dim = len(v_bytes) // 2
field_data.vectors.bfloat16_vector += v_bytes
elif field_type == DataType.SPARSE_FLOAT_VECTOR:
# field_value is a single row of sparse float vector in user provided format
if not sparse_is_scipy_format(field_value):
Expand Down
112 changes: 50 additions & 62 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
from typing import Any, Dict, Iterable, List, Optional, Union

import ujson
import numpy as np

from pymilvus.client import __version__, entity_helper
from pymilvus.exceptions import DataNotMatchException, ExceptionsMessage, ParamError
Expand All @@ -11,7 +11,7 @@
from pymilvus.grpc_gen import schema_pb2 as schema_types
from pymilvus.orm.schema import CollectionSchema

from . import blob, ts_utils
from . import blob, ts_utils, utils
from .check import check_pass_param, is_legal_collection_properties
from .constants import (
DEFAULT_CONSISTENCY_LEVEL,
Expand Down Expand Up @@ -459,9 +459,9 @@ def _pre_batch_check(
):
for entity in entities:
if (
not entity.get("name", None)
or entity.get("values", None) is None
or not entity.get("type", None)
entity.get("name") is None
or entity.get("values") is None
or entity.get("type") is None
):
raise ParamError(
message="Missing param in entities, a field must have type, name and values"
Expand Down Expand Up @@ -575,22 +575,41 @@ def check_str(instr: str, prefix: str):
)

@classmethod
def _prepare_placeholders(cls, vectors: Any, nq: int, tag: Any, pl_type: Any, is_binary: bool):
pl = common_types.PlaceholderValue(tag=tag)
pl.type = pl_type
def _prepare_placeholder_str(cls, data: Any):
# sparse vector
if pl_type == PlaceholderType.SparseFloatVector:
sparse_float_array_proto = entity_helper.sparse_rows_to_proto(vectors)
pl.values.extend(sparse_float_array_proto.contents)
return pl

# dense or binary vector
for i in range(nq):
if is_binary:
pl.values.append(blob.vector_binary_to_bytes(vectors[i]))
if entity_helper.entity_is_sparse_matrix(data):
pl_type = PlaceholderType.SparseFloatVector
pl_values = entity_helper.sparse_rows_to_proto(data).contents

elif isinstance(data[0], np.ndarray):
dtype = data[0].dtype
pl_values = (array.tobytes() for array in data)

if dtype == "bfloat16":
pl_type = PlaceholderType.BFLOAT16_VECTOR
elif dtype == "float16":
pl_type = PlaceholderType.FLOAT16_VECTOR
elif dtype == "float32":
pl_type = PlaceholderType.FloatVector
elif dtype == "byte":
pl_type = PlaceholderType.BinaryVector

else:
pl.values.append(blob.vector_float_to_bytes(vectors[i]))
return pl
err_msg = f"unsupported data type: {dtype}"
raise ParamError(message=err_msg)

elif isinstance(data[0], bytes):
pl_type = PlaceholderType.BinaryVector
pl_values = data # data is already a list of bytes

else:
pl_type = PlaceholderType.FloatVector
pl_values = (blob.vector_float_to_bytes(entity) for entity in data)

pl = common_types.PlaceholderValue(tag="$0", type=pl_type, values=pl_values)
return common_types.PlaceholderGroup.SerializeToString(
common_types.PlaceholderGroup(placeholders=[pl])
)

@classmethod
def search_requests_with_expr(
Expand All @@ -606,16 +625,6 @@ def search_requests_with_expr(
round_decimal: int = -1,
**kwargs,
) -> milvus_types.SearchRequest:
if entity_helper.entity_is_sparse_matrix(data):
is_binary = False
pl_type = PlaceholderType.SparseFloatVector
elif isinstance(data[0], bytes):
is_binary = True
pl_type = PlaceholderType.BinaryVector
else:
is_binary = False
pl_type = PlaceholderType.FloatVector

use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs)

ignore_growing = param.get("ignore_growing", False) or kwargs.get("ignore_growing", False)
Expand Down Expand Up @@ -654,17 +663,13 @@ def search_requests_with_expr(
if anns_field:
search_params["anns_field"] = anns_field

def dump(v: Dict):
if isinstance(v, dict):
return ujson.dumps(v)
return str(v)

req_params = [
common_types.KeyValuePair(key=str(key), value=utils.dumps(value))
for key, value in search_params.items()
]
nq = entity_helper.get_input_num_rows(data)
tag = "$0"
pl = cls._prepare_placeholders(data, nq, tag, pl_type, is_binary)
plg = common_types.PlaceholderGroup()
plg.placeholders.append(pl)
plg_str = common_types.PlaceholderGroup.SerializeToString(plg)
plg_str = cls._prepare_placeholder_str(data)

request = milvus_types.SearchRequest(
collection_name=collection_name,
partition_names=partition_names,
Expand All @@ -673,18 +678,12 @@ def dump(v: Dict):
use_default_consistency=use_default_consistency,
consistency_level=kwargs.get("consistency_level", 0),
nq=nq,
placeholder_group=plg_str,
dsl_type=common_types.DslType.BoolExprV1,
search_params=req_params,
)
request.placeholder_group = plg_str

request.dsl_type = common_types.DslType.BoolExprV1
if expr is not None:
request.dsl = expr
request.search_params.extend(
[
common_types.KeyValuePair(key=str(key), value=dump(value))
for key, value in search_params.items()
]
)

return request

Expand All @@ -704,11 +703,6 @@ def hybrid_search_request_with_ranker(
rerank_param["limit"] = limit
rerank_param["round_decimal"] = round_decimal

def dump(v: Dict):
if isinstance(v, dict):
return ujson.dumps(v)
return str(v)

request = milvus_types.HybridSearchRequest(
collection_name=collection_name,
partition_names=partition_names,
Expand All @@ -721,7 +715,7 @@ def dump(v: Dict):

request.rank_params.extend(
[
common_types.KeyValuePair(key=str(key), value=dump(value))
common_types.KeyValuePair(key=str(key), value=utils.dumps(value))
for key, value in rerank_param.items()
]
)
Expand Down Expand Up @@ -756,26 +750,20 @@ def create_index_request(cls, collection_name: str, field_name: str, params: Dic
index_name=kwargs.get("index_name", ""),
)

def dump(tv: Dict):
return ujson.dumps(tv) if isinstance(tv, dict) else str(tv)

if isinstance(params, dict):
for tk, tv in params.items():
if tk == "dim" and (not tv or not isinstance(tv, int)):
raise ParamError(message="dim must be of int!")
kv_pair = common_types.KeyValuePair(key=str(tk), value=dump(tv))
kv_pair = common_types.KeyValuePair(key=str(tk), value=utils.dumps(tv))
index_params.extra_params.append(kv_pair)

return index_params

@classmethod
def alter_index_request(cls, collection_name: str, index_name: str, extra_params: dict):
def dump(tv: Dict):
return ujson.dumps(tv) if isinstance(tv, dict) else str(tv)

params = []
for k, v in extra_params.items():
params.append(common_types.KeyValuePair(key=str(k), value=dump(v)))
params.append(common_types.KeyValuePair(key=str(k), value=utils.dumps(v)))
return milvus_types.AlterIndexRequest(
collection_name=collection_name, index_name=index_name, extra_params=params
)
Expand Down
25 changes: 6 additions & 19 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from datetime import timedelta
from typing import Any, List, Optional, Union

import ujson

from pymilvus.exceptions import MilvusException, ParamError
from pymilvus.grpc_gen.common_pb2 import Status

Expand Down Expand Up @@ -220,7 +222,6 @@ def traverse_rows_info(fields_info: Any, entities: List):

field_name = field["name"]
location[field_name] = i
field_type = field["type"]

if field.get("is_dynamic", False):
is_dynamic = True
Expand All @@ -241,24 +242,6 @@ def traverse_rows_info(fields_info: Any, entities: List):
value = entity.get(field_name, None)
if value is None:
raise ParamError(message=f"Field {field_name} don't match in entities[{j}]")
# no special check for sparse float vector field
if field_type in [
DataType.FLOAT_VECTOR,
DataType.BINARY_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.FLOAT16_VECTOR,
]:
field_dim = field["params"]["dim"]
entity_dim = len(value)
if field_type in [DataType.BINARY_VECTOR]:
entity_dim = entity_dim * 8
elif field_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]:
entity_dim = int(entity_dim // 2)
if entity_dim != field_dim:
raise ParamError(
message=f"Collection field dim is {field_dim}"
f", but entities field dim is {entity_dim}"
)

# though impossible from sdk
if primary_key_loc is None:
Expand Down Expand Up @@ -334,3 +317,7 @@ def traverse_info(fields_info: Any, entities: List):

def get_server_type(host: str):
return ZILLIZ if (isinstance(host, str) and "zilliz" in host.lower()) else MILVUS


def dumps(v: Union[dict, str]) -> str:
return ujson.dumps(v) if isinstance(v, dict) else str(v)
Loading

0 comments on commit 722d2b5

Please sign in to comment.