Skip to content

Commit

Permalink
chore: sdk tests
Browse files: browse the repository at this point in the history
  • Loading branch information
simjak committed Nov 7, 2024
1 parent 73facde commit 633315f
Show file tree
Hide file tree
Showing 13 changed files with 577 additions and 407 deletions.
32 changes: 16 additions & 16 deletions aurelio_sdk/client.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,7 +10,7 @@
from requests_toolbelt.multipart.encoder import MultipartEncoder

from aurelio_sdk.const import POLLING_INTERVAL, WAIT_TIME_BEFORE_POLLING
from aurelio_sdk.exceptions import APIError, APITimeoutError
from aurelio_sdk.exceptions import ApiError, ApiTimeoutError
from aurelio_sdk.logger import logger
from aurelio_sdk.schema import (
ChunkingOptions,
Expand Down Expand Up @@ -105,18 +105,18 @@ def chunk(
error_content = response.json()
except Exception:
error_content = response.text
raise APIError(
raise ApiError(
message=error_content,
status_code=response.status_code,
base_url=self.base_url,
)
except Exception as e:
raise APIError(message=str(e), base_url=self.base_url) from e
raise ApiError(message=str(e), base_url=self.base_url) from e

def extract_file(
self,
file: Optional[Union[IO[bytes], bytes]] = None,
file_path: Optional[str] = None,
file_path: Optional[Union[str, pathlib.Path]] = None,
quality: Literal["low", "high"] = "low",
chunk: bool = True,
wait: int = 30,
Expand Down Expand Up @@ -188,7 +188,7 @@ def extract_file(
error_content = response.json()
except Exception:
error_content = response.text
raise APIError(
raise ApiError(
message=error_content,
status_code=response.status_code,
base_url=self.base_url,
Expand All @@ -210,11 +210,11 @@ def extract_file(
document_id=document_id, wait=wait, polling_interval=polling_interval
)
except requests.exceptions.Timeout:
raise APITimeoutError(
raise ApiTimeoutError(
timeout=session_timeout, base_url=self.base_url
) from None
except Exception as e:
raise APIError(message=str(e), base_url=self.base_url) from e
raise ApiError(message=str(e), base_url=self.base_url) from e

def extract_url(
self,
Expand Down Expand Up @@ -273,7 +273,7 @@ def extract_url(
error_content = response.json()
except Exception:
error_content = response.text
raise APIError(
raise ApiError(
message=error_content,
status_code=response.status_code,
base_url=self.base_url,
Expand All @@ -295,11 +295,11 @@ def extract_url(
document_id=document_id, wait=wait, polling_interval=polling_interval
)
except requests.exceptions.Timeout:
raise APITimeoutError(
raise ApiTimeoutError(
timeout=session_timeout, base_url=self.base_url
) from None
except Exception as e:
raise APIError(
raise ApiError(
message=str(e),
base_url=self.base_url,
) from e
Expand All @@ -324,15 +324,15 @@ def get_document(self, document_id: str, timeout: int = 30) -> ExtractResponse:
error_content = response.json()
except Exception:
error_content = response.text
raise APIError(
raise ApiError(
message=error_content,
status_code=response.status_code,
base_url=self.base_url,
)
except requests.exceptions.Timeout:
raise APITimeoutError(timeout=timeout, base_url=self.base_url) from None
raise ApiTimeoutError(timeout=timeout, base_url=self.base_url) from None
except Exception as e:
raise APIError(message=str(e), base_url=self.base_url) from e
raise ApiError(message=str(e), base_url=self.base_url) from e

def wait_for(
self,
Expand Down Expand Up @@ -412,15 +412,15 @@ def embedding(
error_content = response.json()
except Exception:
error_content = response.text
raise APIError(
raise ApiError(
message=error_content,
status_code=response.status_code,
base_url=self.base_url,
)
except requests.exceptions.Timeout:
raise APITimeoutError(timeout=timeout, base_url=self.base_url) from None
raise ApiTimeoutError(timeout=timeout, base_url=self.base_url) from None
except Exception as e:
raise APIError(
raise ApiError(
message=str(e),
base_url=self.base_url,
) from e
54 changes: 25 additions & 29 deletions aurelio_sdk/client_async.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,7 @@
UPLOAD_CHUNK_SIZE,
WAIT_TIME_BEFORE_POLLING,
)
from aurelio_sdk.exceptions import APIError, APITimeoutError
from aurelio_sdk.exceptions import ApiError, ApiTimeoutError
from aurelio_sdk.logger import logger
from aurelio_sdk.schema import (
ChunkingOptions,
Expand Down Expand Up @@ -120,24 +120,24 @@ async def chunk(
error_content = await response.json()
except Exception:
error_content = await response.text()
raise APIError(
raise ApiError(
message=error_content,
status_code=response.status,
)
except asyncio.TimeoutError:
raise APITimeoutError(
raise ApiTimeoutError(
timeout=timeout,
base_url=self.base_url,
) from None
except Exception as e:
raise APIError(message=str(e), base_url=self.base_url) from e
raise ApiError(message=str(e), base_url=self.base_url) from e

async def extract_file(
self,
quality: Literal["low", "high"],
chunk: bool,
file: Optional[Union[IO[bytes], bytes]] = None,
file_path: Optional[str] = None,
file_path: Optional[Union[str, Path]] = None,
wait: int = 30,
polling_interval: int = POLLING_INTERVAL,
) -> ExtractResponse:
Expand Down Expand Up @@ -169,37 +169,33 @@ async def extract_file(

client_url = f"{self.base_url}/v1/extract/file"

# Form data
data = aiohttp.FormData()
data.add_field("quality", quality)
data.add_field("chunk", str(chunk))
initial_wait = WAIT_TIME_BEFORE_POLLING if polling_interval > 0 else wait
data.add_field("wait", str(initial_wait))

# Handle file from path, convert to AsyncGenerator
if file_path:
logger.debug(f"Uploading file from path, {file_path}")
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
file_stream = _file_stream_generator(file_path)
filename = Path(file_path).name
else:
filename = None

# Add file field
data = aiohttp.FormData()
if file_stream:
logger.debug("Uploading using stream")
# Wrap the AsyncGenerator with an AsyncIterablePayload
file_payload = aiohttp.payload.AsyncIterablePayload(value=file_stream)
file_payload.content_type
data.add_field(
name="file",
value=file_payload,
filename=filename,
content_type="application/octet-stream",
content_type=file_payload.content_type,
)
else:
logger.debug("Uploading file bytes")
data.add_field("file", file, filename=filename)

# Add other fields
data.add_field("quality", quality)
data.add_field("chunk", str(chunk))
initial_wait = WAIT_TIME_BEFORE_POLLING if polling_interval > 0 else wait
data.add_field("wait", str(initial_wait))
data.add_field("file", file)

if wait <= 0:
session_timeout = None
Expand Down Expand Up @@ -242,12 +238,12 @@ async def extract_file(
document_id=document_id, wait=wait, polling_interval=polling_interval
)
except asyncio.TimeoutError:
raise APITimeoutError(
raise ApiTimeoutError(
base_url=self.base_url,
timeout=session_timeout.total if session_timeout else None,
) from None
except Exception as e:
raise APIError(
raise ApiError(
message=str(e), base_url=self.base_url, status_code=status_code
) from e

Expand Down Expand Up @@ -312,7 +308,7 @@ async def extract_url(
error_content = await response.json()
except Exception:
error_content = await response.text()
raise APIError(
raise ApiError(
message=error_content,
status_code=response.status,
)
Expand All @@ -333,12 +329,12 @@ async def extract_url(
document_id=document_id, wait=wait, polling_interval=polling_interval
)
except asyncio.TimeoutError:
raise APITimeoutError(
raise ApiTimeoutError(
base_url=self.base_url,
timeout=session_timeout.total if session_timeout else None,
) from None
except Exception as e:
raise APIError(
raise ApiError(
message=str(e),
base_url=self.base_url,
) from e
Expand Down Expand Up @@ -370,12 +366,12 @@ async def get_document(
error_content = await response.json()
except Exception:
error_content = await response.text()
raise APIError(
raise ApiError(
message=error_content,
status_code=response.status,
)
except aiohttp.ConnectionTimeoutError as e:
raise APITimeoutError(
raise ApiTimeoutError(
base_url=self.base_url,
timeout=session_timeout.total if session_timeout else None,
) from e
Expand Down Expand Up @@ -461,17 +457,17 @@ async def embedding(
error_content = await response.json()
except Exception:
error_content = await response.text()
raise APIError(
raise ApiError(
message=error_content,
status_code=response.status,
)
except asyncio.TimeoutError:
raise APITimeoutError(
raise ApiTimeoutError(
base_url=self.base_url,
timeout=session_timeout.total if session_timeout else None,
) from None
except Exception as e:
raise APIError(message=str(e), base_url=self.base_url) from e
raise ApiError(message=str(e), base_url=self.base_url) from e


async def _file_stream_generator(
Expand Down
11 changes: 2 additions & 9 deletions aurelio_sdk/exceptions.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, Union


class APIError(Exception):
class ApiError(Exception):
"""
Exception for API errors.
"""
Expand Down Expand Up @@ -30,7 +30,7 @@ def __init__(
super().__init__(full_message)


class APITimeoutError(TimeoutError):
class ApiTimeoutError(TimeoutError):
"""
Exception for timeout errors.
"""
Expand All @@ -47,10 +47,3 @@ def __init__(
message += f" Base URL: {base_url}"
super().__init__(message)


class FileNotFoundError(Exception):
"""
Exception for file not found errors.
"""

pass
46 changes: 31 additions & 15 deletions examples/01_chunk_async.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 633315f

Please sign in to comment.