Skip to content

Commit

Permalink
[utils] update to_producer to work with latest changes
Browse files Browse the repository at this point in the history
now producer doesn’t want a multipart body
  • Loading branch information
odrling committed Jan 21, 2025
1 parent 785af78 commit 8968ed3
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions nanachan/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from concurrent.futures import ProcessPoolExecutor
from contextlib import suppress
from functools import cache, lru_cache, singledispatch, update_wrapper
from typing import Any, AsyncIterable, Coroutine, Optional, Type, TypedDict, Union, cast
from typing import Any, AsyncIterable, Coroutine, Optional, Type, TypedDict, cast

import aiohttp
import backoff
Expand Down Expand Up @@ -129,36 +129,45 @@ class ProducerResponse(TypedDict):


@singledispatch
async def to_producer(file: str | URL | io.IOBase) -> ProducerResponse:
raise RuntimeError('shouldn’t be here')


@to_producer.register
@default_backoff
async def to_producer(file: Union[str, URL]) -> ProducerResponse:
async def _(file: str | URL) -> ProducerResponse:
url = URL(file) if isinstance(file, str) else file

async with get_session().get(url) as req:
async with get_session().get(url) as resp:
filename = url.name
data = aiohttp.FormData()
data.add_field('file', req.content, filename=filename)
headers: dict[str, str] = {
'Authorization': PRODUCER_TOKEN,
'Expires': '0',
'Filename': filename,
}

async with get_session().post(PRODUCER_UPLOAD_ENDPOINT, headers=headers, data=data) as req:
async with get_session().post(
PRODUCER_UPLOAD_ENDPOINT, headers=headers, data=resp.content
) as req:
return await req.json()


async def chunk_iter(file: io.IOBase):
while chunk := file.read(64 * 1024):
yield chunk


@to_producer.register
@default_backoff
async def _(file: io.IOBase, filename=None) -> ProducerResponse:
if filename is not None:
file.name = filename # type: ignore

async def _(file: io.IOBase, filename: str) -> ProducerResponse:
headers: dict[str, str] = {
'Authorization': PRODUCER_TOKEN,
'Expires': '0',
'Filename': filename,
}

async with get_session().post(
PRODUCER_UPLOAD_ENDPOINT, headers=headers, data=dict(file=file)
PRODUCER_UPLOAD_ENDPOINT, headers=headers, data=chunk_iter(file)
) as req:
return await req.json()

Expand Down

0 comments on commit 8968ed3

Please sign in to comment.