From a9eb5d17fa6b1ba64615b6e088fdca898aeda9c1 Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Fri, 12 Apr 2024 14:41:52 +0800 Subject: [PATCH] handle all type Signed-off-by: bigsheeper --- pymilvus/bulk_writer/buffer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pymilvus/bulk_writer/buffer.py b/pymilvus/bulk_writer/buffer.py index 878cda737..e77723777 100644 --- a/pymilvus/bulk_writer/buffer.py +++ b/pymilvus/bulk_writer/buffer.py @@ -213,13 +213,11 @@ def _persist_parquet(self, local_path: str, **kwargs): arr.append(np.array(val, dtype=np.dtype("uint8"))) data[k] = pd.Series(arr) elif field_schema.dtype == DataType.ARRAY: - if field_schema.element_type == DataType.FLOAT: - arr = [] - for val in self._buffer[k]: - arr.append(np.array(val, dtype=np.dtype("float32"))) - data[k] = pd.Series(arr) - else: - data[k] = pd.Series(self._buffer[k]) + dt = NUMPY_TYPE_CREATOR[field_schema.element_type.name] + arr = [] + for val in self._buffer[k]: + arr.append(np.array(val, dtype=dt)) + data[k] = pd.Series(arr) elif field_schema.dtype.name in NUMPY_TYPE_CREATOR: dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name] data[k] = pd.Series(self._buffer[k], dtype=dt)