feat: updated streaming logic
simjak committed Nov 19, 2024
1 parent 6331269 commit 42e35b3
Showing 2 changed files with 192 additions and 94 deletions.
244 changes: 156 additions & 88 deletions aurelio_sdk/client_async.py
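The core pattern adopted by this change is streaming a local file into a multipart upload through aiohttp's AsyncIterablePayload, fed by an aiofiles chunk generator, rather than reading the whole file into memory. A minimal standalone sketch of that pattern follows; the endpoint URL, field name, and chunk size are illustrative assumptions and are not taken from the SDK or this commit.

# Standalone sketch (not part of this commit): stream a file into a multipart
# upload with aiohttp, reading it in chunks via aiofiles. The URL, field name,
# and chunk size are illustrative assumptions.
import aiofiles
import aiohttp
from pathlib import Path
from typing import AsyncGenerator

CHUNK_SIZE = 1024 * 1024  # read 1 MB at a time (assumed value)


async def file_chunks(path: str, chunk_size: int = CHUNK_SIZE) -> AsyncGenerator[bytes, None]:
    # Yield the file contents chunk by chunk so the whole file never sits in memory.
    async with aiofiles.open(path, "rb") as f:
        while True:
            chunk = await f.read(chunk_size)
            if not chunk:
                break
            yield chunk


async def upload(path: str, url: str = "https://example.invalid/upload") -> int:
    # Wrap the async generator in an AsyncIterablePayload and attach it as a
    # multipart field, mirroring the pattern used in the changed client code.
    data = aiohttp.FormData()
    payload = aiohttp.payload.AsyncIterablePayload(file_chunks(path))
    data.add_field(
        "file",
        payload,
        filename=Path(path).name,
        content_type=payload.content_type,
    )
    async with aiohttp.ClientSession() as session:
        async with session.post(url, data=data) as response:
            return response.status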
@@ -236,19 +236,111 @@ async def extract_file(
if file_path:
    if not await aiofiles.os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    filename = Path(file_path).name

    # Wrap the AsyncGenerator with an AsyncIterablePayload
    file_payload = aiohttp.payload.AsyncIterablePayload(
        _file_stream(file_path=file_path, chunk_size=chunk_size)
    )
    file_payload.content_type
    data.add_field(
        name="file",
        value=file_payload,
        filename=filename,
        content_type=file_payload.content_type,
    )
    # Open the file and keep it open during the API call
    async with aiofiles.open(file_path, "rb") as file_buffer:
        filename = Path(file_path).name
        file_size = await aiofiles.os.path.getsize(file_path)

        # Stream the file in chunks
        async def _file_stream() -> AsyncGenerator[bytes, None]:
            total_bytes = 0
            chunk_count = 0
            try:
                while True:
                    chunk = await file_buffer.read(chunk_size)
                    if not chunk:
                        break
                    yield chunk
                    total_bytes += len(chunk)
                    chunk_count += 1
                    logger.debug(
                        f"Reading chunk {chunk_count}, chunk_size: "
                        f"{chunk_size / 1024 / 1024:.2f} MB, "
                        f"total size: {total_bytes / 1024 / 1024:.2f} MB"
                    )

                if total_bytes != file_size:
                    logger.warning(
                        f"File size mismatch: expected {file_size} bytes, "
                        f"but read {total_bytes} bytes"
                    )
            except Exception as e:
                logger.error(f"Error during file streaming: {str(e)}")
                raise

        # Wrap the AsyncGenerator with an AsyncIterablePayload
        file_payload = aiohttp.payload.AsyncIterablePayload(
            _file_stream()
        )
        file_payload.content_type
        data.add_field(
            name="file",
            value=file_payload,
            filename=filename,
            content_type=file_payload.content_type,
        )

        # Call the API within the same context to keep the file open
        try:
            async with session.post(
                client_url, data=data, headers=self.headers
            ) as response:
                logger.debug("Calling API")
                if response.status == 200:
                    extract_response = ExtractResponse(
                        **await response.json()
                    )
                    document_id = extract_response.document.id
                    break  # Success
                elif response.status == 429:
                    raise ApiRateLimitError(
                        status_code=response.status,
                        base_url=self.base_url,
                    )
                elif response.status >= 500:
                    error_content = await response.text()
                    if attempt == retries:
                        raise ApiError(
                            message=error_content,
                            status_code=response.status,
                        )
                    else:
                        logger.debug(
                            f"Retrying due to server error (attempt {attempt}): "
                            f"{error_content}"
                        )
                        continue  # Retry
                else:
                    try:
                        error_content = await response.text()
                    except Exception:
                        error_content = await response.text()
                    raise ApiError(
                        message=error_content,
                        status_code=response.status,
                    )
        except ApiRateLimitError as e:
            raise e from None
        except asyncio.TimeoutError:
            if attempt == retries:
                raise ApiTimeoutError(
                    base_url=self.base_url,
                    timeout=session_timeout.total
                    if session_timeout
                    else None,
                ) from None
            else:
                logger.debug(
                    f"Timeout error on attempt {attempt}, retrying..."
                )
                continue  # Retry
        except Exception as e:
            raise ApiError(
                message=str(e),
                base_url=self.base_url,
            ) from e

# Handles file bytes
else:
    try:
@@ -267,58 +359,65 @@ async def extract_file(
            base_url=self.base_url,
        ) from e

try:
    async with session.post(
        client_url, data=data, headers=self.headers
    ) as response:
        logger.debug("Calling API")
        if response.status == 200:
            extract_response = ExtractResponse(**await response.json())
            document_id = extract_response.document.id
            break  # Success
        elif response.status == 429:
            raise ApiRateLimitError(
                status_code=response.status,
                base_url=self.base_url,
            )
        elif response.status >= 500:
            error_content = await response.text()
            if attempt == retries:
                raise ApiError(
                    message=error_content,
                    status_code=response.status,
                )
            else:
                logger.debug(
                    f"Retrying due to server error (attempt {attempt}): "
                    f"{error_content}"
                )
                continue  # Retry
        else:
            try:
                error_content = await response.text()
            except Exception:
                error_content = await response.text()
            raise ApiError(
                message=error_content,
                status_code=response.status,
            )
except ApiRateLimitError as e:
    raise e from None
except asyncio.TimeoutError:
    if attempt == retries:
        raise ApiTimeoutError(
            base_url=self.base_url,
            timeout=session_timeout.total if session_timeout else None,
        ) from None
    else:
        logger.debug(f"Timeout error on attempt {attempt}, retrying...")
        continue  # Retry
except Exception as e:
    raise ApiError(
        message=str(e),
        base_url=self.base_url,
    ) from e

    # Call the API
    try:
        async with session.post(
            client_url, data=data, headers=self.headers
        ) as response:
            logger.debug("Calling API")
            if response.status == 200:
                extract_response = ExtractResponse(
                    **await response.json()
                )
                document_id = extract_response.document.id
                break  # Success
            elif response.status == 429:
                raise ApiRateLimitError(
                    status_code=response.status,
                    base_url=self.base_url,
                )
            elif response.status >= 500:
                error_content = await response.text()
                if attempt == retries:
                    raise ApiError(
                        message=error_content,
                        status_code=response.status,
                    )
                else:
                    logger.debug(
                        f"Retrying due to server error (attempt {attempt}): "
                        f"{error_content}"
                    )
                    continue  # Retry
            else:
                try:
                    error_content = await response.text()
                except Exception:
                    error_content = await response.text()
                raise ApiError(
                    message=error_content,
                    status_code=response.status,
                )
    except ApiRateLimitError as e:
        raise e from None
    except asyncio.TimeoutError:
        if attempt == retries:
            raise ApiTimeoutError(
                base_url=self.base_url,
                timeout=session_timeout.total
                if session_timeout
                else None,
            ) from None
        else:
            logger.debug(
                f"Timeout error on attempt {attempt}, retrying..."
            )
            continue  # Retry
    except Exception as e:
        raise ApiError(
            message=str(e),
            base_url=self.base_url,
        ) from e

if extract_response is None:
    raise ApiError(
@@ -676,34 +775,3 @@ async def embedding(
        message="Failed to get embedding response",
        base_url=self.base_url,
    )


async def _file_stream(
    file_path: Path | str, chunk_size: int = UPLOAD_CHUNK_SIZE
) -> AsyncGenerator[bytes, None]:
    async with aiofiles.open(file_path, "rb") as file_buffer:
        file_size = await aiofiles.os.path.getsize(file_path)
        total_bytes = 0
        chunk_count = 0
        try:
            while True:
                chunk = await file_buffer.read(chunk_size)
                if not chunk:
                    break
                yield chunk
                total_bytes += len(chunk)
                chunk_count += 1
                logger.debug(
                    f"Reading chunk {chunk_count}, chunk_size: "
                    f"{chunk_size / 1024 / 1024:.2f} MB, "
                    f"total size: {total_bytes / 1024 / 1024:.2f} MB"
                )

            if total_bytes != file_size:
                logger.warning(
                    f"File size mismatch: expected {file_size} bytes, "
                    f"but read {total_bytes} bytes"
                )
        except Exception as e:
            logger.error(f"Error during file streaming: {str(e)}")
            raise
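For reference, the retry handling around the POST in both branches follows the same shape: return on 200, fail immediately on 429, and retry 5xx responses and timeouts until the attempt budget runs out. A condensed standalone sketch of that control flow is below; the names (post_with_retries, MAX_RETRIES, RateLimited) are illustrative and not part of the SDK.

# Standalone sketch of the retry shape used in the diff above (illustrative
# names, not SDK API): retry on 5xx and timeouts, fail fast on 429.
import asyncio
from typing import Any

import aiohttp

MAX_RETRIES = 3  # assumed attempt budget


class RateLimited(Exception):
    pass


async def post_with_retries(session: aiohttp.ClientSession, url: str, data: bytes) -> Any:
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            async with session.post(url, data=data) as response:
                if response.status == 200:
                    return await response.json()
                if response.status == 429:
                    raise RateLimited(await response.text())
                if response.status >= 500 and attempt < MAX_RETRIES:
                    continue  # transient server error: try again
                # out of retries, or an unexpected status: surface the body as an error
                raise RuntimeError(f"API error {response.status}: {await response.text()}")
        except asyncio.TimeoutError:
            if attempt == MAX_RETRIES:
                raise
            # timed out with attempts remaining: loop and retry
    raise RuntimeError("unreachable: every path above returns or raises")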