diff --git a/.coveragerc b/.coveragerc index b9faf956..ac834946 100644 --- a/.coveragerc +++ b/.coveragerc @@ -11,6 +11,14 @@ source = [report] show_missing = true precision = 2 +omit = + src/webob/types.py +exclude_lines = + pragma: no cover + @overload + if TYPE_CHECKING: + if __name__ == .__main__. + raise NotImplementedError [html] show_contexts = True diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 140c4549..43e27627 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -120,7 +120,7 @@ jobs: - run: tox -e docs lint: runs-on: ubuntu-22.04 - name: Lint the package + name: Lint and type check the package steps: - uses: actions/checkout@v4 - name: Setup python @@ -129,4 +129,4 @@ jobs: python-version: "3.13" architecture: x64 - run: pip install tox - - run: tox -e lint + - run: tox -e lint,mypy diff --git a/pyproject.toml b/pyproject.toml index c1ca2d06..1f475b9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,3 +25,9 @@ line_length = 88 force_sort_within_sections = true default_section = "THIRDPARTY" known_first_party = "webob" + +[tool.mypy] +python_version = 3.9 +strict = true +warn_unreachable = true +mypy_path = "$MYPY_CONFIG_FILE_DIR/src" diff --git a/setup.py b/setup.py index a83f673a..ec26fae3 100644 --- a/setup.py +++ b/setup.py @@ -52,8 +52,11 @@ license="MIT", packages=find_packages("src", exclude=["tests"]), package_dir={"": "src"}, + include_package_data=True, + package_data={"webob": ["py.typed"]}, python_requires=">=3.9.0", install_requires=[ + "typing-extensions>=4.12.0", "legacy-cgi>=2.6; python_version>='3.13'", ], zip_safe=True, diff --git a/src/webob/acceptparse.py b/src/webob/acceptparse.py index 039a7726..1c34c4b5 100644 --- a/src/webob/acceptparse.py +++ b/src/webob/acceptparse.py @@ -5,11 +5,83 @@ ``Accept-Language``. """ -from collections import namedtuple +from __future__ import annotations + import re import textwrap +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, overload import warnings +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator, Sequence + from typing import TypeVar + + from _typeshed import SupportsItems + from typing_extensions import Self, TypeAlias + + from webob.request import BaseRequest + from webob.types import AsymmetricPropertyWithDelete, ListOrTuple + + _T = TypeVar("_T") + _ParsedAccept: TypeAlias = tuple[ + str, float, list[tuple[str, str]], list["str | tuple[str, str]"] + ] + + class _SupportsStr(Protocol): + def __str__(self) -> str: + pass + + _AnyAcceptHeader: TypeAlias = ( + "AcceptValidHeader | AcceptInvalidHeader | AcceptNoHeader" + ) + _AnyAcceptCharsetHeader: TypeAlias = ( + "AcceptCharsetValidHeader | AcceptCharsetInvalidHeader | AcceptCharsetNoHeader" + ) + _AnyAcceptEncodingHeader: TypeAlias = ( + "AcceptEncodingValidHeader | AcceptEncodingInvalidHeader | AcceptEncodingNoHeader" + ) + _AnyAcceptLanguageHeader: TypeAlias = ( + "AcceptLanguageValidHeader | AcceptLanguageInvalidHeader | AcceptLanguageNoHeader" + ) + + _AcceptProperty: TypeAlias = AsymmetricPropertyWithDelete[ + _AnyAcceptHeader, + """( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr | str | None + )""", + ] + _AcceptCharsetProperty: TypeAlias = AsymmetricPropertyWithDelete[ + _AnyAcceptCharsetHeader, + """( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr | str | None + )""", + ] + _AcceptEncodingProperty: TypeAlias = AsymmetricPropertyWithDelete[ + _AnyAcceptEncodingHeader, + """( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr | str | None + )""", + ] + _AcceptLanguageProperty: TypeAlias = AsymmetricPropertyWithDelete[ + _AnyAcceptLanguageHeader, + """( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr | str | None + )""", + ] + + # RFC 7230 Section 3.2.3 "Whitespace" # OWS = *( SP / HTAB ) # ; optional whitespace @@ -33,11 +105,11 @@ weight_re = OWS_re + ";" + OWS_re + "[qQ]=(" + qvalue_re + ")" -def _item_n_weight_re(item_re): +def _item_n_weight_re(item_re: str) -> str: return "(" + item_re + ")(?:" + weight_re + ")?" -def _item_qvalue_pair_to_header_element(pair): +def _item_qvalue_pair_to_header_element(pair: tuple[str, float] | list[Any]) -> str: item, qvalue = pair if qvalue == 1.0: @@ -50,7 +122,7 @@ def _item_qvalue_pair_to_header_element(pair): return element -def _list_0_or_more__compiled_re(element_re): +def _list_0_or_more__compiled_re(element_re: str) -> re.Pattern[str]: # RFC 7230 Section 7 "ABNF List Extension: #rule": # #element => [ ( "," / element ) *( OWS "," [ OWS element ] ) ] @@ -70,7 +142,7 @@ def _list_0_or_more__compiled_re(element_re): ) -def _list_1_or_more__compiled_re(element_re): +def _list_1_or_more__compiled_re(element_re: str) -> re.Pattern[str]: # RFC 7230 Section 7 "ABNF List Extension: #rule": # 1#element => *( "," OWS ) element *( OWS "," [ OWS element ] ) # and RFC 7230 Errata ID: 4169 @@ -89,7 +161,7 @@ def _list_1_or_more__compiled_re(element_re): ) -class AcceptOffer(namedtuple("AcceptOffer", ["type", "subtype", "params"])): +class AcceptOffer(NamedTuple): """ A pre-parsed offer tuple represeting a value in the format ``type/subtype;param0=value0;param1=value1``. @@ -100,9 +172,11 @@ class AcceptOffer(namedtuple("AcceptOffer", ["type", "subtype", "params"])): """ - __slots__ = () + type: str + subtype: str + params: tuple[tuple[str, str], ...] - def __str__(self): + def __str__(self) -> str: """ Return the properly quoted media type string. @@ -163,10 +237,9 @@ class Accept: + "/" + subtype_re + ")" - + # '*' is included through type_re and subtype_re, so this covers */* # and type/* - ")" + + ")" + "(" + "(?:" + OWS_re @@ -244,7 +317,7 @@ class Accept: media_type_compiled_re = re.compile("^" + media_type_re + "$") @classmethod - def _escape_and_quote_parameter_value(cls, param_value): + def _escape_and_quote_parameter_value(cls, param_value: str) -> str: """ Escape and quote parameter value where necessary. @@ -262,7 +335,9 @@ def _escape_and_quote_parameter_value(cls, param_value): return param_value @classmethod - def _form_extension_params_segment(cls, extension_params): + def _form_extension_params_segment( + cls, extension_params: Iterable[str | tuple[str, str]] + ) -> str: """ Convert iterable of extension parameters to str segment for header. @@ -272,9 +347,9 @@ def _form_extension_params_segment(cls, extension_params): extension_params_segment = "" for item in extension_params: - try: + if isinstance(item, str): extension_params_segment += ";" + item - except TypeError: + else: param_name, param_value = item param_value = cls._escape_and_quote_parameter_value( param_value=param_value @@ -284,7 +359,9 @@ def _form_extension_params_segment(cls, extension_params): return extension_params_segment @classmethod - def _form_media_range(cls, type_subtype, media_type_params): + def _form_media_range( + cls, type_subtype: str, media_type_params: Iterable[tuple[str, str]] + ) -> str: """ Combine `type_subtype` and `media_type_params` to form a media range. @@ -300,7 +377,7 @@ def _form_media_range(cls, type_subtype, media_type_params): return type_subtype + media_type_params_segment @classmethod - def _iterable_to_header_element(cls, iterable): + def _iterable_to_header_element(cls, iterable: Iterable[Any]) -> str: """ Convert iterable of tuples into header element ``str``. @@ -326,7 +403,9 @@ def _iterable_to_header_element(cls, iterable): return element @classmethod - def _parse_media_type_params(cls, media_type_params_segment): + def _parse_media_type_params( + cls, media_type_params_segment: str + ) -> list[tuple[str, str]]: """ Parse media type parameters segment into list of (name, value) tuples. """ @@ -342,7 +421,7 @@ def _parse_media_type_params(cls, media_type_params_segment): return media_type_params @classmethod - def _process_quoted_string_token(cls, token): + def _process_quoted_string_token(cls, token: str) -> str: """ Return unescaped and unquoted value from quoted token. """ @@ -353,7 +432,7 @@ def _process_quoted_string_token(cls, token): return re.sub(r"\\(?![\\])", "", token[1:-1]).replace("\\\\", "\\") @classmethod - def _python_value_to_header_str(cls, value): + def _python_value_to_header_str(cls, value: object) -> str: """ Convert Python value to header string for __add__/__radd__. """ @@ -394,7 +473,7 @@ def _python_value_to_header_str(cls, value): return header_str @classmethod - def parse(cls, value): + def parse(cls, value: str) -> Iterator[_ParsedAccept]: """ Parse an ``Accept`` header. @@ -431,7 +510,7 @@ def parse(cls, value): if cls.accept_compiled_re.match(value) is None: raise ValueError("Invalid value for an Accept header.") - def generator(value): + def generator(value: str) -> Iterator[_ParsedAccept]: for match in cls.media_range_n_accept_params_compiled_re.finditer(value): groups = match.groups() @@ -478,7 +557,7 @@ def generator(value): return generator(value=value) @classmethod - def parse_offer(cls, offer): + def parse_offer(cls, offer: str | AcceptOffer) -> AcceptOffer: """ Parse an offer into its component parts. @@ -514,12 +593,14 @@ def parse_offer(cls, offer): ) @classmethod - def _parse_and_normalize_offers(cls, offers): + def _parse_and_normalize_offers( + cls, offers: Iterable[str | AcceptOffer] + ) -> list[tuple[int, AcceptOffer]]: """ Throw out any offers that do not match the media range ABNF. - :return: A list of offers split into the format ``[offer_index, - parsed_offer]``. + :return: A list of offers split into the format ``(offer_index, + parsed_offer)``. """ parsed_offers = [] @@ -529,7 +610,7 @@ def _parse_and_normalize_offers(cls, offers): parsed_offer = cls.parse_offer(offer) except ValueError: continue - parsed_offers.append([index, parsed_offer]) + parsed_offers.append((index, parsed_offer)) return parsed_offers @@ -547,13 +628,13 @@ class AcceptValidHeader(Accept): """ @property - def header_value(self): + def header_value(self) -> str: """(``str`` or ``None``) The header value.""" return self._header_value @property - def parsed(self): + def parsed(self) -> list[_ParsedAccept]: """ (``list`` or ``None``) Parsed form of the header. @@ -579,7 +660,7 @@ def parsed(self): return self._parsed - def __init__(self, header_value): + def __init__(self, header_value: str) -> None: """ Create an :class:`AcceptValidHeader` instance. @@ -589,17 +670,27 @@ def __init__(self, header_value): """ self._header_value = header_value self._parsed = list(self.parse(header_value)) - self._parsed_nonzero = [item for item in self.parsed if item[1]] + self._parsed_nonzero = [item for item in self._parsed if item[1]] # item[1] is the qvalue - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. """ return self.__class__(self._header_value) - def __add__(self, other): + def __add__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self: """ Add to header, creating a new header object. @@ -640,7 +731,7 @@ def __add__(self, other): if other.header_value == "": return self.__class__(header_value=self.header_value) else: - return create_accept_header( + return create_accept_header( # type: ignore[return-value] header_value=self.header_value + ", " + other.header_value ) @@ -649,7 +740,7 @@ def __add__(self, other): return self._add_instance_and_non_accept_type(instance=self, other=other) - def __bool__(self): + def __bool__(self) -> Literal[True]: """ Return whether ``self`` represents a valid ``Accept`` header. @@ -662,7 +753,7 @@ def __bool__(self): return True - def __contains__(self, offer): + def __contains__(self, offer: str) -> bool: """ Return ``bool`` indicating whether `offer` is acceptable. @@ -712,7 +803,7 @@ def __contains__(self, offer): return False - def __iter__(self): + def __iter__(self) -> Iterator[str]: """ Return all the ranges with non-0 qvalues, in order of preference. @@ -743,7 +834,17 @@ def __iter__(self): ): yield media_range - def __radd__(self, other): + def __radd__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self: """ Add to header, creating a new header object. @@ -754,10 +855,10 @@ def __radd__(self, other): instance=self, other=other, instance_on_the_right=True ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} ({str(self)!r})>" - def __str__(self): + def __str__(self) -> str: r""" Return a tidied up version of the header value. @@ -783,8 +884,8 @@ def __str__(self): ) def _add_instance_and_non_accept_type( - self, instance, other, instance_on_the_right=False - ): + self, instance: Self, other: object, instance_on_the_right: bool = False + ) -> Self: if not other: return self.__class__(header_value=instance.header_value) @@ -807,7 +908,7 @@ def _add_instance_and_non_accept_type( ) return self.__class__(header_value=new_header_value) - def _old_match(self, mask, offer): + def _old_match(self, mask: str, offer: str) -> bool: """ Check if the offer is covered by the mask @@ -876,7 +977,7 @@ def _old_match(self, mask, offer): return offer.lower() == mask.lower() - def accept_html(self): + def accept_html(self) -> bool: """ Return ``True`` if any HTML-like type is accepted. @@ -894,10 +995,17 @@ def accept_html(self): ) ) - accepts_html = property(fget=accept_html, doc=accept_html.__doc__) - # note the plural + if TYPE_CHECKING: + + @property + def accepts_html(self) -> bool: + pass + + else: + accepts_html = property(fget=accept_html, doc=accept_html.__doc__) - def acceptable_offers(self, offers): + # note the plural + def acceptable_offers(self, offers: Sequence[str]) -> list[tuple[str, float]]: """ Return the offers that are acceptable according to the header. @@ -911,7 +1019,7 @@ def acceptable_offers(self, offers): Any offers that cannot be parsed via :meth:`.Accept.parse_offer` will be ignored. - :param offers: ``iterable`` of ``str`` media types (media types can + :param offers: ``sequence`` of ``str`` media types (media types can include media type parameters) or pre-parsed instances of :class:`.AcceptOffer`. :return: A list of tuples of the form (media type, qvalue), in @@ -920,6 +1028,7 @@ def acceptable_offers(self, offers): `offers`. """ parsed = self.parsed + assert parsed is not None # RFC 7231, section 3.1.1.1 "Media Type": # "The type, subtype, and parameter name tokens are case-insensitive. @@ -935,7 +1044,7 @@ def acceptable_offers(self, offers): ] lowercased_offers_parsed = self._parse_and_normalize_offers(offers) - acceptable_offers_n_quality_factors = {} + acceptable_offers_n_quality_factors: dict[str, tuple[float, int, int]] = {} for offer_index, parsed_offer in lowercased_offers_parsed: offer = offers[offer_index] offer_type, offer_subtype, offer_media_type_params = parsed_offer @@ -989,7 +1098,7 @@ def acceptable_offers(self, offers): specificity, # specifity of matched range ) - acceptable_offers_n_quality_factors = [ + filtered_acceptable_offers_n_quality_factors = [ # key is offer, value[0] is qvalue, value[1] is offer_index (key, value[0], value[1]) for key, value in acceptable_offers_n_quality_factors.items() @@ -1000,19 +1109,38 @@ def acceptable_offers(self, offers): # text/html' (which does not make sense, but is nonetheless valid), # and offers is ['text/html'] ] - acceptable_offers_n_quality_factors.sort( + filtered_acceptable_offers_n_quality_factors.sort( key=lambda tuple_: (tuple_[1], -tuple_[2]), reverse=True, # descending sort by (qvalue, -offer_index) ) # return list of (offer, qvalue) tuples, dropping offer_index - return [(item[0], item[1]) for item in acceptable_offers_n_quality_factors] + return [ + (offer, qvalue) + for offer, qvalue, _ in filtered_acceptable_offers_n_quality_factors + ] # If a media range is repeated in the header (which would not make # sense, but would be valid according to the rules in the RFC), an # offer for which the media range is the most specific match would take # its qvalue from the first appearance of the range in the header. - def best_match(self, offers, default_match=None): + @overload + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: None = None, + ) -> str | None: ... + + @overload + def best_match( + self, offers: Iterable[str | tuple[str, float] | list[Any]], default_match: str + ) -> str: ... + + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: str | None = None, + ) -> str | None: """ Return the best match from the sequence of media type `offers`. @@ -1091,7 +1219,7 @@ def best_match(self, offers, default_match=None): " in the future, as it does not conform to the RFC.", DeprecationWarning, ) - best_quality = -1 + best_quality: float = -1 best_offer = default_match matched_by = "*/*" for offer in offers: @@ -1116,7 +1244,7 @@ def best_match(self, offers, default_match=None): matched_by = mask return best_offer - def quality(self, offer): + def quality(self, offer: str) -> float | None: """ Return quality value of given offer, or ``None`` if there is no match. @@ -1167,7 +1295,7 @@ def quality(self, offer): "in the future, as it does not conform to the RFC.", DeprecationWarning, ) - bestq = 0 + bestq: float = 0 for item in self.parsed: media_range = item[0] qvalue = item[1] @@ -1199,7 +1327,7 @@ class MIMEAccept(Accept): """ - def __init__(self, header_value): + def __init__(self, header_value: str) -> None: warnings.warn( "The MIMEAccept class has been replaced by " "webob.acceptparse.create_accept_header. This compatibility shim " @@ -1215,7 +1343,7 @@ def __init__(self, header_value): self._parsed_nonzero = [] @staticmethod - def parse(value): + def parse(value: str) -> Iterator[tuple[str, float]]: # type: ignore[override] try: parsed_accepted = Accept.parse(value) @@ -1224,35 +1352,39 @@ def parse(value): except ValueError: pass - def __repr__(self): - return self._accept.__repr__() + # NOTE: These all should have the same signatures as in Accept + # so no point in type checking these for this compatibility shim + if not TYPE_CHECKING: - def __iter__(self): - return self._accept.__iter__() + def __repr__(self): + return self._accept.__repr__() - def __str__(self): - return self._accept.__str__() + def __iter__(self): + return self._accept.__iter__() - def __add__(self, other): - if isinstance(other, self.__class__): - return self.__class__(str(self._accept.__add__(other._accept))) - else: - return self.__class__(str(self._accept.__add__(other))) + def __str__(self): + return self._accept.__str__() - def __radd__(self, other): - return self.__class__(str(self._accept.__radd__(other))) + def __add__(self, other): + if isinstance(other, self.__class__): + return self.__class__(str(self._accept.__add__(other._accept))) + else: + return self.__class__(str(self._accept.__add__(other))) - def __contains__(self, offer): - return offer in self._accept + def __radd__(self, other): + return self.__class__(str(self._accept.__radd__(other))) - def quality(self, offer): - return self._accept.quality(offer) + def __contains__(self, offer): + return offer in self._accept - def best_match(self, offers, default_match=None): - return self._accept.best_match(offers, default_match=default_match) + def quality(self, offer): + return self._accept.quality(offer) - def accept_html(self): - return self._accept.accept_html() + def best_match(self, offers, default_match=None): + return self._accept.best_match(offers, default_match=default_match) + + def accept_html(self): + return self._accept.accept_html() class _AcceptInvalidOrNoHeader(Accept): @@ -1268,7 +1400,7 @@ class _AcceptInvalidOrNoHeader(Accept): :class:`.AcceptNoHeader` have much behaviour in common. """ - def __bool__(self): + def __bool__(self) -> Literal[False]: """ Return whether ``self`` represents a valid ``Accept`` header. @@ -1280,7 +1412,7 @@ def __bool__(self): """ return False - def __contains__(self, offer): + def __contains__(self, offer: str) -> Literal[True]: """ Return ``bool`` indicating whether `offer` is acceptable. @@ -1306,7 +1438,7 @@ def __contains__(self, offer): ) return True - def __iter__(self): + def __iter__(self) -> Iterator[str]: """ Return all the ranges with non-0 qvalues, in order of preference. @@ -1332,7 +1464,7 @@ def __iter__(self): ) return iter(()) - def accept_html(self): + def accept_html(self) -> bool: """ Return ``True`` if any HTML-like type is accepted. @@ -1357,14 +1489,14 @@ def accept_html(self): accepts_html = property(fget=accept_html, doc=accept_html.__doc__) # note the plural - def acceptable_offers(self, offers): + def acceptable_offers(self, offers: Sequence[str]) -> list[tuple[str, float]]: """ Return the offers that are acceptable according to the header. Any offers that cannot be parsed via :meth:`.Accept.parse_offer` will be ignored. - :param offers: ``iterable`` of ``str`` media types (media types can + :param offers: ``sequence`` of ``str`` media types (media types can include media type parameters) :return: When the header is invalid, or there is no ``Accept`` header in the request, all `offers` are considered acceptable, so @@ -1380,7 +1512,23 @@ def acceptable_offers(self, offers): for offer_index, _ in self._parse_and_normalize_offers(offers) ] - def best_match(self, offers, default_match=None): + @overload + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: None = None, + ) -> str | None: ... + + @overload + def best_match( + self, offers: Iterable[str | tuple[str, float] | list[Any]], default_match: str + ) -> str: ... + + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: str | None = None, + ) -> str | None: """ Return the best match from the sequence of language tag `offers`. @@ -1424,7 +1572,7 @@ def best_match(self, offers, default_match=None): "in (and currently does not conform to) RFC 7231.", DeprecationWarning, ) - best_quality = -1 + best_quality: float = -1 best_offer = default_match for offer in offers: if isinstance(offer, (list, tuple)): @@ -1436,7 +1584,7 @@ def best_match(self, offers, default_match=None): best_quality = quality return best_offer - def quality(self, offer): + def quality(self, offer: str) -> float: """ Return quality value of given offer, or ``None`` if there is no match. @@ -1476,7 +1624,7 @@ class AcceptNoHeader(_AcceptInvalidOrNoHeader): """ @property - def header_value(self): + def header_value(self) -> None: """ (``str`` or ``None``) The header value. @@ -1485,7 +1633,7 @@ def header_value(self): return self._header_value @property - def parsed(self): + def parsed(self) -> None: """ (``list`` or ``None``) Parsed form of the header. @@ -1493,7 +1641,7 @@ def parsed(self): """ return self._parsed - def __init__(self): + def __init__(self) -> None: """ Create an :class:`AcceptNoHeader` instance. """ @@ -1501,14 +1649,43 @@ def __init__(self): self._parsed = None self._parsed_nonzero = None - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. """ return self.__class__() - def __add__(self, other): + @overload + def __add__(self, other: AcceptValidHeader | Literal[""]) -> AcceptValidHeader: ... + + @overload + def __add__(self, other: AcceptNoHeader | AcceptInvalidHeader | None) -> Self: ... + + @overload + def __add__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptValidHeader: ... + + def __add__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptValidHeader: """ Add to header, creating a new header object. @@ -1547,23 +1724,55 @@ def __add__(self, other): return self._add_instance_and_non_accept_type(instance=self, other=other) - def __radd__(self, other): + @overload + def __radd__(self, other: AcceptValidHeader | Literal[""]) -> AcceptValidHeader: ... + + @overload + def __radd__(self, other: AcceptNoHeader | AcceptInvalidHeader | None) -> Self: ... + + @overload + def __radd__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptValidHeader: ... + + def __radd__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptValidHeader: """ Add to header, creating a new header object. See the docstring for :meth:`AcceptNoHeader.__add__`. """ - return self.__add__(other=other) + return self.__add__(other) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" - def __str__(self): + def __str__(self) -> str: """Return the ``str`` ``''``.""" return "" - def _add_instance_and_non_accept_type(self, instance, other): + def _add_instance_and_non_accept_type( + self, instance: Self, other: object + ) -> Self | AcceptValidHeader: + if other is None: return self.__class__() @@ -1592,13 +1801,13 @@ class AcceptInvalidHeader(_AcceptInvalidOrNoHeader): """ @property - def header_value(self): + def header_value(self) -> str: """(``str`` or ``None``) The header value.""" return self._header_value @property - def parsed(self): + def parsed(self) -> None: """ (``list`` or ``None``) Parsed form of the header. @@ -1607,7 +1816,7 @@ def parsed(self): return self._parsed - def __init__(self, header_value): + def __init__(self, header_value: str) -> None: """ Create an :class:`AcceptInvalidHeader` instance. """ @@ -1615,14 +1824,45 @@ def __init__(self, header_value): self._parsed = None self._parsed_nonzero = None - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. """ return self.__class__(self._header_value) - def __add__(self, other): + @overload + def __add__(self, other: AcceptValidHeader | Literal[""]) -> AcceptValidHeader: ... + + @overload + def __add__( + self, other: AcceptInvalidHeader | AcceptNoHeader | None + ) -> AcceptNoHeader: ... + + @overload + def __add__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptValidHeader | AcceptNoHeader: ... + + def __add__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptValidHeader | AcceptNoHeader: """ Add to header, creating a new header object. @@ -1662,7 +1902,38 @@ def __add__(self, other): return self._add_instance_and_non_accept_type(instance=self, other=other) - def __radd__(self, other): + @overload + def __radd__(self, other: AcceptValidHeader | Literal[""]) -> AcceptValidHeader: ... + + @overload + def __radd__( + self, other: AcceptInvalidHeader | AcceptNoHeader | None + ) -> AcceptNoHeader: ... + + @overload + def __radd__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptValidHeader | AcceptNoHeader: ... + + def __radd__( + self, + other: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptValidHeader | AcceptNoHeader: """ Add to header, creating a new header object. @@ -1673,20 +1944,21 @@ def __radd__(self, other): instance=self, other=other, instance_on_the_right=True ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" # We do not display the header_value, as it is untrusted input. The # header_value could always be easily obtained from the .header_value # property. - def __str__(self): + def __str__(self) -> str: """Return the ``str`` ``''``.""" return "" def _add_instance_and_non_accept_type( - self, instance, other, instance_on_the_right=False - ): + self, instance: Self, other: object, instance_on_the_right: bool = False + ) -> AcceptValidHeader | AcceptNoHeader: + if other is None: return AcceptNoHeader() @@ -1698,7 +1970,35 @@ def _add_instance_and_non_accept_type( return AcceptNoHeader() -def create_accept_header(header_value): +@overload +def create_accept_header( + header_value: AcceptValidHeader | Literal[""], +) -> AcceptValidHeader: ... + + +@overload +def create_accept_header(header_value: AcceptInvalidHeader) -> AcceptInvalidHeader: ... + + +@overload +def create_accept_header(header_value: None | AcceptNoHeader) -> AcceptNoHeader: ... + + +@overload +def create_accept_header( + header_value: str, +) -> AcceptValidHeader | AcceptInvalidHeader: ... + + +@overload +def create_accept_header( + header_value: _AnyAcceptHeader | str | None, +) -> _AnyAcceptHeader: ... + + +def create_accept_header( + header_value: _AnyAcceptHeader | str | None, +) -> _AnyAcceptHeader: """ Create an object representing the ``Accept`` header in a request. @@ -1723,7 +2023,7 @@ def create_accept_header(header_value): return AcceptInvalidHeader(header_value=header_value) -def accept_property(): +def accept_property() -> _AcceptProperty: doc = """ Property representing the ``Accept`` header. @@ -1737,12 +2037,22 @@ def accept_property(): ENVIRON_KEY = "HTTP_ACCEPT" - def fget(request): + def fget(request: BaseRequest) -> _AnyAcceptHeader: """Get an object representing the header in the request.""" return create_accept_header(header_value=request.environ.get(ENVIRON_KEY)) - def fset(request, value): + def fset( + request: BaseRequest, + value: ( + _AnyAcceptHeader + | SupportsItems[str, float | tuple[float, str]] + | ListOrTuple[str | tuple[str, float, str] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> None: """ Set the corresponding key in the request environ. @@ -1775,7 +2085,7 @@ def fset(request, value): header_value = Accept._python_value_to_header_str(value=value) request.environ[ENVIRON_KEY] = header_value - def fdel(request): + def fdel(request: BaseRequest) -> None: """Delete the corresponding key from the request environ.""" try: del request.environ[ENVIRON_KEY] @@ -1805,7 +2115,7 @@ class AcceptCharset: ) @classmethod - def _python_value_to_header_str(cls, value): + def _python_value_to_header_str(cls, value: object) -> str: if isinstance(value, str): header_str = value else: @@ -1826,7 +2136,7 @@ def _python_value_to_header_str(cls, value): return header_str @classmethod - def parse(cls, value): + def parse(cls, value: str) -> Iterator[tuple[str, float]]: """ Parse an ``Accept-Charset`` header. @@ -1844,7 +2154,7 @@ def parse(cls, value): if cls.accept_charset_compiled_re.match(value) is None: raise ValueError("Invalid value for an Accept-Charset header.") - def generator(value): + def generator(value: str) -> Iterator[tuple[str, float]]: for match in cls.charset_n_weight_compiled_re.finditer(value): charset = match.group(1) qvalue = match.group(2) @@ -1867,13 +2177,13 @@ class AcceptCharsetValidHeader(AcceptCharset): """ @property - def header_value(self): + def header_value(self) -> str: """(``str``) The header value.""" return self._header_value @property - def parsed(self): + def parsed(self) -> list[tuple[str, float]]: """ (``list``) Parsed form of the header. @@ -1882,7 +2192,7 @@ def parsed(self): return self._parsed - def __init__(self, header_value): + def __init__(self, header_value: str) -> None: """ Create an :class:`AcceptCharsetValidHeader` instance. @@ -1896,14 +2206,24 @@ def __init__(self, header_value): item for item in self.parsed if item[1] # item[1] is the qvalue ] - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. """ return self.__class__(self._header_value) - def __add__(self, other): + def __add__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self: """ Add to header, creating a new header object. @@ -1931,7 +2251,7 @@ def __add__(self, other): """ if isinstance(other, AcceptCharsetValidHeader): - return create_accept_charset_header( + return create_accept_charset_header( # type: ignore[return-value] header_value=self.header_value + ", " + other.header_value ) @@ -1942,7 +2262,7 @@ def __add__(self, other): instance=self, other=other ) - def __bool__(self): + def __bool__(self) -> Literal[True]: """ Return whether ``self`` represents a valid ``Accept-Charset`` header. @@ -1955,7 +2275,7 @@ def __bool__(self): return True - def __contains__(self, offer): + def __contains__(self, offer: str) -> bool: """ Return ``bool`` indicating whether `offer` is acceptable. @@ -1990,7 +2310,7 @@ def __contains__(self, offer): return False - def __iter__(self): + def __iter__(self) -> Iterator[str]: """ Return all the items with non-0 qvalues, in order of preference. @@ -2021,7 +2341,17 @@ def __iter__(self): ): yield mask - def __radd__(self, other): + def __radd__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self: """ Add to header, creating a new header object. @@ -2032,10 +2362,10 @@ def __radd__(self, other): instance=self, other=other, instance_on_the_right=True ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} ({str(self)!r})>" - def __str__(self): + def __str__(self) -> str: r""" Return a tidied up version of the header value. @@ -2048,8 +2378,9 @@ def __str__(self): ) def _add_instance_and_non_accept_charset_type( - self, instance, other, instance_on_the_right=False - ): + self, instance: Self, other: object, instance_on_the_right: bool = False + ) -> Self: + if not other: return self.__class__(header_value=instance.header_value) @@ -2067,7 +2398,7 @@ def _add_instance_and_non_accept_charset_type( ) return self.__class__(header_value=new_header_value) - def _old_match(self, mask, offer): + def _old_match(self, mask: str, offer: str) -> bool: """ Return whether charset offer matches header item (charset or ``*``). @@ -2090,7 +2421,7 @@ def _old_match(self, mask, offer): """ return mask == "*" or offer.lower() == mask.lower() - def acceptable_offers(self, offers): + def acceptable_offers(self, offers: Sequence[str]) -> list[tuple[str, float]]: """ Return the offers that are acceptable according to the header. @@ -2101,7 +2432,7 @@ def acceptable_offers(self, offers): This uses the matching rules described in :rfc:`RFC 7231, section 5.3.3 <7231#section-5.3.3>`. - :param offers: ``iterable`` of ``str`` charsets + :param offers: ``sequence`` of ``str`` charsets :return: A list of tuples of the form (charset, qvalue), in descending order of qvalue. Where two offers have the same qvalue, they are returned in the same order as their order in `offers`. @@ -2128,9 +2459,9 @@ def acceptable_offers(self, offers): not_acceptable_charsets.add(charset) else: acceptable_charsets[charset] = qvalue - acceptable_charsets = list(acceptable_charsets.items()) + acceptable_charsets_list = list(acceptable_charsets.items()) # Sort acceptable_charsets by qvalue, descending order - acceptable_charsets.sort(key=lambda tuple_: tuple_[1], reverse=True) + acceptable_charsets_list.sort(key=lambda tuple_: tuple_[1], reverse=True) filtered_offers = [] for index, offer in enumerate(lowercased_offers): @@ -2139,7 +2470,7 @@ def acceptable_offers(self, offers): continue matched_charset_qvalue = None - for charset, qvalue in acceptable_charsets: + for charset, qvalue in acceptable_charsets_list: if offer == charset: matched_charset_qvalue = qvalue break @@ -2159,7 +2490,23 @@ def acceptable_offers(self, offers): return [(item[0], item[1]) for item in filtered_offers] # (offer, qvalue), dropping the position - def best_match(self, offers, default_match=None): + @overload + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: None = None, + ) -> str | None: ... + + @overload + def best_match( + self, offers: Iterable[str | tuple[str, float] | list[Any]], default_match: str + ) -> str: ... + + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: str | None = None, + ) -> str | None: """ Return the best match from the sequence of charset `offers`. @@ -2212,7 +2559,7 @@ def best_match(self, offers, default_match=None): "deprecated in the future, as it does not conform to the RFC.", DeprecationWarning, ) - best_quality = -1 + best_quality: float = -1 best_offer = default_match matched_by = "*/*" for offer in offers: @@ -2237,7 +2584,7 @@ def best_match(self, offers, default_match=None): matched_by = mask return best_offer - def quality(self, offer): + def quality(self, offer: str) -> float | None: """ Return quality value of given offer, or ``None`` if there is no match. @@ -2269,7 +2616,7 @@ def quality(self, offer): "deprecated in the future, as it does not conform to the RFC.", DeprecationWarning, ) - bestq = 0 + bestq: float = 0 for mask, quality in self.parsed: if self._old_match(mask, offer): bestq = max(bestq, quality) @@ -2291,7 +2638,7 @@ class _AcceptCharsetInvalidOrNoHeader(AcceptCharset): have much behaviour in common. """ - def __bool__(self): + def __bool__(self) -> Literal[False]: """ Return whether ``self`` represents a valid ``Accept-Charset`` header. @@ -2303,7 +2650,7 @@ def __bool__(self): """ return False - def __contains__(self, offer): + def __contains__(self, offer: str) -> Literal[True]: """ Return ``bool`` indicating whether `offer` is acceptable. @@ -2329,7 +2676,7 @@ def __contains__(self, offer): ) return True - def __iter__(self): + def __iter__(self) -> Iterator[str]: """ Return all the items with non-0 qvalues, in order of preference. @@ -2355,7 +2702,7 @@ def __iter__(self): ) return iter(()) - def acceptable_offers(self, offers): + def acceptable_offers(self, offers: Iterable[str]) -> list[tuple[str, float]]: """ Return the offers that are acceptable according to the header. @@ -2379,7 +2726,23 @@ def acceptable_offers(self, offers): """ return [(offer, 1.0) for offer in offers] - def best_match(self, offers, default_match=None): + @overload + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: None = None, + ) -> str | None: ... + + @overload + def best_match( + self, offers: Iterable[str | tuple[str, float] | list[Any]], default_match: str + ) -> str: ... + + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: str | None = None, + ) -> str | None: """ Return the best match from the sequence of charset `offers`. @@ -2423,7 +2786,7 @@ def best_match(self, offers, default_match=None): "specified in (and currently does not conform to) RFC 7231.", DeprecationWarning, ) - best_quality = -1 + best_quality: float = -1 best_offer = default_match for offer in offers: if isinstance(offer, (list, tuple)): @@ -2435,7 +2798,7 @@ def best_match(self, offers, default_match=None): best_quality = quality return best_offer - def quality(self, offer): + def quality(self, offer: str) -> float | None: """ Return quality value of given offer, or ``None`` if there is no match. @@ -2475,7 +2838,7 @@ class AcceptCharsetNoHeader(_AcceptCharsetInvalidOrNoHeader): """ @property - def header_value(self): + def header_value(self) -> None: """ (``str`` or ``None``) The header value. @@ -2484,7 +2847,7 @@ def header_value(self): return self._header_value @property - def parsed(self): + def parsed(self) -> None: """ (``list`` or ``None``) Parsed form of the header. @@ -2492,7 +2855,7 @@ def parsed(self): """ return self._parsed - def __init__(self): + def __init__(self) -> None: """ Create an :class:`AcceptCharsetNoHeader` instance. """ @@ -2500,14 +2863,46 @@ def __init__(self): self._parsed = None self._parsed_nonzero = None - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. """ return self.__class__() - def __add__(self, other): + @overload + def __add__(self, other: AcceptCharsetValidHeader) -> AcceptCharsetValidHeader: ... + + @overload + def __add__( + self, + other: AcceptCharsetInvalidHeader | AcceptCharsetNoHeader | Literal[""] | None, + ) -> Self: ... + + @overload + def __add__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptCharsetValidHeader: ... + + def __add__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptCharsetValidHeader: """ Add to header, creating a new header object. @@ -2542,23 +2937,58 @@ def __add__(self, other): instance=self, other=other ) - def __radd__(self, other): + @overload + def __radd__(self, other: AcceptCharsetValidHeader) -> AcceptCharsetValidHeader: ... + + @overload + def __radd__( + self, + other: AcceptCharsetInvalidHeader | AcceptCharsetNoHeader | Literal[""] | None, + ) -> Self: ... + + @overload + def __radd__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptCharsetValidHeader: ... + + def __radd__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptCharsetValidHeader: """ Add to header, creating a new header object. See the docstring for :meth:`AcceptCharsetNoHeader.__add__`. """ - return self.__add__(other=other) + return self.__add__(other) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" - def __str__(self): + def __str__(self) -> str: """Return the ``str`` ``''``.""" return "" - def _add_instance_and_non_accept_charset_type(self, instance, other): + def _add_instance_and_non_accept_charset_type( + self, instance: Self, other: object + ) -> Self | AcceptCharsetValidHeader: + if not other: return self.__class__() @@ -2589,13 +3019,13 @@ class AcceptCharsetInvalidHeader(_AcceptCharsetInvalidOrNoHeader): """ @property - def header_value(self): + def header_value(self) -> str: """(``str`` or ``None``) The header value.""" return self._header_value @property - def parsed(self): + def parsed(self) -> None: """ (``list`` or ``None``) Parsed form of the header. @@ -2604,7 +3034,7 @@ def parsed(self): return self._parsed - def __init__(self, header_value): + def __init__(self, header_value: str) -> None: """ Create an :class:`AcceptCharsetInvalidHeader` instance. """ @@ -2612,14 +3042,46 @@ def __init__(self, header_value): self._parsed = None self._parsed_nonzero = None - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. """ return self.__class__(self._header_value) - def __add__(self, other): + @overload + def __add__(self, other: AcceptCharsetValidHeader) -> AcceptCharsetValidHeader: ... + + @overload + def __add__( + self, + other: AcceptCharsetInvalidHeader | AcceptCharsetNoHeader | Literal[""] | None, + ) -> AcceptCharsetNoHeader: ... + + @overload + def __add__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptCharsetValidHeader | AcceptCharsetNoHeader: ... + + def __add__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptCharsetValidHeader | AcceptCharsetNoHeader: """ Add to header, creating a new header object. @@ -2655,7 +3117,39 @@ def __add__(self, other): instance=self, other=other ) - def __radd__(self, other): + @overload + def __radd__(self, other: AcceptCharsetValidHeader) -> AcceptCharsetValidHeader: ... + + @overload + def __radd__( + self, + other: AcceptCharsetInvalidHeader | AcceptCharsetNoHeader | Literal[""] | None, + ) -> AcceptCharsetNoHeader: ... + + @overload + def __radd__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptCharsetValidHeader | AcceptCharsetNoHeader: ... + + def __radd__( + self, + other: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptCharsetValidHeader | AcceptCharsetNoHeader: """ Add to header, creating a new header object. @@ -2666,20 +3160,21 @@ def __radd__(self, other): instance=self, other=other, instance_on_the_right=True ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" # We do not display the header_value, as it is untrusted input. The # header_value could always be easily obtained from the .header_value # property. - def __str__(self): + def __str__(self) -> str: """Return the ``str`` ``''``.""" return "" def _add_instance_and_non_accept_charset_type( - self, instance, other, instance_on_the_right=False - ): + self, instance: Self, other: object, instance_on_the_right: bool = False + ) -> AcceptCharsetValidHeader | AcceptCharsetNoHeader: + if not other: return AcceptCharsetNoHeader() @@ -2691,7 +3186,39 @@ def _add_instance_and_non_accept_charset_type( return AcceptCharsetNoHeader() -def create_accept_charset_header(header_value): +@overload +def create_accept_charset_header( + header_value: AcceptCharsetValidHeader | Literal[""], +) -> AcceptCharsetValidHeader: ... + + +@overload +def create_accept_charset_header( + header_value: AcceptCharsetInvalidHeader, +) -> AcceptCharsetInvalidHeader: ... + + +@overload +def create_accept_charset_header( + header_value: AcceptCharsetNoHeader | None, +) -> AcceptCharsetNoHeader: ... + + +@overload +def create_accept_charset_header( + header_value: str, +) -> AcceptCharsetValidHeader | AcceptCharsetInvalidHeader: ... + + +@overload +def create_accept_charset_header( + header_value: _AnyAcceptCharsetHeader | str | None, +) -> _AnyAcceptCharsetHeader: ... + + +def create_accept_charset_header( + header_value: _AnyAcceptCharsetHeader | str | None, +) -> _AnyAcceptCharsetHeader: """ Create an object representing the ``Accept-Charset`` header in a request. @@ -2716,7 +3243,7 @@ def create_accept_charset_header(header_value): return AcceptCharsetInvalidHeader(header_value=header_value) -def accept_charset_property(): +def accept_charset_property() -> _AcceptCharsetProperty: doc = """ Property representing the ``Accept-Charset`` header. @@ -2730,14 +3257,24 @@ def accept_charset_property(): ENVIRON_KEY = "HTTP_ACCEPT_CHARSET" - def fget(request): + def fget(request: BaseRequest) -> _AnyAcceptCharsetHeader: """Get an object representing the header in the request.""" return create_accept_charset_header( header_value=request.environ.get(ENVIRON_KEY) ) - def fset(request, value): + def fset( + request: BaseRequest, + value: ( + _AnyAcceptCharsetHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> None: """ Set the corresponding key in the request environ. @@ -2765,7 +3302,7 @@ def fset(request, value): header_value = AcceptCharset._python_value_to_header_str(value=value) request.environ[ENVIRON_KEY] = header_value - def fdel(request): + def fdel(request: BaseRequest) -> None: """Delete the corresponding key from the request environ.""" try: del request.environ[ENVIRON_KEY] @@ -2798,7 +3335,7 @@ class AcceptEncoding: ) @classmethod - def _python_value_to_header_str(cls, value): + def _python_value_to_header_str(cls, value: object) -> str: if isinstance(value, str): header_str = value else: @@ -2819,7 +3356,7 @@ def _python_value_to_header_str(cls, value): return header_str @classmethod - def parse(cls, value): + def parse(cls, value: str) -> Iterator[tuple[str, float]]: """ Parse an ``Accept-Encoding`` header. @@ -2837,7 +3374,7 @@ def parse(cls, value): if cls.accept_encoding_compiled_re.match(value) is None: raise ValueError("Invalid value for an Accept-Encoding header.") - def generator(value): + def generator(value: str) -> Iterator[tuple[str, float]]: for match in cls.codings_n_weight_compiled_re.finditer(value): codings = match.group(1) qvalue = match.group(2) @@ -2860,13 +3397,13 @@ class AcceptEncodingValidHeader(AcceptEncoding): """ @property - def header_value(self): + def header_value(self) -> str: """(``str`` or ``None``) The header value.""" return self._header_value @property - def parsed(self): + def parsed(self) -> list[tuple[str, float]]: """ (``list`` or ``None``) Parsed form of the header. @@ -2880,7 +3417,7 @@ def parsed(self): return self._parsed - def __init__(self, header_value): + def __init__(self, header_value: str) -> None: """ Create an :class:`AcceptEncodingValidHeader` instance. @@ -2893,14 +3430,24 @@ def __init__(self, header_value): self._parsed_nonzero = [item for item in self.parsed if item[1]] # item[1] is the qvalue - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. """ return self.__class__(self._header_value) - def __add__(self, other): + def __add__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self: """ Add to header, creating a new header object. @@ -2937,7 +3484,7 @@ def __add__(self, other): if other.header_value == "": return self.__class__(header_value=self.header_value) else: - return create_accept_encoding_header( + return create_accept_encoding_header( # type: ignore[return-value] header_value=self.header_value + ", " + other.header_value ) @@ -2948,7 +3495,7 @@ def __add__(self, other): instance=self, other=other ) - def __bool__(self): + def __bool__(self) -> Literal[True]: """ Return whether ``self`` represents a valid ``Accept-Encoding`` header. @@ -2961,7 +3508,7 @@ def __bool__(self): return True - def __contains__(self, offer): + def __contains__(self, offer: str) -> bool: """ Return ``bool`` indicating whether `offer` is acceptable. @@ -2996,8 +3543,9 @@ def __contains__(self, offer): for mask, _quality in self._parsed_nonzero: if self._old_match(mask, offer): return True + return False - def __iter__(self): + def __iter__(self) -> Iterator[str]: """ Return all the ranges with non-0 qvalues, in order of preference. @@ -3029,7 +3577,17 @@ def __iter__(self): ): yield mask - def __radd__(self, other): + def __radd__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self: """ Add to header, creating a new header object. @@ -3040,10 +3598,10 @@ def __radd__(self, other): instance=self, other=other, instance_on_the_right=True ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} ({str(self)!r})>" - def __str__(self): + def __str__(self) -> str: r""" Return a tidied up version of the header value. @@ -3055,8 +3613,9 @@ def __str__(self): ) def _add_instance_and_non_accept_encoding_type( - self, instance, other, instance_on_the_right=False - ): + self, instance: Self, other: object, instance_on_the_right: bool = False + ) -> Self: + if not other: return self.__class__(header_value=instance.header_value) @@ -3079,7 +3638,7 @@ def _add_instance_and_non_accept_encoding_type( ) return self.__class__(header_value=new_header_value) - def _old_match(self, mask, offer): + def _old_match(self, mask: str, offer: str) -> bool: """ Return whether content-coding offer matches codings header item. @@ -3103,7 +3662,7 @@ def _old_match(self, mask, offer): """ return mask == "*" or offer.lower() == mask.lower() - def acceptable_offers(self, offers): + def acceptable_offers(self, offers: Sequence[str]) -> list[tuple[str, float]]: """ Return the offers that are acceptable according to the header. @@ -3114,7 +3673,7 @@ def acceptable_offers(self, offers): This uses the matching rules described in :rfc:`RFC 7231, section 5.3.4 <7231#section-5.3.4>`. - :param offers: ``iterable`` of ``str``s, where each ``str`` is a + :param offers: ``sequence`` of ``str``s, where each ``str`` is a content-coding or the string ``identity`` (the token used to represent "no encoding") :return: A list of tuples of the form (content-coding or "identity", @@ -3153,9 +3712,9 @@ def acceptable_offers(self, offers): not_acceptable_codingss.add(codings) else: acceptable_codingss[codings] = qvalue - acceptable_codingss = list(acceptable_codingss.items()) + acceptable_codingsl = list(acceptable_codingss.items()) # Sort acceptable_codingss by qvalue, descending order - acceptable_codingss.sort(key=lambda tuple_: tuple_[1], reverse=True) + acceptable_codingsl.sort(key=lambda tuple_: tuple_[1], reverse=True) filtered_offers = [] for index, offer in enumerate(lowercased_offers): @@ -3164,7 +3723,7 @@ def acceptable_offers(self, offers): continue matched_codings_qvalue = None - for codings, qvalue in acceptable_codingss: + for codings, qvalue in acceptable_codingsl: if offer == codings: matched_codings_qvalue = qvalue break @@ -3186,7 +3745,23 @@ def acceptable_offers(self, offers): return [(item[0], item[1]) for item in filtered_offers] # (offer, qvalue), dropping the position - def best_match(self, offers, default_match=None): + @overload + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: None = None, + ) -> str | None: ... + + @overload + def best_match( + self, offers: Iterable[str | tuple[str, float] | list[Any]], default_match: str + ) -> str: ... + + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: str | None = None, + ) -> str | None: """ Return the best match from the sequence of `offers`. @@ -3247,7 +3822,7 @@ def best_match(self, offers, default_match=None): " RFC.", DeprecationWarning, ) - best_quality = -1 + best_quality: float = -1 best_offer = default_match matched_by = "*/*" for offer in offers: @@ -3274,7 +3849,7 @@ def best_match(self, offers, default_match=None): matched_by = mask return best_offer - def quality(self, offer): + def quality(self, offer: str) -> float | None: """ Return quality value of given offer, or ``None`` if there is no match. @@ -3308,7 +3883,7 @@ def quality(self, offer): "deprecated in the future, as it does not conform to the RFC.", DeprecationWarning, ) - bestq = 0 + bestq: float = 0 for mask, q in self.parsed: if self._old_match(mask, offer): bestq = max(bestq, q) @@ -3330,7 +3905,7 @@ class _AcceptEncodingInvalidOrNoHeader(AcceptEncoding): have much behaviour in common. """ - def __bool__(self): + def __bool__(self) -> Literal[False]: """ Return whether ``self`` represents a valid ``Accept-Encoding`` header. @@ -3342,7 +3917,7 @@ def __bool__(self): """ return False - def __contains__(self, offer): + def __contains__(self, offer: str) -> Literal[True]: """ Return ``bool`` indicating whether `offer` is acceptable. @@ -3368,7 +3943,7 @@ def __contains__(self, offer): ) return True - def __iter__(self): + def __iter__(self) -> Iterator[str]: """ Return all the header items with non-0 qvalues, in order of preference. @@ -3395,7 +3970,7 @@ def __iter__(self): ) return iter(()) - def acceptable_offers(self, offers): + def acceptable_offers(self, offers: Iterable[str]) -> list[tuple[str, float]]: """ Return the offers that are acceptable according to the header. @@ -3411,7 +3986,23 @@ def acceptable_offers(self, offers): """ return [(offer, 1.0) for offer in offers] - def best_match(self, offers, default_match=None): + @overload + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: None = None, + ) -> str | None: ... + + @overload + def best_match( + self, offers: Iterable[str | tuple[str, float] | list[Any]], default_match: str + ) -> str: ... + + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: str | None = None, + ) -> str | None: """ Return the best match from the sequence of `offers`. @@ -3457,7 +4048,7 @@ def best_match(self, offers, default_match=None): "specified in (and currently does not conform to) RFC 7231.", DeprecationWarning, ) - best_quality = -1 + best_quality: float = -1 best_offer = default_match for offer in offers: if isinstance(offer, (list, tuple)): @@ -3469,7 +4060,7 @@ def best_match(self, offers, default_match=None): best_quality = quality return best_offer - def quality(self, offer): + def quality(self, offer: str) -> float | None: """ Return quality value of given offer, or ``None`` if there is no match. @@ -3509,7 +4100,7 @@ class AcceptEncodingNoHeader(_AcceptEncodingInvalidOrNoHeader): """ @property - def header_value(self): + def header_value(self) -> None: """ (``str`` or ``None``) The header value. @@ -3518,7 +4109,7 @@ def header_value(self): return self._header_value @property - def parsed(self): + def parsed(self) -> None: """ (``list`` or ``None``) Parsed form of the header. @@ -3526,7 +4117,7 @@ def parsed(self): """ return self._parsed - def __init__(self): + def __init__(self) -> None: """ Create an :class:`AcceptEncodingNoHeader` instance. """ @@ -3534,14 +4125,47 @@ def __init__(self): self._parsed = None self._parsed_nonzero = None - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. """ return self.__class__() - def __add__(self, other): + @overload + def __add__( + self, other: AcceptEncodingValidHeader | Literal[""] + ) -> AcceptEncodingValidHeader: ... + + @overload + def __add__( + self, other: AcceptEncodingInvalidHeader | AcceptEncodingNoHeader | None + ) -> Self: ... + + @overload + def __add__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptEncodingValidHeader: ... + + def __add__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptEncodingValidHeader: """ Add to header, creating a new header object. @@ -3578,23 +4202,59 @@ def __add__(self, other): instance=self, other=other ) - def __radd__(self, other): + @overload + def __radd__( + self, other: AcceptEncodingValidHeader | Literal[""] + ) -> AcceptEncodingValidHeader: ... + + @overload + def __radd__( + self, other: AcceptEncodingInvalidHeader | AcceptEncodingNoHeader | None + ) -> Self: ... + + @overload + def __radd__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptEncodingValidHeader: ... + + def __radd__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptEncodingValidHeader: """ Add to header, creating a new header object. See the docstring for :meth:`AcceptEncodingNoHeader.__add__`. """ - return self.__add__(other=other) + return self.__add__(other) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" - def __str__(self): + def __str__(self) -> str: """Return the ``str`` ``''``.""" return "" - def _add_instance_and_non_accept_encoding_type(self, instance, other): + def _add_instance_and_non_accept_encoding_type( + self, instance: Self, other: object + ) -> Self | AcceptEncodingValidHeader: + if other is None: return self.__class__() @@ -3624,13 +4284,13 @@ class AcceptEncodingInvalidHeader(_AcceptEncodingInvalidOrNoHeader): """ @property - def header_value(self): + def header_value(self) -> str: """(``str`` or ``None``) The header value.""" return self._header_value @property - def parsed(self): + def parsed(self) -> None: """ (``list`` or ``None``) Parsed form of the header. @@ -3639,7 +4299,7 @@ def parsed(self): return self._parsed - def __init__(self, header_value): + def __init__(self, header_value: str) -> None: """ Create an :class:`AcceptEncodingInvalidHeader` instance. """ @@ -3647,14 +4307,47 @@ def __init__(self, header_value): self._parsed = None self._parsed_nonzero = None - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. """ return self.__class__(self._header_value) - def __add__(self, other): + @overload + def __add__( + self, other: AcceptEncodingValidHeader | Literal[""] + ) -> AcceptEncodingValidHeader: ... + + @overload + def __add__( + self, other: AcceptEncodingInvalidHeader | AcceptEncodingNoHeader | None + ) -> AcceptEncodingNoHeader: ... + + @overload + def __add__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptEncodingValidHeader | AcceptEncodingNoHeader: ... + + def __add__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptEncodingValidHeader | AcceptEncodingNoHeader: """ Add to header, creating a new header object. @@ -3692,7 +4385,40 @@ def __add__(self, other): instance=self, other=other ) - def __radd__(self, other): + @overload + def __radd__( + self, other: AcceptEncodingValidHeader | Literal[""] + ) -> AcceptEncodingValidHeader: ... + + @overload + def __radd__( + self, other: AcceptEncodingInvalidHeader | AcceptEncodingNoHeader | None + ) -> AcceptEncodingNoHeader: ... + + @overload + def __radd__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptEncodingValidHeader | AcceptEncodingNoHeader: ... + + def __radd__( + self, + other: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptEncodingValidHeader | AcceptEncodingNoHeader: """ Add to header, creating a new header object. @@ -3703,20 +4429,21 @@ def __radd__(self, other): instance=self, other=other, instance_on_the_right=True ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" # We do not display the header_value, as it is untrusted input. The # header_value could always be easily obtained from the .header_value # property. - def __str__(self): + def __str__(self) -> str: """Return the ``str`` ``''``.""" return "" def _add_instance_and_non_accept_encoding_type( - self, instance, other, instance_on_the_right=False - ): + self, instance: Self, other: object, instance_on_the_right: bool = False + ) -> AcceptEncodingValidHeader | AcceptEncodingNoHeader: + if other is None: return AcceptEncodingNoHeader() @@ -3728,7 +4455,39 @@ def _add_instance_and_non_accept_encoding_type( return AcceptEncodingNoHeader() -def create_accept_encoding_header(header_value): +@overload +def create_accept_encoding_header( + header_value: AcceptEncodingValidHeader | Literal[""], +) -> AcceptEncodingValidHeader: ... + + +@overload +def create_accept_encoding_header( + header_value: AcceptEncodingInvalidHeader, +) -> AcceptEncodingInvalidHeader: ... + + +@overload +def create_accept_encoding_header( + header_value: AcceptEncodingNoHeader | None, +) -> AcceptEncodingNoHeader: ... + + +@overload +def create_accept_encoding_header( + header_value: str, +) -> AcceptEncodingValidHeader | AcceptEncodingInvalidHeader: ... + + +@overload +def create_accept_encoding_header( + header_value: _AnyAcceptEncodingHeader | str | None, +) -> _AnyAcceptEncodingHeader: ... + + +def create_accept_encoding_header( + header_value: _AnyAcceptEncodingHeader | str | None, +) -> _AnyAcceptEncodingHeader: """ Create an object representing the ``Accept-Encoding`` header in a request. @@ -3753,7 +4512,7 @@ def create_accept_encoding_header(header_value): return AcceptEncodingInvalidHeader(header_value=header_value) -def accept_encoding_property(): +def accept_encoding_property() -> _AcceptEncodingProperty: doc = """ Property representing the ``Accept-Encoding`` header. @@ -3767,14 +4526,24 @@ def accept_encoding_property(): ENVIRON_KEY = "HTTP_ACCEPT_ENCODING" - def fget(request): + def fget(request: BaseRequest) -> _AnyAcceptEncodingHeader: """Get an object representing the header in the request.""" return create_accept_encoding_header( header_value=request.environ.get(ENVIRON_KEY) ) - def fset(request, value): + def fset( + request: BaseRequest, + value: ( + _AnyAcceptEncodingHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> None: """ Set the corresponding key in the request environ. @@ -3804,7 +4573,7 @@ def fset(request, value): header_value = AcceptEncoding._python_value_to_header_str(value=value) request.environ[ENVIRON_KEY] = header_value - def fdel(request): + def fdel(request: BaseRequest) -> None: """Delete the corresponding key from the request environ.""" try: del request.environ[ENVIRON_KEY] @@ -3837,7 +4606,7 @@ class AcceptLanguage: ) @classmethod - def _python_value_to_header_str(cls, value): + def _python_value_to_header_str(cls, value: object) -> str: if isinstance(value, str): header_str = value else: @@ -3858,7 +4627,7 @@ def _python_value_to_header_str(cls, value): return header_str @classmethod - def parse(cls, value): + def parse(cls, value: str) -> Iterator[tuple[str, float]]: """ Parse an ``Accept-Language`` header. @@ -3876,7 +4645,7 @@ def parse(cls, value): if cls.accept_language_compiled_re.match(value) is None: raise ValueError("Invalid value for an Accept-Language header.") - def generator(value): + def generator(value: str) -> Iterator[tuple[str, float]]: for match in cls.lang_range_n_weight_compiled_re.finditer(value): lang_range = match.group(1) qvalue = match.group(2) @@ -3903,7 +4672,7 @@ class AcceptLanguageValidHeader(AcceptLanguage): docstring for :meth:`AcceptLanguageValidHeader.__add__`). """ - def __init__(self, header_value): + def __init__(self, header_value: str) -> None: """ Create an :class:`AcceptLanguageValidHeader` instance. @@ -3916,7 +4685,7 @@ def __init__(self, header_value): self._parsed_nonzero = [item for item in self.parsed if item[1]] # item[1] is the qvalue - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. @@ -3924,13 +4693,13 @@ def copy(self): return self.__class__(self._header_value) @property - def header_value(self): + def header_value(self) -> str: """(``str`` or ``None``) The header value.""" return self._header_value @property - def parsed(self): + def parsed(self) -> list[tuple[str, float]]: """ (``list`` or ``None``) Parsed form of the header. @@ -3939,7 +4708,17 @@ def parsed(self): return self._parsed - def __add__(self, other): + def __add__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self: """ Add to header, creating a new header object. @@ -3968,7 +4747,7 @@ def __add__(self, other): """ if isinstance(other, AcceptLanguageValidHeader): - return create_accept_language_header( + return create_accept_language_header( # type: ignore[return-value] header_value=self.header_value + ", " + other.header_value ) @@ -3979,7 +4758,7 @@ def __add__(self, other): instance=self, other=other ) - def __bool__(self): + def __bool__(self) -> Literal[True]: """ Return whether ``self`` represents a valid ``Accept-Language`` header. @@ -3992,7 +4771,7 @@ def __bool__(self): return True - def __contains__(self, offer): + def __contains__(self, offer: str) -> bool: """ Return ``bool`` indicating whether `offer` is acceptable. @@ -4044,7 +4823,7 @@ def __contains__(self, offer): return False - def __iter__(self): + def __iter__(self) -> Iterator[str]: """ Return all the ranges with non-0 qvalues, in order of preference. @@ -4075,7 +4854,17 @@ def __iter__(self): ): yield mask - def __radd__(self, other): + def __radd__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self: """ Add to header, creating a new header object. @@ -4086,10 +4875,10 @@ def __radd__(self, other): instance=self, other=other, instance_on_the_right=True ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} ({str(self)!r})>" - def __str__(self): + def __str__(self) -> str: r""" Return a tidied up version of the header value. @@ -4102,8 +4891,9 @@ def __str__(self): ) def _add_instance_and_non_accept_language_type( - self, instance, other, instance_on_the_right=False - ): + self, instance: Self, other: object, instance_on_the_right: bool = False + ) -> Self: + if not other: return self.__class__(header_value=instance.header_value) @@ -4121,7 +4911,7 @@ def _add_instance_and_non_accept_language_type( ) return self.__class__(header_value=new_header_value) - def _old_match(self, mask, item): + def _old_match(self, mask: str, item: str) -> bool: """ Return whether a language tag matches a language range. @@ -4193,7 +4983,7 @@ def _old_match(self, mask, item): or item == mask.split("-")[0] ) - def basic_filtering(self, language_tags): + def basic_filtering(self, language_tags: Sequence[str]) -> list[tuple[str, float]]: """ Return the tags that match the header, using Basic Filtering. @@ -4204,7 +4994,7 @@ def basic_filtering(self, language_tags): tags in the `language_tags` argument and returns the ones that match the header according to the matching scheme. - :param language_tags: (``iterable``) language tags + :param language_tags: (``sequence``) language tags :return: A list of tuples of the form (language tag, qvalue), in descending order of qvalue. If two or more tags have the same qvalue, they are returned in the same order as that in the @@ -4269,18 +5059,18 @@ def basic_filtering(self, language_tags): not_acceptable_ranges.add(range_) else: acceptable_ranges[range_] = (qvalue, position_in_header) - acceptable_ranges = [ + acceptable_ranges_list = [ (range_, qvalue, position_in_header) for range_, (qvalue, position_in_header) in acceptable_ranges.items() ] # Sort acceptable_ranges by position_in_header, ascending order - acceptable_ranges.sort(key=lambda tuple_: tuple_[2]) + acceptable_ranges_list.sort(key=lambda tuple_: tuple_[2]) # Sort acceptable_ranges by qvalue, descending order - acceptable_ranges.sort(key=lambda tuple_: tuple_[1], reverse=True) + acceptable_ranges_list.sort(key=lambda tuple_: tuple_[1], reverse=True) # Sort guaranteed to be stable with Python >= 2.2, so position in # header is tiebreaker when two ranges have the same qvalue - def match(tag, range_): + def match(tag: str, range_: str) -> bool: # RFC 4647, section 2.1: 'A language range matches a particular # language tag if, in a case-insensitive comparison, it exactly # equals the tag, or if it exactly equals a prefix of the tag such @@ -4297,7 +5087,7 @@ def match(tag, range_): continue matched_range_qvalue = None - for range_, qvalue, position_in_header in acceptable_ranges: + for range_, qvalue, position_in_header in acceptable_ranges_list: # acceptable_ranges is in descending order of qvalue, and tied # ranges are in ascending order of position_in_header, so the # first range_ that matches the tag is the best match @@ -4349,7 +5139,23 @@ def match(tag, range_): # (same qvalue), which we would not be able to do easily with a set or # a list without e.g. making a member of the set or list a sequence. - def best_match(self, offers, default_match=None): + @overload + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: None = None, + ) -> str | None: ... + + @overload + def best_match( + self, offers: Iterable[str | tuple[str, float] | list[Any]], default_match: str + ) -> str: ... + + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: str | None = None, + ) -> str | None: """ Return the best match from the sequence of language tag `offers`. @@ -4457,7 +5263,7 @@ def best_match(self, offers, default_match=None): "RFC.", DeprecationWarning, ) - best_quality = -1 + best_quality: float = -1 best_offer = default_match matched_by = "*/*" # [We can see that this was written for the ``Accept`` header and not @@ -4494,7 +5300,70 @@ def best_match(self, offers, default_match=None): matched_by = mask return best_offer - def lookup(self, language_tags, default_range=None, default_tag=None, default=None): + @overload + def lookup( + self, + language_tags: Sequence[str], + default_range: str | None, + default_tag: str, + default: None = None, + ) -> str | None: ... + + @overload + def lookup( + self, + language_tags: Sequence[str], + *, + default_range: str | None = None, + default_tag: str, + default: None = None, + ) -> str | None: ... + + @overload + def lookup( + self, + language_tags: Sequence[str], + default_range: str | None, + default_tag: None, + default: _T | Callable[[], _T], + ) -> _T | str | None: ... + + @overload + def lookup( + self, + language_tags: Sequence[str], + default_range: str | None, + default_tag: str, + default: _T | Callable[[], _T], + ) -> _T | str: ... + + @overload + def lookup( + self, + language_tags: Sequence[str], + *, + default_range: str | None = None, + default_tag: None = None, + default: _T | Callable[[], _T], + ) -> _T | str | None: ... + + @overload + def lookup( + self, + language_tags: Sequence[str], + *, + default_range: str | None = None, + default_tag: str, + default: _T | Callable[[], _T], + ) -> _T | str: ... + + def lookup( + self, + language_tags: Sequence[str], + default_range: str | None = None, + default_tag: str | None = None, + default: _T | Callable[[], _T] | None = None, + ) -> _T | str | None: """ Return the language tag that best matches the header, using Lookup. @@ -4523,7 +5392,7 @@ def lookup(self, language_tags, default_range=None, default_tag=None, default=No 5. zh 6. (default) - :param language_tags: (``iterable``) language tags + :param language_tags: (``sequence``) language tags :param default_range: (optional, ``None`` or ``str``) @@ -4647,7 +5516,7 @@ def lookup(self, language_tags, default_range=None, default_tag=None, default=No tags = language_tags not_acceptable_ranges = [] - acceptable_ranges = [] + acceptable_range_pairs = [] asterisk_q0_found = False # Whether there is a '*' range in the header with q=0 @@ -4664,17 +5533,17 @@ def lookup(self, language_tags, default_range=None, default_tag=None, default=No elif qvalue == 0.0: not_acceptable_ranges.append(range_.lower()) else: - acceptable_ranges.append((range_, qvalue)) + acceptable_range_pairs.append((range_, qvalue)) # range_ is .lower()ed later # Sort acceptable_ranges by qvalue, descending order - acceptable_ranges.sort(key=lambda tuple_: tuple_[1], reverse=True) + acceptable_range_pairs.sort(key=lambda tuple_: tuple_[1], reverse=True) # Sort guaranteed to be stable with Python >= 2.2, so position in # header is tiebreaker when two ranges have the same qvalue - acceptable_ranges = [tuple_[0] for tuple_ in acceptable_ranges] + acceptable_ranges = [tuple_[0] for tuple_ in acceptable_range_pairs] lowered_tags = [tag.lower() for tag in tags] - def best_match(range_): + def best_match(range_: str) -> str | None: subtags = range_.split("-") while True: for index, tag in enumerate(lowered_tags): @@ -4696,7 +5565,7 @@ def best_match(range_): try: subtag_before_this = subtags[-2] except IndexError: # len(subtags) == 1 - break + return None # len(subtags) >= 2 if len(subtag_before_this) == 1 and ( subtag_before_this.isdigit() or subtag_before_this.isalpha() @@ -4722,12 +5591,12 @@ def best_match(range_): if lowered_default_tag not in not_acceptable_ranges: return default_tag - try: + if callable(default): return default() - except TypeError: # default is not a callable + else: # default is not a callable return default - def quality(self, offer): + def quality(self, offer: str) -> float | None: """ Return quality value of given offer, or ``None`` if there is no match. @@ -4802,7 +5671,7 @@ def quality(self, offer): "RFC.", DeprecationWarning, ) - bestq = 0 + bestq: float = 0 for mask, q in self.parsed: if self._old_match(mask, offer): bestq = max(bestq, q) @@ -4824,7 +5693,7 @@ class _AcceptLanguageInvalidOrNoHeader(AcceptLanguage): have much behaviour in common. """ - def __bool__(self): + def __bool__(self) -> Literal[False]: """ Return whether ``self`` represents a valid ``Accept-Language`` header. @@ -4836,7 +5705,7 @@ def __bool__(self): """ return False - def __contains__(self, offer): + def __contains__(self, offer: str) -> Literal[True]: """ Return ``bool`` indicating whether `offer` is acceptable. @@ -4862,7 +5731,7 @@ def __contains__(self, offer): ) return True - def __iter__(self): + def __iter__(self) -> Iterator[str]: """ Return all the ranges with non-0 qvalues, in order of preference. @@ -4888,7 +5757,7 @@ def __iter__(self): ) return iter(()) - def basic_filtering(self, language_tags): + def basic_filtering(self, language_tags: Iterable[str]) -> list[tuple[str, float]]: """ Return the tags that match the header, using Basic Filtering. @@ -4901,7 +5770,23 @@ def basic_filtering(self, language_tags): """ return [] - def best_match(self, offers, default_match=None): + @overload + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: None = None, + ) -> str | None: ... + + @overload + def best_match( + self, offers: Iterable[str | tuple[str, float] | list[Any]], default_match: str + ) -> str: ... + + def best_match( + self, + offers: Iterable[str | tuple[str, float] | list[Any]], + default_match: str | None = None, + ) -> str | None: """ Return the best match from the sequence of language tag `offers`. @@ -4948,7 +5833,7 @@ def best_match(self, offers, default_match=None): "specified in (and currently does not conform to) RFC 7231.", DeprecationWarning, ) - best_quality = -1 + best_quality: float = -1 best_offer = default_match for offer in offers: if isinstance(offer, (list, tuple)): @@ -4960,9 +5845,51 @@ def best_match(self, offers, default_match=None): best_quality = quality return best_offer + @overload + def lookup( + self, + language_tags: object, + default_range: object, + default_tag: str, + default: object = None, + ) -> str: ... + + @overload + def lookup( + self, + language_tags: object = None, + *, + default_range: object = None, + default_tag: str, + default: object = None, + ) -> str: ... + + @overload + def lookup( + self, + language_tags: object, + default_range: object, + default_tag: None, + default: _T | Callable[[], _T], + ) -> _T: ... + + @overload + def lookup( + self, + language_tags: object = None, + *, + default_range: object = None, + default_tag: None = None, + default: _T | Callable[[], _T], + ) -> _T: ... + def lookup( - self, language_tags=None, default_range=None, default_tag=None, default=None - ): + self, + language_tags: object = None, + default_range: object = None, + default_tag: str | None = None, + default: _T | Callable[[], _T] | None = None, + ) -> _T | str | None: """ Return the language tag that best matches the header, using Lookup. @@ -5034,12 +5961,12 @@ def lookup( if default_tag is not None: return default_tag - try: + if callable(default): return default() - except TypeError: # default is not a callable + else: # default is not a callable return default - def quality(self, offer): + def quality(self, offer: str) -> float | None: """ Return quality value of given offer, or ``None`` if there is no match. @@ -5078,7 +6005,7 @@ class AcceptLanguageNoHeader(_AcceptLanguageInvalidOrNoHeader): docstring for :meth:`AcceptLanguageNoHeader.__add__`). """ - def __init__(self): + def __init__(self) -> None: """ Create an :class:`AcceptLanguageNoHeader` instance. """ @@ -5086,7 +6013,7 @@ def __init__(self): self._parsed = None self._parsed_nonzero = None - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. @@ -5094,7 +6021,7 @@ def copy(self): return self.__class__() @property - def header_value(self): + def header_value(self) -> None: """ (``str`` or ``None``) The header value. @@ -5103,7 +6030,7 @@ def header_value(self): return self._header_value @property - def parsed(self): + def parsed(self) -> None: """ (``list`` or ``None``) Parsed form of the header. @@ -5111,7 +6038,43 @@ def parsed(self): """ return self._parsed - def __add__(self, other): + @overload + def __add__( + self, other: AcceptLanguageValidHeader + ) -> AcceptLanguageValidHeader: ... + + @overload + def __add__( + self, + other: ( + AcceptLanguageInvalidHeader | AcceptLanguageNoHeader | Literal[""] | None + ), + ) -> Self: ... + + @overload + def __add__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptLanguageValidHeader: ... + + def __add__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptLanguageValidHeader: """ Add to header, creating a new header object. @@ -5147,23 +6110,62 @@ def __add__(self, other): instance=self, other=other ) - def __radd__(self, other): + @overload + def __radd__( + self, other: AcceptLanguageValidHeader + ) -> AcceptLanguageValidHeader: ... + + @overload + def __radd__( + self, + other: ( + AcceptLanguageInvalidHeader | AcceptLanguageNoHeader | Literal[""] | None + ), + ) -> Self: ... + + @overload + def __radd__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptLanguageValidHeader: ... + + def __radd__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> Self | AcceptLanguageValidHeader: """ Add to header, creating a new header object. See the docstring for :meth:`AcceptLanguageNoHeader.__add__`. """ - return self.__add__(other=other) + return self.__add__(other) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" - def __str__(self): + def __str__(self) -> str: """Return the ``str`` ``''``.""" return "" - def _add_instance_and_non_accept_language_type(self, instance, other): + def _add_instance_and_non_accept_language_type( + self, instance: Self, other: object + ) -> Self | AcceptLanguageValidHeader: + if not other: return self.__class__() @@ -5193,7 +6195,7 @@ class AcceptLanguageInvalidHeader(_AcceptLanguageInvalidOrNoHeader): docstring for :meth:`AcceptLanguageInvalidHeader.__add__`). """ - def __init__(self, header_value): + def __init__(self, header_value: str) -> None: """ Create an :class:`AcceptLanguageInvalidHeader` instance. """ @@ -5201,7 +6203,7 @@ def __init__(self, header_value): self._parsed = None self._parsed_nonzero = None - def copy(self): + def copy(self) -> Self: """ Create a copy of the header object. @@ -5209,13 +6211,13 @@ def copy(self): return self.__class__(self._header_value) @property - def header_value(self): + def header_value(self) -> str: """(``str`` or ``None``) The header value.""" return self._header_value @property - def parsed(self): + def parsed(self) -> None: """ (``list`` or ``None``) Parsed form of the header. @@ -5224,7 +6226,43 @@ def parsed(self): return self._parsed - def __add__(self, other): + @overload + def __add__( + self, other: AcceptLanguageValidHeader + ) -> AcceptLanguageValidHeader: ... + + @overload + def __add__( + self, + other: ( + AcceptLanguageInvalidHeader | AcceptLanguageNoHeader | Literal[""] | None + ), + ) -> AcceptLanguageNoHeader: ... + + @overload + def __add__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptLanguageValidHeader | AcceptLanguageNoHeader: ... + + def __add__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptLanguageValidHeader | AcceptLanguageNoHeader: """ Add to header, creating a new header object. @@ -5261,7 +6299,43 @@ def __add__(self, other): instance=self, other=other ) - def __radd__(self, other): + @overload + def __radd__( + self, other: AcceptLanguageValidHeader + ) -> AcceptLanguageValidHeader: ... + + @overload + def __radd__( + self, + other: ( + AcceptLanguageInvalidHeader | AcceptLanguageNoHeader | Literal[""] | None + ), + ) -> AcceptLanguageNoHeader: ... + + @overload + def __radd__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptLanguageValidHeader | AcceptLanguageNoHeader: ... + + def __radd__( + self, + other: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> AcceptLanguageValidHeader | AcceptLanguageNoHeader: """ Add to header, creating a new header object. @@ -5272,20 +6346,21 @@ def __radd__(self, other): instance=self, other=other, instance_on_the_right=True ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" # We do not display the header_value, as it is untrusted input. The # header_value could always be easily obtained from the .header_value # property. - def __str__(self): + def __str__(self) -> str: """Return the ``str`` ``''``.""" return "" def _add_instance_and_non_accept_language_type( - self, instance, other, instance_on_the_right=False - ): + self, instance: Self, other: object, instance_on_the_right: bool = False + ) -> AcceptLanguageValidHeader | AcceptLanguageNoHeader: + if not other: return AcceptLanguageNoHeader() @@ -5297,7 +6372,39 @@ def _add_instance_and_non_accept_language_type( return AcceptLanguageNoHeader() -def create_accept_language_header(header_value): +@overload +def create_accept_language_header( + header_value: AcceptLanguageValidHeader | Literal[""], +) -> AcceptLanguageValidHeader: ... + + +@overload +def create_accept_language_header( + header_value: AcceptLanguageNoHeader | None, +) -> AcceptLanguageNoHeader: ... + + +@overload +def create_accept_language_header( + header_value: AcceptLanguageInvalidHeader, +) -> AcceptLanguageInvalidHeader: ... + + +@overload +def create_accept_language_header( + header_value: str, +) -> AcceptLanguageValidHeader | AcceptLanguageInvalidHeader: ... + + +@overload +def create_accept_language_header( + header_value: _AnyAcceptLanguageHeader | str | None, +) -> _AnyAcceptLanguageHeader: ... + + +def create_accept_language_header( + header_value: _AnyAcceptLanguageHeader | str | None, +) -> _AnyAcceptLanguageHeader: """ Create an object representing the ``Accept-Language`` header in a request. @@ -5322,7 +6429,7 @@ def create_accept_language_header(header_value): return AcceptLanguageInvalidHeader(header_value=header_value) -def accept_language_property(): +def accept_language_property() -> _AcceptLanguageProperty: doc = """ Property representing the ``Accept-Language`` header. @@ -5336,14 +6443,24 @@ def accept_language_property(): ENVIRON_KEY = "HTTP_ACCEPT_LANGUAGE" - def fget(request): + def fget(request: BaseRequest) -> _AnyAcceptLanguageHeader: """Get an object representing the header in the request.""" return create_accept_language_header( header_value=request.environ.get(ENVIRON_KEY) ) - def fset(request, value): + def fset( + request: BaseRequest, + value: ( + _AnyAcceptLanguageHeader + | SupportsItems[str, float] + | ListOrTuple[str | tuple[str, float] | list[Any]] + | _SupportsStr + | str + | None + ), + ) -> None: """ Set the corresponding key in the request environ. @@ -5372,7 +6489,7 @@ def fset(request, value): header_value = AcceptLanguage._python_value_to_header_str(value=value) request.environ[ENVIRON_KEY] = header_value - def fdel(request): + def fdel(request: BaseRequest) -> None: """Delete the corresponding key from the request environ.""" try: del request.environ[ENVIRON_KEY] diff --git a/src/webob/byterange.py b/src/webob/byterange.py index 823bd499..e142e12f 100644 --- a/src/webob/byterange.py +++ b/src/webob/byterange.py @@ -1,4 +1,12 @@ +from __future__ import annotations + import re +from typing import TYPE_CHECKING, overload + +if TYPE_CHECKING: + from collections.abc import Iterator + + from typing_extensions import Self __all__ = ["Range", "ContentRange"] @@ -11,12 +19,18 @@ class Range: Represents the Range header. """ - def __init__(self, start, end): + @overload + def __init__(self, start: None, end: None) -> None: ... + + @overload + def __init__(self, start: int, end: int | None) -> None: ... + + def __init__(self, start: int | None, end: int | None) -> None: assert end is None or end >= 0, "Bad range end: %r" % end self.start = start self.end = end # non-inclusive - def range_for_length(self, length): + def range_for_length(self, length: int | None) -> tuple[int, int] | None: """ *If* there is only one range, and *if* it is satisfiable by the given length, then return a (start, end) non-inclusive range @@ -26,16 +40,18 @@ def range_for_length(self, length): return None start, end = self.start, self.end if end is None: + assert start is not None end = length if start < 0: start += length if _is_content_range_valid(start, end, length): + assert start is not None stop = min(end, length) return (start, stop) else: return None - def content_range(self, length): + def content_range(self, length: int | None) -> ContentRange | None: """ Works like range_for_length; returns None or a ContentRange object @@ -50,23 +66,24 @@ def content_range(self, length): return None return ContentRange(range[0], range[1], length) - def __str__(self): + def __str__(self) -> str: s, e = self.start, self.end if e is None: + assert s is not None r = "bytes=%s" % s if s >= 0: r += "-" return r return f"bytes={s}-{e - 1}" - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} bytes {self.start!r}-{self.end!r}>" - def __iter__(self): + def __iter__(self) -> Iterator[int | None]: return iter((self.start, self.end)) @classmethod - def parse(cls, header): + def parse(cls, header: str | None) -> Self | None: """ Parse the header; may return None if header is invalid """ @@ -93,28 +110,35 @@ class ContentRange: can be ``*`` (represented as None in the attributes). """ - def __init__(self, start, stop, length): + @overload + def __init__(self, start: None, stop: None, length: int | None) -> None: ... + + @overload + def __init__(self, start: int, stop: int, length: int | None) -> None: ... + + def __init__(self, start: int | None, stop: int | None, length: int | None) -> None: if not _is_content_range_valid(start, stop, length): raise ValueError(f"Bad start:stop/length: {start!r}-{stop!r}/{length!r}") self.start = start self.stop = stop # this is python-style range end (non-inclusive) self.length = length - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self}>" - def __str__(self): + def __str__(self) -> str: if self.length is None: - length = "*" + length: str | int = "*" else: length = self.length if self.start is None: assert self.stop is None return "bytes */%s" % length + assert self.stop is not None stop = self.stop - 1 # from non-inclusive to HTTP-style return f"bytes {self.start}-{stop}/{length}" - def __iter__(self): + def __iter__(self) -> Iterator[int | None]: """ Mostly so you can unpack this, like: @@ -123,29 +147,38 @@ def __iter__(self): return iter([self.start, self.stop, self.length]) @classmethod - def parse(cls, value): + def parse(cls, value: str | None) -> Self | None: """ Parse the header. May return None if it cannot parse. """ m = _rx_content_range.match(value or "") if not m: return None - s, e, l = m.groups() - if s: - s = int(s) - e = int(e) + 1 - l = l and int(l) + s_str, e_str, l_str = m.groups() + if s_str: + s = int(s_str) + e = int(e_str) + 1 + else: + s = None + e = None + l = int(l_str) if l_str else None if not _is_content_range_valid(s, e, l, response=True): return None - return cls(s, e, l) + return cls(s, e, l) # type: ignore[arg-type] + +def _is_content_range_valid( + start: int | None, stop: int | None, length: int | None, response: bool = False +) -> bool: -def _is_content_range_valid(start, stop, length, response=False): if (start is None) != (stop is None): return False - elif start is None: + + if start is None: return length is None or length >= 0 - elif length is None: + + assert stop is not None + if length is None: return 0 <= start < stop elif start >= stop: return False diff --git a/src/webob/cachecontrol.py b/src/webob/cachecontrol.py index 524cbc48..42345236 100644 --- a/src/webob/cachecontrol.py +++ b/src/webob/cachecontrol.py @@ -2,10 +2,34 @@ Represents the Cache-Control header """ +from __future__ import annotations + import re +from typing import TYPE_CHECKING, Any, Generic, Literal, overload + +if TYPE_CHECKING: + from collections.abc import Callable + + from _typeshed import SupportsItems + from typing_extensions import Self, TypeAlias, TypeVar + + _T = TypeVar("_T") + _DefaultT = TypeVar("_DefaultT", default=None) + _NoneLiteral = TypeVar("_NoneLiteral", default=None) + _ScopeT = TypeVar( + "_ScopeT", Literal["request"], Literal["response"], None, default=None + ) + _ScopeT2 = TypeVar("_ScopeT2", Literal["request"], Literal["response"], None) +else: + from typing import TypeVar + _T = TypeVar("_T") + _DefaultT = TypeVar("_DefaultT") + _NoneLiteral = TypeVar("_NoneLiteral") + _ScopeT = TypeVar("_ScopeT") -class UpdateDict(dict): + +class UpdateDict(dict[str, Any]): """ Dict that has a callback on all updates """ @@ -13,10 +37,10 @@ class UpdateDict(dict): # these are declared as class attributes so that # we don't need to override constructor just to # set some defaults - updated = None - updated_args = None + updated: Callable[..., Any] | None = None + updated_args: tuple[Any, ...] | None = None - def _updated(self): + def _updated(self) -> None: """ Assign to new_dict.updated to track updates """ @@ -27,59 +51,94 @@ def _updated(self): args = (self,) updated(*args) - def __setitem__(self, key, item): - dict.__setitem__(self, key, item) - self._updated() + # NOTE: These wrappers are supposed to be transparent, so let's + # not bother copying the type annotations + if not TYPE_CHECKING: - def __delitem__(self, key): - dict.__delitem__(self, key) - self._updated() + def __setitem__(self, key, item): + dict.__setitem__(self, key, item) + self._updated() - def clear(self): - dict.clear(self) - self._updated() + def __delitem__(self, key): + dict.__delitem__(self, key) + self._updated() - def update(self, *args, **kw): - dict.update(self, *args, **kw) - self._updated() + def clear(self): + dict.clear(self) + self._updated() - def setdefault(self, key, value=None): - val = dict.setdefault(self, key, value) - if val is value: + def update(self, *args, **kw): + dict.update(self, *args, **kw) self._updated() - return val - def pop(self, *args): - v = dict.pop(self, *args) - self._updated() - return v + def setdefault(self, key, value=None): + val = dict.setdefault(self, key, value) + if val is value: + self._updated() + return val + + def pop(self, *args): + v = dict.pop(self, *args) + self._updated() + return v - def popitem(self): - v = dict.popitem(self) - self._updated() - return v + def popitem(self): + v = dict.popitem(self) + self._updated() + return v token_re = re.compile(r'([a-zA-Z][a-zA-Z_-]*)\s*(?:=(?:"([^"]*)"|([^ \t",;]*)))?') need_quote_re = re.compile(r"[^a-zA-Z0-9._-]") -class exists_property: +class exists_property(Generic[_ScopeT]): """ Represents a property that either is listed in the Cache-Control header, or is not listed (has no value) """ - def __init__(self, prop, type=None): + def __init__( + self, prop: str, type: _ScopeT = None # type: ignore[assignment] + ) -> None: self.prop = prop - self.type = type - - def __get__(self, obj, type=None): + self.type: _ScopeT = type + + @overload + def __get__( + self, obj: None, type: type[CacheControl[Any]] | None = None + ) -> Self: ... + + @overload + def __get__( + self: exists_property[None], + obj: CacheControl[Any], + type: type[CacheControl[Any]] | None = None, + ) -> bool: ... + + @overload + def __get__( + self, obj: CacheControl[_ScopeT], type: type[CacheControl[Any]] | None = None + ) -> bool: ... + + def __get__( + self, obj: CacheControl[Any] | None, type: type[CacheControl[Any]] | None = None + ) -> Self | bool: if obj is None: return self return self.prop in obj.properties - def __set__(self, obj, value): + @overload + def __set__( + self: exists_property[None], obj: CacheControl[Any], value: bool | None + ) -> None: ... + + @overload + def __set__( + self, obj: CacheControl[_ScopeT] | CacheControl[None], value: bool | None + ) -> None: ... + + def __set__(self, obj: CacheControl[Any], value: bool | None) -> None: if self.type is not None and self.type != obj.type: raise AttributeError( "The property %s only applies to %s Cache-Control" @@ -92,24 +151,56 @@ def __set__(self, obj, value): if self.prop in obj.properties: del obj.properties[self.prop] - def __delete__(self, obj): + @overload + def __delete__(self: exists_property[None], obj: CacheControl[Any]) -> None: ... + + @overload + def __delete__(self, obj: CacheControl[_ScopeT] | CacheControl[None]) -> None: ... + + def __delete__(self, obj: CacheControl[Any]) -> None: self.__set__(obj, False) -class value_property: +class value_property(Generic[_T, _DefaultT, _NoneLiteral, _ScopeT]): """ Represents a property that has a value in the Cache-Control header. When no value is actually given, the value of self.none is returned. """ - def __init__(self, prop, default=None, none=None, type=None): + def __init__( + self, + prop: str, + default: _DefaultT = None, # type: ignore[assignment] + none: _NoneLiteral = None, # type: ignore[assignment] + type: _ScopeT = None, # type: ignore[assignment] + ) -> None: self.prop = prop - self.default = default - self.none = none - self.type = type + self.default: _DefaultT = default + self.none: _NoneLiteral = none + self.type: _ScopeT = type + + @overload + def __get__( + self, obj: None, type: type[CacheControl[Any]] | None = None + ) -> Self: ... + + @overload + def __get__( + self: value_property[_T, _DefaultT, _NoneLiteral, None], + obj: CacheControl[Any], + type: type[CacheControl[Any]] | None = None, + ) -> _T | _DefaultT | _NoneLiteral: ... + + @overload + def __get__( + self, obj: CacheControl[_ScopeT], type: type[CacheControl[Any]] | None = None + ) -> _T | _DefaultT | _NoneLiteral: ... + + def __get__( + self, obj: CacheControl[Any] | None, type: type[CacheControl[Any]] | None = None + ) -> Self | _T | _DefaultT | _NoneLiteral: - def __get__(self, obj, type=None): if obj is None: return self if self.prop in obj.properties: @@ -117,11 +208,26 @@ def __get__(self, obj, type=None): if value is None: return self.none else: - return value + return value # type: ignore[no-any-return] else: return self.default - def __set__(self, obj, value): + @overload + def __set__( + self: value_property[_T, _DefaultT, _NoneLiteral, None], + obj: CacheControl[Any], + value: _T | _DefaultT | Literal[True] | None, + ) -> None: ... + + @overload + def __set__( + self, obj: CacheControl[_ScopeT], value: _T | _DefaultT | Literal[True] | None + ) -> None: ... + + def __set__( + self, obj: CacheControl[Any], value: _T | _DefaultT | Literal[True] | None + ) -> None: + if self.type is not None and self.type != obj.type: raise AttributeError( "The property %s only applies to %s Cache-Control" @@ -135,12 +241,20 @@ def __set__(self, obj, value): else: obj.properties[self.prop] = value - def __delete__(self, obj): + @overload + def __delete__( + self: value_property[_T, _DefaultT, _NoneLiteral, None], obj: CacheControl[Any] + ) -> None: ... + + @overload + def __delete__(self, obj: CacheControl[_ScopeT] | CacheControl[None]) -> None: ... + + def __delete__(self, obj: CacheControl[Any]) -> None: if self.prop in obj.properties: del obj.properties[self.prop] -class CacheControl: +class CacheControl(Generic[_ScopeT]): """ Represents the Cache-Control header. @@ -149,20 +263,57 @@ class CacheControl: only apply to requests or responses). """ + # NOTE: This only exists when accessed through Response/BaseRequest + header_value: str + update_dict = UpdateDict - def __init__(self, properties, type): + def __init__(self, properties: dict[str, Any], type: _ScopeT) -> None: self.properties = properties - self.type = type + self.type: _ScopeT = type + @overload + @classmethod + def parse( + cls, + header: str, + updates_to: Callable[[dict[str, Any]], Any] | None = None, + type: None = None, + ) -> CacheControl[None]: ... + + @overload @classmethod - def parse(cls, header, updates_to=None, type=None): + def parse( + cls, + header: str, + updates_to: Callable[[dict[str, Any]], Any] | None, + type: _ScopeT2, + ) -> CacheControl[_ScopeT2]: ... + + @overload + @classmethod + def parse( + cls, + header: str, + updates_to: Callable[[dict[str, Any]], Any] | None = None, + *, + type: _ScopeT2, + ) -> CacheControl[_ScopeT2]: ... + + @classmethod + def parse( + cls, + header: str, + updates_to: Callable[[dict[str, Any]], Any] | None = None, + type: Any = None, + ) -> CacheControl[Any]: """ Parse the header, returning a CacheControl object. The object is bound to the request or response object ``updates_to``, if that is given. """ + props: dict[str, Any] if updates_to: props = cls.update_dict() props.updated = updates_to @@ -179,46 +330,57 @@ def parse(cls, header, updates_to=None, type=None): props[name] = value obj = cls(props, type=type) if updates_to: + assert isinstance(props, cls.update_dict) props.updated_args = (obj,) return obj - def __repr__(self): + def __repr__(self) -> str: return "" % str(self) # Request values: # no-cache shared (below) # no-store shared (below) # max-age shared (below) + max_stale: value_property[int, None, Literal["*"], Literal["request"]] max_stale = value_property("max-stale", none="*", type="request") + min_fresh: value_property[int, None, None, Literal["request"]] min_fresh = value_property("min-fresh", type="request") # no-transform shared (below) only_if_cached = exists_property("only-if-cached", type="request") # Response values: public = exists_property("public", type="response") + private: value_property[str, None, Literal["*"], Literal["response"]] private = value_property("private", none="*", type="response") + no_cache: value_property[str, None, Literal["*"], None] no_cache = value_property("no-cache", none="*") no_store = exists_property("no-store") no_transform = exists_property("no-transform") must_revalidate = exists_property("must-revalidate", type="response") proxy_revalidate = exists_property("proxy-revalidate", type="response") + max_age: value_property[int, None, Literal[-1], None] max_age = value_property("max-age", none=-1) + s_maxage: value_property[int, None, None, Literal["response"]] s_maxage = value_property("s-maxage", type="response") s_max_age = s_maxage + stale_while_revalidate: value_property[int, None, None, Literal["response"]] stale_while_revalidate = value_property("stale-while-revalidate", type="response") + stale_if_error: value_property[int, None, None, Literal["response"]] stale_if_error = value_property("stale-if-error", type="response") - def __str__(self): + def __str__(self) -> str: return serialize_cache_control(self.properties) - def copy(self): + def copy(self) -> Self: """ Returns a copy of this object. """ return self.__class__(self.properties.copy(), type=self.type) -def serialize_cache_control(properties): +def serialize_cache_control( + properties: SupportsItems[str, Any] | CacheControl[Any], +) -> str: if isinstance(properties, CacheControl): properties = properties.properties parts = [] @@ -231,3 +393,7 @@ def serialize_cache_control(properties): value = '"%s"' % value parts.append(f"{name}={value}") return ", ".join(parts) + + +RequestCacheControl: TypeAlias = CacheControl[Literal["request"]] +ResponseCacheControl: TypeAlias = CacheControl[Literal["response"]] diff --git a/src/webob/client.py b/src/webob/client.py index 1b32d9e7..5593deb7 100644 --- a/src/webob/client.py +++ b/src/webob/client.py @@ -1,17 +1,27 @@ +from __future__ import annotations + import errno import re -import sys try: - import httplib + import httplib # type: ignore except ImportError: import http.client as httplib import socket +from typing import TYPE_CHECKING, ClassVar from urllib.parse import quote as url_quote from webob import exc +if TYPE_CHECKING: + from collections.abc import Iterable + from http.client import HTTPConnection, HTTPMessage, HTTPSConnection + + from _typeshed.wsgi import StartResponse, WSGIEnvironment + + from webob.response import Response + __all__ = ["send_request_app", "SendRequest"] @@ -36,13 +46,16 @@ class SendRequest: def __init__( self, - HTTPConnection=httplib.HTTPConnection, - HTTPSConnection=httplib.HTTPSConnection, - ): + HTTPConnection: type[HTTPConnection] = httplib.HTTPConnection, + HTTPSConnection: type[HTTPSConnection] = httplib.HTTPSConnection, + ) -> None: self.HTTPConnection = HTTPConnection self.HTTPSConnection = HTTPSConnection - def __call__(self, environ, start_response): + def __call__( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: + scheme = environ["wsgi.url_scheme"] if scheme == "http": @@ -69,7 +82,7 @@ def __call__(self, environ, start_response): environ["SERVER_PORT"] = port kw = {} - if "webob.client.timeout" in environ and self._timeout_supported(ConnClass): + if "webob.client.timeout" in environ: kw["timeout"] = environ["webob.client.timeout"] conn = ConnClass("%(SERVER_NAME)s:%(SERVER_PORT)s" % environ, **kw) headers = {} @@ -100,6 +113,7 @@ def __call__(self, environ, start_response): if environ.get("CONTENT_TYPE"): headers["Content-Type"] = environ["CONTENT_TYPE"] + resp: Response if not path.startswith("/"): path = "/" + path try: @@ -142,16 +156,16 @@ def __call__(self, environ, start_response): # Remove these headers from response (specify lower case header # names): - filtered_headers = ("transfer-encoding",) + filtered_headers: ClassVar[tuple[str, ...]] = ("transfer-encoding",) MULTILINE_RE = re.compile(r"\r?\n\s*") - def parse_headers(self, message): + def parse_headers(self, message: HTTPMessage) -> list[tuple[str, str]]: """ Turn a Message object into a list of WSGI-style headers. """ - headers_out = [] - headers = message._headers + headers_out: list[tuple[str, str]] = [] + headers = message._headers # type: ignore[attr-defined] for full_header in headers: if not full_header: # pragma: no cover @@ -191,19 +205,10 @@ def parse_headers(self, message): return headers_out - def _timeout_supported(self, ConnClass): - if sys.version_info < (2, 7) and ConnClass in ( - httplib.HTTPConnection, - httplib.HTTPSConnection, - ): # pragma: no cover - return False - - return True - -send_request_app = SendRequest() +send_request_app: SendRequest = SendRequest() -_e_refused = (errno.ECONNREFUSED,) +_e_refused: tuple[int, ...] = (errno.ECONNREFUSED,) if hasattr(errno, "ENODATA"): # pragma: no cover _e_refused += (errno.ENODATA,) diff --git a/src/webob/compat.py b/src/webob/compat.py index 55fbef9e..b840a0b8 100644 --- a/src/webob/compat.py +++ b/src/webob/compat.py @@ -1,5 +1,4 @@ # flake8: noqa - import cgi from cgi import FieldStorage as _cgi_FieldStorage, parse_header from html import escape @@ -7,111 +6,117 @@ import sys import tempfile import types +from typing import TYPE_CHECKING # Various different FieldStorage work-arounds required on Python 3.x class cgi_FieldStorage(_cgi_FieldStorage): # pragma: no cover - def __repr__(self): - """monkey patch for FieldStorage.__repr__ - - Unbelievably, the default __repr__ on FieldStorage reads - the entire file content instead of being sane about it. - This is a simple replacement that doesn't do that - """ - - if self.file: - return f"FieldStorage({self.name!r}, {self.filename!r})" - - return f"FieldStorage({self.name!r}, {self.filename!r}, {self.value!r})" - - # Work around https://bugs.python.org/issue27777 - def make_file(self): - if self._binary_file or self.length >= 0: - return tempfile.TemporaryFile("wb+") - else: - return tempfile.TemporaryFile("w+", encoding=self.encoding, newline="\n") - - # Work around http://bugs.python.org/issue23801 - # This is taken exactly from Python 3.5's cgi.py module - def read_multi(self, environ, keep_blank_values, strict_parsing): - """Internal: read a part that is itself multipart.""" - ib = self.innerboundary - - if not cgi.valid_boundary(ib): - raise ValueError(f"Invalid boundary in multipart form: {ib!r}") - self.list = [] - - if self.qs_on_post: - query = cgi.urllib.parse.parse_qsl( - self.qs_on_post, - self.keep_blank_values, - self.strict_parsing, - encoding=self.encoding, - errors=self.errors, - ) - - for key, value in query: - self.list.append(cgi.MiniFieldStorage(key, value)) - - klass = self.FieldStorageClass or self.__class__ - first_line = self.fp.readline() # bytes - - if not isinstance(first_line, bytes): - raise ValueError( - f"{self.fp} should return bytes, got {type(first_line).__name__}" - ) - self.bytes_read += len(first_line) - - # Ensure that we consume the file until we've hit our innerboundary - - while first_line.strip() != (b"--" + self.innerboundary) and first_line: - first_line = self.fp.readline() + # NOTE: No point in adding type checking for these workarounds + if not TYPE_CHECKING: + + def __repr__(self): + """monkey patch for FieldStorage.__repr__ + + Unbelievably, the default __repr__ on FieldStorage reads + the entire file content instead of being sane about it. + This is a simple replacement that doesn't do that + """ + + if self.file: + return f"FieldStorage({self.name!r}, {self.filename!r})" + + return f"FieldStorage({self.name!r}, {self.filename!r}, {self.value!r})" + + # Work around https://bugs.python.org/issue27777 + def make_file(self): + if self._binary_file or self.length >= 0: + return tempfile.TemporaryFile("wb+") + else: + return tempfile.TemporaryFile( + "w+", encoding=self.encoding, newline="\n" + ) + + # Work around http://bugs.python.org/issue23801 + # This is taken exactly from Python 3.5's cgi.py module + def read_multi(self, environ, keep_blank_values, strict_parsing): + """Internal: read a part that is itself multipart.""" + ib = self.innerboundary + + if not cgi.valid_boundary(ib): + raise ValueError(f"Invalid boundary in multipart form: {ib!r}") + self.list = [] + + if self.qs_on_post: + query = cgi.urllib.parse.parse_qsl( + self.qs_on_post, + self.keep_blank_values, + self.strict_parsing, + encoding=self.encoding, + errors=self.errors, + ) + + for key, value in query: + self.list.append(cgi.MiniFieldStorage(key, value)) + + klass = self.FieldStorageClass or self.__class__ + first_line = self.fp.readline() # bytes + + if not isinstance(first_line, bytes): + raise ValueError( + f"{self.fp} should return bytes, got {type(first_line).__name__}" + ) self.bytes_read += len(first_line) - while True: - parser = cgi.FeedParser() - hdr_text = b"" + # Ensure that we consume the file until we've hit our innerboundary + + while first_line.strip() != (b"--" + self.innerboundary) and first_line: + first_line = self.fp.readline() + self.bytes_read += len(first_line) while True: - data = self.fp.readline() - hdr_text += data + parser = cgi.FeedParser() + hdr_text = b"" - if not data.strip(): - break + while True: + data = self.fp.readline() + hdr_text += data - if not hdr_text: - break - # parser takes strings, not bytes - self.bytes_read += len(hdr_text) - parser.feed(hdr_text.decode(self.encoding, self.errors)) - headers = parser.close() - # Some clients add Content-Length for part headers, ignore them - - if "content-length" in headers: - filename = None - - if "content-disposition" in self.headers: - cdisp, pdict = parse_header(self.headers["content-disposition"]) - - if "filename" in pdict: - filename = pdict["filename"] - - if filename is None: - del headers["content-length"] - part = klass( - self.fp, - headers, - ib, - environ, - keep_blank_values, - strict_parsing, - self.limit - self.bytes_read, - self.encoding, - self.errors, - ) - self.bytes_read += part.bytes_read - self.list.append(part) - - if part.done or self.bytes_read >= self.length > 0: - break - self.skip_lines() + if not data.strip(): + break + + if not hdr_text: + break + # parser takes strings, not bytes + self.bytes_read += len(hdr_text) + parser.feed(hdr_text.decode(self.encoding, self.errors)) + headers = parser.close() + # Some clients add Content-Length for part headers, ignore them + + if "content-length" in headers: + filename = None + + if "content-disposition" in self.headers: + cdisp, pdict = parse_header(self.headers["content-disposition"]) + + if "filename" in pdict: + filename = pdict["filename"] + + if filename is None: + del headers["content-length"] + part = klass( + self.fp, + headers, + ib, + environ, + keep_blank_values, + strict_parsing, + self.limit - self.bytes_read, + self.encoding, + self.errors, + ) + self.bytes_read += part.bytes_read + self.list.append(part) + + if part.done or self.bytes_read >= self.length > 0: + break + self.skip_lines() diff --git a/src/webob/cookies.py b/src/webob/cookies.py index 4c141e28..6c433db6 100644 --- a/src/webob/cookies.py +++ b/src/webob/cookies.py @@ -1,6 +1,16 @@ +from __future__ import annotations + import base64 import binascii -from collections.abc import MutableMapping +from collections.abc import ( + Callable, + Collection, + ItemsView, + Iterator, + KeysView, + MutableMapping, + ValuesView, +) from datetime import date, datetime, timedelta import hashlib import hmac @@ -8,10 +18,38 @@ import re import string import time +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload import warnings from webob.util import bytes_, text_ +if TYPE_CHECKING: + from _hashlib import HASH + from _typeshed.wsgi import WSGIEnvironment + from typing_extensions import TypeAlias, TypedDict, TypeGuard + + from webob.request import BaseRequest + from webob.response import Response + from webob.types import AsymmetricProperty, SymmetricProperty + + _T = TypeVar("_T") + _MorselValueT = TypeVar("_MorselValueT", bytes, bool, "bytes | None", "bool | None") + # we accept both the official spelling and the one used in the WebOb docs + # the implementation compares after lower() so technically there are more + # valid spellings, but it seems more natural to support these two spellings + _SameSitePolicy: TypeAlias = Literal[ + "Strict", "Lax", "None", "strict", "lax", "none" + ] + + class _Serializer(Protocol): + def loads(self, appstruct: Any, /) -> bytes: ... + def dumps(self, bstruct: bytes, /) -> Any: ... + + class _Renamer(TypedDict): + name: bytes + quoter: Callable[[bytes], bytes] + + __all__ = [ "Cookie", "CookieProfile", @@ -29,22 +67,23 @@ SAMESITE_VALIDATION = True -class RequestCookies(MutableMapping): +class RequestCookies(MutableMapping[str, str]): _cache_key = "webob._parsed_cookies" - def __init__(self, environ): + def __init__(self, environ: WSGIEnvironment) -> None: self._environ = environ @property - def _cache(self): + def _cache(self) -> dict[str, str]: env = self._environ header = env.get("HTTP_COOKIE", "") + cache: dict[str, str] cache, cache_header = env.get(self._cache_key, ({}, None)) if cache_header == header: return cache - def d(b): + def d(b: bytes) -> str: return b.decode("utf8") cache = {d(k): d(v) for k, v in parse_cookie(header)} @@ -52,7 +91,7 @@ def d(b): return cache - def _mutate_header(self, name, value): + def _mutate_header(self, name: str, value: str | None) -> bool: header = self._environ.get("HTTP_COOKIE") had_header = header is not None header = header or "" @@ -95,7 +134,7 @@ def _mutate_header(self, name, value): return found - def _valid_cookie_name(self, name): + def _valid_cookie_name(self, name: str) -> str: if not isinstance(name, str): raise TypeError(name, "cookie name must be a string") @@ -109,7 +148,7 @@ def _valid_cookie_name(self, name): return name - def __setitem__(self, name, value): + def __setitem__(self, name: str, value: str) -> None: name = self._valid_cookie_name(name) if not isinstance(value, str): @@ -117,51 +156,57 @@ def __setitem__(self, name, value): self._mutate_header(name, value) - def __getitem__(self, name): + def __getitem__(self, name: str) -> str: return self._cache[name] - def get(self, name, default=None): + @overload + def get(self, name: str, default: None = None) -> str | None: ... + + @overload + def get(self, name: str, default: _T) -> str | _T: ... + + def get(self, name: str, default: Any = None) -> str | Any: return self._cache.get(name, default) - def __delitem__(self, name): + def __delitem__(self, name: str) -> None: name = self._valid_cookie_name(name) found = self._mutate_header(name, None) if not found: raise KeyError(name) - def keys(self): + def keys(self) -> KeysView[str]: return self._cache.keys() - def values(self): + def values(self) -> ValuesView[str]: return self._cache.values() - def items(self): + def items(self) -> ItemsView[str, str]: return self._cache.items() - def __contains__(self, name): + def __contains__(self, name: object) -> bool: return name in self._cache - def __iter__(self): + def __iter__(self) -> Iterator[str]: return self._cache.__iter__() - def __len__(self): + def __len__(self) -> int: return len(self._cache) - def clear(self): + def clear(self) -> None: self._environ["HTTP_COOKIE"] = "" - def __repr__(self): + def __repr__(self) -> str: return f"" -class Cookie(dict): - def __init__(self, input=None): +class Cookie(dict[bytes, "Morsel"]): + def __init__(self, input: str | None = None) -> None: if input: self.load(input) - def load(self, data): - morsel = {} + def load(self, data: str) -> None: + morsel: Morsel | dict[bytes, bytes] = {} for key, val in _parse_cookie(data): if key.lower() in _c_keys: @@ -169,7 +214,7 @@ def load(self, data): else: morsel = self.add(key, val) - def add(self, key, val): + def add(self, key: str | bytes, val: str | bytes) -> Morsel | dict[bytes, bytes]: if not isinstance(key, bytes): key = key.encode("ascii", "replace") @@ -180,31 +225,31 @@ def add(self, key, val): return r - __setitem__ = add + __setitem__ = add # type: ignore[assignment] - def serialize(self, full=True): + def serialize(self, full: bool = True) -> str: return "; ".join(m.serialize(full) for m in self.values()) - def values(self): + def values(self) -> list[Morsel]: # type: ignore[override] return [m for _, m in sorted(self.items())] __str__ = serialize - def __repr__(self): + def __repr__(self) -> str: return "<{}: [{}]>".format( self.__class__.__name__, ", ".join(map(repr, self.values())), ) -def _parse_cookie(data): - data = data.encode("latin-1") +def _parse_cookie(data: str) -> Iterator[tuple[bytes, bytes]]: + data_bytes = data.encode("latin-1") - for key, val in _rx_cookie.findall(data): + for key, val in _rx_cookie.findall(data_bytes): yield key, _unquote(val) -def parse_cookie(data): +def parse_cookie(data: str) -> Iterator[tuple[bytes, bytes]]: """ Parse cookies ignoring anything except names and values """ @@ -212,14 +257,26 @@ def parse_cookie(data): return ((k, v) for k, v in _parse_cookie(data) if _valid_cookie_name(k)) -def cookie_property(key, serialize=lambda v: v): - def fset(self, v): +@overload +def cookie_property(key: bytes) -> SymmetricProperty[bytes | None]: ... + + +@overload +def cookie_property( + key: bytes, serialize: Callable[[_T], _MorselValueT] +) -> AsymmetricProperty[_MorselValueT, _T]: ... + + +def cookie_property( + key: bytes, serialize: Callable[[Any], Any] = lambda v: v +) -> AsymmetricProperty[Any, Any]: + def fset(self: Morsel, v: _T) -> None: self[key] = serialize(v) return property(lambda self: self[key], fset) -def serialize_max_age(v): +def serialize_max_age(v: timedelta | int | str | bytes | None) -> bytes | None: if isinstance(v, timedelta): v = str(v.seconds + v.days * 24 * 60 * 60) elif isinstance(v, int): @@ -228,7 +285,20 @@ def serialize_max_age(v): return bytes_(v) -def serialize_cookie_date(v): +def serialize_cookie_date( + v: ( + datetime + | date + | timedelta + | time._TimeTuple + | time.struct_time + | bytes + | str + | int + | None + ), +) -> bytes | None: + if v is None: return None elif isinstance(v, bytes): @@ -248,7 +318,7 @@ def serialize_cookie_date(v): return bytes_(r % (weekdays[v[6]], months[v[1]]), "ascii") -def serialize_samesite(v): +def serialize_samesite(v: str | bytes) -> bytes: v = bytes_(v) if SAMESITE_VALIDATION: @@ -258,10 +328,10 @@ def serialize_samesite(v): return v -class Morsel(dict): +class Morsel(dict[bytes, "bytes | bool | None"]): __slots__ = ("name", "value") - def __init__(self, name, value): + def __init__(self, name: str | bytes, value: str | bytes) -> None: self.name = bytes_(name, encoding="ascii") self.value = bytes_(value, encoding="ascii") assert _valid_cookie_name(self.name) @@ -276,14 +346,14 @@ def __init__(self, name, value): secure = cookie_property(b"secure", bool) samesite = cookie_property(b"samesite", serialize_samesite) - def __setitem__(self, k, v): + def __setitem__(self, k: str | bytes, v: bytes | bool | None) -> None: k = bytes_(k.lower(), "ascii") if k in _c_keys: dict.__setitem__(self, k, v) - def serialize(self, full=True): - result = [] + def serialize(self, full: bool = True) -> str: + result: list[bytes] = [] add = result.append add(self.name + b"=" + _value_quote(self.value)) @@ -292,11 +362,12 @@ def serialize(self, full=True): v = self[k] if v: + assert isinstance(v, bytes) info = _c_renames[k] name = info["name"] quoter = info["quoter"] add(name + b"=" + quoter(v)) - expires = self[b"expires"] + expires = self.expires if expires: add(b"expires=" + expires) @@ -319,7 +390,7 @@ def serialize(self, full=True): __str__ = serialize - def __repr__(self): + def __repr__(self) -> str: return "<{}: {}={!r}>".format( self.__class__.__name__, text_(self.name), @@ -346,7 +417,7 @@ def __repr__(self): _rx_unquote = re.compile(bytes_(r"\\([0-3][0-7][0-7]|.)", "ascii")) -def _bchr(i): +def _bchr(i: int) -> bytes: return bytes([i]) @@ -357,7 +428,7 @@ def _bchr(i): _b_quote_mark = ord('"') -def _unquote(v): +def _unquote(v: bytes) -> bytes: # assert isinstance(v, bytes) if v and v[0] == v[-1] == _b_quote_mark: @@ -366,7 +437,7 @@ def _unquote(v): return _rx_unquote.sub(_ch_unquote, v) -def _ch_unquote(m): +def _ch_unquote(m: re.Match[bytes]) -> bytes: return _ch_unquote_map[m.group(1)] @@ -403,10 +474,11 @@ def _ch_unquote(m): # this is a map used to escape the values _escape_noop_chars = _allowed_cookie_chars + " " -_escape_map = {chr(i): "\\%03o" % i for i in range(256)} -_escape_map.update(zip(_escape_noop_chars, _escape_noop_chars)) +_escape_map_str = {chr(i): "\\%03o" % i for i in range(256)} +_escape_map_str.update(zip(_escape_noop_chars, _escape_noop_chars)) -_escape_map = {ord(k): bytes_(v, "ascii") for k, v in _escape_map.items()} +_escape_map = {ord(k): bytes_(v, "ascii") for k, v in _escape_map_str.items()} +del _escape_map_str _escape_char = _escape_map.__getitem__ weekdays = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun") @@ -428,10 +500,15 @@ def _ch_unquote(m): # This is temporary, until we can remove this from _value_quote -_should_raise = None +_should_raise: bool | None = None -def __warn_or_raise(text, warn_class, to_raise, raise_reason): +def __warn_or_raise( + text: str, + warn_class: type[Warning], + to_raise: type[BaseException], + raise_reason: str, +) -> None: if _should_raise: raise to_raise(raise_reason) @@ -439,7 +516,7 @@ def __warn_or_raise(text, warn_class, to_raise, raise_reason): warnings.warn(text, warn_class, stacklevel=2) -def _value_quote(v): +def _value_quote(v: bytes) -> bytes: # This looks scary, but is simple. We remove all valid characters from the # string, if we end up with leftovers (string is longer than 0, we have # invalid characters in our value) @@ -461,7 +538,7 @@ def _value_quote(v): return v -def _valid_cookie_name(key): +def _valid_cookie_name(key: object) -> TypeGuard[bytes]: return isinstance(key, bytes) and not ( key.translate(None, _valid_token_bytes) # Not explicitly required by RFC6265, may consider removing later: @@ -470,14 +547,14 @@ def _valid_cookie_name(key): ) -def _path_quote(v): +def _path_quote(v: bytes) -> bytes: return b"".join(map(_escape_char, v)) _domain_quote = _path_quote _max_age_quote = _path_quote -_c_renames = { +_c_renames: dict[bytes, _Renamer] = { b"path": {"name": b"Path", "quoter": _path_quote}, b"comment": {"name": b"Comment", "quoter": _value_quote}, b"domain": {"name": b"Domain", "quoter": _domain_quote}, @@ -489,16 +566,16 @@ def _path_quote(v): def make_cookie( - name, - value, - max_age=None, - path="/", - domain=None, - secure=False, - httponly=False, - comment=None, - samesite=None, -): + name: str | bytes, + value: str | bytes | None, + max_age: int | timedelta | None = None, + path: str = "/", + domain: str | None = None, + secure: bool | None = False, + httponly: bool | None = False, + comment: str | None = None, + samesite: _SameSitePolicy | None = None, +) -> str: """ Generate a cookie value. @@ -558,6 +635,7 @@ def make_cookie( # We are deleting the cookie, override max_age and expires + expires: int | str | None if value is None: value = b"" # Note that the max-age value of zero is technically contraspec; @@ -615,10 +693,10 @@ def make_cookie( class JSONSerializer: """A serializer which uses `json.dumps`` and ``json.loads``""" - def dumps(self, appstruct): + def dumps(self, appstruct: Any) -> bytes: return bytes_(json.dumps(appstruct), encoding="utf-8") - def loads(self, bstruct): + def loads(self, bstruct: bytes | str) -> Any: # NB: json.loads raises ValueError if no json object can be decoded # so we don't have to do it explicitly here. @@ -628,13 +706,13 @@ def loads(self, bstruct): class Base64Serializer: """A serializer which uses base64 to encode/decode data""" - def __init__(self, serializer=None): + def __init__(self, serializer: _Serializer | None = None) -> None: if serializer is None: serializer = JSONSerializer() self.serializer = serializer - def dumps(self, appstruct): + def dumps(self, appstruct: Any) -> bytes: """ Given an ``appstruct``, serialize and sign the data. @@ -644,7 +722,7 @@ def dumps(self, appstruct): return base64.urlsafe_b64encode(cstruct) - def loads(self, bstruct): + def loads(self, bstruct: bytes) -> Any: """ Given a ``bstruct`` (a bytestring), verify the signature and then deserialize and return the deserialized value. @@ -688,7 +766,13 @@ class SignedSerializer: """ - def __init__(self, secret, salt, hashalg="sha512", serializer=None): + def __init__( + self, + secret: str | bytes, + salt: str | bytes, + hashalg: str = "sha512", + serializer: _Serializer | None = None, + ) -> None: self.salt = salt self.secret = secret self.hashalg = hashalg @@ -699,7 +783,6 @@ def __init__(self, secret, salt, hashalg="sha512", serializer=None): except UnicodeEncodeError: self.salted_secret = bytes_(salt or "", "utf-8") + bytes_(secret, "utf-8") - self.digestmod = lambda string=b"": hashlib.new(self.hashalg, string) self.digest_size = self.digestmod().digest_size if serializer is None: @@ -707,7 +790,10 @@ def __init__(self, secret, salt, hashalg="sha512", serializer=None): self.serializer = serializer - def dumps(self, appstruct): + def digestmod(self, string: bytes = b"") -> HASH: + return hashlib.new(self.hashalg, string) + + def dumps(self, appstruct: Any) -> bytes: """ Given an ``appstruct``, serialize and sign the data. @@ -718,7 +804,7 @@ def dumps(self, appstruct): return base64.urlsafe_b64encode(sig + cstruct).rstrip(b"=") - def loads(self, bstruct): + def loads(self, bstruct: bytes) -> Any: """ Given a ``bstruct`` (a bytestring), verify the signature and then deserialize and return the deserialized value. @@ -793,15 +879,15 @@ class CookieProfile: def __init__( self, - cookie_name, - secure=False, - max_age=None, - httponly=None, - samesite=None, - path="/", - domains=None, - serializer=None, - ): + cookie_name: str, + secure: bool = False, + max_age: int | timedelta | None = None, + httponly: bool | None = None, + samesite: _SameSitePolicy | None = None, + path: str = "/", + domains: Collection[str] | None = None, + serializer: _Serializer | None = None, + ) -> None: self.cookie_name = cookie_name self.secure = secure self.max_age = max_age @@ -814,14 +900,14 @@ def __init__( serializer = Base64Serializer() self.serializer = serializer - self.request = None + self.request: BaseRequest | None = None - def __call__(self, request): + def __call__(self, request: BaseRequest) -> CookieProfile: """Bind a request to a copy of this instance and return it""" return self.bind(request) - def bind(self, request): + def bind(self, request: BaseRequest) -> CookieProfile: """Bind a request to a copy of this instance and return it""" selfish = CookieProfile( @@ -838,7 +924,7 @@ def bind(self, request): return selfish - def get_value(self): + def get_value(self) -> Any | None: """Looks for a cookie by name in the currently bound request, and returns its value. If the cookie profile is not bound to a request, this method will raise a :exc:`ValueError`. @@ -857,19 +943,20 @@ def get_value(self): try: return self.serializer.loads(bytes_(cookie)) except ValueError: - return None + pass + return None def set_cookies( self, - response, - value, - domains=_default, - max_age=_default, - path=_default, - secure=_default, - httponly=_default, - samesite=_default, - ): + response: Response, + value: Any, + domains: Collection[str] = _default, # type: ignore[assignment] + max_age: int | timedelta | None = _default, # type: ignore[assignment] + path: str = _default, # type: ignore[assignment] + secure: bool = _default, # type: ignore[assignment] + httponly: bool = _default, # type: ignore[assignment] + samesite: _SameSitePolicy | None = _default, # type: ignore[assignment] + ) -> Response: """Set the cookies on a response.""" cookies = self.get_headers( value, @@ -886,14 +973,14 @@ def set_cookies( def get_headers( self, - value, - domains=_default, - max_age=_default, - path=_default, - secure=_default, - httponly=_default, - samesite=_default, - ): + value: Any, + domains: Collection[str] = _default, # type: ignore[assignment] + max_age: int | timedelta | None = _default, # type: ignore[assignment] + path: str = _default, # type: ignore[assignment] + secure: bool = _default, # type: ignore[assignment] + httponly: bool = _default, # type: ignore[assignment] + samesite: _SameSitePolicy | None = _default, # type: ignore[assignment] + ) -> list[tuple[str, str]]: """Retrieve raw headers for setting cookies. Returns a list of headers that should be set for the cookies to @@ -916,7 +1003,16 @@ def get_headers( samesite=samesite, ) - def _get_cookies(self, value, domains, max_age, path, secure, httponly, samesite): + def _get_cookies( + self, + value: Any, + domains: Collection[str] | None, + max_age: int | timedelta | None, + path: str, + secure: bool, + httponly: bool | None, + samesite: _SameSitePolicy | None, + ) -> list[tuple[str, str]]: """Internal function This returns a list of cookies that are valid HTTP Headers. @@ -1049,18 +1145,18 @@ class SignedCookieProfile(CookieProfile): def __init__( self, - secret, - salt, - cookie_name, - secure=False, - max_age=None, - httponly=False, - samesite=None, - path="/", - domains=None, - hashalg="sha512", - serializer=None, - ): + secret: str | bytes, + salt: str | bytes, + cookie_name: str, + secure: bool = False, + max_age: int | timedelta | None = None, + httponly: bool | None = False, + samesite: _SameSitePolicy | None = None, + path: str = "/", + domains: Collection[str] | None = None, + hashalg: str = "sha512", + serializer: _Serializer | None = None, + ) -> None: self.secret = secret self.salt = salt self.hashalg = hashalg @@ -1081,7 +1177,7 @@ def __init__( serializer=signed_serializer, ) - def bind(self, request): + def bind(self, request: BaseRequest) -> SignedCookieProfile: """Bind a request to a copy of this instance and return it""" selfish = SignedCookieProfile( @@ -1100,3 +1196,8 @@ def bind(self, request): selfish.request = request return selfish + + if TYPE_CHECKING: + # NOTE: Since the return type of bind changed, __call__ changes as well + def __call__(self, request: BaseRequest) -> SignedCookieProfile: + pass diff --git a/src/webob/datetime_utils.py b/src/webob/datetime_utils.py index a9825124..a6832e02 100644 --- a/src/webob/datetime_utils.py +++ b/src/webob/datetime_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import calendar from datetime import date, datetime, timedelta, tzinfo from email.utils import formatdate, mktime_tz, parsedate_tz @@ -25,23 +27,23 @@ class _UTC(tzinfo): - def dst(self, dt): + def dst(self, dt: datetime | None) -> timedelta: return timedelta(0) - def utcoffset(self, dt): + def utcoffset(self, dt: datetime | None) -> timedelta: return timedelta(0) - def tzname(self, dt): + def tzname(self, dt: datetime | None) -> str: return "UTC" - def __repr__(self): + def __repr__(self) -> str: return "UTC" -UTC = _UTC() +UTC: _UTC = _UTC() -def timedelta_to_seconds(td): +def timedelta_to_seconds(td: timedelta) -> int: """ Converts a timedelta instance to seconds. """ @@ -49,17 +51,17 @@ def timedelta_to_seconds(td): return td.seconds + (td.days * 24 * 60 * 60) -day = timedelta(days=1) -week = timedelta(weeks=1) -hour = timedelta(hours=1) -minute = timedelta(minutes=1) -second = timedelta(seconds=1) +day: timedelta = timedelta(days=1) +week: timedelta = timedelta(weeks=1) +hour: timedelta = timedelta(hours=1) +minute: timedelta = timedelta(minutes=1) +second: timedelta = timedelta(seconds=1) # Estimate, I know; good enough for expirations -month = timedelta(days=30) -year = timedelta(days=365) +month: timedelta = timedelta(days=30) +year: timedelta = timedelta(days=365) -def parse_date(value): +def parse_date(value: str | bytes | None) -> datetime | None: if not value: return None try: @@ -74,12 +76,23 @@ def parse_date(value): return None - t = mktime_tz(t) + tt = mktime_tz(t) - return datetime.fromtimestamp(t, UTC) + return datetime.fromtimestamp(tt, UTC) -def serialize_date(dt): +def serialize_date( + dt: ( + datetime + | date + | timedelta + | time._TimeTuple + | time.struct_time + | float + | str + | bytes + ), +) -> str: if isinstance(dt, (bytes, str)): return text_(dt) @@ -101,7 +114,7 @@ def serialize_date(dt): return formatdate(dt, usegmt=True) -def parse_date_delta(value): +def parse_date_delta(value: str | bytes | None) -> datetime | None: """ like parse_date, but also handle delta seconds """ @@ -109,14 +122,25 @@ def parse_date_delta(value): if not value: return None try: - value = int(value) + int_value = int(value) except ValueError: return parse_date(value) else: - return _now() + timedelta(seconds=value) - - -def serialize_date_delta(value): + return _now() + timedelta(seconds=int_value) + + +def serialize_date_delta( + value: ( + datetime + | date + | timedelta + | time._TimeTuple + | time.struct_time + | float + | str + | bytes + ), +) -> str: if isinstance(value, (float, int)): return str(int(value)) else: diff --git a/src/webob/dec.py b/src/webob/dec.py index caa1bfbc..b1b05e34 100644 --- a/src/webob/dec.py +++ b/src/webob/dec.py @@ -6,14 +6,62 @@ instantiated request). """ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, overload + +from typing_extensions import Concatenate, Never, ParamSpec, Self, TypeAlias, TypeVar + from webob.exc import HTTPException from webob.request import Request from webob.util import bytes_ +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Mapping + from typing import type_check_only + + from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment + + from webob.request import BaseRequest + from webob.response import Response + + _AnyResponse: TypeAlias = "Response | WSGIApplication | str | None" + _RequestHandlerCallable: TypeAlias = Callable[ + Concatenate["_RequestT", _P], _AnyResponse + ] + _RequestHandlerMethod: TypeAlias = Callable[ + Concatenate[Any, "_RequestT", _P], _AnyResponse + ] + _MiddlewareCallable: TypeAlias = Callable[ + Concatenate["_RequestT", "_AppT", _P], _AnyResponse + ] + _MiddlewareMethod: TypeAlias = Callable[ + Concatenate[Any, "_RequestT", "_AppT", _P], _AnyResponse + ] + _RequestHandler: TypeAlias = """( + _RequestHandlerCallable["_RequestT", _P] + | _RequestHandlerMethod["_RequestT", _P] + )""" + _Middleware: TypeAlias = """( + _MiddlewareCallable["_RequestT", "_AppT", _P] + | _MiddlewareMethod["_RequestT", "_AppT", _P] + )""" + +_S = TypeVar("_S") +_AppT = TypeVar("_AppT", bound="WSGIApplication") +_AppT_contra = TypeVar("_AppT_contra", bound="WSGIApplication", contravariant=True) +_RequestT = TypeVar("_RequestT", bound="BaseRequest") +_RequestT_contra = TypeVar( + "_RequestT_contra", bound="BaseRequest", default=Request, contravariant=True +) +_P = ParamSpec("_P") +_P2 = ParamSpec("_P2") + + __all__ = ["wsgify"] -class wsgify: +class wsgify(Generic[_P, _RequestT_contra]): """Turns a request-taking, response-returning function into a WSGI app @@ -77,11 +125,90 @@ def serve_json(req, json_obj): the function). """ - RequestClass = Request + RequestClass: type[_RequestT_contra] = Request # type: ignore[assignment] + func: _RequestHandler[_RequestT_contra, _P] | None + args: tuple[Any, ...] + kwargs: dict[str, Any] + middleware_wraps: WSGIApplication | None + + # NOTE: We disallow passing args/kwargs using this direct API, because + # we can't really make it work as a decorator this way, these + # arguments should only really be used indirectly through the + # middleware decorator, where we can be more type safe + @overload def __init__( - self, func=None, RequestClass=None, args=(), kwargs=None, middleware_wraps=None - ): + self: wsgify[[], Request], + func: _RequestHandler[Request, []] | None = None, + RequestClass: None = None, + args: tuple[()] = (), + kwargs: None = None, + middleware_wraps: None = None, + ) -> None: ... + + @overload + def __init__( + self: wsgify[[], _RequestT_contra], + func: _RequestHandler[_RequestT_contra, []] | None, + RequestClass: type[_RequestT_contra], + args: tuple[()] = (), + kwargs: None = None, + middleware_wraps: None = None, + ) -> None: ... + + @overload + def __init__( + self: wsgify[[], _RequestT_contra], + func: _RequestHandler[_RequestT_contra, []] | None = None, + *, + RequestClass: type[_RequestT_contra], + args: tuple[()] = (), + kwargs: None = None, + middleware_wraps: None = None, + ) -> None: ... + + @overload + def __init__( + self: wsgify[[_AppT_contra], Request], + func: _Middleware[Request, _AppT_contra, []] | None = None, + RequestClass: None = None, + args: tuple[()] = (), + kwargs: None = None, + *, + middleware_wraps: _AppT_contra, + ) -> None: ... + + @overload + def __init__( + self: wsgify[[_AppT_contra], _RequestT_contra], + func: _Middleware[_RequestT_contra, _AppT_contra, []] | None, + RequestClass: type[_RequestT_contra], + args: tuple[()] = (), + kwargs: None = None, + *, + middleware_wraps: _AppT_contra, + ) -> None: ... + + @overload + def __init__( + self: wsgify[[_AppT_contra], _RequestT_contra], + func: _Middleware[_RequestT_contra, _AppT_contra, []] | None = None, + *, + RequestClass: type[_RequestT_contra], + args: tuple[()] = (), + kwargs: None = None, + middleware_wraps: _AppT_contra, + ) -> None: ... + + def __init__( + self, + func: _RequestHandler[Any, ...] | None = None, + RequestClass: type[Any] | None = None, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + middleware_wraps: Any | None = None, + ) -> None: + self.func = func if RequestClass is not None and RequestClass is not self.RequestClass: @@ -93,18 +220,47 @@ def __init__( self.kwargs = kwargs self.middleware_wraps = middleware_wraps - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} at {id(self)} wrapping {self.func!r}>" - def __get__(self, obj, type=None): + @overload + def __get__( + self, obj: None, type: type[_S] + ) -> _unbound_wsgify[_P, _S, _RequestT_contra]: ... + + @overload + def __get__(self, obj: object, type: type | None = None) -> Self: ... + + def __get__(self, obj: object, type: type | None = None) -> Self: # This handles wrapping methods if hasattr(self.func, "__get__"): + assert self.func is not None return self.clone(self.func.__get__(obj, type)) else: return self - def __call__(self, req, *args, **kw): + @overload + def __call__( + self, env: WSGIEnvironment, /, start_response: StartResponse + ) -> Iterable[bytes]: ... + + @overload + def __call__( + self, + func: _RequestHandler[_RequestT_contra, _P], + /, + ) -> Self: ... + + @overload + def __call__(self, req: _RequestT_contra) -> _AnyResponse: ... + + @overload + def __call__( + self, req: _RequestT_contra, *args: _P.args, **kw: _P.kwargs + ) -> _AnyResponse: ... + + def __call__(self, req: Any, *args: Any, **kw: Any) -> Any: """Call this as a WSGI application or with a request""" func = self.func @@ -127,6 +283,7 @@ def __call__(self, req, *args, **kw): start_response = args[0] req = self.RequestClass(environ) req.response = req.ResponseClass() + resp: Any try: args, kw = self._prepare_args(None, None) resp = self.call_func(req, *args, **kw) @@ -154,7 +311,7 @@ def __call__(self, req, *args, **kw): return self.call_func(req, *args, **kw) - def get(self, url, **kw): + def get(self, url: str, **kw: Any) -> _AnyResponse: """Run a GET request on this application, returning a Response. This creates a request object using the given URL, and any @@ -170,7 +327,18 @@ def get(self, url, **kw): return self(req) - def post(self, url, POST=None, **kw): + def post( + self, + url: str, + POST: ( + str + | bytes + | Mapping[Any, Any] + | Mapping[Any, list[Any] | tuple[Any, ...]] + | None + ) = None, + **kw: Any, + ) -> _AnyResponse: """Run a POST request on this application, returning a Response. The second argument (`POST`) can be the request body (a @@ -188,7 +356,7 @@ def post(self, url, POST=None, **kw): return self(req) - def request(self, url, **kw): + def request(self, url: str, **kw: Any) -> _AnyResponse: """Run a request on this application, returning a Response. This can be used for DELETE, PUT, etc requests. E.g.:: @@ -199,17 +367,23 @@ def request(self, url, **kw): return self(req) - def call_func(self, req, *args, **kwargs): + def call_func( + self, req: _RequestT_contra, *args: _P.args, **kwargs: _P.kwargs + ) -> _AnyResponse: """Call the wrapped function; override this in a subclass to change how the function is called.""" - return self.func(req, *args, **kwargs) + assert self.func is not None + return self.func(req, *args, **kwargs) # type: ignore[arg-type] - def clone(self, func=None, **kw): + # technically this could bind different type vars, but we disallow it for safety + def clone( + self, func: _RequestHandler[_RequestT_contra, _P] | None = None, **kw: Never + ) -> Self: """Creates a copy/clone of this object, but with some parameters rebound """ - kwargs = {} + kwargs: dict[str, Any] = {} if func is not None: kwargs["func"] = func @@ -228,11 +402,65 @@ def clone(self, func=None, **kw): # To match @decorator: @property - def undecorated(self): + def undecorated(self) -> _RequestHandler[_RequestT_contra, _P] | None: return self.func + @overload + @classmethod + def middleware( + cls, + middle_func: None = None, + app: None | _AppT = None, + *args: _P.args, + **kw: _P.kwargs, + ) -> _UnboundMiddleware[_P, _AppT, Any]: ... + + @overload + @classmethod + def middleware( + cls, middle_func: _MiddlewareCallable[_RequestT, _AppT, _P2], app: None = None + ) -> _MiddlewareFactory[_P2, _AppT, _RequestT]: ... + + @overload @classmethod - def middleware(cls, middle_func=None, app=None, **kw): + def middleware( + cls, middle_func: _MiddlewareMethod[_RequestT, _AppT, _P2], app: None = None + ) -> _MiddlewareFactory[_P2, _AppT, _RequestT]: ... + + @overload + @classmethod + def middleware( + cls, + middle_func: _MiddlewareMethod[_RequestT, _AppT, _P2], + app: None = None, + *args: _P2.args, + **kw: _P2.kwargs, + ) -> _MiddlewareFactory[_P2, _AppT, _RequestT]: ... + + @overload + @classmethod + def middleware( + cls, middle_func: _MiddlewareMethod[_RequestT, _AppT, _P2], app: _AppT + ) -> type[wsgify[Concatenate[_AppT, _P2], _RequestT]]: ... + + @overload + @classmethod + def middleware( + cls, + middle_func: _MiddlewareMethod[_RequestT, _AppT, _P2], + app: _AppT, + *args: _P2.args, + **kw: _P2.kwargs, + ) -> type[wsgify[Concatenate[_AppT, _P2], _RequestT]]: ... + + @classmethod + def middleware( + cls, + middle_func: _MiddlewareMethod[Any, Any, ...] | None = None, + app: _AppT | None = None, + *args: Any, + **kw: Any, + ) -> Any: """Creates middleware Use this like:: @@ -278,14 +506,17 @@ def all_caps(req, app): """ if middle_func is None: - return _UnboundMiddleware(cls, app, kw) + return _UnboundMiddleware(cls, app, args, kw) # type: ignore[arg-type] if app is None: - return _MiddlewareFactory(cls, middle_func, kw) + return _MiddlewareFactory(cls, middle_func, args, kw) # type: ignore[arg-type] + + return cls(middle_func, middleware_wraps=app, args=args, kwargs=kw) # type: ignore[call-overload] - return cls(middle_func, middleware_wraps=app, kwargs=kw) + def _prepare_args( + self, args: tuple[Any, ...] | None, kwargs: dict[str, Any] | None + ) -> tuple[tuple[Any, ...], dict[str, Any]]: - def _prepare_args(self, args, kwargs): args = args or self.args kwargs = kwargs or self.kwargs @@ -295,46 +526,116 @@ def _prepare_args(self, args, kwargs): return args, kwargs -class _UnboundMiddleware: +if TYPE_CHECKING: + + @type_check_only + class _unbound_wsgify( + wsgify[_P, _RequestT_contra], Generic[_P, _S, _RequestT_contra] + ): + @overload # type: ignore[override] + def __call__( + self, __self: _S, env: WSGIEnvironment, /, start_response: StartResponse + ) -> Iterable[bytes]: ... + + @overload + def __call__( + self, + __self: _S, + func: _RequestHandler[_RequestT_contra, _P], + /, + ) -> Self: ... + + @overload + def __call__(self, __self: _S, /, req: _RequestT_contra) -> _AnyResponse: ... + + @overload + def __call__( + self, __self: _S, /, req: _RequestT_contra, *args: _P.args, **kw: _P.kwargs + ) -> _AnyResponse: ... + + def __call__(self, __self: _S, /, req: Any, *args: Any, **kw: Any) -> Any: + pass + + +class _UnboundMiddleware(Generic[_P, _AppT_contra, _RequestT_contra]): """A `wsgify.middleware` invocation that has not yet wrapped a middleware function; the intermediate object when you do something like ``@wsgify.middleware(RequestClass=Foo)`` """ - def __init__(self, wrapper_class, app, kw): + def __init__( + self, + wrapper_class: type[wsgify[Concatenate[_AppT_contra, _P], _RequestT_contra]], + app: _AppT_contra | None, + args: tuple[Any, ...], + kw: dict[str, Any], + ) -> None: self.wrapper_class = wrapper_class self.app = app + self.args = args self.kw = kw - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} at {id(self)} wrapping {self.app!r}>" - def __call__(self, func, app=None): + @overload + def __call__(self, func: None, app: _AppT_contra | None = None) -> Self: ... + + @overload + def __call__( + self, func: _Middleware[_RequestT_contra, _AppT_contra, _P], app: None = None + ) -> wsgify[Concatenate[_AppT_contra, _P], _RequestT_contra]: ... + + @overload + def __call__( + self, func: _Middleware[_RequestT_contra, _AppT_contra, _P], app: _AppT_contra + ) -> wsgify[Concatenate[_AppT_contra, _P], _RequestT_contra]: ... + + def __call__( + self, func: _Middleware[Any, Any, ...] | None, app: _AppT_contra | None = None + ) -> Any: if app is None: app = self.app - return self.wrapper_class.middleware(func, app=app, **self.kw) + return self.wrapper_class.middleware(func, app=app, **self.kw) # type: ignore -class _MiddlewareFactory: +class _MiddlewareFactory(Generic[_P, _AppT_contra, _RequestT_contra]): """A middleware that has not yet been bound to an application or configured. """ - def __init__(self, wrapper_class, middleware, kw): + def __init__( + self, + wrapper_class: type[wsgify[Concatenate[_AppT_contra, _P], _RequestT_contra]], + middleware: _Middleware[_RequestT_contra, _AppT_contra, _P], + args: tuple[Any, ...], + kw: dict[str, Any], + ) -> None: self.wrapper_class = wrapper_class self.middleware = middleware + self.args = args self.kw = kw - def __repr__(self): + def __repr__(self) -> str: return "<{} at {} wrapping {!r}>".format( self.__class__.__name__, id(self), self.middleware, ) - def __call__(self, app=None, **config): + @overload + def __call__( + self, app: None = None, *args: _P.args, **config: _P.kwargs + ) -> _MiddlewareFactory[[], _AppT_contra, _RequestT_contra]: ... + + @overload + def __call__( + self, app: _AppT_contra, *args: _P.args, **config: _P.kwargs + ) -> wsgify[[_AppT_contra], _RequestT_contra]: ... + + def __call__(self, app: Any = None, *args: Any, **config: Any) -> Any: kw = self.kw.copy() kw.update(config) - return self.wrapper_class.middleware(self.middleware, app, **kw) + return self.wrapper_class.middleware(self.middleware, app, *args, **kw) diff --git a/src/webob/descriptors.py b/src/webob/descriptors.py index 7230fc96..d3e7c19b 100644 --- a/src/webob/descriptors.py +++ b/src/webob/descriptors.py @@ -1,11 +1,58 @@ -from collections import namedtuple -from datetime import date, datetime +from __future__ import annotations + +from datetime import date, datetime, timedelta import re +from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, overload from webob.byterange import ContentRange, Range from webob.datetime_utils import parse_date, serialize_date from webob.util import header_docstring, warn_deprecation +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from time import _TimeTuple, struct_time + + from typing_extensions import TypeAlias + + from webob.etag import IfRange, IfRangeDate + from webob.request import BaseRequest + from webob.response import Response + from webob.types import ( + AsymmetricPropertyWithDelete, + SymmetricProperty, + SymmetricPropertyWithDelete, + ) + + _T = TypeVar("_T") + _DefaultT = TypeVar("_DefaultT") + _GetterReturnType = TypeVar("_GetterReturnType") + _SetterValueType = TypeVar("_SetterValueType") + _ConvertedGetterReturnType = TypeVar("_ConvertedGetterReturnType") + _ConvertedSetterValueType = TypeVar("_ConvertedSetterValueType") + _DescriptorT = TypeVar("_DescriptorT", bound=AsymmetricPropertyWithDelete[Any, Any]) + + _StringProperty: TypeAlias = SymmetricPropertyWithDelete["str | None"] + _ListProperty: TypeAlias = AsymmetricPropertyWithDelete[ + "tuple[str, ...] | None", "Iterable[str] | str | None" + ] + _DateProperty: TypeAlias = AsymmetricPropertyWithDelete[ + "datetime | None", + "date | datetime | timedelta | _TimeTuple | struct_time | float | str | None", + ] + _ContentRangeParams: TypeAlias = """( + ContentRange + | list[int] + | list[None] + | list[int | None] + | tuple[int, int] + | tuple[None, None] + | tuple[int, int, int | None] + | tuple[None, None, int | None] + | str + | None + )""" + + CHARSET_RE = re.compile(r";\s*charset=([^;]*)", re.I) SCHEME_RE = re.compile(r"^[a-z]+:", re.I) @@ -13,115 +60,162 @@ _not_given = object() -def environ_getter(key, default=_not_given, rfc_section=None): +@overload +def environ_getter( + key: str, *, rfc_section: str | None = None +) -> SymmetricProperty[Any]: ... + + +@overload +def environ_getter( + key: str, default: None, rfc_section: str | None = None +) -> SymmetricPropertyWithDelete[Any | None]: ... + + +@overload +def environ_getter( + key: str, default: _DefaultT, rfc_section: str | None = None +) -> AsymmetricPropertyWithDelete[Any | _DefaultT, Any | _DefaultT | None]: ... + + +def environ_getter( + key: str, default: Any = _not_given, rfc_section: str | None = None +) -> AsymmetricPropertyWithDelete[Any, Any] | SymmetricProperty[Any]: if rfc_section: doc = header_docstring(key, rfc_section) else: doc = "Gets and sets the ``%s`` key in the environment." % key if default is _not_given: - def fget(req): + def fget(req: BaseRequest) -> Any: return req.environ[key] - def fset(req, val): + def fset(req: BaseRequest, val: Any) -> None: req.environ[key] = val fdel = None else: - def fget(req): + def fget(req: BaseRequest) -> Any | _DefaultT: return req.environ.get(key, default) - def fset(req, val): + def fset(req: BaseRequest, val: Any) -> None: if val is None: if key in req.environ: del req.environ[key] else: req.environ[key] = val - def fdel(req): + def fdel(req: BaseRequest) -> None: del req.environ[key] return property(fget, fset, fdel, doc=doc) -def environ_decoder(key, default=_not_given, rfc_section=None, encattr=None): +@overload +def environ_decoder( + key: str, *, rfc_section: str | None = None, encattr: str | None = None +) -> SymmetricProperty[str]: ... + + +@overload +def environ_decoder( + key: str, default: str, rfc_section: str | None = None, encattr: str | None = None +) -> AsymmetricPropertyWithDelete[str, str | None]: ... + + +@overload +def environ_decoder( + key: str, default: None, rfc_section: str | None = None, encattr: str | None = None +) -> SymmetricPropertyWithDelete[str | None]: ... + + +def environ_decoder( + key: str, + default: Any = _not_given, + rfc_section: str | None = None, + encattr: str | None = None, +) -> SymmetricPropertyWithDelete[str | None] | SymmetricProperty[str]: + if rfc_section: doc = header_docstring(key, rfc_section) else: doc = "Gets and sets the ``%s`` key in the environment." % key if default is _not_given: - def fget(req): + def fget(req: BaseRequest) -> str: return req.encget(key, encattr=encattr) - def fset(req, val): + def fset(req: BaseRequest, val: str) -> None: return req.encset(key, val, encattr=encattr) fdel = None else: - def fget(req): + def fget(req: BaseRequest) -> str | _DefaultT: return req.encget(key, default, encattr=encattr) - def fset(req, val): + def fset(req: BaseRequest, val: str | None) -> None: # type: ignore[misc] if val is None: if key in req.environ: del req.environ[key] else: return req.encset(key, val, encattr=encattr) - def fdel(req): + def fdel(req: BaseRequest) -> None: del req.environ[key] return property(fget, fset, fdel, doc=doc) -def upath_property(key): - def fget(req): +def upath_property(key: str) -> SymmetricProperty[str]: + def fget(req: BaseRequest) -> str: encoding = req.url_encoding - return req.environ.get(key, "").encode("latin-1").decode(encoding) + return req.environ.get(key, "").encode("latin-1").decode(encoding) # type: ignore[no-any-return] - def fset(req, val): + def fset(req: BaseRequest, val: str) -> None: encoding = req.url_encoding req.environ[key] = val.encode(encoding).decode("latin-1") return property(fget, fset, doc="upath_property(%r)" % key) -def deprecated_property(attr, name, text, version): # pragma: no cover +def deprecated_property( + attr: _DescriptorT, name: str, text: str, version: str +) -> _DescriptorT: # pragma: no cover """ Wraps a descriptor, with a deprecation warning or error """ - def warn(): + def warn() -> None: warn_deprecation(f"The attribute {name} is deprecated: {text}", version, 3) - def fget(self): + def fget(self: object) -> Any: warn() return attr.__get__(self, type(self)) - def fset(self, val): + def fset(self: object, val: Any) -> None: warn() attr.__set__(self, val) - def fdel(self): + def fdel(self: object) -> None: warn() attr.__delete__(self) - return property(fget, fset, fdel, "" % name) + return property(fget, fset, fdel, "" % name) # type: ignore[return-value] -def header_getter(header, rfc_section): +def header_getter(header: str, rfc_section: str) -> _StringProperty: doc = header_docstring(header, rfc_section) key = header.lower() - def fget(r): + def fget(r: Response) -> str | None: for k, v in r._headerlist: if k.lower() == key: return v + return None - def fset(r, value): + def fset(r: Response, value: str | None) -> None: fdel(r) if value is not None: if not isinstance(value, str): @@ -131,13 +225,21 @@ def fset(r, value): r._headerlist.append((header, value)) - def fdel(r): + def fdel(r: Response) -> None: r._headerlist[:] = [(k, v) for (k, v) in r._headerlist if k.lower() != key] return property(fget, fset, fdel, doc) -def converter(prop, parse, serialize, convert_name=None): +def converter( + prop: AsymmetricPropertyWithDelete[_GetterReturnType, _SetterValueType], + parse: Callable[[_GetterReturnType], _ConvertedGetterReturnType], + serialize: Callable[[_ConvertedSetterValueType], _SetterValueType], + convert_name: str | None = None, +) -> AsymmetricPropertyWithDelete[ + _ConvertedGetterReturnType, _ConvertedSetterValueType | None +]: + assert isinstance(prop, property) convert_name = convert_name or "``{}`` and ``{}``".format( parse.__name__, @@ -147,40 +249,44 @@ def converter(prop, parse, serialize, convert_name=None): doc += " Converts it using %s." % convert_name hget, hset = prop.fget, prop.fset - def fget(r): + def fget(r: object) -> _ConvertedGetterReturnType: + assert hget is not None return parse(hget(r)) - def fset(r, val): + def fset(r: object, val: _ConvertedSetterValueType) -> None: + assert hset is not None if val is not None: - val = serialize(val) - hset(r, val) + sval = serialize(val) + else: + sval = None + hset(r, sval) return property(fget, fset, prop.fdel, doc) -def list_header(header, rfc_section): +def list_header(header: str, rfc_section: str) -> _ListProperty: prop = header_getter(header, rfc_section) return converter(prop, parse_list, serialize_list, "list") -def parse_list(value): +def parse_list(value: str | None) -> tuple[str, ...] | None: if not value: return None return tuple(filter(None, [v.strip() for v in value.split(",")])) -def serialize_list(value): +def serialize_list(value: Iterable[str] | str) -> str: if isinstance(value, (str, bytes)): return str(value) else: return ", ".join(map(str, value)) -def converter_date(prop): - return converter(prop, parse_date, serialize_date, "HTTP date") +def converter_date(prop: _StringProperty) -> _DateProperty: + return converter(prop, parse_date, serialize_date, "HTTP date") # type: ignore[arg-type] -def date_header(header, rfc_section): +def date_header(header: str, rfc_section: str) -> _DateProperty: return converter_date(header_getter(header, rfc_section)) @@ -192,7 +298,7 @@ def date_header(header, rfc_section): _rx_etag = re.compile(r'(?:^|\s)(W/)?"((?:\\"|.)*?)"') -def parse_etag_response(value, strong=False): +def parse_etag_response(value: str | None, strong: bool = False) -> str | None: """ Parse a response ETag. See: @@ -212,7 +318,9 @@ def parse_etag_response(value, strong=False): return m.group(2).replace('\\"', '"') -def serialize_etag_response(value): # return '"%s"' % value.replace('"', '\\"') +def serialize_etag_response( + value: tuple[str, bool] | str, +) -> str: # return '"%s"' % value.replace('"', '\\"') strong = True if isinstance(value, tuple): value, strong = value @@ -226,21 +334,25 @@ def serialize_etag_response(value): # return '"%s"' % value.replace('"', '\\"') return r -def serialize_if_range(value): +def serialize_if_range( + value: IfRange | IfRangeDate | datetime | date | str, +) -> str | None: if isinstance(value, (datetime, date)): return serialize_date(value) value = str(value) return value or None -def parse_range(value): +def parse_range(value: str | None) -> Range | None: if not value: return None # Might return None too: return Range.parse(value) -def serialize_range(value): +def serialize_range( + value: tuple[int, int | None] | list[int | None] | list[int] | str | None, +) -> str | None: if not value: return None elif isinstance(value, (list, tuple)): @@ -250,13 +362,13 @@ def serialize_range(value): return value -def parse_int(value): +def parse_int(value: str | None) -> int | None: if value is None or value == "": return None return int(value) -def parse_int_safe(value): +def parse_int_safe(value: str | None) -> int | None: if value is None or value == "": return None try: @@ -265,29 +377,29 @@ def parse_int_safe(value): return None -serialize_int = str +serialize_int: Callable[[int], str] = str -def parse_content_range(value): +def parse_content_range(value: str | None) -> ContentRange | None: if not value or not value.strip(): return None # May still return None return ContentRange.parse(value) -def serialize_content_range(value): +def serialize_content_range(value: _ContentRangeParams) -> str | None: if isinstance(value, (tuple, list)): if len(value) not in (2, 3): raise ValueError( "When setting content_range to a list/tuple, it must " - "be length 2 or 3 (not %r)" % value + f"be length 2 or 3 (not {value!r})" ) if len(value) == 2: begin, end = value length = None else: begin, end, length = value - value = ContentRange(begin, end, length) + value = ContentRange(begin, end, length) # type: ignore[arg-type] value = str(value).strip() if not value: return None @@ -297,7 +409,7 @@ def serialize_content_range(value): _rx_auth_param = re.compile(r'([a-z]+)[ \t]*=[ \t]*(".*?"|[^,]*?)[ \t]*(?:\Z|, *)') -def parse_auth_params(params): +def parse_auth_params(params: str) -> dict[str, str]: r = {} for k, v in _rx_auth_param.findall(params): r[k] = v.strip('"') @@ -305,22 +417,32 @@ def parse_auth_params(params): # see http://lists.w3.org/Archives/Public/ietf-http-wg/2009OctDec/0297.html -known_auth_schemes = [ - "Basic", - "Digest", - "WSSE", - "HMACDigest", - "GoogleLogin", - "Cookie", - "OpenID", -] -known_auth_schemes = dict.fromkeys(known_auth_schemes, None) +known_auth_schemes = dict.fromkeys( + [ + "Basic", + "Digest", + "WSSE", + "HMACDigest", + "GoogleLogin", + "Cookie", + "OpenID", + ], + None, +) + + +class Authorization(NamedTuple): + authtype: str + params: dict[str, str] | str + -_authorization = namedtuple("Authorization", ["authtype", "params"]) +_authorization = Authorization +del Authorization -def parse_auth(val): +def parse_auth(val: str | None) -> _authorization | None: if val is not None: + params: dict[str, str] | str authtype, sep, params = val.partition(" ") if authtype in known_auth_schemes: if authtype == "Basic" and '"' not in params: @@ -332,7 +454,9 @@ def parse_auth(val): return val -def serialize_auth(val): +def serialize_auth( + val: tuple[str, dict[str, str] | str] | list[Any] | str | None, +) -> str | None: if isinstance(val, (tuple, list)): authtype, params = val if isinstance(params, dict): diff --git a/src/webob/etag.py b/src/webob/etag.py index 70153d73..444a6e9c 100644 --- a/src/webob/etag.py +++ b/src/webob/etag.py @@ -4,31 +4,51 @@ Also If-Range parsing """ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + from webob.datetime_utils import parse_date, serialize_date from webob.descriptors import _rx_etag from webob.util import header_docstring +if TYPE_CHECKING: + from collections.abc import Collection + from datetime import datetime + + from typing_extensions import TypeAlias + + from webob.request import BaseRequest + from webob.response import Response + from webob.types import AsymmetricPropertyWithDelete + + _ETag: TypeAlias = "_AnyETag | _NoETag | ETagMatcher" + _ETagProperty: TypeAlias = AsymmetricPropertyWithDelete[_ETag, "_ETag | str | None"] + __all__ = ["AnyETag", "NoETag", "ETagMatcher", "IfRange", "etag_property"] -def etag_property(key, default, rfc_section, strong=True): +def etag_property( + key: str, default: _ETag, rfc_section: str, strong: bool = True +) -> _ETagProperty: + doc = header_docstring(key, rfc_section) doc += " Converts it as a Etag." - def fget(req): + def fget(req: BaseRequest) -> _ETag: value = req.environ.get(key) if not value: return default else: return ETagMatcher.parse(value, strong=strong) - def fset(req, val): + def fset(req: BaseRequest, val: _ETag | str | None) -> None: if val is None: req.environ[key] = None else: req.environ[key] = str(val) - def fdel(req): + def fdel(req: BaseRequest) -> None: del req.environ[key] return property(fget, fset, fdel, doc=doc) @@ -39,20 +59,20 @@ class _AnyETag: Represents an ETag of *, or a missing ETag when matching is 'safe' """ - def __repr__(self): + def __repr__(self) -> str: return "" - def __bool__(self): + def __bool__(self) -> Literal[False]: return False - def __contains__(self, other): + def __contains__(self, other: str | None) -> Literal[True]: return True - def __str__(self): + def __str__(self) -> str: return "*" -AnyETag = _AnyETag() +AnyETag: _AnyETag = _AnyETag() class _NoETag: @@ -60,37 +80,37 @@ class _NoETag: Represents a missing ETag when matching is unsafe """ - def __repr__(self): + def __repr__(self) -> str: return "" - def __bool__(self): + def __bool__(self) -> Literal[False]: return False - def __contains__(self, other): + def __contains__(self, other: str | None) -> Literal[False]: return False - def __str__(self): + def __str__(self) -> str: return "" -NoETag = _NoETag() +NoETag: _NoETag = _NoETag() # TODO: convert into a simple tuple class ETagMatcher: - def __init__(self, etags): + def __init__(self, etags: Collection[str]) -> None: self.etags = etags - def __contains__(self, other): + def __contains__(self, other: str | None) -> bool: return other in self.etags - def __repr__(self): + def __repr__(self) -> str: return "" % (" or ".join(self.etags)) @classmethod - def parse(cls, value, strong=True): + def parse(cls, value: str, strong: bool = True) -> ETagMatcher | _AnyETag: """ Parse this from a header value """ @@ -106,16 +126,16 @@ def parse(cls, value, strong=True): else: return cls([t for w, t in matches]) - def __str__(self): + def __str__(self) -> str: return ", ".join(map('"%s"'.__mod__, self.etags)) class IfRange: - def __init__(self, etag): + def __init__(self, etag: _ETag) -> None: self.etag = etag @classmethod - def parse(cls, value): + def parse(cls, value: str | None) -> IfRange | IfRangeDate: """ Parse this from a header value. """ @@ -123,36 +143,37 @@ def parse(cls, value): return cls(AnyETag) elif value.endswith(" GMT"): # Must be a date - return IfRangeDate(parse_date(value)) + # FIXME: What if the date is not valid? + return IfRangeDate(parse_date(value)) # type: ignore[arg-type] else: return cls(ETagMatcher.parse(value)) - def __contains__(self, resp): + def __contains__(self, resp: Response) -> bool: """ Return True if the If-Range header matches the given etag or last_modified """ return resp.etag_strong in self.etag - def __bool__(self): + def __bool__(self) -> bool: return bool(self.etag) - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.etag!r})" - def __str__(self): + def __str__(self) -> str: return str(self.etag) if self.etag else "" class IfRangeDate: - def __init__(self, date): + def __init__(self, date: datetime) -> None: self.date = date - def __contains__(self, resp): + def __contains__(self, resp: Response) -> bool: last_modified = resp.last_modified - return last_modified and (last_modified <= self.date) + return (last_modified <= self.date) if last_modified else False - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.date!r})" - def __str__(self): + def __str__(self) -> str: return serialize_date(self.date) diff --git a/src/webob/exc.py b/src/webob/exc.py index 8439fef0..fef2dcca 100644 --- a/src/webob/exc.py +++ b/src/webob/exc.py @@ -165,10 +165,13 @@ """ +from __future__ import annotations + import json import re from string import Template import sys +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar from urllib import parse as urlparse from webob.acceptparse import create_accept_header @@ -176,28 +179,49 @@ from webob.response import Response from webob.util import html_escape, text_ +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from _typeshed import OptExcInfo, SupportsItems, SupportsKeysAndGetItem + from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment + from typing_extensions import Self, TypeAlias + + _Headers: TypeAlias = """( + SupportsItems[str, str] + | SupportsKeysAndGetItem[str, str] + | Iterable[tuple[str, str]] + )""" + + class _JSONFormatter(Protocol): + def __call__( + self, *, body: str, status: str, title: str, environ: WSGIEnvironment + ) -> Any: ... + + +_T = TypeVar("_T") + tag_re = re.compile(r"<.*?>", re.S) br_re = re.compile(r"", re.I | re.S) comment_re = re.compile(r"") -class _lazified: - def __init__(self, func, value): +class _lazified(Generic[_T]): + def __init__(self, func: Callable[[_T], str], value: _T) -> None: self.func = func self.value = value - def __str__(self): + def __str__(self) -> str: return self.func(self.value) -def lazify(func): - def wrapper(value): +def lazify(func: Callable[[_T], str]) -> Callable[[_T], _lazified[_T]]: + def wrapper(value: _T) -> _lazified[_T]: return _lazified(func, value) return wrapper -def no_escape(value): +def no_escape(value: object) -> str: if value is None: return "" @@ -210,7 +234,7 @@ def no_escape(value): return value -def strip_tags(value): +def strip_tags(value: str) -> str: value = value.replace("\n", " ") value = value.replace("\r", "") value = br_re.sub("\n", value) @@ -221,11 +245,13 @@ def strip_tags(value): class HTTPException(Exception): - def __init__(self, message, wsgi_response): + def __init__(self, message: str, wsgi_response: Response) -> None: Exception.__init__(self, message) self.wsgi_response = wsgi_response - def __call__(self, environ, start_response): + def __call__( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: return self.wsgi_response(environ, start_response) @@ -238,7 +264,7 @@ class WSGIHTTPException(Response, HTTPException): code = 500 title = "Internal Server Error" explanation = "" - body_template_obj = Template( + body_template_obj: Template = Template( """\ ${explanation}

