Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfies committed Dec 3, 2024
1 parent 92e6fce commit b3af027
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
4 changes: 3 additions & 1 deletion curl_cffi/requests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"AsyncWebSocket",
"WebSocket",
"WebSocketError",
"WebSocketClosed",
"WebSocketTimeout",
"WsCloseCode",
"ExtraFingerprints",
"CookieTypes",
Expand All @@ -39,7 +41,7 @@
from .impersonate import BrowserType, BrowserTypeLiteral, ExtraFingerprints, ExtraFpDict
from .models import Request, Response
from .session import AsyncSession, HttpMethod, ProxySpec, Session, ThreadType
from .websockets import AsyncWebSocket, WebSocket, WebSocketError, WsCloseCode
from .websockets import AsyncWebSocket, WebSocket, WebSocketClosed, WebSocketError, WebSocketTimeout, WsCloseCode


def request(
Expand Down
52 changes: 28 additions & 24 deletions curl_cffi/requests/websockets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import asyncio
from select import select
import struct
from enum import IntEnum
from json import loads, dumps
from functools import partial
from json import dumps, loads
from select import select
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand All @@ -14,26 +16,25 @@
Literal,
Optional,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)

from .options import set_curl_options
from .exceptions import SessionClosed, Timeout
from ..aio import CURL_SOCKET_BAD
from ..const import CurlECode, CurlOpt, CurlWsFlag, CurlInfo
from ..const import CurlECode, CurlInfo, CurlOpt, CurlWsFlag
from ..curl import Curl, CurlError
from .exceptions import SessionClosed, Timeout
from .options import set_curl_options

if TYPE_CHECKING:
from typing_extensions import Self

from ..const import CurlHttpVersion
from ..curl import CurlWsFrame
from .cookies import CookieTypes
from .headers import HeaderTypes
from .impersonate import BrowserTypeLiteral, ExtraFingerprints, ExtraFpDict
from .session import AsyncSession, ProxySpec
from ..const import CurlHttpVersion
from ..curl import CurlWsFrame

T = TypeVar("T")

Expand All @@ -45,6 +46,9 @@

not_set: Final[Any] = object()

# We need a partial for dumps() because a custom function may not accept the parameter
dumps = partial(dumps, separators=(",", ":"))


class WsCloseCode(IntEnum):
OK = 1000
Expand All @@ -70,6 +74,14 @@ def __init__(self, message: str, code: Union[WsCloseCode, CurlECode, Literal[0]]
super().__init__(message, code) # type: ignore


class WebSocketClosed(WebSocketError, SessionClosed):
"""WebSocket is already closed."""


class WebSocketTimeout(WebSocketError, Timeout):
"""WebSocket operation timed out."""


async def aselect(fd, *, loop: asyncio.AbstractEventLoop, timeout: Optional[float] = None) -> bool:
future = loop.create_future()
loop.add_reader(fd, future.set_result, None)
Expand Down Expand Up @@ -178,7 +190,7 @@ def __init__(

def __iter__(self) -> WebSocket:
if self.closed:
raise SessionClosed("WebSocket is closed")
raise WebSocketClosed("WebSocket is closed")
return self

def __next__(self) -> bytes:
Expand Down Expand Up @@ -265,12 +277,7 @@ def connect(
curl_options: extra curl options to use.
"""
if not self.closed:
raise TypeError("WebSocket is already connected")

if proxy and proxies:
raise TypeError("Cannot specify both 'proxy' and 'proxies'")
if proxy:
proxies = {"all": proxy}
raise RuntimeError("WebSocket is already connected")

self.curl = curl = Curl(debug=self.debug)
set_curl_options(
Expand Down Expand Up @@ -311,7 +318,7 @@ def connect(
def recv_fragment(self) -> Tuple[bytes, CurlWsFrame]:
"""Receive a single frame as bytes."""
if self.closed:
raise SessionClosed("WebSocket is closed")
raise WebSocketClosed("WebSocket is closed")

chunk, frame = self.curl.ws_recv()
if frame.flags & CurlWsFlag.CLOSE:
Expand Down Expand Up @@ -382,7 +389,7 @@ def send(self, payload: Union[str, bytes], flags: CurlWsFlag = CurlWsFlag.BINARY
flags: flags for the frame.
"""
if self.closed:
raise SessionClosed("WebSocket is closed")
raise WebSocketClosed("WebSocket is closed")

# curl expects bytes
if isinstance(payload, str):
Expand Down Expand Up @@ -436,9 +443,6 @@ def run_forever(self, url: str, **kwargs):
libcurl automatically handles pings and pongs.
ref: https://curl.se/libcurl/c/libcurl-ws.html
"""
if not self.closed:
raise TypeError("WebSocket is already connected")

self.connect(url, **kwargs)
sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
if sock_fd == CURL_SOCKET_BAD:
Expand Down Expand Up @@ -526,7 +530,7 @@ def loop(self):

def __aiter__(self) -> Self:
if self.closed:
raise SessionClosed("WebSocket is closed")
raise WebSocketClosed("WebSocket is closed")
return self

async def __anext__(self) -> bytes:
Expand All @@ -542,7 +546,7 @@ async def recv_fragment(self, *, timeout: Optional[float] = None) -> Tuple[bytes
timeout: how many seconds to wait before giving up.
"""
if self.closed:
raise SessionClosed("WebSocket is closed")
raise WebSocketClosed("WebSocket is closed")
if self._recv_lock.locked():
raise TypeError("Concurrent call to recv_fragment() is not allowed")

Expand All @@ -552,7 +556,7 @@ async def recv_fragment(self, *, timeout: Optional[float] = None) -> Tuple[bytes
self.loop.run_in_executor(None, self.curl.ws_recv), timeout
)
except asyncio.TimeoutError:
raise Timeout("WebSocket recv_fragment() timed out")
raise WebSocketTimeout("WebSocket recv_fragment() timed out")
if frame.flags & CurlWsFlag.CLOSE:
try:
code, message = self._close_code, self._close_reason = self._unpack_close_frame(chunk)
Expand Down Expand Up @@ -632,7 +636,7 @@ async def send(self, payload: Union[str, bytes], flags: CurlWsFlag = CurlWsFlag.
flags: flags for the frame.
"""
if self.closed:
raise SessionClosed("WebSocket is closed")
raise WebSocketClosed("WebSocket is closed")

# curl expects bytes
if isinstance(payload, str):
Expand Down

0 comments on commit b3af027

Please sign in to comment.