feat: retry logic
simjak committed Nov 12, 2024
1 parent 841eeea commit 4e6e512
Showing 4 changed files with 105 additions and 78 deletions.
1 change: 0 additions & 1 deletion aurelio_sdk/client.py
@@ -289,7 +289,6 @@ def extract_file(
continue # Retry
except Exception as e:
if attempt == retries:
logger.error(f"Error on attempt {attempt}: {e}")
raise ApiError(message=str(e), base_url=self.base_url) from e
else:
logger.debug(f"Retrying due to exception (attempt {attempt}): {e}")
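For reference, the retry shape used in the synchronous client boils down to the loop below. This is a minimal sketch rather than the SDK's exact code; the with_retries wrapper, call_api callable, and logger setup are illustrative:

import logging

logger = logging.getLogger(__name__)

def with_retries(call_api, retries: int = 3):
    # Try up to `retries` times; the last failure is re-raised to the caller.
    for attempt in range(1, retries + 1):
        try:
            return call_api()
        except Exception as e:
            if attempt == retries:
                raise  # out of attempts: surface the original error
            logger.debug(f"Retrying due to exception (attempt {attempt}): {e}")
            # no sleep here: as in the diff above, the retry is immediate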
96 changes: 55 additions & 41 deletions aurelio_sdk/client_async.py
@@ -213,34 +213,6 @@ async def extract_file(

client_url = f"{self.base_url}/v1/extract/file"

# Form data
data = aiohttp.FormData()
data.add_field("quality", quality)
data.add_field("chunk", str(chunk))
initial_wait = WAIT_TIME_BEFORE_POLLING if polling_interval > 0 else wait
data.add_field("wait", str(initial_wait))

# Handle file from path, convert to AsyncGenerator
if file_path:
logger.debug(f"Uploading file from path, {file_path}")
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
file_stream = _file_stream_generator(file_path)
filename = Path(file_path).name

# Wrap the AsyncGenerator with an AsyncIterablePayload
file_payload = aiohttp.payload.AsyncIterablePayload(value=file_stream)
file_payload.content_type
data.add_field(
name="file",
value=file_payload,
filename=filename,
content_type=file_payload.content_type,
)
else:
logger.debug("Uploading file bytes")
data.add_field("file", file)

if wait <= 0:
session_timeout = None
else:
@@ -249,8 +221,49 @@

extract_response = None

        # If file is a file-like object, remember its starting position so
        # each retry can rewind and re-read it
        if hasattr(file, "seek"):
            initial_file_position = file.tell()
        else:
            initial_file_position = None

async with aiohttp.ClientSession(timeout=session_timeout) as session:
for attempt in range(1, retries + 1):
# Form data
data = aiohttp.FormData()
data.add_field("quality", quality)
data.add_field("chunk", str(chunk))
initial_wait = (
WAIT_TIME_BEFORE_POLLING if polling_interval > 0 else wait
)
data.add_field("wait", str(initial_wait))

# Handle file from path, convert to AsyncGenerator
if file_path:
logger.debug(f"Uploading file from path, {file_path}")
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
file_stream = _file_stream_generator(file_path)
filename = Path(file_path).name

# Wrap the AsyncGenerator with an AsyncIterablePayload
file_payload = aiohttp.payload.AsyncIterablePayload(
value=file_stream
)
file_payload.content_type
data.add_field(
name="file",
value=file_payload,
filename=filename,
content_type=file_payload.content_type,
)
else:
logger.debug("Uploading file bytes")
# Reset file-like object position
if initial_file_position is not None:
file.seek(initial_file_position)
data.add_field("file", file)

try:
async with session.post(
client_url, data=data, headers=self.headers
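
The core of this hunk: the multipart form is now rebuilt on every attempt. A FormData instance cannot be safely reused once a request has consumed it, and an async-generator payload is exhausted after the first send, so a retry with the same object would fail; file-like objects additionally need to be rewound. A minimal sketch of the pattern, with the upload_with_retries name and bare Exception handling as illustrative simplifications:

import aiohttp

async def upload_with_retries(url: str, file, headers: dict, retries: int = 3):
    # Remember where the file-like object starts so each retry can rewind it.
    initial_pos = file.tell() if hasattr(file, "seek") else None
    async with aiohttp.ClientSession() as session:
        for attempt in range(1, retries + 1):
            # Rebuild the form each attempt: a FormData that has been sent
            # cannot be reused.
            data = aiohttp.FormData()
            if initial_pos is not None:
                file.seek(initial_pos)  # rewind before re-uploading
            data.add_field("file", file)
            try:
                async with session.post(url, data=data, headers=headers) as resp:
                    resp.raise_for_status()
                    return await resp.json()
            except Exception:
                if attempt == retries:
                    raise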
@@ -358,18 +371,6 @@ async def extract_url(
            ApiError: If there's an error in the API response.
"""
client_url = f"{self.base_url}/v1/extract/url"
data = aiohttp.FormData()
data.add_field("url", url)
data.add_field("quality", quality)
data.add_field("chunk", str(chunk))

# If polling is enabled (polling_interval > 0), use a short wait time
# (WAIT_TIME_BEFORE_POLLING)
# If polling is disabled (polling_interval <= 0), use the full wait time
        initial_request_timeout = (
            WAIT_TIME_BEFORE_POLLING if polling_interval > 0 else wait
        )
data.add_field("wait", str(initial_request_timeout))

if wait <= 0:
session_timeout = None
@@ -381,6 +382,19 @@

async with aiohttp.ClientSession(timeout=session_timeout) as session:
for attempt in range(1, retries + 1):
data = aiohttp.FormData()
data.add_field("url", url)
data.add_field("quality", quality)
data.add_field("chunk", str(chunk))

# If polling is enabled (polling_interval > 0), use a short wait time
# (WAIT_TIME_BEFORE_POLLING)
# If polling is disabled (polling_interval <= 0), use the full wait time
                initial_request_timeout = (
                    WAIT_TIME_BEFORE_POLLING if polling_interval > 0 else wait
                )
data.add_field("wait", str(initial_request_timeout))

try:
async with session.post(
client_url, data=data, headers=self.headers
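
The wait field sent with the first request depends on whether the client will poll afterwards: per the comment above, with polling enabled (polling_interval > 0) the first request can return quickly and the client polls for the result. A small worked example, with the constant's value chosen only for illustration:

WAIT_TIME_BEFORE_POLLING = 5  # illustrative value, not the SDK's actual constant

def initial_request_timeout(wait: int, polling_interval: int) -> int:
    # Polling enabled: keep the first request short and poll for the rest.
    # Polling disabled: the first request blocks for the full wait.
    return WAIT_TIME_BEFORE_POLLING if polling_interval > 0 else wait

assert initial_request_timeout(wait=30, polling_interval=2) == 5
assert initial_request_timeout(wait=30, polling_interval=0) == 30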
@@ -617,13 +631,13 @@ async def embedding(
ApiRateLimitError: If the rate limit is exceeded.
"""
client_url = f"{self.base_url}/v1/embeddings"
data = {"input": input, "model": model}

session_timeout = aiohttp.ClientTimeout(total=timeout)

# Added retry logic similar to extract_url
async with aiohttp.ClientSession(timeout=session_timeout) as session:
for attempt in range(1, retries + 1):
data = {"input": input, "model": model}
try:
async with session.post(
client_url, json=data, headers=self.headers
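For the embeddings endpoint the body is a plain dict; unlike a consumed multipart form it could be reused safely, but rebuilding it per attempt keeps all three methods symmetric. A sketch of the resulting call shape — the embed wrapper is illustrative, while the endpoint path and payload fields come from the diff above:

import aiohttp

async def embed(base_url: str, headers: dict, input, model: str,
                retries: int = 3, timeout: int = 30):
    session_timeout = aiohttp.ClientTimeout(total=timeout)
    async with aiohttp.ClientSession(timeout=session_timeout) as session:
        for attempt in range(1, retries + 1):
            data = {"input": input, "model": model}  # rebuilt per attempt
            try:
                async with session.post(
                    f"{base_url}/v1/embeddings", json=data, headers=headers
                ) as resp:
                    resp.raise_for_status()
                    return await resp.json()
            except Exception:
                if attempt == retries:
                    raise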
45 changes: 11 additions & 34 deletions examples/02_extract.ipynb

Large diffs are not rendered by default.