${detail} @@ -246,14 +272,14 @@ class WSGIHTTPException(Response, HTTPException): """ ) - plain_template_obj = Template( + plain_template_obj: Template = Template( """\ ${status} ${body}""" ) - html_template_obj = Template( + html_template_obj: Template = Template( """\ @@ -271,13 +297,13 @@ class WSGIHTTPException(Response, HTTPException): def __init__( self, - detail=None, - headers=None, - comment=None, - body_template=None, - json_formatter=None, - **kw, - ): + detail: str | None = None, + headers: _Headers | None = None, + comment: str | None = None, + body_template: str | None = None, + json_formatter: _JSONFormatter | None = None, + **kw: Any, + ) -> None: Response.__init__(self, status=f"{self.code} {self.title}", **kw) Exception.__init__(self, detail) @@ -295,21 +321,24 @@ def __init__( del self.content_length if json_formatter is not None: - self.json_formatter = json_formatter + self.json_formatter = json_formatter # type: ignore[method-assign] - def __str__(self): + def __str__(self) -> str: # type: ignore[override] return self.detail or self.explanation - def _make_body(self, environ, escape): - escape = lazify(escape) - args = { - "explanation": escape(self.explanation), - "detail": escape(self.detail or ""), - "comment": escape(self.comment or ""), + def _make_body( + self, environ: WSGIEnvironment, escape: Callable[[object], str] + ) -> str: + + lazy_escape = lazify(escape) + args: dict[str, object] = { + "explanation": lazy_escape(self.explanation), + "detail": lazy_escape(self.detail or ""), + "comment": lazy_escape(self.comment or ""), } if self.comment: - args["html_comment"] = "" % escape(self.comment) + args["html_comment"] = "" % lazy_escape(self.comment) else: args["html_comment"] = "" @@ -317,15 +346,15 @@ def _make_body(self, environ, escape): # Custom template; add headers to args for k, v in environ.items(): - args[k] = escape(v) + args[k] = lazy_escape(v) for k, v in self.headers.items(): - args[k.lower()] = escape(v) + args[k.lower()] = lazy_escape(v) t_obj = self.body_template_obj return t_obj.safe_substitute(args) - def plain_body(self, environ): + def plain_body(self, environ: WSGIEnvironment) -> str: body = self._make_body(environ, no_escape) body = strip_tags(body) @@ -333,15 +362,18 @@ def plain_body(self, environ): status=self.status, title=self.title, body=body ) - def html_body(self, environ): + def html_body(self, environ: WSGIEnvironment) -> str: body = self._make_body(environ, html_escape) return self.html_template_obj.substitute(status=self.status, body=body) - def json_formatter(self, body, status, title, environ): + def json_formatter( + self, *, body: str, status: str, title: str, environ: WSGIEnvironment + ) -> Any: + return {"message": body, "code": status, "title": title} - def json_body(self, environ): + def json_body(self, environ: WSGIEnvironment) -> str: # type: ignore[override] body = self._make_body(environ, no_escape) jsonbody = self.json_formatter( body=body, status=self.status, title=self.title, environ=environ @@ -349,7 +381,9 @@ def json_body(self, environ): return json.dumps(jsonbody) - def generate_response(self, environ, start_response): + def generate_response( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: if self.content_length is not None: del self.content_length headerlist = list(self.headerlist) @@ -376,7 +410,10 @@ def generate_response(self, environ, start_response): return resp(environ, start_response) - def __call__(self, environ, start_response): + def __call__( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: + is_head = environ["REQUEST_METHOD"] == "HEAD" if self.has_body or self.empty_body or is_head: @@ -390,7 +427,7 @@ def __call__(self, environ, start_response): return app_iter @property - def wsgi_response(self): + def wsgi_response(self) -> Self: # type: ignore[override] return self @@ -553,13 +590,13 @@ class _HTTPMove(HTTPRedirection): def __init__( self, - detail=None, - headers=None, - comment=None, - body_template=None, - location=None, - add_slash=False, - ): + detail: str | None = None, + headers: _Headers | None = None, + comment: str | None = None, + body_template: str | None = None, + location: str | None = None, + add_slash: bool = False, + ) -> None: super().__init__( detail=detail, headers=headers, comment=comment, body_template=body_template ) @@ -577,7 +614,10 @@ def __init__( ) self.add_slash = add_slash - def __call__(self, environ, start_response): + def __call__( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: + req = Request(environ) if self.add_slash: @@ -1313,26 +1353,36 @@ class HTTPExceptionMiddleware: *expected* exceptions raise through the WSGI stack is dangerous. """ - def __init__(self, application): + def __init__(self, application: WSGIApplication) -> None: self.application = application - def __call__(self, environ, start_response): + def __call__( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: + try: return self.application(environ, start_response) - except HTTPException: + except HTTPException as exc: parent_exc_info = sys.exc_info() - def repl_start_response(status, headers, exc_info=None): + def repl_start_response( + status: str, + headers: list[tuple[str, str]], + exc_info: OptExcInfo | None = None, + ) -> Callable[[bytes], object]: + if exc_info is None: exc_info = parent_exc_info return start_response(status, headers, exc_info) - return parent_exc_info[1](environ, repl_start_response) + return exc(environ, repl_start_response) __all__ = ["HTTPExceptionMiddleware", "status_map"] -status_map = {} +status_map: dict[ + int, type[HTTPOk | HTTPRedirection | HTTPClientError | HTTPServerError] +] = {} for name, value in list(globals().items()): if ( @@ -1342,13 +1392,11 @@ def repl_start_response(status, headers, exc_info=None): ): __all__.append(name) - if all( - ( - getattr(value, "code", None), - value not in (HTTPRedirection, HTTPClientError, HTTPServerError), - issubclass( - value, (HTTPOk, HTTPRedirection, HTTPClientError, HTTPServerError) - ), + if ( + getattr(value, "code", None) + and value not in (HTTPRedirection, HTTPClientError, HTTPServerError) + and issubclass( + value, (HTTPOk, HTTPRedirection, HTTPClientError, HTTPServerError) ) ): status_map[value.code] = value diff --git a/src/webob/headers.py b/src/webob/headers.py index 0061486a..4872b99c 100644 --- a/src/webob/headers.py +++ b/src/webob/headers.py @@ -1,17 +1,25 @@ -from collections.abc import MutableMapping +from __future__ import annotations + +from collections.abc import Iterator, MutableMapping +from typing import TYPE_CHECKING, TypeVar, overload from webob.multidict import MultiDict +if TYPE_CHECKING: + from _typeshed.wsgi import WSGIEnvironment + + _T = TypeVar("_T") + __all__ = ["ResponseHeaders", "EnvironHeaders"] -class ResponseHeaders(MultiDict): +class ResponseHeaders(MultiDict[str, str]): """ Dictionary view on the response headerlist. Keys are normalized for case and whitespace. """ - def __getitem__(self, key): + def __getitem__(self, key: str) -> str: key = key.lower() for k, v in reversed(self._items): @@ -19,13 +27,13 @@ def __getitem__(self, key): return v raise KeyError(key) - def getall(self, key): + def getall(self, key: str) -> list[str]: key = key.lower() return [v for (k, v) in self._items if k.lower() == key] - def mixed(self): - r = self.dict_of_lists() + def mixed(self) -> dict[str, str | list[str]]: + r: dict[str, str | list[str]] = self.dict_of_lists() # type: ignore[assignment] for key, val in r.items(): if len(val) == 1: @@ -33,20 +41,20 @@ def mixed(self): return r - def dict_of_lists(self): - r = {} + def dict_of_lists(self) -> dict[str, list[str]]: + r: dict[str, list[str]] = {} for key, val in self.items(): r.setdefault(key.lower(), []).append(val) return r - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: str) -> None: norm_key = key.lower() self._items[:] = [(k, v) for (k, v) in self._items if k.lower() != norm_key] self._items.append((key, value)) - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: key = key.lower() items = self._items found = False @@ -59,7 +67,10 @@ def __delitem__(self, key): if not found: raise KeyError(key) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: + if not isinstance(key, str): + return False # pragma: no cover + key = key.lower() for k, _ in self._items: @@ -70,7 +81,7 @@ def __contains__(self, key): has_key = __contains__ - def setdefault(self, key, default=None): + def setdefault(self, key: str, default: str) -> str: c_key = key.lower() for k, v in self._items: @@ -80,7 +91,13 @@ def setdefault(self, key, default=None): return default - def pop(self, key, *args): + @overload + def pop(self, key: str) -> str: ... + + @overload + def pop(self, key: str, default: _T, /) -> str | _T: ... + + def pop(self, key: str, *args: _T) -> str | _T: if len(args) > 1: raise TypeError( "pop expected at most 2 arguments, got %s" % repr(1 + len(args)) @@ -110,7 +127,7 @@ def pop(self, key, *args): header2key = {v.upper(): k for (k, v) in key2header.items()} -def _trans_key(key): +def _trans_key(key: object) -> str | None: if not isinstance(key, str): return None elif key in key2header: @@ -121,7 +138,7 @@ def _trans_key(key): return None -def _trans_name(name): +def _trans_name(name: str) -> str: name = name.upper() if name in header2key: @@ -130,7 +147,7 @@ def _trans_name(name): return "HTTP_" + name.replace("-", "_") -class EnvironHeaders(MutableMapping): +class EnvironHeaders(MutableMapping[str, str]): """An object that represents the headers as present in a WSGI environment. @@ -141,26 +158,28 @@ class EnvironHeaders(MutableMapping): headers). """ - def __init__(self, environ): + def __init__(self, environ: WSGIEnvironment) -> None: self.environ = environ - def __getitem__(self, hname): - return self.environ[_trans_name(hname)] + def __getitem__(self, hname: str) -> str: + return self.environ[_trans_name(hname)] # type: ignore[no-any-return] - def __setitem__(self, hname, value): + def __setitem__(self, hname: str, value: str) -> None: self.environ[_trans_name(hname)] = value - def __delitem__(self, hname): + def __delitem__(self, hname: str) -> None: del self.environ[_trans_name(hname)] - def keys(self): + def keys(self) -> Iterator[str]: # type: ignore[override] return filter(None, map(_trans_key, self.environ)) - def __contains__(self, hname): + def __contains__(self, hname: object) -> bool: + if not isinstance(hname, str): + return False # pragma: no cover return _trans_name(hname) in self.environ - def __len__(self): + def __len__(self) -> int: return len(list(self.keys())) - def __iter__(self): + def __iter__(self) -> Iterator[str]: yield from self.keys() diff --git a/src/webob/multidict.py b/src/webob/multidict.py index e54ea3b0..d356f7d4 100644 --- a/src/webob/multidict.py +++ b/src/webob/multidict.py @@ -4,22 +4,69 @@ """ Gives a multi-value dictionary object (MultiDict) plus several wrappers """ +from __future__ import annotations + import binascii -from collections.abc import MutableMapping +from collections.abc import Collection, Iterable, Iterator, MutableMapping +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload from urllib.parse import urlencode as url_encode import warnings +if TYPE_CHECKING: + from _typeshed import SupportsKeysAndGetItem + from _typeshed.wsgi import WSGIEnvironment + from typing_extensions import Self + + from webob.compat import cgi_FieldStorage + from webob.types import _FieldStorageWithFile + + _KT_co = TypeVar("_KT_co", covariant=True) + _VT_co = TypeVar("_VT_co", covariant=True) + + class _SupportsItemsWithIterableResult(Protocol[_KT_co, _VT_co]): + def items(self) -> Iterable[tuple[_KT_co, _VT_co]]: ... + + +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + __all__ = ["MultiDict", "NestedMultiDict", "NoVars", "GetDict"] -class MultiDict(MutableMapping): +class MultiDict(MutableMapping[_KT, _VT]): """ An ordered dictionary that can have multiple values for each key. Adds the methods getall, getone, mixed and extend and add to the normal dictionary interface. """ - def __init__(self, *args, **kw): + @overload + def __init__(self) -> None: ... + + @overload + def __init__(self: MultiDict[str, _VT], **kwargs: _VT) -> None: ... + + @overload + def __init__(self, m: _SupportsItemsWithIterableResult[_KT, _VT], /) -> None: ... + + @overload + def __init__( + self: MultiDict[str, _VT], + m: _SupportsItemsWithIterableResult[str, _VT], + /, + **kwargs: _VT, + ) -> None: ... + + @overload + def __init__(self, m: Iterable[tuple[_KT, _VT]], /) -> None: ... + + @overload + def __init__( + self: MultiDict[str, _VT], m: Iterable[tuple[str, _VT]], /, **kwargs: _VT + ) -> None: ... + + def __init__(self, *args: Any, **kw: _VT) -> None: # type: ignore[misc] if len(args) > 1: raise TypeError( "MultiDict can only be called with one positional " "argument" @@ -30,15 +77,15 @@ def __init__(self, *args, **kw): items = list(args[0].items()) else: items = list(args[0]) - self._items = items + self._items: list[tuple[_KT, _VT]] = items else: self._items = [] if kw: - self._items.extend(kw.items()) + self._items.extend(kw.items()) # type: ignore[arg-type] @classmethod - def view_list(cls, lst): + def view_list(cls, lst: list[tuple[_KT, _VT]]) -> Self: """ Create a multidict that is a view on the given list """ @@ -48,35 +95,37 @@ def view_list(cls, lst): "%s.view_list(obj) takes only actual list objects, not %r" % (cls.__name__, lst) ) - obj = cls() + obj: Self = cls() obj._items = lst return obj @classmethod - def from_fieldstorage(cls, fs): + def from_fieldstorage( + cls, fs: cgi_FieldStorage + ) -> MultiDict[str, str | _FieldStorageWithFile]: """ Create a multidict from a cgi.FieldStorage instance """ - obj = cls() + obj: MultiDict[str, str | _FieldStorageWithFile] = cls() # fs.list can be None when there's nothing to parse for field in fs.list or (): charset = field.type_options.get("charset", "utf8") transfer_encoding = field.headers.get("Content-Transfer-Encoding", None) - supported_transfer_encoding = { + supported_transfer_encoding: dict[str, Any] = { "base64": binascii.a2b_base64, "quoted-printable": binascii.a2b_qp, } if charset == "utf8": - def decode(b): + def decode(b: str) -> str: return b else: - def decode(b): + def decode(b: str) -> str: return b.encode("utf8").decode(charset) if field.filename: @@ -96,33 +145,33 @@ def decode(b): return obj - def __getitem__(self, key): + def __getitem__(self, key: _KT) -> _VT: for k, v in reversed(self._items): if k == key: return v raise KeyError(key) - def __setitem__(self, key, value): + def __setitem__(self, key: _KT, value: _VT) -> None: try: del self[key] except KeyError: pass self._items.append((key, value)) - def add(self, key, value): + def add(self, key: _KT, value: _VT) -> None: """ Add the key and value, not overwriting any previous value. """ self._items.append((key, value)) - def getall(self, key): + def getall(self, key: _KT) -> list[_VT]: """ Return a list of all values matching the key (may be an empty list) """ return [v for k, v in self._items if k == key] - def getone(self, key): + def getone(self, key: _KT) -> _VT: """ Get one value matching the key, raising a KeyError if multiple values were found. @@ -137,7 +186,7 @@ def getone(self, key): return v[0] - def mixed(self): + def mixed(self) -> dict[_KT, _VT | list[_VT]]: """ Returns a dictionary where the values are either single values, or a list of values when a key/value appears more than @@ -145,8 +194,8 @@ def mixed(self): dictionary often used to represent the variables in a web request. """ - result = {} - multi = {} + result: dict[_KT, _VT | list[_VT]] = {} + multi: dict[_KT, None] = {} for key, value in self.items(): if key in result: @@ -154,27 +203,27 @@ def mixed(self): # *actual* values in this dictionary: if key in multi: - result[key].append(value) + result[key].append(value) # type: ignore[union-attr] else: - result[key] = [result[key], value] + result[key] = [result[key], value] # type: ignore[list-item] multi[key] = None else: result[key] = value return result - def dict_of_lists(self): + def dict_of_lists(self) -> dict[_KT, list[_VT]]: """ Returns a dictionary where each key is associated with a list of values. """ - r = {} + r: dict[_KT, list[_VT]] = {} for key, val in self.items(): r.setdefault(key, []).append(val) return r - def __delitem__(self, key): + def __delitem__(self, key: _KT) -> None: items = self._items found = False @@ -186,7 +235,7 @@ def __delitem__(self, key): if not found: raise KeyError(key) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: for k, _ in self._items: if k == key: return True @@ -195,21 +244,35 @@ def __contains__(self, key): has_key = __contains__ - def clear(self): + def clear(self) -> None: del self._items[:] - def copy(self): + def copy(self) -> Self: return self.__class__(self) - def setdefault(self, key, default=None): + @overload + def setdefault( + self: MultiDict[_KT, _VT | None], key: _KT, default: None = None + ) -> _VT | None: ... + + @overload + def setdefault(self, key: _KT, default: _VT) -> _VT: ... + + def setdefault(self, key: _KT, default: _VT | None = None) -> _VT | None: for k, v in self._items: if key == k: return v - self._items.append((key, default)) + self._items.append((key, default)) # type: ignore[arg-type] return default - def pop(self, key, *args): + @overload + def pop(self, key: _KT) -> _VT: ... + + @overload + def pop(self, key: _KT, default: _T, /) -> _VT | _T: ... + + def pop(self, key: _KT, *args: _T) -> _VT | _T: if len(args) > 1: raise TypeError( "pop expected at most 2 arguments, got %s" % repr(1 + len(args)) @@ -225,10 +288,21 @@ def pop(self, key, *args): return args[0] raise KeyError(key) - def popitem(self): + def popitem(self) -> tuple[_KT, _VT]: return self._items.pop() - def update(self, *args, **kw): + @overload # type: ignore[override] + def update(self: MultiDict[str, _VT], **kwargs: _VT) -> None: ... + + @overload + def update(self, m: Collection[tuple[_KT, _VT]], /) -> None: ... + + @overload + def update( + self: MultiDict[str, _VT], m: Collection[tuple[str, _VT]], /, **kwargs: _VT + ) -> None: ... + + def update(self, *args: Collection[tuple[_KT, _VT]], **kw: _VT) -> None: # type: ignore[misc] if args: lst = args[0] @@ -242,7 +316,40 @@ def update(self, *args, **kw): warnings.warn(msg, UserWarning, stacklevel=2) MutableMapping.update(self, *args, **kw) - def extend(self, other=None, **kwargs): + @overload + def extend(self, other: _SupportsItemsWithIterableResult[_KT, _VT]) -> None: ... + + @overload + def extend( + self: MultiDict[str, _VT], + other: _SupportsItemsWithIterableResult[str, _VT], + **kwargs: _VT, + ) -> None: ... + + @overload + def extend(self, other: Iterable[tuple[_KT, _VT]]) -> None: ... + + @overload + def extend( + self: MultiDict[str, _VT], other: Iterable[tuple[str, _VT]], **kwargs: _VT + ) -> None: ... + + @overload + def extend(self, other: SupportsKeysAndGetItem[_KT, _VT]) -> None: ... + + @overload + def extend( + self: MultiDict[str, _VT], + other: SupportsKeysAndGetItem[str, _VT], + **kwargs: _VT, + ) -> None: ... + + @overload + def extend( + self: MultiDict[str, _VT], other: None = None, **kwargs: _VT + ) -> None: ... + + def extend(self, other: Any | None = None, **kwargs: _VT) -> None: if other is None: pass elif hasattr(other, "items"): @@ -255,47 +362,67 @@ def extend(self, other=None, **kwargs): self._items.append((k, v)) if kwargs: - self.update(kwargs) + self.update(kwargs) # type: ignore[arg-type] - def __repr__(self): + def __repr__(self) -> str: items = map("(%r, %r)".__mod__, _hide_passwd(self.items())) return "{}([{}])".format(self.__class__.__name__, ", ".join(items)) - def __len__(self): + def __len__(self) -> int: return len(self._items) # # All the iteration: # - def keys(self): + def keys(self) -> Iterator[_KT]: # type: ignore[override] for k, _ in self._items: yield k __iter__ = keys - def items(self): + def items(self) -> Iterator[tuple[_KT, _VT]]: # type: ignore[override] return iter(self._items) - def values(self): + def values(self) -> Iterator[_VT]: # type: ignore[override] for _, v in self._items: yield v + if TYPE_CHECKING: + # more permissive get + @overload + def get(self, key: _KT, /) -> _VT | None: ... + + @overload + def get(self, key: _KT, /, default: _T) -> _VT | _T: ... + + def get(self, key: _KT, /, default: Any = None) -> Any: ... + _dummy = object() -class GetDict(MultiDict): +class GetDict(MultiDict[str, str]): # def __init__(self, data, tracker, encoding, errors): # d = lambda b: b.decode(encoding, errors) # data = [(d(k), d(v)) for k,v in data] - def __init__(self, data, env): + @overload + def __init__( + self, data: _SupportsItemsWithIterableResult[str, str], env: WSGIEnvironment + ) -> None: ... + + @overload + def __init__( + self, data: Iterable[tuple[str, str]], env: WSGIEnvironment + ) -> None: ... + + def __init__(self, data: Any, env: WSGIEnvironment) -> None: self.env = env MultiDict.__init__(self, data) - def on_change(self): - def e(t): + def on_change(self) -> None: + def e(t: str) -> bytes: return t.encode("utf8") data = [(e(k), e(v)) for k, v in self.items()] @@ -303,89 +430,125 @@ def e(t): self.env["QUERY_STRING"] = qs self.env["webob._parsed_query_vars"] = (self, qs) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: str) -> None: MultiDict.__setitem__(self, key, value) self.on_change() - def add(self, key, value): + def add(self, key: str, value: str) -> None: MultiDict.add(self, key, value) self.on_change() - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: MultiDict.__delitem__(self, key) self.on_change() - def clear(self): + def clear(self) -> None: MultiDict.clear(self) self.on_change() - def setdefault(self, key, default=None): + def setdefault(self, key: str, default: str) -> str: result = MultiDict.setdefault(self, key, default) self.on_change() return result - def pop(self, key, *args): + @overload + def pop(self, key: str) -> str: ... + + @overload + def pop(self, key: str, default: _T, /) -> str | _T: ... + + def pop(self, key: str, *args: _T) -> str | _T: result = MultiDict.pop(self, key, *args) self.on_change() return result - def popitem(self): + def popitem(self) -> tuple[str, str]: result = MultiDict.popitem(self) self.on_change() return result - def update(self, *args, **kwargs): + @overload # type: ignore[override] + def update(self, **kwargs: str) -> None: ... + + @overload + def update(self, m: Collection[tuple[str, str]], /, **kwargs: str) -> None: ... + + def update(self, *args: Any, **kwargs: str) -> None: # type: ignore[misc] MultiDict.update(self, *args, **kwargs) self.on_change() - def extend(self, *args, **kwargs): - MultiDict.extend(self, *args, **kwargs) + @overload + def extend( + self, other: _SupportsItemsWithIterableResult[str, str], **kwargs: str + ) -> None: ... + + @overload + def extend(self, other: Iterable[tuple[str, str]], **kwargs: str) -> None: ... + + @overload + def extend( + self, other: SupportsKeysAndGetItem[str, str], **kwargs: str + ) -> None: ... + + @overload + def extend(self, other: None = None, **kwargs: str) -> None: ... + + def extend(self, *args: Any, **kwargs: str) -> None: # type: ignore[misc] + MultiDict.extend(self, *args, **kwargs) # type: ignore[arg-type] self.on_change() - def __repr__(self): + def __repr__(self) -> str: items = map("(%r, %r)".__mod__, _hide_passwd(self.items())) # TODO: GET -> GetDict return "GET([%s])" % (", ".join(items)) - def copy(self): + def copy(self) -> MultiDict[str, str]: # type: ignore[override] # Copies shouldn't be tracked return MultiDict(self) -class NestedMultiDict(MultiDict): +class NestedMultiDict(MultiDict[_KT, _VT]): """ Wraps several MultiDict objects, treating it as one large MultiDict """ - def __init__(self, *dicts): + # FIXME: the annotation here is too strict currently, because we need to + # allow violating the variance of _VT for MultiDict, we should replace + # this with a MultiMapping Protocol with the correct variance + def __init__(self, *dicts: MultiDict[_KT, _VT]) -> None: self.dicts = dicts - def __getitem__(self, key): + def __getitem__(self, key: _KT) -> _VT: for d in self.dicts: value = d.get(key, _dummy) if value is not _dummy: - return value + return value # type: ignore[return-value] raise KeyError(key) - def _readonly(self, *args, **kw): - raise KeyError("NestedMultiDict objects are read-only") + if TYPE_CHECKING: + # NOTE: This gives us a slightly better type checker error + _readonly = None + else: + + def _readonly(self, *args, **kw): + raise KeyError("NestedMultiDict objects are read-only") - __setitem__ = _readonly - add = _readonly - __delitem__ = _readonly - clear = _readonly - setdefault = _readonly - pop = _readonly - popitem = _readonly - update = _readonly + __setitem__ = _readonly # type: ignore[assignment] + add = _readonly # type: ignore[assignment] + __delitem__ = _readonly # type: ignore[assignment] + clear = _readonly # type: ignore[assignment] + setdefault = _readonly # type: ignore[assignment] + pop = _readonly # type: ignore[assignment] + popitem = _readonly # type: ignore[assignment] + update = _readonly # type: ignore[assignment] - def getall(self, key): + def getall(self, key: _KT) -> list[_VT]: result = [] for d in self.dicts: @@ -398,10 +561,10 @@ def getall(self, key): # mixed # dict_of_lists - def copy(self): + def copy(self) -> MultiDict[_KT, _VT]: # type: ignore[override] return MultiDict(self) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: for d in self.dicts: if key in d: return True @@ -410,7 +573,7 @@ def __contains__(self, key): has_key = __contains__ - def __len__(self): + def __len__(self) -> int: v = 0 for d in self.dicts: @@ -418,22 +581,22 @@ def __len__(self): return v - def __bool__(self): + def __bool__(self) -> bool: for d in self.dicts: if d: return True return False - def items(self): + def items(self) -> Iterator[tuple[_KT, _VT]]: # type: ignore[override] for d in self.dicts: yield from d.items() - def values(self): + def values(self) -> Iterator[_VT]: # type: ignore[override] for d in self.dicts: yield from d.values() - def keys(self): + def keys(self) -> Iterator[_KT]: # type: ignore[override] for d in self.dicts: yield from d @@ -448,65 +611,79 @@ class NoVars: This is read-only """ - def __init__(self, reason=None): + def __init__(self, reason: str | None = None) -> None: self.reason = reason or "N/A" - def __getitem__(self, key): - raise KeyError(f"No key {key!r}: {self.reason}") + if not TYPE_CHECKING: + # NOTE: It's better to pretend the methods don't exist for NoVars + # so we get better type errors + + def __getitem__(self, key): + raise KeyError(f"No key {key!r}: {self.reason}") - def __setitem__(self, *args, **kw): - raise KeyError("Cannot add variables: %s" % self.reason) + def __setitem__(self, *args, **kw): + raise KeyError("Cannot add variables: %s" % self.reason) - add = __setitem__ - setdefault = __setitem__ - update = __setitem__ + add = __setitem__ + setdefault = __setitem__ + update = __setitem__ - def __delitem__(self, *args, **kw): - raise KeyError("No keys to delete: %s" % self.reason) + def __delitem__(self, *args, **kw): + raise KeyError("No keys to delete: %s" % self.reason) - clear = __delitem__ - pop = __delitem__ - popitem = __delitem__ + clear = __delitem__ + pop = __delitem__ + popitem = __delitem__ - def get(self, key, default=None): + getone = __getitem__ + + @overload + def get(self, key: str, default: None = None) -> None: ... + + @overload + def get(self, key: str, default: _T) -> _T: ... + + def get(self, key: str, default: _T | None = None) -> _T | None: return default - def getall(self, key): + def getall(self, key: str) -> list[str]: return [] - def getone(self, key): - return self[key] - - def mixed(self): + def mixed(self) -> dict[str, str | list[str]]: return {} - dict_of_lists = mixed + def dict_of_lists(self) -> dict[str, list[str]]: + return {} # pragma: no cover - def __contains__(self, key): + def __contains__(self, key: object) -> Literal[False]: return False has_key = __contains__ - def copy(self): + def copy(self) -> Self: return self - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: {self.reason}>" - def __len__(self): + def __len__(self) -> Literal[0]: return 0 - def keys(self): + def keys(self) -> Iterator[str]: + return iter([]) + + def items(self) -> Iterator[tuple[str, str]]: return iter([]) - items = keys values = keys __iter__ = keys -def _hide_passwd(items): +def _hide_passwd( + items: Iterable[tuple[object, object]], +) -> Iterator[tuple[object, object]]: for k, v in items: - if "password" in k or "passwd" in k or "pwd" in k: + if isinstance(k, str) and ("password" in k or "passwd" in k or "pwd" in k): yield k, "******" else: yield k, v diff --git a/src/webob/py.typed b/src/webob/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/webob/request.py b/src/webob/request.py index ee52a7d1..fb812222 100644 --- a/src/webob/request.py +++ b/src/webob/request.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import binascii import io import mimetypes @@ -5,6 +7,16 @@ import re import sys import tempfile +from typing import ( + IO, + TYPE_CHECKING, + Any, + ClassVar, + Literal, + Protocol, + TypeVar, + overload, +) from urllib import parse as urlparse from urllib.parse import quote as url_quote, quote_plus, urlencode as url_encode import warnings @@ -40,21 +52,72 @@ from webob.multidict import GetDict, MultiDict, NestedMultiDict, NoVars from webob.util import bytes_, parse_qsl_text, text_, url_unquote -try: - import simplejson as json -except ImportError: +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Mapping + import datetime import json + from re import Pattern + from typing import type_check_only + + from _typeshed import ( + OptExcInfo, + SupportsKeysAndGetItem, + SupportsNoArgReadline, + SupportsRead, + SupportsWrite, + WriteableBuffer, + ) + from _typeshed.wsgi import WSGIApplication, WSGIEnvironment + from typing_extensions import Self + + from webob.acceptparse import ( + _AcceptCharsetProperty, + _AcceptEncodingProperty, + _AcceptLanguageProperty, + _AcceptProperty, + ) + from webob.byterange import Range + from webob.cachecontrol import RequestCacheControl + from webob.client import SendRequest + from webob.descriptors import _authorization + from webob.etag import IfRangeDate + from webob.response import Response + from webob.types import ( + AsymmetricProperty, + AsymmetricPropertyWithDelete, + HTTPMethod, + ListOrTuple, + RequestCacheControlDict, + SymmetricProperty, + SymmetricPropertyWithDelete, + _FieldStorageWithFile, + ) + + _T = TypeVar("_T") + _SupportsWriteT = TypeVar("_SupportsWriteT", bound=SupportsWrite[bytes]) + + @type_check_only + class _SupportsReadAndNoArgReadline( + SupportsRead["str | bytes"], SupportsNoArgReadline["str | bytes"], Protocol + ): + pass + +else: + try: + import simplejson as json + except ImportError: + import json __all__ = ["BaseRequest", "Request"] class _NoDefault: - def __repr__(self): + def __repr__(self) -> str: return "(No Default)" -NoDefault = _NoDefault() +NoDefault: _NoDefault = _NoDefault() DEFAULT = object() PATH_SAFE = "/~!$&'()*+,;=:@" @@ -73,16 +136,19 @@ def __repr__(self): "8859", ) +_send_request_app: SendRequest + class BaseRequest: # The limit after which request bodies should be stored on disk # if they are read in (under this, and the request body is stored # in memory): - request_body_tempfile_limit = 10 * 1024 + request_body_tempfile_limit: ClassVar[int] = 10 * 1024 + environ: WSGIEnvironment - _charset = None + _charset: str | None = None - def __init__(self, environ, **kw): + def __init__(self, environ: WSGIEnvironment, **kw: Any) -> None: if type(environ) is not dict: raise TypeError(f"WSGI environ must be a dict; you passed {environ!r}") @@ -101,7 +167,15 @@ def __init__(self, environ, **kw): raise TypeError(f"Unexpected keyword: {name}={value!r}") setattr(self, name, value) - def encget(self, key, default=NoDefault, encattr=None): + @overload + def encget(self, key: str, default: _T, encattr: str | None = None) -> str | _T: ... + + @overload + def encget(self, key: str, *, encattr: str | None = None) -> str: ... + + def encget( + self, key: str, default: Any = NoDefault, encattr: str | None = None + ) -> Any: val = self.environ.get(key, default) if val is NoDefault: @@ -119,7 +193,7 @@ def encget(self, key, default=NoDefault, encattr=None): return bytes_(val, "latin-1").decode(encoding) - def encset(self, key, val, encattr=None): + def encset(self, key: str, val: str, encattr: str | None = None) -> None: if encattr: encoding = getattr(self, encattr) else: @@ -127,7 +201,7 @@ def encset(self, key, val, encattr=None): self.environ[key] = bytes_(val, encoding).decode("latin-1") @property - def charset(self): + def charset(self) -> str | None: if self._charset is None: charset = detect_charset(self._content_type_raw) @@ -137,19 +211,24 @@ def charset(self): return self._charset - @charset.setter - def charset(self, charset): - if _is_utf8(charset): - charset = "UTF-8" + # NOTE: The setter is deprecated and doesn't do anything except emit + # a warning, so we should pretend it doesn't exist + if not TYPE_CHECKING: + + @charset.setter + def charset(self, charset): + if _is_utf8(charset): + charset = "UTF-8" - if charset != self.charset: - raise DeprecationWarning("Use req = req.decode(%r)" % charset) + if charset != self.charset: + raise DeprecationWarning("Use req = req.decode(%r)" % charset) - def decode(self, charset=None, errors="strict"): + def decode(self, charset: str | None = None, errors: str = "strict") -> Self: charset = charset or self.charset if charset == "UTF-8": return self + assert charset is not None # cookies and path are always utf-8 t = Transcoder(charset, errors) @@ -172,7 +251,7 @@ def decode(self, charset=None, errors="strict"): fs_environ.setdefault("CONTENT_LENGTH", "0") fs_environ["QUERY_STRING"] = "" fs = cgi_FieldStorage( - fp=self.body_file, + fp=self.body_file, # type: ignore[arg-type] environ=fs_environ, keep_blank_values=True, encoding=charset, @@ -194,7 +273,7 @@ def decode(self, charset=None, errors="strict"): _setattr_stacklevel = 2 @property - def body_file(self): + def body_file(self) -> SupportsRead[bytes]: """ Input stream of the request (wsgi.input). Setting this property resets the content_length and seekable flag @@ -226,7 +305,7 @@ def body_file(self): return r @body_file.setter - def body_file(self, value): + def body_file(self, value: SupportsRead[bytes]) -> None: if isinstance(value, bytes): raise ValueError("Excepted fileobj but received bytes.") @@ -236,13 +315,13 @@ def body_file(self, value): self.is_body_readable = True @body_file.deleter - def body_file(self): + def body_file(self) -> None: self.body = b"" - body_file_raw = environ_getter("wsgi.input") + body_file_raw: SymmetricProperty[SupportsRead[bytes]] = environ_getter("wsgi.input") @property - def body_file_seekable(self): + def body_file_seekable(self) -> IO[bytes]: """ Get the body of the request (wsgi.input) as a seekable file-like object. Middleware and routing applications should use this @@ -254,25 +333,37 @@ def body_file_seekable(self): if not self.is_body_seekable: self.make_body_seekable() - return self.body_file_raw + return self.body_file_raw # type: ignore[return-value] - url_encoding = environ_getter("webob.url_encoding", "UTF-8") - scheme = environ_getter("wsgi.url_scheme") - method = environ_getter("REQUEST_METHOD", "GET") - http_version = environ_getter("SERVER_PROTOCOL") + url_encoding: AsymmetricPropertyWithDelete[str, str | None] = environ_getter( + "webob.url_encoding", "UTF-8" + ) + scheme: SymmetricProperty[str] = environ_getter("wsgi.url_scheme") + method: AsymmetricPropertyWithDelete[HTTPMethod, HTTPMethod | None] = ( + environ_getter("REQUEST_METHOD", "GET") + ) + http_version: SymmetricProperty[str] = environ_getter("SERVER_PROTOCOL") content_length = converter( environ_getter("CONTENT_LENGTH", None, "14.13"), parse_int_safe, serialize_int, "int", ) - remote_user = environ_getter("REMOTE_USER", None) - remote_host = environ_getter("REMOTE_HOST", None) - remote_addr = environ_getter("REMOTE_ADDR", None) - query_string = environ_getter("QUERY_STRING", "") - server_name = environ_getter("SERVER_NAME") - server_port = converter( - environ_getter("SERVER_PORT"), parse_int, serialize_int, "int" + remote_user: SymmetricPropertyWithDelete[str | None] = environ_getter( + "REMOTE_USER", None + ) + remote_host: SymmetricPropertyWithDelete[str | None] = environ_getter( + "REMOTE_HOST", None + ) + remote_addr: SymmetricPropertyWithDelete[str | None] = environ_getter( + "REMOTE_ADDR", None + ) + query_string: AsymmetricPropertyWithDelete[str, str | None] = environ_getter( + "QUERY_STRING", "" + ) + server_name: SymmetricProperty[str] = environ_getter("SERVER_NAME") + server_port: SymmetricProperty[int] = converter( + environ_getter("SERVER_PORT"), parse_int, serialize_int, "int" # type: ignore[arg-type] ) script_name = environ_decoder("SCRIPT_NAME", "", encattr="url_encoding") @@ -282,9 +373,11 @@ def body_file_seekable(self): uscript_name = script_name upath_info = path_info - _content_type_raw = environ_getter("CONTENT_TYPE", "") + _content_type_raw: AsymmetricPropertyWithDelete[str, str | None] = environ_getter( + "CONTENT_TYPE", "" + ) - def _content_type__get(self): + def _content_type__get(self) -> str: """Return the content type, but leaving off any parameters (like charset, but also things like the type in ``application/atom+xml; type=entry``) @@ -296,7 +389,7 @@ def _content_type__get(self): return self._content_type_raw.split(";", 1)[0] - def _content_type__set(self, value=None): + def _content_type__set(self, value: str | None = None) -> None: if value is not None: value = str(value) @@ -307,16 +400,16 @@ def _content_type__set(self, value=None): value += ";" + content_type.split(";", 1)[1] self._content_type_raw = value - content_type = property( + content_type: AsymmetricPropertyWithDelete[str, str | None] = property( _content_type__get, _content_type__set, _content_type__set, _content_type__get.__doc__, ) - _headers = None + _headers: EnvironHeaders | None = None - def _headers__get(self): + def _headers__get(self) -> EnvironHeaders: """ All the request headers as a case-insensitive dictionary-like object. @@ -327,14 +420,18 @@ def _headers__get(self): return self._headers - def _headers__set(self, value): + def _headers__set( + self, value: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]] + ) -> None: self.headers.clear() self.headers.update(value) - headers = property(_headers__get, _headers__set, doc=_headers__get.__doc__) + headers: AsymmetricProperty[ + EnvironHeaders, SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]] + ] = property(_headers__get, _headers__set, doc=_headers__get.__doc__) @property - def client_addr(self): + def client_addr(self) -> str | None: """ The effective client IP address as a string. If the ``HTTP_X_FORWARDED_FOR`` header exists in the WSGI environ, this @@ -356,17 +453,17 @@ def client_addr(self): must be behind a trusted proxy for this to be true. """ e = self.environ - xff = e.get("HTTP_X_FORWARDED_FOR") + xff: str | None = e.get("HTTP_X_FORWARDED_FOR") if xff is not None: - addr = xff.split(",")[0].strip() + addr: str | None = xff.split(",")[0].strip() else: addr = e.get("REMOTE_ADDR") return addr @property - def host_port(self): + def host_port(self) -> str: """ The effective server port number as a string. If the ``HTTP_HOST`` header exists in the WSGI environ, this attribute returns the port @@ -378,7 +475,7 @@ def host_port(self): ``SERVER_PORT`` header (which is guaranteed to be present). """ e = self.environ - host = e.get("HTTP_HOST") + host: str | None = e.get("HTTP_HOST") if host is not None: if ":" in host and host[-1] != "]": @@ -396,12 +493,12 @@ def host_port(self): return port @property - def host_url(self): + def host_url(self) -> str: """ The URL through the host (no path) """ e = self.environ - scheme = e.get("wsgi.url_scheme") + scheme: str = e["wsgi.url_scheme"] url = scheme + "://" host = e.get("HTTP_HOST") @@ -428,7 +525,7 @@ def host_url(self): return url @property - def application_url(self): + def application_url(self) -> str: """ The URL including SCRIPT_NAME (no PATH_INFO or query string) """ @@ -437,7 +534,7 @@ def application_url(self): return self.host_url + url_quote(bscript_name, PATH_SAFE) @property - def path_url(self): + def path_url(self) -> str: """ The URL including SCRIPT_NAME and PATH_INFO, but not QUERY_STRING """ @@ -446,7 +543,7 @@ def path_url(self): return self.application_url + url_quote(bpath_info, PATH_SAFE) @property - def path(self): + def path(self) -> str: """ The path of the request, without host or query string """ @@ -456,7 +553,7 @@ def path(self): return url_quote(bscript, PATH_SAFE) + url_quote(bpath, PATH_SAFE) @property - def path_qs(self): + def path_qs(self) -> str: """ The path of the request, without host but with query string """ @@ -469,7 +566,7 @@ def path_qs(self): return path @property - def url(self): + def url(self) -> str: """ The full request URL, including QUERY_STRING """ @@ -481,7 +578,7 @@ def url(self): return url - def relative_url(self, other_url, to_application=False): + def relative_url(self, other_url: str, to_application: bool = False) -> str: """ Resolve other_url relative to the request URL. @@ -499,7 +596,7 @@ def relative_url(self, other_url, to_application=False): return urlparse.urljoin(url, other_url) - def path_info_pop(self, pattern=None): + def path_info_pop(self, pattern: Pattern[str] | None = None) -> str | None: """ 'Pops' off the next segment of PATH_INFO, pushing it onto SCRIPT_NAME, and returning the popped segment. Returns None if @@ -532,8 +629,9 @@ def path_info_pop(self, pattern=None): self.path_info = path[idx:] return r + return None - def path_info_peek(self): + def path_info_peek(self) -> str | None: """ Returns the next segment on PATH_INFO, or None if there is no next segment. Doesn't modify the environment. @@ -546,7 +644,7 @@ def path_info_peek(self): return path.split("/", 1)[0] - def _urlvars__get(self): + def _urlvars__get(self) -> dict[str, str]: """ Return any *named* variables matched in the URL. @@ -555,16 +653,16 @@ def _urlvars__get(self): """ if "paste.urlvars" in self.environ: - return self.environ["paste.urlvars"] + return self.environ["paste.urlvars"] # type: ignore[no-any-return] elif "wsgiorg.routing_args" in self.environ: - return self.environ["wsgiorg.routing_args"][1] + return self.environ["wsgiorg.routing_args"][1] # type: ignore[no-any-return] else: - result = {} + result: dict[str, str] = {} self.environ["wsgiorg.routing_args"] = ((), result) return result - def _urlvars__set(self, value): + def _urlvars__set(self, value: dict[str, str]) -> None: environ = self.environ if "wsgiorg.routing_args" in environ: @@ -580,7 +678,7 @@ def _urlvars__set(self, value): else: environ["wsgiorg.routing_args"] = ((), value) - def _urlvars__del(self): + def _urlvars__del(self) -> None: if "paste.urlvars" in self.environ: del self.environ["paste.urlvars"] @@ -593,11 +691,11 @@ def _urlvars__del(self): {}, ) - urlvars = property( + urlvars: SymmetricPropertyWithDelete[dict[str, str]] = property( _urlvars__get, _urlvars__set, _urlvars__del, doc=_urlvars__get.__doc__ ) - def _urlargs__get(self): + def _urlargs__get(self) -> tuple[str, ...]: """ Return any *positional* variables matched in the URL. @@ -606,14 +704,14 @@ def _urlargs__get(self): """ if "wsgiorg.routing_args" in self.environ: - return self.environ["wsgiorg.routing_args"][0] + return self.environ["wsgiorg.routing_args"][0] # type: ignore[no-any-return] else: # Since you can't update this value in-place, we don't need # to set the key in the environment return () - def _urlargs__set(self, value): + def _urlargs__set(self, value: tuple[str, ...]) -> None: environ = self.environ if "paste.urlvars" in environ: @@ -626,7 +724,7 @@ def _urlargs__set(self, value): routing_args = (value, {}) environ["wsgiorg.routing_args"] = routing_args - def _urlargs__del(self): + def _urlargs__del(self) -> None: if "wsgiorg.routing_args" in self.environ: if not self.environ["wsgiorg.routing_args"][1]: del self.environ["wsgiorg.routing_args"] @@ -636,12 +734,12 @@ def _urlargs__del(self): self.environ["wsgiorg.routing_args"][1], ) - urlargs = property( + urlargs: SymmetricPropertyWithDelete[tuple[str, ...]] = property( _urlargs__get, _urlargs__set, _urlargs__del, _urlargs__get.__doc__ ) @property - def is_xhr(self): + def is_xhr(self) -> bool: """Is X-Requested-With header present and equal to ``XMLHttpRequest``? Note: this isn't set by every XMLHttpRequest request, it is @@ -649,27 +747,29 @@ def is_xhr(self): (or you set the header yourself manually). Currently Prototype and jQuery are known to set this header.""" - return self.environ.get("HTTP_X_REQUESTED_WITH", "") == "XMLHttpRequest" + return self.environ.get("HTTP_X_REQUESTED_WITH", "") == "XMLHttpRequest" # type: ignore[no-any-return] - def _host__get(self): + def _host__get(self) -> str: """Host name provided in HTTP_HOST, with fall-back to SERVER_NAME""" if "HTTP_HOST" in self.environ: - return self.environ["HTTP_HOST"] + return self.environ["HTTP_HOST"] # type: ignore[no-any-return] else: return "%(SERVER_NAME)s:%(SERVER_PORT)s" % self.environ - def _host__set(self, value): + def _host__set(self, value: str) -> None: self.environ["HTTP_HOST"] = value - def _host__del(self): + def _host__del(self) -> None: if "HTTP_HOST" in self.environ: del self.environ["HTTP_HOST"] - host = property(_host__get, _host__set, _host__del, doc=_host__get.__doc__) + host: SymmetricPropertyWithDelete[str] = property( + _host__get, _host__set, _host__del, doc=_host__get.__doc__ + ) @property - def domain(self): + def domain(self) -> str: """Returns the domain portion of the host value. Equivalent to: .. code-block:: python @@ -694,8 +794,10 @@ def domain(self): return domain + # NOTE: Technically this should be an asymmetric property, since we're allowed + # to set this to `None`, but it doesn't seem worth the effort right now @property - def body(self): + def body(self) -> bytes: """ Return the content of the request body. """ @@ -704,13 +806,13 @@ def body(self): return b"" self.make_body_seekable() # we need this to have content_length - r = self.body_file.read(self.content_length) - self.body_file_raw.seek(0) + r = self.body_file.read(self.content_length) # type: ignore[arg-type] + self.body_file_raw.seek(0) # type: ignore[attr-defined] return r @body.setter - def body(self, value): + def body(self, value: bytes | None) -> None: if value is None: value = b"" @@ -723,23 +825,25 @@ def body(self, value): self.is_body_seekable = True @body.deleter - def body(self): + def body(self) -> None: self.body = b"" - def _json_body__get(self): + def _json_body__get(self) -> Any: """Access the body of the request as JSON""" - return json.loads(self.body.decode(self.charset)) + return json.loads(self.body.decode(self.charset)) # type: ignore[arg-type] - def _json_body__set(self, value): - self.body = json.dumps(value, separators=(",", ":")).encode(self.charset) + def _json_body__set(self, value: Any) -> None: + self.body = json.dumps(value, separators=(",", ":")).encode(self.charset) # type: ignore[arg-type] - def _json_body__del(self): + def _json_body__del(self) -> None: del self.body + json: SymmetricPropertyWithDelete[Any] + json_body: SymmetricPropertyWithDelete[Any] json = json_body = property(_json_body__get, _json_body__set, _json_body__del) - def _text__get(self): + def _text__get(self) -> str: """ Get/set the text value of the body """ @@ -750,7 +854,7 @@ def _text__get(self): return body.decode(self.charset) - def _text__set(self, value): + def _text__set(self, value: str) -> None: if not self.charset: raise AttributeError( "You cannot access Response.text unless charset is set" @@ -763,13 +867,15 @@ def _text__set(self, value): ) self.body = value.encode(self.charset) - def _text__del(self): + def _text__del(self) -> None: del self.body - text = property(_text__get, _text__set, _text__del, doc=_text__get.__doc__) + text: SymmetricPropertyWithDelete[str] = property( + _text__get, _text__set, _text__del, doc=_text__get.__doc__ + ) @property - def POST(self): + def POST(self) -> MultiDict[str, str | _FieldStorageWithFile] | NoVars: """ Return a MultiDict containing all the variables from a form request. Returns an empty dict-like object for non-form requests. @@ -779,6 +885,7 @@ def POST(self): """ env = self.environ + vars: MultiDict[str, str | _FieldStorageWithFile] if "webob._parsed_post_vars" in env: vars, body_file = env["webob._parsed_post_vars"] @@ -799,7 +906,7 @@ def POST(self): self._check_charset() self.make_body_seekable() - self.body_file_raw.seek(0) + self.body_file_raw.seek(0) # type: ignore[attr-defined] fs_environ = env.copy() # FieldStorage assumes a missing CONTENT_LENGTH, but a @@ -807,20 +914,20 @@ def POST(self): fs_environ.setdefault("CONTENT_LENGTH", "0") fs_environ["QUERY_STRING"] = "" fs = cgi_FieldStorage( - fp=self.body_file, + fp=self.body_file, # type: ignore[arg-type] environ=fs_environ, keep_blank_values=True, encoding="utf8", ) - self.body_file_raw.seek(0) + self.body_file_raw.seek(0) # type: ignore[attr-defined] vars = MultiDict.from_fieldstorage(fs) env["webob._parsed_post_vars"] = (vars, self.body_file_raw) return vars @property - def GET(self): + def GET(self) -> GetDict: """ Return a MultiDict containing all the variables from the QUERY_STRING. @@ -828,13 +935,14 @@ def GET(self): env = self.environ source = env.get("QUERY_STRING", "") + vars: GetDict if "webob._parsed_query_vars" in env: vars, qs = env["webob._parsed_query_vars"] if qs == source: return vars - data = [] + data: Iterable[tuple[str, str]] = [] if source: # this is disabled because we want to access req.GET @@ -848,7 +956,7 @@ def GET(self): return vars - def _check_charset(self): + def _check_charset(self) -> None: if self.charset != "UTF-8": raise DeprecationWarning( "Requests are expected to be submitted in UTF-8, not %s. " @@ -857,17 +965,21 @@ def _check_charset(self): ) @property - def params(self): + def params(self) -> NestedMultiDict[str, str | _FieldStorageWithFile]: """ A dictionary-like object containing both the parameters from the query string and request body. """ - params = NestedMultiDict(self.GET, self.POST) + params = NestedMultiDict(self.GET, self.POST) # type: ignore return params - @property - def cookies(self): + cookies: AsymmetricProperty[ + RequestCookies, SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]] + ] + + @property # type: ignore[no-redef] + def cookies(self) -> RequestCookies: """ Return a dictionary of cookies as found in the request. """ @@ -875,12 +987,14 @@ def cookies(self): return RequestCookies(self.environ) @cookies.setter - def cookies(self, val): + def cookies( + self, val: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]] + ) -> None: self.environ.pop("HTTP_COOKIE", None) r = RequestCookies(self.environ) r.update(val) - def copy(self): + def copy(self) -> Self: """ Copy the request and environment object. @@ -893,7 +1007,7 @@ def copy(self): return new_req - def copy_get(self): + def copy_get(self) -> Self: """ Copies the request and environment object, but turning this request into a GET along the way. If this was a POST request (or any other @@ -905,10 +1019,11 @@ def copy_get(self): # webob.is_body_seekable marks input streams that are seekable # this way we can have seekable input without testing the .seek() method + is_body_seekable: AsymmetricPropertyWithDelete[bool, bool | None] is_body_seekable = environ_getter("webob.is_body_seekable", False) @property - def is_body_readable(self): + def is_body_readable(self) -> bool: """ webob.is_body_readable is a flag that tells us that we can read the input stream even though CONTENT_LENGTH is missing. @@ -924,7 +1039,7 @@ def is_body_readable(self): # self.body_file with something that is readable and EOF's # correctly. - return self.environ.get( + return self.environ.get( # type: ignore[no-any-return] "wsgi.input_terminated", # For backwards compatibility, we fall back to checking if # webob.is_body_readable is set in the environ @@ -934,10 +1049,10 @@ def is_body_readable(self): return False @is_body_readable.setter - def is_body_readable(self, flag): + def is_body_readable(self, flag: bool) -> None: self.environ["wsgi.input_terminated"] = bool(flag) - def make_body_seekable(self): + def make_body_seekable(self) -> None: """ This forces ``environ['wsgi.input']`` to be seekable. That means that, the content is copied into a BytesIO or temporary @@ -952,11 +1067,11 @@ def make_body_seekable(self): """ if self.is_body_seekable: - self.body_file_raw.seek(0) + self.body_file_raw.seek(0) # type: ignore[attr-defined] else: self.copy_body() - def copy_body(self): + def copy_body(self) -> None: """ Copies the body, in cases where it might be shared with another request object and that is not desired. @@ -969,13 +1084,13 @@ def copy_body(self): # Before we copy, if we can, rewind the body file if self.is_body_seekable: - self.body_file_raw.seek(0) + self.body_file_raw.seek(0) # type: ignore[attr-defined] tempfile_limit = self.request_body_tempfile_limit todo = self.content_length if self.content_length is not None else 65535 newbody = b"" - fileobj = None + fileobj: io.BufferedRandom | None = None input = self.body_file while todo > 0: @@ -1043,7 +1158,7 @@ def copy_body(self): # cheap. self.body = b"" - def make_tempfile(self): + def make_tempfile(self) -> io.BufferedRandom: """ Create a tempfile to store big request body. This API is not stable yet. A 'size' argument might be added. @@ -1053,11 +1168,11 @@ def make_tempfile(self): def remove_conditional_headers( self, - remove_encoding=True, - remove_range=True, - remove_match=True, - remove_modified=True, - ): + remove_encoding: bool = True, + remove_range: bool = True, + remove_match: bool = True, + remove_modified: bool = True, + ) -> None: """ Remove headers that make the request conditional. @@ -1085,20 +1200,25 @@ def remove_conditional_headers( if key in self.environ: del self.environ[key] - accept = accept_property() - accept_charset = accept_charset_property() - accept_encoding = accept_encoding_property() - accept_language = accept_language_property() - - authorization = converter( - environ_getter("HTTP_AUTHORIZATION", None, "14.8"), parse_auth, serialize_auth + accept: _AcceptProperty = accept_property() + accept_charset: _AcceptCharsetProperty = accept_charset_property() + accept_encoding: _AcceptEncodingProperty = accept_encoding_property() + accept_language: _AcceptLanguageProperty = accept_language_property() + + authorization: AsymmetricPropertyWithDelete[ + _authorization | None, tuple[str, str | dict[str, str]] | list[Any] | str | None + ] = converter( + environ_getter("HTTP_AUTHORIZATION", None, "14.8"), + parse_auth, + serialize_auth, # type: ignore[arg-type] ) - def _cache_control__get(self): + def _cache_control__get(self) -> RequestCacheControl: """ Get/set/modify the Cache-Control header (`HTTP spec section 14.9 `_) """ + cache_obj: RequestCacheControl | None env = self.environ value = env.get("HTTP_CACHE_CONTROL", "") cache_header, cache_obj = env.get("webob._cache_control", (None, None)) @@ -1112,12 +1232,14 @@ def _cache_control__get(self): return cache_obj - def _cache_control__set(self, value): + def _cache_control__set( + self, value: RequestCacheControl | RequestCacheControlDict | str | None + ) -> None: env = self.environ value = value or "" if isinstance(value, dict): - value = CacheControl(value, type="request") + value = CacheControl(value, type="request") # type: ignore[arg-type] if isinstance(value, CacheControl): str_value = str(value) @@ -1127,7 +1249,7 @@ def _cache_control__set(self, value): env["HTTP_CACHE_CONTROL"] = str(value) env["webob._cache_control"] = (None, None) - def _cache_control__del(self): + def _cache_control__del(self) -> None: env = self.environ if "HTTP_CACHE_CONTROL" in env: @@ -1136,10 +1258,12 @@ def _cache_control__del(self): if "webob._cache_control" in env: del env["webob._cache_control"] - def _update_cache_control(self, prop_dict): + def _update_cache_control(self, prop_dict: dict[str, Any]) -> None: self.environ["HTTP_CACHE_CONTROL"] = serialize_cache_control(prop_dict) - cache_control = property( + cache_control: AsymmetricPropertyWithDelete[ + RequestCacheControl, RequestCacheControl | RequestCacheControlDict | str | None + ] = property( _cache_control__get, _cache_control__set, _cache_control__del, @@ -1156,35 +1280,46 @@ def _update_cache_control(self, prop_dict): if_unmodified_since = converter_date( environ_getter("HTTP_IF_UNMODIFIED_SINCE", None, "14.28") ) - if_range = converter( + if_range: AsymmetricPropertyWithDelete[ + IfRange | IfRangeDate, + IfRange | IfRangeDate | datetime.datetime | datetime.date | str | None, + ] = converter( environ_getter("HTTP_IF_RANGE", None, "14.27"), IfRange.parse, - serialize_if_range, + serialize_if_range, # type: ignore[arg-type] "IfRange object", ) - max_forwards = converter( + max_forwards: SymmetricPropertyWithDelete[int | None] = converter( environ_getter("HTTP_MAX_FORWARDS", None, "14.31"), parse_int, serialize_int, "int", ) - pragma = environ_getter("HTTP_PRAGMA", None, "14.32") + pragma: SymmetricPropertyWithDelete[str | None] = environ_getter( + "HTTP_PRAGMA", None, "14.32" + ) - range = converter( + range: AsymmetricPropertyWithDelete[ + Range | None, tuple[int, int | None] | list[int | None] | list[int] | str | None + ] = converter( environ_getter("HTTP_RANGE", None, "14.35"), parse_range, - serialize_range, + serialize_range, # type: ignore[arg-type] "Range object", ) - referer = environ_getter("HTTP_REFERER", None, "14.36") + referer: SymmetricPropertyWithDelete[str | None] = environ_getter( + "HTTP_REFERER", None, "14.36" + ) referrer = referer - user_agent = environ_getter("HTTP_USER_AGENT", None, "14.43") + user_agent: SymmetricPropertyWithDelete[str | None] = environ_getter( + "HTTP_USER_AGENT", None, "14.43" + ) - def __repr__(self): + def __repr__(self) -> str: try: name = f"{self.method} {self.url}" except KeyError: @@ -1193,7 +1328,7 @@ def __repr__(self): return msg - def as_bytes(self, skip_body=False): + def as_bytes(self, skip_body: bool = False) -> bytes: """ Return HTTP bytes representing this request. If skip_body is True, exclude the body. @@ -1230,15 +1365,15 @@ def as_bytes(self, skip_body=False): return b"\r\n".join(parts) - def as_text(self, skip_body=False): + def as_text(self, skip_body: bool = False) -> str: bytes = self.as_bytes(skip_body) - return bytes.decode(self.charset) + return bytes.decode(self.charset) # type: ignore[arg-type] __str__ = as_text @classmethod - def from_bytes(cls, b): + def from_bytes(cls, b: bytes) -> Self: """ Create a request from HTTP bytes data. If the bytes contain extra data after the request, raise a ValueError. @@ -1252,13 +1387,13 @@ def from_bytes(cls, b): return r @classmethod - def from_text(cls, s): + def from_text(cls, s: str) -> Self: b = bytes_(s, "utf-8") return cls.from_bytes(b) @classmethod - def from_file(cls, fp): + def from_file(cls, fp: _SupportsReadAndNoArgReadline) -> Self: """Read a request from a file-like object (it must implement ``.read(size)`` and ``.readline()``). @@ -1272,6 +1407,8 @@ def from_file(cls, fp): start_line = fp.readline() is_text = isinstance(start_line, str) + crlf: str | bytes + colon: str | bytes if is_text: crlf = "\r\n" colon = ":" @@ -1279,7 +1416,7 @@ def from_file(cls, fp): crlf = b"\r\n" colon = b":" try: - header = start_line.rstrip(crlf) + header = start_line.rstrip(crlf) # type: ignore[arg-type] method, resource, http_version = header.split(None, 2) method = text_(method, "utf-8") resource = text_(resource, "utf-8") @@ -1298,7 +1435,7 @@ def from_file(cls, fp): # end of headers break - hname, hval = line.split(colon, 1) + hname, hval = line.split(colon, 1) # type: ignore[arg-type] hname = text_(hname, "utf-8") hval = text_(hval, "utf-8").strip() @@ -1315,11 +1452,34 @@ def from_file(cls, fp): if is_text: body = bytes_(body, "utf-8") - r.body = body + r.body = body # type: ignore[assignment] return r - def call_application(self, application, catch_exc_info=False): + @overload + def call_application( + self, application: WSGIApplication, catch_exc_info: Literal[False] = False + ) -> tuple[str, list[tuple[str, str]], Iterable[bytes]]: ... + + @overload + def call_application( + self, application: WSGIApplication, catch_exc_info: Literal[True] + ) -> tuple[str, list[tuple[str, str]], Iterable[bytes], OptExcInfo | None]: ... + + @overload + def call_application( + self, application: WSGIApplication, catch_exc_info: bool + ) -> ( + tuple[str, list[tuple[str, str]], Iterable[bytes], OptExcInfo | None] + | tuple[str, list[tuple[str, str]], Iterable[bytes]] + ): ... + + def call_application( + self, application: WSGIApplication, catch_exc_info: bool = False + ) -> ( + tuple[str, list[tuple[str, str]], Iterable[bytes], OptExcInfo | None] + | tuple[str, list[tuple[str, str]], Iterable[bytes]] + ): """ Call the given WSGI application, returning ``(status_string, headerlist, app_iter)`` @@ -1334,15 +1494,22 @@ def call_application(self, application, catch_exc_info=False): """ if self.is_body_seekable: - self.body_file_raw.seek(0) - captured = [] - output = [] + self.body_file_raw.seek(0) # type: ignore[attr-defined] + captured: tuple[str, list[tuple[str, str]], OptExcInfo | None] + output: list[bytes] = [] + + def start_response( + status: str, + headers: list[tuple[str, str]], + exc_info: OptExcInfo | None = None, + ) -> Callable[[bytes], object]: - def start_response(status, headers, exc_info=None): + nonlocal captured if exc_info is not None and not catch_exc_info: etype, exc, tb = exc_info - raise etype(exc).with_traceback(tb) - captured[:] = [status, headers, exc_info] + if etype is not None: + raise etype(exc).with_traceback(tb) + captured = (status, headers, exc_info) return output.append @@ -1362,9 +1529,14 @@ def start_response(status, headers, exc_info=None): return (captured[0], captured[1], app_iter) # Will be filled in later: - ResponseClass = None + ResponseClass: type[Response] + if not TYPE_CHECKING: + # FIXME: We probably don't need this assignment to None + ResponseClass = None - def send(self, application=None, catch_exc_info=False): + def send( + self, application: WSGIApplication | None = None, catch_exc_info: bool = False + ) -> Response: """ Like ``.call_application(application)``, except returns a response object with ``.status``, ``.headers``, and ``.body`` @@ -1396,19 +1568,29 @@ def send(self, application=None, catch_exc_info=False): get_response = send - def make_default_send_app(self): - global _client + def make_default_send_app(self) -> SendRequest: + global _send_request_app try: - client = _client + send_request_app = _send_request_app except NameError: - from webob import client + from webob.client import send_request_app - _client = client + _send_request_app = send_request_app - return client.send_request_app + return send_request_app @classmethod - def blank(cls, path, environ=None, base_url=None, headers=None, POST=None, **kw): + def blank( + cls, + path: str, + environ: dict[str, None] | None = None, + base_url: str | None = None, + headers: Mapping[str, str] | None = None, + POST: ( + str | bytes | Mapping[Any, Any] | Mapping[Any, ListOrTuple[Any]] | None + ) = None, + **kw: Any, + ) -> Self: """ Create a blank request environ (and Request wrapper) with the given path (path should be urlencoded), and any keys from @@ -1474,7 +1656,7 @@ def blank(cls, path, environ=None, base_url=None, headers=None, POST=None, **kw) class AdhocAttrMixin: _setattr_stacklevel = 3 - def __setattr__(self, attr, value, DEFAULT=DEFAULT): + def __setattr__(self, attr: str, value: Any, DEFAULT: object = DEFAULT) -> None: if getattr(self.__class__, attr, DEFAULT) is not DEFAULT or attr.startswith( "_" ): @@ -1482,13 +1664,13 @@ def __setattr__(self, attr, value, DEFAULT=DEFAULT): else: self.environ.setdefault("webob.adhoc_attrs", {})[attr] = value - def __getattr__(self, attr, DEFAULT=DEFAULT): + def __getattr__(self, attr: str) -> Any: try: return self.environ["webob.adhoc_attrs"][attr] except KeyError: raise AttributeError(attr) - def __delattr__(self, attr, DEFAULT=DEFAULT): + def __delattr__(self, attr: str, DEFAULT: object = DEFAULT) -> None: if getattr(self.__class__, attr, DEFAULT) is not DEFAULT: return object.__delattr__(self, attr) try: @@ -1501,7 +1683,7 @@ class Request(AdhocAttrMixin, BaseRequest): """The default request implementation""" -def environ_from_url(path): +def environ_from_url(path: str) -> WSGIEnvironment: if SCHEME_RE.search(path): scheme, netloc, path, qs, fragment = urlparse.urlsplit(path) @@ -1550,7 +1732,11 @@ def environ_from_url(path): return env -def environ_add_POST(env, data, content_type=None): +def environ_add_POST( + env: WSGIEnvironment, + data: str | bytes | Mapping[Any, Any] | Mapping[Any, ListOrTuple[Any]] | None, + content_type: str | None = None, +) -> None: if data is None: return elif isinstance(data, str): @@ -1561,9 +1747,9 @@ def environ_add_POST(env, data, content_type=None): has_files = False if hasattr(data, "items"): - data = list(data.items()) + data = list(data.items()) # type: ignore[assignment] - for _, v in data: + for _, v in data: # type: ignore[misc] if isinstance(v, (tuple, list)) or hasattr(v, "filename"): has_files = True @@ -1609,25 +1795,25 @@ class DisconnectionError(IOError): class LimitedLengthFile(io.RawIOBase): - def __init__(self, file, maxlen): + def __init__(self, file: SupportsRead[bytes], maxlen: int) -> None: self.file = file self.maxlen = maxlen self.remaining = maxlen - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}({self.file!r}, maxlen={self.maxlen})>" - def fileno(self): - return self.file.fileno() + def fileno(self) -> int: + return self.file.fileno() # type: ignore[attr-defined, no-any-return] @staticmethod - def readable(): + def readable() -> Literal[True]: return True - def readinto(self, buff): + def readinto(self, buff: WriteableBuffer) -> int: if not self.remaining: return 0 - sz0 = min(len(buff), self.remaining) + sz0 = min(len(buff), self.remaining) # type: ignore[arg-type] data = self.file.read(sz0) sz = len(data) self.remaining -= sz @@ -1637,24 +1823,41 @@ def readinto(self, buff): "The client disconnected while sending the body " "(%d more bytes were expected)" % (self.remaining,) ) - buff[:sz] = data + buff[:sz] = data # type: ignore[index] return sz -def _get_multipart_boundary(ctype): +def _get_multipart_boundary(ctype: str) -> str | None: m = re.search(r"boundary=([^ ]+)", ctype, re.I) if m: return text_(m.group(1).strip('"')) + return None + + +@overload +def _encode_multipart( + vars: Iterable[tuple[str, Any]], content_type: str +) -> tuple[str, bytes]: ... + + +@overload +def _encode_multipart( + vars: Iterable[tuple[str, Any]], content_type: str, fout: _SupportsWriteT +) -> tuple[str, _SupportsWriteT]: ... -def _encode_multipart(vars, content_type, fout=None): +def _encode_multipart( + vars: Iterable[tuple[str, Any]], + content_type: str, + fout: SupportsWrite[bytes] | None = None, +) -> tuple[str, Any]: """Encode a multipart request body into a string""" - f = fout or io.BytesIO() + f = io.BytesIO() if fout is None else fout w = f.write - def wt(t): + def wt(t: str) -> None: w(t.encode("utf8")) CRLF = b"\r\n" @@ -1674,8 +1877,8 @@ def wt(t): wt('; name="%s"' % name) filename = None - if getattr(value, "filename", None): - filename = value.filename + if filename := getattr(value, "filename", None): + filename = filename elif isinstance(value, (list, tuple)): filename, value = value @@ -1714,20 +1917,22 @@ def wt(t): w(CRLF) wt("--%s--" % boundary) - if fout: + if fout is not None: return content_type, fout else: + assert isinstance(f, io.BytesIO) return content_type, f.getvalue() -def detect_charset(ctype): +def detect_charset(ctype: str) -> str | None: m = CHARSET_RE.search(ctype) if m: return m.group(1).strip('"').strip() + return None -def _is_utf8(charset): +def _is_utf8(charset: str | None) -> bool: if not charset: return True else: @@ -1735,12 +1940,12 @@ def _is_utf8(charset): class Transcoder: - def __init__(self, charset, errors="strict"): + def __init__(self, charset: str, errors: str = "strict") -> None: self.charset = charset # source charset self.errors = errors # unicode errors self._trans = lambda b: b.decode(charset, errors).encode("utf8") - def transcode_query(self, q): + def transcode_query(self, q: str) -> str: q_orig = q if "=" not in q: @@ -1748,13 +1953,13 @@ def transcode_query(self, q): return q_orig - q = list(parse_qsl_text(q, self.charset)) + q_list = list(parse_qsl_text(q, self.charset)) - return url_encode(q) + return url_encode(q_list) - def transcode_fs(self, fs, content_type): + def transcode_fs(self, fs: cgi_FieldStorage, content_type: str) -> io.BytesIO: # transcode FieldStorage - def decode(b): + def decode(b: str) -> str: return b data = [] diff --git a/src/webob/response.py b/src/webob/response.py index 5fffbfd5..8bdffac6 100644 --- a/src/webob/response.py +++ b/src/webob/response.py @@ -1,14 +1,21 @@ +from __future__ import annotations + from base64 import b64encode from datetime import datetime, timedelta from hashlib import md5 import re import struct +from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, overload from urllib import parse as urlparse from urllib.parse import quote as url_quote import zlib from webob.byterange import ContentRange -from webob.cachecontrol import CacheControl, serialize_cache_control +from webob.cachecontrol import ( + CacheControl, + ResponseCacheControl, + serialize_cache_control, +) from webob.cookies import Cookie, make_cookie from webob.datetime_utils import ( parse_date_delta, @@ -42,11 +49,37 @@ warn_deprecation, ) -try: - import simplejson as json -except ImportError: +if TYPE_CHECKING: + from collections.abc import Callable, Collection, Iterable, Iterator, Sequence import json + from _typeshed import OptExcInfo, SupportsItems, SupportsRead + from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment + from typing_extensions import Self + + from webob.cookies import _SameSitePolicy + from webob.descriptors import ( + _authorization, + _ContentRangeParams, + _DateProperty, + ) + from webob.request import Request + from webob.types import ( + AsymmetricProperty, + AsymmetricPropertyWithDelete, + ResponseCacheControlDict, + ResponseCacheExpires, + SymmetricProperty, + SymmetricPropertyWithDelete, + ) + + _ResponseT = TypeVar("_ResponseT", bound="Response") +else: + try: + import simplejson as json + except ImportError: + import json + __all__ = ["Response"] @@ -158,8 +191,8 @@ class Response: # constructor they correctly get saved and set, however they are not used # by any part of the Response. See commit # 627593bbcd4ab52adc7ee569001cdda91c670d5d for rationale. - request = None - environ = None + request: Request | None = None + environ: WSGIEnvironment | None = None # # __init__, from_file, copy @@ -167,15 +200,15 @@ class Response: def __init__( self, - body=None, - status=None, - headerlist=None, - app_iter=None, - content_type=None, - conditional_response=None, - charset=_marker, - **kw, - ): + body: bytes | str | None = None, + status: int | str | bytes | None = None, + headerlist: list[tuple[str, str]] | None = None, + app_iter: Iterable[bytes] | None = None, + content_type: str | None = None, + conditional_response: bool | None = None, + charset: str = _marker, # type: ignore[assignment] + **kw: Any, + ) -> None: # Do some sanity checking, and turn json_body into an actual body if app_iter is None and body is None and ("json_body" in kw or "json" in kw): @@ -202,7 +235,7 @@ def __init__( self.status = status # Initialize headers - self._headers = None + self._headers: ResponseHeaders | None = None if headerlist is None: self._headerlist = [] @@ -296,6 +329,7 @@ def __init__( # Set up app_iter if the HTTP Status code has a body if app_iter is None and code_has_body: + assert body is not None if isinstance(body, str): # Fall back to trying self.charset if encoding is not set. In # most cases encoding will be set to the default value. @@ -318,6 +352,7 @@ def __init__( elif app_iter is None and not code_has_body: app_iter = [b""] + assert app_iter is not None self._app_iter = app_iter # Loop through all the remaining keyword arguments @@ -329,7 +364,7 @@ def __init__( setattr(self, name, value) @classmethod - def from_file(cls, fp): + def from_file(cls, fp: IO[str] | IO[bytes]) -> Response: """Reads a response from a file-like object (it must implement ``.read(size)`` and ``.readline()``). @@ -343,6 +378,8 @@ def from_file(cls, fp): status = fp.readline().strip() is_text = isinstance(status, str) + _colon: str | bytes + _http: str | bytes if is_text: _colon = ":" _http = "HTTP/" @@ -350,7 +387,7 @@ def from_file(cls, fp): _colon = b":" _http = b"HTTP/" - if status.startswith(_http): + if status.startswith(_http): # type: ignore[arg-type] (http_ver, status_num, status_text) = status.split(None, 2) status = f"{text_(status_num)} {text_(status_text)}" @@ -362,7 +399,7 @@ def from_file(cls, fp): break try: - header_name, value = line.split(_colon, 1) + header_name, value = line.split(_colon, 1) # type: ignore[arg-type] except ValueError: raise ValueError("Bad header line: %r" % line) value = value.strip() @@ -371,13 +408,13 @@ def from_file(cls, fp): body = fp.read(r.content_length or 0) if is_text: - r.text = body + r.text = body # type: ignore[assignment] else: - r.body = body + r.body = body # type: ignore[assignment] return r - def copy(self): + def copy(self) -> Response: """Makes a copy of the response.""" # we need to do this for app_iter to be reusable app_iter = list(self._app_iter) @@ -396,10 +433,10 @@ def copy(self): # __repr__, __str__ # - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} at 0x{abs(id(self)):x} {self.status}>" - def __str__(self, skip_body=False): + def __str__(self, skip_body: bool = False) -> str: parts = [self.status] if not skip_body: @@ -416,14 +453,14 @@ def __str__(self, skip_body=False): # status, status_code/status_int # - def _status__get(self): + def _status__get(self) -> str: """ The status string. """ return self._status - def _status__set(self, value): + def _status__set(self, value: int | str | bytes) -> None: try: code = int(value) except (ValueError, TypeError): @@ -451,21 +488,25 @@ def _status__set(self, value): raise ValueError("Invalid status code, integer required.") self._status = value - status = property(_status__get, _status__set, doc=_status__get.__doc__) + status: AsymmetricProperty[str, int | str | bytes] = property( + _status__get, _status__set, doc=_status__get.__doc__ + ) - def _status_code__get(self): + def _status_code__get(self) -> int: """ The status as an integer. """ return int(self._status.split()[0]) - def _status_code__set(self, code): + def _status_code__set(self, code: int) -> None: try: self._status = "%d %s" % (code, status_reasons[code]) except KeyError: self._status = "%d %s" % (code, status_generic_reasons[code // 100]) + status_code: SymmetricProperty[int] + status_int: SymmetricProperty[int] status_code = status_int = property( _status_code__get, _status_code__set, doc=_status_code__get.__doc__ ) @@ -474,14 +515,16 @@ def _status_code__set(self, code): # headerslist, headers # - def _headerlist__get(self): + def _headerlist__get(self) -> list[tuple[str, str]]: """ The list of response headers. """ return self._headerlist - def _headerlist__set(self, value): + def _headerlist__set( + self, value: Iterable[tuple[str, str]] | SupportsItems[str, str] + ) -> None: self._headers = None if not isinstance(value, list): @@ -490,17 +533,19 @@ def _headerlist__set(self, value): value = list(value) self._headerlist = value - def _headerlist__del(self): + def _headerlist__del(self) -> None: self.headerlist = [] - headerlist = property( + headerlist: AsymmetricPropertyWithDelete[ + list[tuple[str, str]], Iterable[tuple[str, str]] | SupportsItems[str, str] + ] = property( _headerlist__get, _headerlist__set, _headerlist__del, doc=_headerlist__get.__doc__, ) - def _headers__get(self): + def _headers__get(self) -> ResponseHeaders: """ The headers in a dictionary-like object. """ @@ -510,19 +555,23 @@ def _headers__get(self): return self._headers - def _headers__set(self, value): + def _headers__set( + self, value: Iterable[tuple[str, str]] | SupportsItems[str, str] + ) -> None: if hasattr(value, "items"): value = value.items() self.headerlist = value self._headers = None - headers = property(_headers__get, _headers__set, doc=_headers__get.__doc__) + headers: AsymmetricProperty[ + ResponseHeaders, SupportsItems[str, str] | Iterable[tuple[str, str]] + ] = property(_headers__get, _headers__set, doc=_headers__get.__doc__) # # body # - def _body__get(self): + def _body__get(self) -> bytes: """ The body of the response, as a :class:`bytes`. This will read in the entire app_iter if necessary. @@ -535,7 +584,7 @@ def _body__get(self): # pass if isinstance(app_iter, list) and len(app_iter) == 1: - return app_iter[0] + return app_iter[0] # type: ignore[no-any-return] if app_iter is None: raise AttributeError("No body has been set") @@ -560,9 +609,9 @@ def _body__get(self): return body - def _body__set(self, value=b""): + def _body__set(self, value: bytes = b"") -> None: if not isinstance(value, bytes): - if isinstance(value, str): + if isinstance(value, str): # type: ignore[unreachable] msg = ( "You cannot set Response.body to a text object " "(use Response.text)" @@ -582,9 +631,11 @@ def _body__set(self, value=b""): # self.body = '' # #self.content_length = None - body = property(_body__get, _body__set, _body__set) + body: SymmetricPropertyWithDelete[bytes] = property( + _body__get, _body__set, _body__set + ) - def _json_body__get(self): + def _json_body__get(self) -> Any: """ Set/get the body of the response as JSON. @@ -600,15 +651,18 @@ def _json_body__get(self): return json.loads(self.body.decode("UTF-8")) - def _json_body__set(self, value): + def _json_body__set(self, value: Any) -> None: self.body = json.dumps(value, separators=(",", ":")).encode("UTF-8") - def _json_body__del(self): + def _json_body__del(self) -> None: del self.body + json: SymmetricPropertyWithDelete[Any] + json_body: SymmetricPropertyWithDelete[Any] json = json_body = property(_json_body__get, _json_body__set, _json_body__del) - def _has_body__get(self): + @property + def has_body(self) -> bool: """ Determine if the the response has a :attr:`~Response.body`. In contrast to simply accessing :attr:`~Response.body`, this method @@ -624,17 +678,15 @@ def _has_body__get(self): return False if app_iter is None: # pragma: no cover - return False + return False # type: ignore[unreachable] return True - has_body = property(_has_body__get) - # # text, unicode_body, ubody # - def _text__get(self): + def _text__get(self) -> str: """ Get/set the text value of the body using the ``charset`` of the ``Content-Type`` or the ``default_body_encoding``. @@ -650,7 +702,7 @@ def _text__get(self): return body.decode(decoding, self.unicode_errors) - def _text__set(self, value): + def _text__set(self, value: str) -> None: if not self.charset and not self.default_body_encoding: raise AttributeError( "You cannot access Response.text unless charset or " @@ -665,11 +717,15 @@ def _text__set(self, value): encoding = self.charset or self.default_body_encoding self.body = value.encode(encoding) - def _text__del(self): + def _text__del(self) -> None: del self.body - text = property(_text__get, _text__set, _text__del, doc=_text__get.__doc__) + text: SymmetricPropertyWithDelete[str] = property( + _text__get, _text__set, _text__del, doc=_text__get.__doc__ + ) + unicode_body: SymmetricPropertyWithDelete[str] + ubody: SymmetricPropertyWithDelete[str] unicode_body = ubody = property( _text__get, _text__set, _text__del, "Deprecated alias for .text" ) @@ -678,7 +734,7 @@ def _text__del(self): # body_file, write(text) # - def _body_file__get(self): + def _body_file__get(self) -> ResponseBodyFile: """ A file-like object that can be used to write to the body. If you passed in a list ``app_iter``, that ``app_iter`` will be @@ -687,20 +743,21 @@ def _body_file__get(self): return ResponseBodyFile(self) - def _body_file__set(self, file): + def _body_file__set(self, file: SupportsRead[bytes]) -> None: self.app_iter = iter_file(file) - def _body_file__del(self): + def _body_file__del(self) -> None: del self.body + body_file: AsymmetricPropertyWithDelete[ResponseBodyFile, SupportsRead[bytes]] body_file = property( _body_file__get, _body_file__set, _body_file__del, doc=_body_file__get.__doc__ ) - def write(self, text): + def write(self, text: str | bytes) -> int: if not isinstance(text, bytes): if not isinstance(text, str): - msg = "You can only write str to a Response.body_file, not %s" + msg = "You can only write str to a Response.body_file, not %s" # type: ignore[unreachable] raise TypeError(msg % type(text)) if not self.charset: @@ -728,7 +785,7 @@ def write(self, text): # app_iter # - def _app_iter__get(self): + def _app_iter__get(self) -> Iterable[bytes]: """ Returns the ``app_iter`` of the response. @@ -738,17 +795,17 @@ def _app_iter__get(self): return self._app_iter - def _app_iter__set(self, value): + def _app_iter__set(self, value: Iterable[bytes]) -> None: if self._app_iter is not None: # Undo the automatically-set content-length self.content_length = None self._app_iter = value - def _app_iter__del(self): + def _app_iter__del(self) -> None: self._app_iter = [] self.content_length = None - app_iter = property( + app_iter: SymmetricPropertyWithDelete[Iterable[bytes]] = property( _app_iter__get, _app_iter__set, _app_iter__del, doc=_app_iter__get.__doc__ ) @@ -772,10 +829,12 @@ def _app_iter__del(self): content_disposition = header_getter("Content-Disposition", "19.5.1") accept_ranges = header_getter("Accept-Ranges", "14.5") - content_range = converter( + content_range: AsymmetricPropertyWithDelete[ + ContentRange | None, _ContentRangeParams + ] = converter( header_getter("Content-Range", "14.16"), parse_content_range, - serialize_content_range, + serialize_content_range, # type: ignore[arg-type] "ContentRange object", ) @@ -784,37 +843,40 @@ def _app_iter__del(self): last_modified = date_header("Last-Modified", "14.29") _etag_raw = header_getter("ETag", "14.19") - etag = converter( - _etag_raw, parse_etag_response, serialize_etag_response, "Entity tag" + etag: AsymmetricPropertyWithDelete[ + str | None, tuple[str, bool] | str | None + ] = converter( + _etag_raw, parse_etag_response, serialize_etag_response, "Entity tag" # type: ignore[arg-type] ) @property - def etag_strong(self): + def etag_strong(self) -> str | None: return parse_etag_response(self._etag_raw, strong=True) location = header_getter("Location", "14.30") pragma = header_getter("Pragma", "14.32") age = converter(header_getter("Age", "14.6"), parse_int_safe, serialize_int, "int") - - retry_after = converter( + retry_after: _DateProperty = converter( header_getter("Retry-After", "14.37"), parse_date_delta, - serialize_date_delta, + serialize_date_delta, # type: ignore[arg-type] "HTTP date or delta seconds", ) server = header_getter("Server", "14.38") # TODO: the standard allows this to be a list of challenges - www_authenticate = converter( - header_getter("WWW-Authenticate", "14.47"), parse_auth, serialize_auth + www_authenticate: AsymmetricPropertyWithDelete[ + _authorization | None, tuple[str, str | dict[str, str]] | list[Any] | str | None + ] = converter( + header_getter("WWW-Authenticate", "14.47"), parse_auth, serialize_auth # type: ignore[arg-type] ) # # charset # - def _charset__get(self): + def _charset__get(self) -> str | None: """ Get/set the ``charset`` specified in ``Content-Type``. @@ -832,7 +894,7 @@ def _charset__get(self): return None - def _charset__set(self, charset): + def _charset__set(self, charset: str | None) -> None: if charset is None: self._charset__del() @@ -850,7 +912,7 @@ def _charset__set(self, charset): header += "; charset=%s" % charset self.headers["Content-Type"] = header - def _charset__del(self): + def _charset__del(self) -> None: header = self.headers.pop("Content-Type", None) if header is None: @@ -863,7 +925,7 @@ def _charset__del(self): header = header[: match.start()] + header[match.end() :] self.headers["Content-Type"] = header - charset = property( + charset: SymmetricPropertyWithDelete[str | None] = property( _charset__get, _charset__set, _charset__del, doc=_charset__get.__doc__ ) @@ -871,7 +933,7 @@ def _charset__del(self): # content_type # - def _content_type__get(self): + def _content_type__get(self) -> str | None: """ Get/set the ``Content-Type`` header. If no ``Content-Type`` header is set, this will return ``None``. @@ -900,7 +962,7 @@ def _content_type__get(self): return header.split(";", 1)[0] - def _content_type__set(self, value): + def _content_type__set(self, value: str | None) -> None: if not value: self._content_type__del() @@ -933,10 +995,10 @@ def _content_type__set(self, value): self.headers["Content-Type"] = content_type - def _content_type__del(self): + def _content_type__del(self) -> None: self.headers.pop("Content-Type", None) - content_type = property( + content_type: SymmetricPropertyWithDelete[str | None] = property( _content_type__get, _content_type__set, _content_type__del, @@ -947,7 +1009,7 @@ def _content_type__del(self): # content_type_params # - def _content_type_params__get(self): + def _content_type_params__get(self) -> dict[str, str]: """ A dictionary of all the parameters in the content type. @@ -966,7 +1028,9 @@ def _content_type_params__get(self): return result - def _content_type_params__set(self, value_dict): + def _content_type_params__set( + self, value_dict: SupportsItems[str, str] | None + ) -> None: if not value_dict: self._content_type_params__del() @@ -982,12 +1046,14 @@ def _content_type_params__set(self, value_dict): ct += "".join(params) self.headers["Content-Type"] = ct - def _content_type_params__del(self): + def _content_type_params__del(self) -> None: self.headers["Content-Type"] = self.headers.get("Content-Type", "").split( ";", 1 )[0] - content_type_params = property( + content_type_params: AsymmetricPropertyWithDelete[ + dict[str, str], SupportsItems[str, str] | None + ] = property( _content_type_params__get, _content_type_params__set, _content_type_params__del, @@ -1000,17 +1066,17 @@ def _content_type_params__del(self): def set_cookie( self, - name, - value="", - max_age=None, - path="/", - domain=None, - secure=False, - httponly=False, - comment=None, - overwrite=False, - samesite=None, - ): + name: str | bytes, + value: str | bytes | None = "", + max_age: int | timedelta | None = None, + path: str = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + comment: str | None = None, + overwrite: bool = False, + samesite: _SameSitePolicy | None = None, + ) -> None: """ Set (add) a cookie for the response. @@ -1097,7 +1163,9 @@ def set_cookie( ) self.headerlist.append(("Set-Cookie", cookie)) - def delete_cookie(self, name, path="/", domain=None): + def delete_cookie( + self, name: str | bytes, path: str = "/", domain: str | None = None + ) -> None: """ Delete a cookie from the client. Note that ``path`` and ``domain`` must match how the cookie was originally set. @@ -1107,7 +1175,7 @@ def delete_cookie(self, name, path="/", domain=None): """ self.set_cookie(name, None, path=path, domain=domain) - def unset_cookie(self, name, strict=True): + def unset_cookie(self, name: str | bytes, strict: bool = True) -> None: """ Unset a cookie with the given name (remove it from the response). """ @@ -1132,7 +1200,15 @@ def unset_cookie(self, name, strict=True): elif strict: raise KeyError("No cookie has been set with the name %r" % name) - def merge_cookies(self, resp): + @overload + def merge_cookies(self, resp: _ResponseT) -> _ResponseT: ... + + @overload + def merge_cookies(self, resp: WSGIApplication) -> WSGIApplication: ... + + def merge_cookies( + self, resp: Response | WSGIApplication + ) -> Response | WSGIApplication: """Merge the cookies that were set on this response with the given ``resp`` object (which can be any WSGI application). @@ -1151,11 +1227,16 @@ def merge_cookies(self, resp): else: c_headers = [h for h in self.headerlist if h[0].lower() == "set-cookie"] - def repl_app(environ, start_response): - def repl_start_response(status, headers, exc_info=None): - return start_response( - status, headers + c_headers, exc_info=exc_info - ) + def repl_app( + environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: + + def repl_start_response( + status: str, + headers: list[tuple[str, str]], + exc_info: OptExcInfo | None = None, + ) -> Callable[[bytes], object]: + return start_response(status, headers + c_headers, exc_info) return resp(environ, repl_start_response) @@ -1165,9 +1246,9 @@ def repl_start_response(status, headers, exc_info=None): # cache_control # - _cache_control_obj = None + _cache_control_obj: ResponseCacheControl | None = None - def _cache_control__get(self): + def _cache_control__get(self) -> ResponseCacheControl: """ Get/set/modify the Cache-Control header (`HTTP spec section 14.9 `_). @@ -1188,14 +1269,16 @@ def _cache_control__get(self): return self._cache_control_obj - def _cache_control__set(self, value): + def _cache_control__set( + self, value: ResponseCacheControl | ResponseCacheControlDict | str | None + ) -> None: # This actually becomes a copy if not value: value = "" if isinstance(value, dict): - value = CacheControl(value, "response") + value = CacheControl(value, "response") # type: ignore[arg-type] if isinstance(value, str): value = str(value) @@ -1205,15 +1288,15 @@ def _cache_control__set(self, value): self.headers["Cache-Control"] = value return - value = CacheControl.parse(value, "response") + value = CacheControl.parse(value, type="response") cache = self.cache_control cache.properties.clear() cache.properties.update(value.properties) - def _cache_control__del(self): + def _cache_control__del(self) -> None: self.cache_control = {} - def _update_cache_control(self, prop_dict): + def _update_cache_control(self, prop_dict: dict[str, Any]) -> None: value = serialize_cache_control(prop_dict) if not value: @@ -1222,7 +1305,10 @@ def _update_cache_control(self, prop_dict): else: self.headers["Cache-Control"] = value - cache_control = property( + cache_control: AsymmetricProperty[ + ResponseCacheControl, + ResponseCacheControl | ResponseCacheControlDict | str | None, + ] = property( _cache_control__get, _cache_control__set, _cache_control__del, @@ -1233,7 +1319,7 @@ def _update_cache_control(self, prop_dict): # cache_expires # - def _cache_expires(self, seconds=0, **kw): + def _cache_expires(self, seconds: int | timedelta = 0, **kw: Any) -> None: """ Set expiration on this request. This sets the response to expire in the given seconds, and any other attributes are used @@ -1257,8 +1343,6 @@ def _cache_expires(self, seconds=0, **kw): cache_control.no_cache = True cache_control.must_revalidate = True cache_control.max_age = 0 - cache_control.post_check = 0 - cache_control.pre_check = 0 self.expires = datetime.utcnow() if "last-modified" not in self.headers: @@ -1273,13 +1357,17 @@ def _cache_expires(self, seconds=0, **kw): for name, value in kw.items(): setattr(cache_control, name, value) - cache_expires = property(lambda self: self._cache_expires, _cache_expires) + cache_expires: AsymmetricProperty[ + ResponseCacheExpires, timedelta | int | bool | None + ] = property(lambda self: self._cache_expires, _cache_expires) # # encode_content, decode_content, md5_etag # - def encode_content(self, encoding="gzip", lazy=False): + def encode_content( + self, encoding: Literal["gzip", "identity"] = "gzip", lazy: bool = False + ) -> None: """ Encode the content with the given encoding (only ``gzip`` and ``identity`` are supported). @@ -1302,7 +1390,7 @@ def encode_content(self, encoding="gzip", lazy=False): self.content_length = sum(map(len, self._app_iter)) self.content_encoding = "gzip" - def decode_content(self): + def decode_content(self) -> None: content_encoding = self.content_encoding or "identity" if content_encoding == "identity": @@ -1333,7 +1421,9 @@ def decode_content(self): self.body = zlib.decompress(self.body, -15) self.content_encoding = None - def md5_etag(self, body=None, set_content_md5=False): + def md5_etag( + self, body: bytes | None = None, set_content_md5: bool = False + ) -> None: """ Generate an etag for the response object using an MD5 hash of the body (the ``body`` parameter, or ``self.body`` if not given). @@ -1348,14 +1438,14 @@ def md5_etag(self, body=None, set_content_md5=False): md5_digest = md5(body).digest() md5_digest = b64encode(md5_digest) md5_digest = md5_digest.replace(b"\n", b"") - md5_digest = text_(md5_digest) - self.etag = md5_digest.strip("=") + md5_digest_str = text_(md5_digest) + self.etag = md5_digest_str.strip("=") if set_content_md5: - self.content_md5 = md5_digest + self.content_md5 = md5_digest_str @staticmethod - def _make_location_absolute(environ, value): + def _make_location_absolute(environ: WSGIEnvironment, value: str) -> str: if SCHEME_RE.search(value): return value @@ -1368,7 +1458,7 @@ def _make_location_absolute(environ, value): return new_location - def _abs_headerlist(self, environ): + def _abs_headerlist(self, environ: WSGIEnvironment) -> list[tuple[str, str]]: # Build the headerlist, if we have a Location header, make it absolute return [ @@ -1384,7 +1474,9 @@ def _abs_headerlist(self, environ): # __call__, conditional_response_app # - def __call__(self, environ, start_response): + def __call__( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: """ WSGI application interface """ @@ -1405,7 +1497,9 @@ def __call__(self, environ, start_response): _safe_methods = ("GET", "HEAD") - def conditional_response_app(self, environ, start_response): + def conditional_response_app( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: """ Like the normal ``__call__`` interface, but checks conditional headers: @@ -1463,12 +1557,13 @@ def conditional_response_app(self, environ, start_response): return [body] else: + assert content_range.start is not None app_iter = self.app_iter_range(content_range.start, content_range.stop) if app_iter is not None: # the following should be guaranteed by # Range.range_for_length(length) - assert content_range.start is not None + assert content_range.stop is not None headerlist = [ ( "Content-Length", @@ -1490,7 +1585,7 @@ def conditional_response_app(self, environ, start_response): return self._app_iter - def app_iter_range(self, start, stop): + def app_iter_range(self, start: int, stop: int | None) -> AppIterRange: """ Return a new ``app_iter`` built from the response ``app_iter``, that serves up only the given ``start:stop`` range. @@ -1498,16 +1593,21 @@ def app_iter_range(self, start, stop): app_iter = self._app_iter if hasattr(app_iter, "app_iter_range"): - return app_iter.app_iter_range(start, stop) + return app_iter.app_iter_range(start, stop) # type: ignore[no-any-return] return AppIterRange(app_iter, start, stop) -def filter_headers(hlist, remove_headers=("content-length", "content-type")): +def filter_headers( + hlist: Iterable[tuple[str, str]], + remove_headers: Collection[str] = ("content-length", "content-type"), +) -> list[tuple[str, str]]: return [h for h in hlist if (h[0].lower() not in remove_headers)] -def iter_file(file, block_size=1 << 18): # 256Kb +def iter_file( + file: SupportsRead[bytes], block_size: int = 1 << 18 # 256Kb +) -> Iterator[bytes]: while True: data = file.read(block_size) @@ -1517,25 +1617,25 @@ def iter_file(file, block_size=1 << 18): # 256Kb class ResponseBodyFile: - mode = "wb" - closed = False + mode: Literal["wb"] = "wb" + closed: Literal[False] = False - def __init__(self, response): + def __init__(self, response: Response) -> None: """ Represents a :class:`~Response` as a file like object. """ self.response = response self.write = response.write - def __repr__(self): + def __repr__(self) -> str: return "" % self.response - encoding = property( - lambda self: self.response.charset, - doc="The encoding of the file (inherited from response.charset)", - ) + @property + def encoding(self) -> str | None: + """The encoding of the file (inherited from response.charset)""" + return self.response.charset # pragma: no cover - def writelines(self, seq): + def writelines(self, seq: Sequence[str | bytes]) -> None: """ Write a sequence of lines to the response. """ @@ -1543,13 +1643,13 @@ def writelines(self, seq): for item in seq: self.write(item) - def close(self): + def close(self) -> None: raise NotImplementedError("Response bodies cannot be closed") - def flush(self): + def flush(self) -> None: pass - def tell(self): + def tell(self) -> int: """ Provide the current location where we are going to start writing. """ @@ -1565,7 +1665,7 @@ class AppIterRange: Wraps an ``app_iter``, returning just a range of bytes. """ - def __init__(self, app_iter, start, stop): + def __init__(self, app_iter: Iterable[bytes], start: int, stop: int | None) -> None: assert start >= 0, "Bad start: %r" % start assert stop is None or (stop >= 0 and stop >= start), "Bad stop: %r" % stop self.app_iter = iter(app_iter) @@ -1573,10 +1673,10 @@ def __init__(self, app_iter, start, stop): self.start = start self.stop = stop - def __iter__(self): + def __iter__(self) -> Self: return self - def _skip_start(self): + def _skip_start(self) -> bytes: start, stop = self.start, self.stop for chunk in self.app_iter: @@ -1597,7 +1697,7 @@ def _skip_start(self): else: raise StopIteration() - def next(self): + def next(self) -> bytes: if self._pos < self.start: # need to skip some leading bytes @@ -1617,7 +1717,7 @@ def next(self): __next__ = next # py3 - def close(self): + def close(self) -> None: iter_close(self.app_iter) @@ -1629,23 +1729,23 @@ class EmptyResponse: method to close an underlying ``app_iter`` it replaces. """ - def __init__(self, app_iter=None): + def __init__(self, app_iter: Iterable[bytes] | None = None) -> None: if app_iter is not None and hasattr(app_iter, "close"): self.close = app_iter.close - def __iter__(self): + def __iter__(self) -> Self: return self - def __len__(self): + def __len__(self) -> Literal[0]: return 0 - def next(self): + def next(self) -> bytes: raise StopIteration() __next__ = next # py3 -def _is_xml(content_type): +def _is_xml(content_type: str) -> bool: return ( content_type.startswith("application/xml") or (content_type.startswith("application/") and content_type.endswith("+xml")) @@ -1653,15 +1753,15 @@ def _is_xml(content_type): ) -def _content_type_has_charset(content_type): +def _content_type_has_charset(content_type: str) -> bool: return content_type.startswith("text/") or _is_xml(content_type) -def _request_uri(environ): +def _request_uri(environ: WSGIEnvironment) -> str: """Like ``wsgiref.url.request_uri``, except eliminates ``:80`` ports. Returns the full request URI.""" - url = environ["wsgi.url_scheme"] + "://" + url: str = environ["wsgi.url_scheme"] + "://" if environ.get("HTTP_HOST"): url += environ["HTTP_HOST"] @@ -1687,12 +1787,12 @@ def _request_uri(environ): return url -def iter_close(iter): +def iter_close(iter: Iterable[bytes]) -> None: if hasattr(iter, "close"): iter.close() -def gzip_app_iter(app_iter): +def gzip_app_iter(app_iter: Iterable[bytes]) -> Iterator[bytes]: size = 0 crc = zlib.crc32(b"") & 0xFFFFFFFF compress = zlib.compressobj( diff --git a/src/webob/static.py b/src/webob/static.py index 36631e2d..aeba9652 100644 --- a/src/webob/static.py +++ b/src/webob/static.py @@ -1,13 +1,26 @@ +from __future__ import annotations + import mimetypes import os +from typing import IO, TYPE_CHECKING, Any from webob import exc from webob.dec import wsgify from webob.response import Response +if TYPE_CHECKING: + from collections.abc import Iterator + + from _typeshed import StrPath + from _typeshed.wsgi import WSGIApplication + + from webob.request import Request + + __all__ = ["FileApp", "DirectoryApp"] -mimetypes._winreg = None # do not load mimetypes from windows registry +# do not load mimetypes from windows registry +mimetypes._winreg = None # type: ignore[attr-defined] mimetypes.add_type( "text/javascript", ".js" ) # stdlib default is application/x-javascript @@ -22,7 +35,7 @@ class FileApp: Adds a mime type based on `mimetypes.guess_type()`. """ - def __init__(self, filename, **kw): + def __init__(self, filename: StrPath, **kw: Any) -> None: self.filename = filename content_type, content_encoding = mimetypes.guess_type(filename) kw.setdefault("content_type", content_type) @@ -33,7 +46,7 @@ def __init__(self, filename, **kw): self._open = open @wsgify - def __call__(self, req): + def __call__(self, req: Request) -> WSGIApplication: if req.method not in ("GET", "HEAD"): return exc.HTTPMethodNotAllowed("You cannot %s a file" % req.method) try: @@ -63,10 +76,15 @@ def __call__(self, req): class FileIter: - def __init__(self, file): + def __init__(self, file: IO[bytes]) -> None: self.file = file - def app_iter_range(self, seek=None, limit=None, block_size=None): + def app_iter_range( + self, + seek: int | None = None, + limit: int | None = None, + block_size: int | None = None, + ) -> Iterator[bytes]: """Iter over the content of the file. You can set the `seek` parameter to read the file starting from a @@ -117,8 +135,12 @@ class DirectoryApp: """ def __init__( - self, path, index_page="index.html", hide_index_with_redirect=False, **kw - ): + self, + path: StrPath, + index_page: str = "index.html", + hide_index_with_redirect: bool = False, + **kw: Any, + ) -> None: self.path = os.path.abspath(path) if not self.path.endswith(os.path.sep): self.path += os.path.sep @@ -128,11 +150,11 @@ def __init__( self.hide_index_with_redirect = hide_index_with_redirect self.fileapp_kw = kw - def make_fileapp(self, path): + def make_fileapp(self, path: StrPath) -> FileApp: return FileApp(path, **self.fileapp_kw) @wsgify - def __call__(self, req): + def __call__(self, req: Request) -> Response | FileApp: path = os.path.abspath(os.path.join(self.path, req.path_info.lstrip("/"))) if os.path.isdir(path) and self.index_page: return self.index(req, path) @@ -153,7 +175,7 @@ def __call__(self, req): else: return self.make_fileapp(path) - def index(self, req, path): + def index(self, req: Request, path: StrPath) -> Response | FileApp: index_path = os.path.join(path, self.index_page) if not os.path.isfile(index_path): return exc.HTTPNotFound(comment=index_path) diff --git a/src/webob/types.py b/src/webob/types.py new file mode 100644 index 00000000..51cea251 --- /dev/null +++ b/src/webob/types.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import IO, TYPE_CHECKING, Literal, Protocol, TypedDict, TypeVar, overload + +if TYPE_CHECKING: + from typing import type_check_only + + from typing_extensions import TypeAlias + + from webob.compat import cgi_FieldStorage + + # NOTE: the field storage objects we expose always contain a file + @type_check_only + class _FieldStorageWithFile(cgi_FieldStorage): + file: IO[bytes] + filename: str + + +T = TypeVar("T") +GetterReturnType_co = TypeVar("GetterReturnType_co", covariant=True) +SetterValueType_contra = TypeVar("SetterValueType_contra", contravariant=True) + + +class AsymmetricProperty(Protocol[GetterReturnType_co, SetterValueType_contra]): + @overload + def __get__(self, obj: None, type: type[object] | None = ..., /) -> property: ... + @overload + def __get__( + self, obj: object, type: type[object] | None = ..., / + ) -> GetterReturnType_co: ... + + def __set__(self, obj: object, value: SetterValueType_contra, /) -> None: + pass + + +class AsymmetricPropertyWithDelete( + AsymmetricProperty[GetterReturnType_co, SetterValueType_contra], + Protocol[GetterReturnType_co, SetterValueType_contra], +): + def __delete__(self, obj: object, /) -> None: + pass + + +SymmetricProperty: TypeAlias = AsymmetricProperty[T, T] +SymmetricPropertyWithDelete: TypeAlias = AsymmetricPropertyWithDelete[T, T] + +HTTPMethod: TypeAlias = Literal[ + "GET", + "HEAD", + "POST", + "PUT", + "DELETE", + "CONNECT", + "OPTIONS", + "TRACE", + "PATCH", +] +ListOrTuple: TypeAlias = "list[T] | tuple[T, ...]" + + +class RequestCacheControlDict(TypedDict, total=False): + max_stale: int + min_stale: int + only_if_cached: bool + no_cache: Literal[True] | str + no_store: bool + no_transform: bool + max_age: int + + +class ResponseCacheControlDict(TypedDict, total=False): + public: bool + private: Literal[True] | str + no_cache: Literal[True] | str + no_store: bool + no_transform: bool + must_revalidate: bool + proxy_revalidate: bool + max_age: int + s_maxage: int + s_max_age: int + stale_while_revalidate: int + stale_if_error: int + + +class ResponseCacheExpires(Protocol): + def __call__( + self, + seconds: int | timedelta = 0, + *, + public: bool = ..., + private: Literal[True] | str = ..., + no_cache: Literal[True] | str = ..., + no_store: bool = ..., + no_transform: bool = ..., + must_revalidate: bool = ..., + proxy_revalidate: bool = ..., + max_age: int = ..., + s_maxage: int = ..., + s_max_age: int = ..., + stale_while_revalidate: int = ..., + stale_if_error: int = ..., + ) -> None: ... diff --git a/src/webob/util.py b/src/webob/util.py index d26358e3..28f82185 100644 --- a/src/webob/util.py +++ b/src/webob/util.py @@ -1,10 +1,16 @@ +from __future__ import annotations + +from html import escape +from typing import TYPE_CHECKING, overload import warnings -from webob.compat import escape from webob.headers import _trans_key +if TYPE_CHECKING: + from collections.abc import Iterator + -def unquote(string): +def unquote(string: bytes) -> bytes: if not string: return b"" res = string.split(b"%") @@ -18,40 +24,62 @@ def unquote(string): return string -def url_unquote(s): +def url_unquote(s: str) -> str: return unquote(s.encode("ascii")).decode("latin-1") -def parse_qsl_text(qs, encoding="utf-8"): - qs = qs.encode("latin-1") - qs = qs.replace(b"+", b" ") - pairs = [s2 for s1 in qs.split(b"&") for s2 in s1.split(b";") if s2] +def parse_qsl_text(qs: str, encoding: str = "utf-8") -> Iterator[tuple[str, str]]: + qsb = qs.encode("latin-1") + qsb = qsb.replace(b"+", b" ") + pairs = [s2 for s1 in qsb.split(b"&") for s2 in s1.split(b";") if s2] for name_value in pairs: nv = name_value.split(b"=", 1) if len(nv) != 2: - nv.append("") + nv.append(b"") name = unquote(nv[0]) value = unquote(nv[1]) yield (name.decode(encoding), value.decode(encoding)) -def text_(s, encoding="latin-1", errors="strict"): +@overload +def text_(s: str | bytes, encoding: str = "latin-1", errors: str = "strict") -> str: ... + + +@overload +def text_(s: None, encoding: str = "latin-1", errors: str = "strict") -> None: ... + + +def text_( + s: str | bytes | None, encoding: str = "latin-1", errors: str = "strict" +) -> str | None: if isinstance(s, bytes): return str(s, encoding, errors) return s -def bytes_(s, encoding="latin-1", errors="strict"): +@overload +def bytes_( + s: str | bytes, encoding: str = "latin-1", errors: str = "strict" +) -> bytes: ... + + +@overload +def bytes_(s: None, encoding: str = "latin-1", errors: str = "strict") -> None: ... + + +def bytes_( + s: str | bytes | None, encoding: str = "latin-1", errors: str = "strict" +) -> bytes | None: if isinstance(s, str): return s.encode(encoding, errors) return s -def html_escape(s): +def html_escape(s: object) -> str: """HTML-escape a string or object This converts any non-string objects passed into it to strings @@ -67,7 +95,7 @@ def html_escape(s): __html__ = getattr(s, "__html__", None) if __html__ is not None and callable(__html__): - return s.__html__() + return __html__() # type: ignore[no-any-return] if not isinstance(s, str): s = str(s) @@ -79,9 +107,10 @@ def html_escape(s): return text_(s) -def header_docstring(header, rfc_section): +def header_docstring(header: str, rfc_section: str) -> str: if header.isupper(): - header = _trans_key(header) + # FIXME: What should we do when this returns `None`? + header = _trans_key(header) # type: ignore[assignment] major_section = rfc_section.split(".")[0] link = "http://www.w3.org/Protocols/rfc2616/rfc2616-sec{}.html#sec{}".format( major_section, @@ -95,7 +124,7 @@ def header_docstring(header, rfc_section): ) -def warn_deprecation(text, version, stacklevel): +def warn_deprecation(text: str, version: str, stacklevel: int) -> None: # version specifies when to start raising exceptions instead of warnings if version in ("1.2", "1.3", "1.4", "1.5", "1.6", "1.7"): @@ -105,7 +134,7 @@ def warn_deprecation(text, version, stacklevel): warnings.warn(text, cls, stacklevel=stacklevel + 1) -status_reasons = { +status_reasons: dict[int, str] = { # Status Codes # Informational 100: "Continue", @@ -171,7 +200,7 @@ def warn_deprecation(text, version, stacklevel): } # generic class responses as per RFC2616 -status_generic_reasons = { +status_generic_reasons: dict[int, str] = { 1: "Continue", 2: "Success", 3: "Multiple Choices", diff --git a/tests/mypy/check_cache_control.py b/tests/mypy/check_cache_control.py new file mode 100644 index 00000000..6a9879d2 --- /dev/null +++ b/tests/mypy/check_cache_control.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from typing import Any, Literal, Union + +from typing_extensions import assert_type + +from webob.cachecontrol import CacheControl +from webob.request import BaseRequest +from webob.response import Response + +req = BaseRequest({}) +res = Response() +assert_type(req.cache_control, CacheControl[Literal["request"]]) +assert_type(res.cache_control, CacheControl[Literal["response"]]) + +assert_type(CacheControl.parse(""), CacheControl[None]) +assert_type(CacheControl.parse("", type="request"), CacheControl[Literal["request"]]) +assert_type(CacheControl.parse("", type="response"), CacheControl[Literal["response"]]) + +req_cc = req.cache_control +res_cc = res.cache_control +shared_cc = CacheControl.parse("") +assert_type(req_cc, CacheControl[Literal["request"]]) +assert_type(res_cc, CacheControl[Literal["response"]]) +assert_type(shared_cc, CacheControl[None]) +any_cc = CacheControl[Any]({}, None) + +assert_type(req_cc.max_stale, Union[int, Literal["*"], None]) +res_cc.max_stale # type: ignore +shared_cc.max_stale # type: ignore +assert_type(any_cc.max_stale, Union[int, Literal["*"], None]) + +assert_type(req_cc.min_fresh, Union[int, None]) +res_cc.min_fresh # type: ignore +shared_cc.min_fresh # type: ignore +assert_type(any_cc.min_fresh, Union[int, None]) + +assert_type(req_cc.only_if_cached, bool) +res_cc.only_if_cached # type: ignore +shared_cc.only_if_cached # type: ignore +assert_type(any_cc.only_if_cached, bool) + +req_cc.public # type: ignore +assert_type(res_cc.public, bool) +shared_cc.public # type: ignore +assert_type(any_cc.public, bool) + +req_cc.private # type: ignore +assert_type(res_cc.private, Union[str, Literal["*"], None]) +shared_cc.private # type: ignore +assert_type(any_cc.private, Union[str, Literal["*"], None]) + +assert_type(req_cc.no_cache, Union[str, Literal["*"], None]) +assert_type(res_cc.no_cache, Union[str, Literal["*"], None]) +assert_type(shared_cc.no_cache, Union[str, Literal["*"], None]) +assert_type(any_cc.no_cache, Union[str, Literal["*"], None]) + +assert_type(req_cc.no_store, bool) +assert_type(res_cc.no_store, bool) +assert_type(shared_cc.no_store, bool) +assert_type(any_cc.no_store, bool) + +assert_type(req_cc.no_transform, bool) +assert_type(res_cc.no_transform, bool) +assert_type(shared_cc.no_transform, bool) +assert_type(any_cc.no_transform, bool) + +req_cc.must_revalidate # type: ignore +assert_type(res_cc.must_revalidate, bool) +shared_cc.must_revalidate # type: ignore +assert_type(any_cc.must_revalidate, bool) + +req_cc.proxy_revalidate # type: ignore +assert_type(res_cc.proxy_revalidate, bool) +shared_cc.proxy_revalidate # type: ignore +assert_type(any_cc.proxy_revalidate, bool) + +assert_type(req_cc.max_age, Union[int, Literal[-1], None]) +assert_type(res_cc.max_age, Union[int, Literal[-1], None]) +assert_type(shared_cc.max_age, Union[int, Literal[-1], None]) +assert_type(any_cc.max_age, Union[int, Literal[-1], None]) + +req_cc.s_maxage # type: ignore +assert_type(res_cc.s_maxage, Union[int, None]) +shared_cc.s_maxage # type: ignore +assert_type(any_cc.s_maxage, Union[int, None]) + +req_cc.s_max_age # type: ignore +assert_type(res_cc.s_max_age, Union[int, None]) +shared_cc.s_max_age # type: ignore +assert_type(any_cc.s_max_age, Union[int, None]) + +req_cc.stale_while_revalidate # type: ignore +assert_type(res_cc.stale_while_revalidate, Union[int, None]) +shared_cc.stale_while_revalidate # type: ignore +assert_type(any_cc.stale_while_revalidate, Union[int, None]) + +req_cc.stale_if_error # type: ignore +assert_type(res_cc.stale_if_error, Union[int, None]) +shared_cc.stale_if_error # type: ignore +assert_type(any_cc.stale_if_error, Union[int, None]) diff --git a/tests/mypy/check_wsgfy.py b/tests/mypy/check_wsgfy.py new file mode 100644 index 00000000..41ed7ccf --- /dev/null +++ b/tests/mypy/check_wsgfy.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from collections.abc import Iterable # noqa: F401 + +from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment +from typing_extensions import assert_type + +from webob.dec import _AnyResponse, wsgify +from webob.request import Request + + +class App: + @wsgify + def __call__(self, request: Request) -> str: + return "hello" + + +env: WSGIEnvironment = {} +start_response: StartResponse = lambda x, y, z=None: lambda b: None +application: WSGIApplication = lambda e, s: [b""] +request: Request = Request(env) + +x = App() +# since we wsgified our __call__ we should now be a valid WSGIApplication +application = x +assert_type(x(env, start_response), "Iterable[bytes]") +# currently we lose the exact response type, but that should be fine in +# most use-cases, since middlewares operate on an application level, not +# on these raw intermediary functions +assert_type(x(request), _AnyResponse) + +# accessing the method from the class should work as you expect it to +assert_type(App.__call__(x, env, start_response), "Iterable[bytes]") +assert_type(App.__call__(x, request), _AnyResponse) + + +# but we can also wrap it with a middleware that expects to deal with requests +class Middleware: + @wsgify.middleware + def restrict_ip( + self, req: Request, app: WSGIApplication, ips: list[str] + ) -> WSGIApplication: + return app + + __call__ = restrict_ip(x, ips=["127.0.0.1"]) + + +# and we still end up with a valid WSGIApplication +m = Middleware() +application = m +assert_type(m(env, start_response), "Iterable[bytes]") +assert_type(m(request), _AnyResponse) + + +# the same should work with plain functions +@wsgify +def app(request: Request) -> str: + return "hello" + + +application = app +assert_type(app, "wsgify[[], Request]") +assert_type(app(env, start_response), "Iterable[bytes]") +assert_type(app(request), _AnyResponse) +assert_type(app(application), "wsgify[[], Request]") +application = app(application) + + +@wsgify.middleware +def restrict_ip(req: Request, app: WSGIApplication, ips: list[str]) -> WSGIApplication: + return app + + +@restrict_ip(ips=["127.0.0.1"]) +@wsgify +def m_app(request: Request) -> str: + return "hello" + + +application = m_app +assert_type(m_app, "wsgify[[WSGIApplication], Request]") +assert_type(m_app(env, start_response), "Iterable[bytes]") +assert_type(m_app(request), _AnyResponse) +assert_type(m_app(application), "wsgify[[WSGIApplication], Request]") +application = m_app(application) + + +# custom request +class MyRequest(Request): + pass + + +@wsgify(RequestClass=MyRequest) +def my_request_app(request: MyRequest) -> None: + pass + + +application = my_request_app +assert_type(my_request_app, "wsgify[[], MyRequest]") + + +# we are allowed to accept a less specific request class +@wsgify(RequestClass=MyRequest) +def valid_request_app(request: Request) -> None: + pass + + +# but the opposite is not allowed +@wsgify # type: ignore +def invalid_request_app(request: MyRequest) -> None: + pass + + +# we can't really make passing extra arguments directly work +# otherwise we have to give up most of our type safety for +# something that should only be used through wsgify.middleware +wsgify(args=(1,)) # type: ignore +wsgify(kwargs={"ips": ["127.0.0.1"]}) # type: ignore diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 32f884ec..ba20160f 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -93,9 +93,6 @@ def slow_app(req): def test_client_slow(serve, client_app=None): if client_app is None: client_app = SendRequest() - if not client_app._timeout_supported(client_app.HTTPConnection): - # timeout isn't supported - return with serve(slow_app) as server: req = Request.blank(server.url) req.environ["webob.client.timeout"] = 0.1 diff --git a/tests/test_response.py b/tests/test_response.py index f539b422..0d848fb8 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1052,7 +1052,7 @@ def dummy_wsgi_callable(environ, start_response): environ = {} def dummy_start_response(status, headers, exc_info=None): - assert headers, [("Set-Cookie" == "a=1; Path=/")] + assert headers, ["Set-Cookie" == "a=1; Path=/"] result = wsgiapp(environ, dummy_start_response) assert result == "abc" @@ -1269,7 +1269,6 @@ def test_cache_expires_set_zero(): assert res.cache_control.no_cache == "*" assert res.cache_control.must_revalidate is True assert res.cache_control.max_age == 0 - assert res.cache_control.post_check == 0 def test_encode_content_unknown(): diff --git a/tox.ini b/tox.ini index c04f38f4..8779e213 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,7 @@ envlist = py39,py310,py311,py312,py313,pypy39,pypy310, coverage, docs, + mypy, isolated_build = True [testenv] @@ -74,6 +75,14 @@ deps = black isort +[testenv:mypy] +skip_install = True +commands = + mypy -p webob + mypy tests/mypy/ +deps = + mypy + [testenv:build] skip_install = true commands =