Change BulkWriter default parameters (#1877)
Signed-off-by: yhmo <[email protected]>
yhmo authored Jan 18, 2024
1 parent d7a402b commit 14e51f8
Showing 6 changed files with 37 additions and 24 deletions.
18 changes: 12 additions & 6 deletions examples/example_bulkwriter.py
@@ -108,7 +108,7 @@ def read_sample_data(file_path: str, writer: [LocalBulkWriter, RemoteBulkWriter]
row = {}
for col in csv_data.columns.values:
if col == "vector":
- vec = json.loads(csv_data[col][i])
+ vec = json.loads(csv_data[col][i]) # convert the string format vector to List[float]
row[col] = vec
else:
row[col] = csv_data[col][i]
@@ -123,6 +123,8 @@ def local_writer(schema: CollectionSchema, file_type: BulkFileType):
segment_size=128*1024*1024,
file_type=file_type,
) as local_writer:
+ # read data from csv
+ read_sample_data("./data/train_embeddings.csv", local_writer)

# append rows
for i in range(100000):
@@ -179,7 +181,7 @@ def _append_row(writer: LocalBulkWriter, begin: int, end: int):
schema=schema,
local_path="/tmp/bulk_writer",
segment_size=128 * 1024 * 1024, # 128MB
- file_type=BulkFileType.JSON_RB,
+ file_type=BulkFileType.JSON,
)
threads = []
thread_count = 10
@@ -267,7 +269,7 @@ def all_types_writer(bin_vec: bool, schema: CollectionSchema, file_type: BulkFil
"double": np.float64(i/7),
"varchar": f"varchar_{i}",
"json": json.dumps({"dummy": i, "ok": f"name_{i}"}),
"vector": gen_binary_vector() if bin_vec else gen_float_vector(),
"vector": np.array(gen_binary_vector(), np.dtype("int8")) if bin_vec else np.array(gen_float_vector(), np.dtype("float32")),
f"dynamic_{i}": i,
# bulkinsert doesn't support import npy with array field, the below values will be stored into dynamic field
"array_str": np.array([f"str_{k}" for k in range(5)], np.dtype("str")),
@@ -380,7 +382,11 @@ def cloud_bulkinsert():
if __name__ == '__main__':
create_connection()

- file_types = [BulkFileType.JSON_RB, BulkFileType.NPY, BulkFileType.PARQUET]
+ file_types = [
+     BulkFileType.JSON,
+     BulkFileType.NUMPY,
+     BulkFileType.PARQUET,
+ ]

schema = build_simple_collection()
for file_type in file_types:
@@ -394,15 +400,15 @@ def cloud_bulkinsert():
# float vectors + all scalar types
for file_type in file_types:
# Note: bulkinsert doesn't support import npy with array field
- schema = build_all_type_schema(bin_vec=False, has_array=False if file_type==BulkFileType.NPY else True)
+ schema = build_all_type_schema(bin_vec=False, has_array=False if file_type==BulkFileType.NUMPY else True)
batch_files = all_types_writer(bin_vec=False, schema=schema, file_type=file_type)
call_bulkinsert(schema, batch_files)
retrieve_imported_data(bin_vec=False)

# binary vectors + all scalar types
for file_type in file_types:
# Note: bulkinsert doesn't support import npy with array field
- schema = build_all_type_schema(bin_vec=True, has_array=False if file_type == BulkFileType.NPY else True)
+ schema = build_all_type_schema(bin_vec=True, has_array=False if file_type == BulkFileType.NUMPY else True)
batch_files = all_types_writer(bin_vec=True, schema=schema, file_type=file_type)
call_bulkinsert(schema, batch_files)
retrieve_imported_data(bin_vec=True)
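Worth noting from the hunks above: the example now hands vectors to the writer as typed NumPy arrays instead of plain Python lists. A minimal sketch of that pattern (the dimension, the field names, and the stand-in gen_float_vector helper are assumptions, not the example's actual values):

import numpy as np

DIM = 768  # assumed dimension for this sketch

def gen_float_vector(dim: int = DIM) -> list:
    # stand-in for the example's helper: a random float vector
    return np.random.random(dim).tolist()

row = {
    "id": 1,
    # float vectors are wrapped as float32 arrays; a binary vector would be
    # wrapped as np.array(gen_binary_vector(), np.dtype("int8")) as in the diff
    "vector": np.array(gen_float_vector(), np.dtype("float32")),
}
# writer.append_row(row)  # writer being a LocalBulkWriter or RemoteBulkWriter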
6 changes: 3 additions & 3 deletions pymilvus/bulk_writer/buffer.py
@@ -41,7 +41,7 @@ class Buffer:
def __init__(
self,
schema: CollectionSchema,
- file_type: BulkFileType = BulkFileType.NPY,
+ file_type: BulkFileType = BulkFileType.NUMPY,
):
self._buffer = {}
self._fields = {}
@@ -115,9 +115,9 @@ def persist(self, local_path: str, **kwargs) -> list:
)

# output files
- if self._file_type == BulkFileType.NPY:
+ if self._file_type == BulkFileType.NUMPY:
return self._persist_npy(local_path, **kwargs)
- if self._file_type == BulkFileType.JSON_RB:
+ if self._file_type == BulkFileType.JSON:
return self._persist_json_rows(local_path, **kwargs)
if self._file_type == BulkFileType.PARQUET:
return self._persist_parquet(local_path, **kwargs)
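The persist() dispatch above now compares against the renamed members. Because the deprecated spellings survive as aliases of the same integer values (see constants.py below), code that still passes BulkFileType.NPY or BulkFileType.JSON_RB hits the same branches. A quick check, assuming only the enum from this commit:

from pymilvus.bulk_writer.constants import BulkFileType

# deprecated names are aliases, so both comparisons in persist() still match
assert BulkFileType.NPY == BulkFileType.NUMPY
assert BulkFileType.JSON_RB == BulkFileType.JSON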
15 changes: 10 additions & 5 deletions pymilvus/bulk_writer/bulk_writer.py
@@ -37,18 +37,23 @@ class BulkWriter:
def __init__(
self,
schema: CollectionSchema,
- segment_size: int,
- file_type: BulkFileType = BulkFileType.NPY,
+ chunk_size: int,
+ file_type: BulkFileType,
**kwargs,
):
self._schema = schema
self._buffer_size = 0
self._buffer_row_count = 0
self._total_row_count = 0
- self._segment_size = segment_size
self._file_type = file_type
self._buffer_lock = Lock()

+ # the old parameter segment_size is renamed to chunk_size; keep accepting segment_size for legacy code
+ self._chunk_size = chunk_size
+ segment_size = kwargs.get("segment_size", 0)
+ if segment_size > 0:
+ self._chunk_size = segment_size

if len(self._schema.fields) == 0:
self._throw("collection schema fields list is empty")

@@ -71,8 +76,8 @@ def total_row_count(self):
return self._total_row_count

@property
- def segment_size(self):
- return self._segment_size
+ def chunk_size(self):
+ return self._chunk_size

def _new_buffer(self):
old_buffer = self._buffer
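The constructor keeps honoring the old segment_size keyword by mapping it onto chunk_size. The fallback restated in isolation (resolve_chunk_size is a made-up name for illustration, not part of the library):

def resolve_chunk_size(chunk_size: int, **kwargs) -> int:
    # legacy callers may still pass segment_size=...; it wins when positive
    segment_size = kwargs.get("segment_size", 0)
    return segment_size if segment_size > 0 else chunk_size

assert resolve_chunk_size(128 * 1024 * 1024) == 128 * 1024 * 1024
assert resolve_chunk_size(128 * 1024 * 1024, segment_size=512 * 1024 * 1024) == 512 * 1024 * 1024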
8 changes: 5 additions & 3 deletions pymilvus/bulk_writer/constants.py
@@ -43,7 +43,7 @@
DataType.FLOAT.name: lambda x: isinstance(x, float),
DataType.DOUBLE.name: lambda x: isinstance(x, float),
DataType.VARCHAR.name: lambda x, max_len: isinstance(x, str) and len(x) <= max_len,
- DataType.JSON.name: lambda x: isinstance(x, (dict, list)),
+ DataType.JSON.name: lambda x: isinstance(x, dict),
DataType.FLOAT_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) == dim,
DataType.BINARY_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) * 8 == dim,
DataType.ARRAY.name: lambda x, cap: isinstance(x, list) and len(x) <= cap,
@@ -66,6 +66,8 @@


class BulkFileType(IntEnum):
- NPY = 1
- JSON_RB = 2
+ NUMPY = 1
+ NPY = 1 # deprecated
+ JSON = 2
+ JSON_RB = 2 # deprecated
PARQUET = 3
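Besides the renamed enum members, the DataType.JSON validator is narrowed from (dict, list) to dict only. The new rule restated on its own:

def json_check(x) -> bool:
    # the updated DataType.JSON rule from the validator table above
    return isinstance(x, dict)

assert json_check({"dummy": 1, "ok": "name_1"})
assert not json_check([1, 2, 3])  # lists no longer validate as JSON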
8 changes: 4 additions & 4 deletions pymilvus/bulk_writer/local_bulk_writer.py
@@ -35,11 +35,11 @@ def __init__(
self,
schema: CollectionSchema,
local_path: str,
- segment_size: int = 512 * MB,
- file_type: BulkFileType = BulkFileType.NPY,
+ chunk_size: int = 128 * MB,
+ file_type: BulkFileType = BulkFileType.PARQUET,
**kwargs,
):
- super().__init__(schema, segment_size, file_type, **kwargs)
+ super().__init__(schema, chunk_size, file_type, **kwargs)
self._local_path = local_path
self._uuid = str(uuid.uuid4())
self._flush_count = 0
@@ -94,7 +94,7 @@ def append_row(self, row: dict, **kwargs):
# in async mode the flush thread runs asynchronously; other threads can
# continue to append if the new buffer size is less than the target size
with self._working_thread_lock:
- if super().buffer_size > super().segment_size:
+ if super().buffer_size > super().chunk_size:
self.commit(_async=True)

def commit(self, **kwargs):
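With the new defaults, a LocalBulkWriter created without chunk_size or file_type now flushes in 128 MB chunks and writes Parquet. A hedged usage sketch — the toy schema below is an assumption; the writer calls follow the example file above:

from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.bulk_writer import LocalBulkWriter

# assumed toy schema: an int64 primary key plus an 8-dim float vector
schema = CollectionSchema(fields=[
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=8),
])

with LocalBulkWriter(
    schema=schema,
    local_path="/tmp/bulk_writer",
    # chunk_size now defaults to 128 MB and file_type to PARQUET; pass them
    # (or the legacy segment_size keyword) only to override
) as writer:
    writer.append_row({"id": 1, "vector": [0.1] * 8})
    writer.commit()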
6 changes: 3 additions & 3 deletions pymilvus/bulk_writer/remote_bulk_writer.py
@@ -60,12 +60,12 @@ def __init__(
schema: CollectionSchema,
remote_path: str,
connect_param: ConnectParam,
- segment_size: int = 512 * MB,
- file_type: BulkFileType = BulkFileType.NPY,
+ chunk_size: int = 1024 * MB,
+ file_type: BulkFileType = BulkFileType.PARQUET,
**kwargs,
):
local_path = Path(sys.argv[0]).resolve().parent.joinpath("bulk_writer")
- super().__init__(schema, str(local_path), segment_size, file_type, **kwargs)
+ super().__init__(schema, str(local_path), chunk_size, file_type, **kwargs)
self._remote_path = Path("/").joinpath(remote_path).joinpath(super().uuid)
self._connect_param = connect_param
self._client = None
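For the remote writer the defaults become 1024 MB chunks and Parquet output. A construction sketch under assumptions — the ConnectParam argument names below (endpoint, access_key, secret_key, bucket_name) are what a MinIO/S3-style store would use and should be checked against the actual ConnectParam signature; schema is the CollectionSchema from the previous sketch:

from pymilvus.bulk_writer import RemoteBulkWriter

# assumed MinIO/S3-style connection parameters
conn = RemoteBulkWriter.ConnectParam(
    endpoint="localhost:9000",
    access_key="minioadmin",
    secret_key="minioadmin",
    bucket_name="a-bucket",
)

writer = RemoteBulkWriter(
    schema=schema,
    remote_path="bulk_data",
    connect_param=conn,
    # chunk_size now defaults to 1024 MB and file_type to PARQUET; the legacy
    # segment_size keyword is still accepted and mapped onto chunk_size
)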
