Skip to content

Commit

Permalink
Create WSType on the test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Aug 30, 2023
1 parent 7519e6b commit 87ad36a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ omit = [

[tool.coverage.report]
precision = 2
fail_under = 98.35
fail_under = 98.65
show_missing = true
skip_covered = true
exclude_lines = [
Expand Down
9 changes: 8 additions & 1 deletion tests/middleware/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import contextlib
import logging
import socket
Expand All @@ -22,8 +24,13 @@

if typing.TYPE_CHECKING:
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.websockets_sansio_impl import (
WebSocketsSansIOProtocol,
)
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol

WSType = typing.Type["WSProtocol | WebSocketProtocol | WebSocketsSansIOProtocol"]


@contextlib.contextmanager
def caplog_for_logger(caplog, logger_name):
Expand Down Expand Up @@ -96,7 +103,7 @@ async def test_trace_logging_on_http_protocol(

@pytest.mark.anyio
async def test_trace_logging_on_ws_protocol(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
caplog,
logging_config,
unused_tcp_port: int,
Expand Down
70 changes: 38 additions & 32 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import typing
from copy import deepcopy
Expand All @@ -14,6 +16,7 @@
from tests.utils import run_server
from uvicorn.config import Config
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol

try:
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
Expand All @@ -22,6 +25,9 @@
except ModuleNotFoundError:
skip_if_no_wsproto = pytest.mark.skipif(True, reason="wsproto is not installed.")

if typing.TYPE_CHECKING:
WSType = typing.Type["WSProtocol | WebSocketProtocol | WebSocketsSansIOProtocol"]


class WebSocketResponse:
def __init__(self, scope, receive, send):
Expand All @@ -46,7 +52,7 @@ async def asgi(self):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_invalid_upgrade(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -85,7 +91,7 @@ def app(scope):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_accept_connection(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand All @@ -112,7 +118,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_supports_permessage_deflate_extension(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -142,7 +148,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_can_disable_permessage_deflate_extension(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -175,7 +181,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_close_connection(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -205,7 +211,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_headers(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -238,7 +244,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_extra_headers(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -267,7 +273,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_path_and_raw_path(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -298,7 +304,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_text_data_to_client(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -326,7 +332,7 @@ async def get_data(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_binary_data_to_client(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -354,7 +360,7 @@ async def get_data(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_and_close_connection(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -390,7 +396,7 @@ async def get_data(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_text_data_to_server(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -422,7 +428,7 @@ async def send_text(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_binary_data_to_server(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -454,7 +460,7 @@ async def send_text(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_text_data_to_server_in_multiple_frames(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -497,7 +503,7 @@ async def send_text(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_binary_data_to_server_in_multiple_frames(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -540,7 +546,7 @@ async def send_bytes(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_after_protocol_close(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -578,7 +584,7 @@ async def get_data(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_missing_handshake(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand All @@ -604,7 +610,7 @@ async def connect(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_before_handshake(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand All @@ -630,7 +636,7 @@ async def connect(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_duplicate_handshake(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -658,7 +664,7 @@ async def connect(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_asgi_return_value(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -697,7 +703,7 @@ async def connect(url):
ids=["none_as_reason", "normal_reason", "without_reason"],
)
async def test_app_close(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
code,
Expand Down Expand Up @@ -744,7 +750,7 @@ async def websocket_session(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_client_close(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -777,7 +783,7 @@ async def websocket_session(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_client_connection_lost(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -816,7 +822,7 @@ async def app(scope, receive, send):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_connection_lost_before_handshake_complete(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -870,7 +876,7 @@ async def websocket_session(uri):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_close_on_server_shutdown(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -921,7 +927,7 @@ async def websocket_session(uri):
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
@pytest.mark.parametrize("subprotocol", ["proto1", "proto2"])
async def test_subprotocols(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
subprotocol,
unused_tcp_port: int,
Expand Down Expand Up @@ -1014,7 +1020,7 @@ async def send_text(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_reject_connection(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -1055,7 +1061,7 @@ async def websocket_session(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_can_read_messages_in_buffer_after_close(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -1100,7 +1106,7 @@ async def send_text(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_default_server_headers(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand All @@ -1127,7 +1133,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_no_server_headers(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -1183,7 +1189,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_multiple_server_header(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down Expand Up @@ -1218,7 +1224,7 @@ async def open_connection(url):
@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_lifespan_state(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
ws_protocol_cls: WSType,
http_protocol_cls,
unused_tcp_port: int,
):
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/websockets/websockets_sansio_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
config.load()
config.load() # pragma: no cover

self.config = config
self.app = config.loaded_app
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
config.load()
config.load() # pragma: no cover

self.config = config
self.app = config.loaded_app
Expand Down

0 comments on commit 87ad36a

Please sign in to comment.