diff --git a/pinecone/data/features/bulk_import.py b/pinecone/data/features/bulk_import.py index 43d9b079..3c4a8ac0 100644 --- a/pinecone/data/features/bulk_import.py +++ b/pinecone/data/features/bulk_import.py @@ -37,14 +37,17 @@ def __init__(self, **kwargs): ) openapi_config = ConfigBuilder.build_openapi_config(config, kwargs.get("openapi_config", None)) - self.__import_operations_api = setup_openapi_client( - api_client_klass=ApiClient, - api_klass=BulkOperationsApi, - config=config, - openapi_config=openapi_config, - pool_threads=kwargs.get("pool_threads", 1), - api_version=API_VERSION, - ) + if kwargs.get("__import_operations_api", None): + self.__import_operations_api = kwargs.get("__import_operations_api") + else: + self.__import_operations_api = setup_openapi_client( + api_client_klass=ApiClient, + api_klass=BulkOperationsApi, + config=config, + openapi_config=openapi_config, + pool_threads=kwargs.get("pool_threads", 1), + api_version=API_VERSION, + ) @prerelease_feature def start_import(self, uri: str, integration: Optional[str] = None) -> StartImportResponse: diff --git a/tests/unit/data/test_bulk_import.py b/tests/unit/data/test_bulk_import.py new file mode 100644 index 00000000..81c95e5f --- /dev/null +++ b/tests/unit/data/test_bulk_import.py @@ -0,0 +1,88 @@ +import pytest +import warnings + +from urllib3 import BaseHTTPResponse, HTTPResponse + +from datetime import datetime, date + +from pinecone.core_ea.openapi.db_data.api.bulk_operations_api import BulkOperationsApi +from pinecone.core_ea.openapi.db_data.models import ImportModel, StartImportResponse +from pinecone.core_ea.openapi.shared.api_client import ApiClient +from pinecone.core_ea.openapi.shared.exceptions import PineconeApiException + +from pinecone.data.features.bulk_import import ImportFeatureMixin + + +def build_api_w_faked_response(mocker, body: str, status: int = 200) -> BaseHTTPResponse: + response = mocker.Mock() + response.headers = {"content-type": "application/json"} + response.status = status + response.data = body.encode("utf-8") + + api_client = ApiClient() + mocker.patch.object(api_client.rest_client.pool_manager, "request", return_value=response) + return BulkOperationsApi(api_client=api_client) + + +def build_client_w_faked_response(mocker, body: str, status: int = 200): + api_client = build_api_w_faked_response(mocker, body, status) + return ImportFeatureMixin(__import_operations_api=api_client, api_key="asdf", host="asdf") + + +class TestBulkImportStartImport: + def test_start_import(self, mocker): + body = """ + { + "id": "1" + } + """ + client = build_client_w_faked_response(mocker, body) + + with pytest.warns(UserWarning, match="prerelease"): + my_import = client.start_import("s3://path/to/file.parquet") + assert my_import.id == "1" + assert my_import["id"] == "1" + assert my_import.to_dict() == {"id": "1"} + assert my_import.__class__ == StartImportResponse + + def test_start_import_with_kwargs(self, mocker): + body = """ + { + "id": "1" + } + """ + client = build_client_w_faked_response(mocker, body) + + with pytest.warns(UserWarning, match="prerelease"): + my_import = client.start_import(uri="s3://path/to/file.parquet") + assert my_import.id == "1" + assert my_import["id"] == "1" + assert my_import.to_dict() == {"id": "1"} + assert my_import.__class__ == StartImportResponse + + def test_start_invalid_uri(self, mocker): + body = """ + { + "code": "3", + "message": "Bulk import URIs must start with the scheme of a supported storage provider", + "details": [] + } + """ + client = build_client_w_faked_response(mocker, body, 400) + + with pytest.warns(UserWarning, match="prerelease"): + with pytest.raises(PineconeApiException) as e: + my_import = client.start_import(uri="invalid path") + + assert e.value.status == 400 + assert e.value.body == body + assert "Bulk import URIs must start with the scheme of a supported storage provider" in str(e.value) + + def test_no_arguments(self, mocker): + client = build_client_w_faked_response(mocker, "") + + with pytest.warns(UserWarning, match="prerelease"): + with pytest.raises(TypeError) as e: + my_import = client.start_import() + + assert "missing 1 required positional argument" in str(e.value) diff --git a/tests/unit/data/test_import_datetime_parsing.py b/tests/unit/data/test_import_datetime_parsing.py index aeddf121..cf6213f9 100644 --- a/tests/unit/data/test_import_datetime_parsing.py +++ b/tests/unit/data/test_import_datetime_parsing.py @@ -9,20 +9,21 @@ from pinecone.core_ea.openapi.shared.api_client import ApiClient, Endpoint as _Endpoint from pinecone.core_ea.openapi.shared.rest import RESTResponse + def fake_response(mocker, body: str, status: int = 200) -> BaseHTTPResponse: resp = HTTPResponse( - body=body.encode('utf-8'), - headers={'content-type': 'application/json'}, + body=body.encode("utf-8"), + headers={"content-type": "application/json"}, status=status, reason="OK", - preload_content=True + preload_content=True, ) api_client = ApiClient() - mocker.patch.object(api_client, 'request', return_value=RESTResponse(resp)) + mocker.patch.object(api_client, "request", return_value=RESTResponse(resp)) return api_client -class TestBulkImport(): +class TestBulkImport: def test_parsing_datetime_fields(self, mocker): body = """ { @@ -36,10 +37,9 @@ def test_parsing_datetime_fields(self, mocker): """ api_client = fake_response(mocker, body, 200) api = BulkOperationsApi(api_client=api_client) - - r = api.describe_import(id='1') + + r = api.describe_import(id="1") assert r.created_at.year == 2024 assert r.created_at.month == 8 assert r.created_at.date() == date(year=2024, month=8, day=27) assert r.created_at.time() == datetime.strptime("17:10:32.206413", "%H:%M:%S.%f").time() - \ No newline at end of file