Skip to content

Commit

Permalink
WIP on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Aug 29, 2024
1 parent b69b421 commit 85897bf
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 16 deletions.
19 changes: 11 additions & 8 deletions pinecone/data/features/bulk_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ def __init__(self, **kwargs):
)
openapi_config = ConfigBuilder.build_openapi_config(config, kwargs.get("openapi_config", None))

self.__import_operations_api = setup_openapi_client(
api_client_klass=ApiClient,
api_klass=BulkOperationsApi,
config=config,
openapi_config=openapi_config,
pool_threads=kwargs.get("pool_threads", 1),
api_version=API_VERSION,
)
if kwargs.get("__import_operations_api", None):
self.__import_operations_api = kwargs.get("__import_operations_api")
else:
self.__import_operations_api = setup_openapi_client(
api_client_klass=ApiClient,
api_klass=BulkOperationsApi,
config=config,
openapi_config=openapi_config,
pool_threads=kwargs.get("pool_threads", 1),
api_version=API_VERSION,
)

@prerelease_feature
def start_import(self, uri: str, integration: Optional[str] = None) -> StartImportResponse:
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/data/test_bulk_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest
import warnings

from urllib3 import BaseHTTPResponse, HTTPResponse

from datetime import datetime, date

from pinecone.core_ea.openapi.db_data.api.bulk_operations_api import BulkOperationsApi
from pinecone.core_ea.openapi.db_data.models import ImportModel, StartImportResponse
from pinecone.core_ea.openapi.shared.api_client import ApiClient
from pinecone.core_ea.openapi.shared.exceptions import PineconeApiException

from pinecone.data.features.bulk_import import ImportFeatureMixin


def build_api_w_faked_response(mocker, body: str, status: int = 200) -> BaseHTTPResponse:
response = mocker.Mock()
response.headers = {"content-type": "application/json"}
response.status = status
response.data = body.encode("utf-8")

api_client = ApiClient()
mocker.patch.object(api_client.rest_client.pool_manager, "request", return_value=response)
return BulkOperationsApi(api_client=api_client)


def build_client_w_faked_response(mocker, body: str, status: int = 200):
api_client = build_api_w_faked_response(mocker, body, status)
return ImportFeatureMixin(__import_operations_api=api_client, api_key="asdf", host="asdf")


class TestBulkImportStartImport:
def test_start_import(self, mocker):
body = """
{
"id": "1"
}
"""
client = build_client_w_faked_response(mocker, body)

with pytest.warns(UserWarning, match="prerelease"):
my_import = client.start_import("s3://path/to/file.parquet")
assert my_import.id == "1"
assert my_import["id"] == "1"
assert my_import.to_dict() == {"id": "1"}
assert my_import.__class__ == StartImportResponse

def test_start_import_with_kwargs(self, mocker):
body = """
{
"id": "1"
}
"""
client = build_client_w_faked_response(mocker, body)

with pytest.warns(UserWarning, match="prerelease"):
my_import = client.start_import(uri="s3://path/to/file.parquet")
assert my_import.id == "1"
assert my_import["id"] == "1"
assert my_import.to_dict() == {"id": "1"}
assert my_import.__class__ == StartImportResponse

def test_start_invalid_uri(self, mocker):
body = """
{
"code": "3",
"message": "Bulk import URIs must start with the scheme of a supported storage provider",
"details": []
}
"""
client = build_client_w_faked_response(mocker, body, 400)

with pytest.warns(UserWarning, match="prerelease"):
with pytest.raises(PineconeApiException) as e:
my_import = client.start_import(uri="invalid path")

assert e.value.status == 400
assert e.value.body == body
assert "Bulk import URIs must start with the scheme of a supported storage provider" in str(e.value)

def test_no_arguments(self, mocker):
client = build_client_w_faked_response(mocker, "")

with pytest.warns(UserWarning, match="prerelease"):
with pytest.raises(TypeError) as e:
my_import = client.start_import()

assert "missing 1 required positional argument" in str(e.value)
16 changes: 8 additions & 8 deletions tests/unit/data/test_import_datetime_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,21 @@
from pinecone.core_ea.openapi.shared.api_client import ApiClient, Endpoint as _Endpoint
from pinecone.core_ea.openapi.shared.rest import RESTResponse


def fake_response(mocker, body: str, status: int = 200) -> BaseHTTPResponse:
resp = HTTPResponse(
body=body.encode('utf-8'),
headers={'content-type': 'application/json'},
body=body.encode("utf-8"),
headers={"content-type": "application/json"},
status=status,
reason="OK",
preload_content=True
preload_content=True,
)
api_client = ApiClient()
mocker.patch.object(api_client, 'request', return_value=RESTResponse(resp))
mocker.patch.object(api_client, "request", return_value=RESTResponse(resp))
return api_client


class TestBulkImport():
class TestBulkImport:
def test_parsing_datetime_fields(self, mocker):
body = """
{
Expand All @@ -36,10 +37,9 @@ def test_parsing_datetime_fields(self, mocker):
"""
api_client = fake_response(mocker, body, 200)
api = BulkOperationsApi(api_client=api_client)
r = api.describe_import(id='1')

r = api.describe_import(id="1")
assert r.created_at.year == 2024
assert r.created_at.month == 8
assert r.created_at.date() == date(year=2024, month=8, day=27)
assert r.created_at.time() == datetime.strptime("17:10:32.206413", "%H:%M:%S.%f").time()

0 comments on commit 85897bf

Please sign in to comment.