From 722d2b5bca2946b0bbefdeb909b0891d89c2615f Mon Sep 17 00:00:00 2001 From: XuanYang-cn Date: Tue, 2 Apr 2024 16:19:07 +0800 Subject: [PATCH] fix: Use wrong placeholder type for bf16 and float16 (#2011) See also: #2004 --------- Signed-off-by: yangxuan --- examples/bfloat16_example.py | 28 ++++---- examples/float16_example.py | 30 ++++----- pymilvus/client/blob.py | 7 +- pymilvus/client/entity_helper.py | 20 ++++-- pymilvus/client/prepare.py | 112 ++++++++++++++----------------- pymilvus/client/utils.py | 25 ++----- pymilvus/orm/prepare.py | 22 ++++-- 7 files changed, 115 insertions(+), 129 deletions(-) diff --git a/examples/bfloat16_example.py b/examples/bfloat16_example.py index 162dd2139..06064794a 100644 --- a/examples/bfloat16_example.py +++ b/examples/bfloat16_example.py @@ -20,9 +20,8 @@ def gen_bf16_vectors(num, dim): for _ in range(num): raw_vector = [random.random() for _ in range(dim)] raw_vectors.append(raw_vector) - # bf16_vector = np.array(raw_vector, dtype=tf.bfloat16).view(np.uint8).tolist() - bf16_vector = tf.cast(raw_vector, dtype=tf.bfloat16).numpy().view(np.uint8).tolist() - bf16_vectors.append(bytes(bf16_vector)) + bf16_vector = tf.cast(raw_vector, dtype=tf.bfloat16).numpy() + bf16_vectors.append(bf16_vector) return raw_vectors, bf16_vectors def bf16_vector_search(): @@ -35,23 +34,20 @@ def bf16_vector_search(): bf16_vector = FieldSchema(name=vector_field_name, dtype=DataType.BFLOAT16_VECTOR, dim=dim) schema = CollectionSchema(fields=[int64_field, bf16_vector]) - has = utility.has_collection("hello_milvus_fp16") - if has: - hello_milvus = Collection("hello_milvus_fp16") - hello_milvus.drop() - else: - hello_milvus = Collection("hello_milvus_fp16", schema) + if utility.has_collection("hello_milvus_fp16"): + utility.drop_collection("hello_milvus_fp16") + hello_milvus = Collection("hello_milvus_fp16", schema, consistency_level="Strong") _, vectors = gen_bf16_vectors(nb, dim) + hello_milvus.insert([vectors[:6]]) rows = [ - {vector_field_name: vectors[0]}, - {vector_field_name: vectors[1]}, - {vector_field_name: vectors[2]}, - {vector_field_name: vectors[3]}, - {vector_field_name: vectors[4]}, - {vector_field_name: vectors[5]}, + {vector_field_name: vectors[6]}, + {vector_field_name: vectors[7]}, + {vector_field_name: vectors[8]}, + {vector_field_name: vectors[9]}, + {vector_field_name: vectors[10]}, + {vector_field_name: vectors[11]}, ] - hello_milvus.insert(rows) hello_milvus.flush() diff --git a/examples/float16_example.py b/examples/float16_example.py index 95bf7cfab..d3cc519d6 100644 --- a/examples/float16_example.py +++ b/examples/float16_example.py @@ -19,8 +19,8 @@ def gen_fp16_vectors(num, dim): for _ in range(num): raw_vector = [random.random() for _ in range(dim)] raw_vectors.append(raw_vector) - fp16_vector = np.array(raw_vector, dtype=np.float16).view(np.uint8).tolist() - fp16_vectors.append(bytes(fp16_vector)) + fp16_vector = np.array(raw_vector, dtype=np.float16) + fp16_vectors.append(fp16_vector) return raw_vectors, fp16_vectors def fp16_vector_search(): @@ -33,23 +33,21 @@ def fp16_vector_search(): fp16_vector = FieldSchema(name=vector_field_name, dtype=DataType.FLOAT16_VECTOR, dim=dim) schema = CollectionSchema(fields=[int64_field, fp16_vector]) - has = utility.has_collection("hello_milvus_fp16") - if has: - hello_milvus = Collection("hello_milvus_fp16") - hello_milvus.drop() - else: - hello_milvus = Collection("hello_milvus_fp16", schema) + if utility.has_collection("hello_milvus_fp16"): + utility.drop_collection("hello_milvus_fp16") + + hello_milvus = Collection("hello_milvus_fp16", schema) _, vectors = gen_fp16_vectors(nb, dim) + hello_milvus.insert([vectors[:6]]) rows = [ - {vector_field_name: vectors[0]}, - {vector_field_name: vectors[1]}, - {vector_field_name: vectors[2]}, - {vector_field_name: vectors[3]}, - {vector_field_name: vectors[4]}, - {vector_field_name: vectors[5]}, + {vector_field_name: vectors[6]}, + {vector_field_name: vectors[7]}, + {vector_field_name: vectors[8]}, + {vector_field_name: vectors[9]}, + {vector_field_name: vectors[10]}, + {vector_field_name: vectors[11]}, ] - hello_milvus.insert(rows) hello_milvus.flush() @@ -67,4 +65,4 @@ def fp16_vector_search(): hello_milvus.drop() if __name__ == "__main__": - fp16_vector_search() \ No newline at end of file + fp16_vector_search() diff --git a/pymilvus/client/blob.py b/pymilvus/client/blob.py index 56fcbc02c..80faa0798 100644 --- a/pymilvus/client/blob.py +++ b/pymilvus/client/blob.py @@ -1,13 +1,8 @@ import struct from typing import List -# reference: https://docs.python.org/3/library/struct.html#struct.pack - - -def vector_binary_to_bytes(v: bytes): - return bytes(v) - +# reference: https://docs.python.org/3/library/struct.html#struct.pack def vector_float_to_bytes(v: List[float]): # pack len(v) number of float return struct.pack(f"{len(v)}f", *v) diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index cfb5e91fb..1fa78c196 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -310,11 +310,23 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info field_data.vectors.dim = len(field_value) * 8 field_data.vectors.binary_vector += bytes(field_value) elif field_type == DataType.FLOAT16_VECTOR: - field_data.vectors.dim = len(field_value) // 2 - field_data.vectors.float16_vector += bytes(field_value) + v_bytes = ( + bytes(field_value) + if not isinstance(field_value, np.ndarray) + else field_value.view(np.uint8).tobytes() + ) + + field_data.vectors.dim = len(v_bytes) // 2 + field_data.vectors.float16_vector += v_bytes elif field_type == DataType.BFLOAT16_VECTOR: - field_data.vectors.dim = len(field_value) // 2 - field_data.vectors.bfloat16_vector += bytes(field_value) + v_bytes = ( + bytes(field_value) + if not isinstance(field_value, np.ndarray) + else field_value.view(np.uint8).tobytes() + ) + + field_data.vectors.dim = len(v_bytes) // 2 + field_data.vectors.bfloat16_vector += v_bytes elif field_type == DataType.SPARSE_FLOAT_VECTOR: # field_value is a single row of sparse float vector in user provided format if not sparse_is_scipy_format(field_value): diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 8cf094f74..3f9ed2f9a 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -2,7 +2,7 @@ import datetime from typing import Any, Dict, Iterable, List, Optional, Union -import ujson +import numpy as np from pymilvus.client import __version__, entity_helper from pymilvus.exceptions import DataNotMatchException, ExceptionsMessage, ParamError @@ -11,7 +11,7 @@ from pymilvus.grpc_gen import schema_pb2 as schema_types from pymilvus.orm.schema import CollectionSchema -from . import blob, ts_utils +from . import blob, ts_utils, utils from .check import check_pass_param, is_legal_collection_properties from .constants import ( DEFAULT_CONSISTENCY_LEVEL, @@ -459,9 +459,9 @@ def _pre_batch_check( ): for entity in entities: if ( - not entity.get("name", None) - or entity.get("values", None) is None - or not entity.get("type", None) + entity.get("name") is None + or entity.get("values") is None + or entity.get("type") is None ): raise ParamError( message="Missing param in entities, a field must have type, name and values" @@ -575,22 +575,41 @@ def check_str(instr: str, prefix: str): ) @classmethod - def _prepare_placeholders(cls, vectors: Any, nq: int, tag: Any, pl_type: Any, is_binary: bool): - pl = common_types.PlaceholderValue(tag=tag) - pl.type = pl_type + def _prepare_placeholder_str(cls, data: Any): # sparse vector - if pl_type == PlaceholderType.SparseFloatVector: - sparse_float_array_proto = entity_helper.sparse_rows_to_proto(vectors) - pl.values.extend(sparse_float_array_proto.contents) - return pl - - # dense or binary vector - for i in range(nq): - if is_binary: - pl.values.append(blob.vector_binary_to_bytes(vectors[i])) + if entity_helper.entity_is_sparse_matrix(data): + pl_type = PlaceholderType.SparseFloatVector + pl_values = entity_helper.sparse_rows_to_proto(data).contents + + elif isinstance(data[0], np.ndarray): + dtype = data[0].dtype + pl_values = (array.tobytes() for array in data) + + if dtype == "bfloat16": + pl_type = PlaceholderType.BFLOAT16_VECTOR + elif dtype == "float16": + pl_type = PlaceholderType.FLOAT16_VECTOR + elif dtype == "float32": + pl_type = PlaceholderType.FloatVector + elif dtype == "byte": + pl_type = PlaceholderType.BinaryVector + else: - pl.values.append(blob.vector_float_to_bytes(vectors[i])) - return pl + err_msg = f"unsupported data type: {dtype}" + raise ParamError(message=err_msg) + + elif isinstance(data[0], bytes): + pl_type = PlaceholderType.BinaryVector + pl_values = data # data is already a list of bytes + + else: + pl_type = PlaceholderType.FloatVector + pl_values = (blob.vector_float_to_bytes(entity) for entity in data) + + pl = common_types.PlaceholderValue(tag="$0", type=pl_type, values=pl_values) + return common_types.PlaceholderGroup.SerializeToString( + common_types.PlaceholderGroup(placeholders=[pl]) + ) @classmethod def search_requests_with_expr( @@ -606,16 +625,6 @@ def search_requests_with_expr( round_decimal: int = -1, **kwargs, ) -> milvus_types.SearchRequest: - if entity_helper.entity_is_sparse_matrix(data): - is_binary = False - pl_type = PlaceholderType.SparseFloatVector - elif isinstance(data[0], bytes): - is_binary = True - pl_type = PlaceholderType.BinaryVector - else: - is_binary = False - pl_type = PlaceholderType.FloatVector - use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs) ignore_growing = param.get("ignore_growing", False) or kwargs.get("ignore_growing", False) @@ -654,17 +663,13 @@ def search_requests_with_expr( if anns_field: search_params["anns_field"] = anns_field - def dump(v: Dict): - if isinstance(v, dict): - return ujson.dumps(v) - return str(v) - + req_params = [ + common_types.KeyValuePair(key=str(key), value=utils.dumps(value)) + for key, value in search_params.items() + ] nq = entity_helper.get_input_num_rows(data) - tag = "$0" - pl = cls._prepare_placeholders(data, nq, tag, pl_type, is_binary) - plg = common_types.PlaceholderGroup() - plg.placeholders.append(pl) - plg_str = common_types.PlaceholderGroup.SerializeToString(plg) + plg_str = cls._prepare_placeholder_str(data) + request = milvus_types.SearchRequest( collection_name=collection_name, partition_names=partition_names, @@ -673,18 +678,12 @@ def dump(v: Dict): use_default_consistency=use_default_consistency, consistency_level=kwargs.get("consistency_level", 0), nq=nq, + placeholder_group=plg_str, + dsl_type=common_types.DslType.BoolExprV1, + search_params=req_params, ) - request.placeholder_group = plg_str - - request.dsl_type = common_types.DslType.BoolExprV1 if expr is not None: request.dsl = expr - request.search_params.extend( - [ - common_types.KeyValuePair(key=str(key), value=dump(value)) - for key, value in search_params.items() - ] - ) return request @@ -704,11 +703,6 @@ def hybrid_search_request_with_ranker( rerank_param["limit"] = limit rerank_param["round_decimal"] = round_decimal - def dump(v: Dict): - if isinstance(v, dict): - return ujson.dumps(v) - return str(v) - request = milvus_types.HybridSearchRequest( collection_name=collection_name, partition_names=partition_names, @@ -721,7 +715,7 @@ def dump(v: Dict): request.rank_params.extend( [ - common_types.KeyValuePair(key=str(key), value=dump(value)) + common_types.KeyValuePair(key=str(key), value=utils.dumps(value)) for key, value in rerank_param.items() ] ) @@ -756,26 +750,20 @@ def create_index_request(cls, collection_name: str, field_name: str, params: Dic index_name=kwargs.get("index_name", ""), ) - def dump(tv: Dict): - return ujson.dumps(tv) if isinstance(tv, dict) else str(tv) - if isinstance(params, dict): for tk, tv in params.items(): if tk == "dim" and (not tv or not isinstance(tv, int)): raise ParamError(message="dim must be of int!") - kv_pair = common_types.KeyValuePair(key=str(tk), value=dump(tv)) + kv_pair = common_types.KeyValuePair(key=str(tk), value=utils.dumps(tv)) index_params.extra_params.append(kv_pair) return index_params @classmethod def alter_index_request(cls, collection_name: str, index_name: str, extra_params: dict): - def dump(tv: Dict): - return ujson.dumps(tv) if isinstance(tv, dict) else str(tv) - params = [] for k, v in extra_params.items(): - params.append(common_types.KeyValuePair(key=str(k), value=dump(v))) + params.append(common_types.KeyValuePair(key=str(k), value=utils.dumps(v))) return milvus_types.AlterIndexRequest( collection_name=collection_name, index_name=index_name, extra_params=params ) diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index d0b029890..4d407e1ca 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -2,6 +2,8 @@ from datetime import timedelta from typing import Any, List, Optional, Union +import ujson + from pymilvus.exceptions import MilvusException, ParamError from pymilvus.grpc_gen.common_pb2 import Status @@ -220,7 +222,6 @@ def traverse_rows_info(fields_info: Any, entities: List): field_name = field["name"] location[field_name] = i - field_type = field["type"] if field.get("is_dynamic", False): is_dynamic = True @@ -241,24 +242,6 @@ def traverse_rows_info(fields_info: Any, entities: List): value = entity.get(field_name, None) if value is None: raise ParamError(message=f"Field {field_name} don't match in entities[{j}]") - # no special check for sparse float vector field - if field_type in [ - DataType.FLOAT_VECTOR, - DataType.BINARY_VECTOR, - DataType.BFLOAT16_VECTOR, - DataType.FLOAT16_VECTOR, - ]: - field_dim = field["params"]["dim"] - entity_dim = len(value) - if field_type in [DataType.BINARY_VECTOR]: - entity_dim = entity_dim * 8 - elif field_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]: - entity_dim = int(entity_dim // 2) - if entity_dim != field_dim: - raise ParamError( - message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim}" - ) # though impossible from sdk if primary_key_loc is None: @@ -334,3 +317,7 @@ def traverse_info(fields_info: Any, entities: List): def get_server_type(host: str): return ZILLIZ if (isinstance(host, str) and "zilliz" in host.lower()) else MILVUS + + +def dumps(v: Union[dict, str]) -> str: + return ujson.dumps(v) if isinstance(v, dict) else str(v) diff --git a/pymilvus/orm/prepare.py b/pymilvus/orm/prepare.py index 0d5d320c2..4466fa78d 100644 --- a/pymilvus/orm/prepare.py +++ b/pymilvus/orm/prepare.py @@ -17,6 +17,7 @@ import pandas as pd from pymilvus.client import entity_helper +from pymilvus.client.types import DataType from pymilvus.exceptions import ( DataNotMatchException, DataTypeNotSupportException, @@ -66,16 +67,25 @@ def prepare_insert_data( for i, field in enumerate(tmp_fields): try: - if isinstance(data[i], np.ndarray): - d = data[i].tolist() - else: - d = data[i] if data[i] is not None else [] - - entities.append({"name": field.name, "type": field.dtype, "values": d}) + f_data = data[i] # the last missing part of data is also completed in order according to the schema except IndexError: entities.append({"name": field.name, "type": field.dtype, "values": []}) + if isinstance(f_data, np.ndarray): + d = f_data.tolist() + + elif isinstance(f_data[0], np.ndarray) and field.dtype in ( + DataType.FLOAT16_VECTOR, + DataType.BFLOAT16_VECTOR, + ): + d = [bytes(ndarr.view(np.uint8).tolist()) for ndarr in f_data] + + else: + d = f_data if f_data is not None else [] + + entities.append({"name": field.name, "type": field.dtype, "values": d}) + return entities @classmethod