diff --git a/src/endpoint.py b/src/endpoint.py index 33fbac9..3daf4ba 100644 --- a/src/endpoint.py +++ b/src/endpoint.py @@ -1,5 +1,9 @@ from __future__ import annotations -from typing import Union, Optional + +from typing import Optional, Any +from typing_extensions import Literal + + import importlib import os import pathlib @@ -44,8 +48,8 @@ def __init__(self, summary: str = None, # description: str = "Description", # desc: str = "Description", - types: Union[list, str] = "application/octet-stream", - example: Union[dict, str, int, float, bool, any] = None, + types: list | str = "application/octet-stream", + example: Any = None, security: Optional[dict] = None, responses: Optional[list] = None, tags: Optional[list] = None, @@ -66,10 +70,10 @@ def __init__(self, def http(method: str, require_auth: bool = True, - args: Union[tuple, list, Argument] = (), + args: tuple | list | Argument = (), docs: Optional[Document] = None): def _context(handler): - path = None + path: Optional[str] = None file = handler.__globals__["__file__"] if "___" in os.path.normpath(file).split(os.path.sep): raise IsADirectoryError("Path-argument like directory found.") @@ -148,21 +152,27 @@ def __init__(self, document: Optional[Document] = None): class Undefined: pass +ArgumentTypes = Literal["str", "string", "bool", "boolean", "number", "int", + "long", "double", "decimal", "float", "other"] + + class Argument(Documented): + type: ArgumentTypes + def __init__(self, name: str, - arg_type: str, + arg_type: ArgumentTypes, arg_in: str, required: bool = True, auto_cast: bool = True, minimum: int = -1, maximum: int = -1, - must_be: Union[tuple, list] = (), + must_be: tuple | list = (), doc: Optional[Document] = None, format_type: Optional[str] = None, ignore_check_expect100: bool = False, - enum: Union[tuple, list] = (), - default: Optional[any] = Undefined): + enum: tuple | list = (), + default: Any = Undefined): super().__init__(doc) if arg_type not in ["str", "string", "bool", "boolean", "number", "int", "long", "double", "decimal", "float", "other"]: @@ -182,14 +192,14 @@ def __init__(self, self.ignore_check_expect100 = ignore_check_expect100 self.default = default - def norm_type(self, val: Optional[any] = None) -> Optional[any]: + def norm_type(self, val: Any = None) -> Any: if "str" in self.type: return "string" if val is None else str(val) elif "bool" in self.type: return "boolean" if val is None else bool(val) - elif self.type is "number" or "int" in self.type: + elif self.type == "number" or "int" in self.type: return "integer" if val is None else int(val) - elif self.type is "long": + elif self.type == "long": return "integer" if val is None else int(val) else: return "number" if val is None else float(val) @@ -217,20 +227,20 @@ def validate(self, param_dict: dict) -> int: value = param_dict[name] if "str" in typ: - if len(must_be) is not 0 and value not in must_be: + if len(must_be) != 0 and value not in must_be: return 1 - if min_val is not -1 and len(value) < min_val: + if min_val != -1 and len(value) < min_val: return 3 - if max_val is not -1 and len(value) > max_val: + if max_val != -1 and len(value) > max_val: return 4 if cast: param_dict[name] = str(value) elif "bool" in typ: - if value not in ("true", "false") + self.must_be: + if value not in ("true", "false") + tuple(self.must_be): return 1 if cast: @@ -246,13 +256,13 @@ def validate(self, param_dict: dict) -> int: except ValueError: return 2 - if len(must_be) is not 0 and val not in must_be: + if len(must_be) != 0 and val not in must_be: return 1 - if min_val is not -1 and val < min_val: + if min_val != -1 and val < min_val: return 3 - if max_val is not -1 and val > max_val: + if max_val != -1 and val > max_val: return 4 if cast: @@ -280,7 +290,7 @@ def __init__(self, self.args = () if args is None else args self.path_arg = path_arg - def handle(self, handler, params: dict, queries: dict, path_param: dict) -> Union[Response, any]: + def handle(self, handler, params: dict, queries: dict, path_param: dict) -> Any: if self.auth_required and handler.do_auth(): return @@ -309,7 +319,7 @@ def validate_arg(self, handler, params: dict, queries: dict, path_param: dict) - continue elif code == 1: if "bool" in arg.type: - quick_invalid(handler, arg.name, "[" + ", ".join(("true", "false") + arg.must_be) + "]") + quick_invalid(handler, arg.name, "[" + ", ".join(("true", "false") + tuple(arg.must_be)) + "]") return False else: quick_invalid(handler, arg.name, "[" + ", ".join(arg.must_be) + "]") @@ -339,7 +349,7 @@ def validate_arg(self, handler, params: dict, queries: dict, path_param: dict) - val = arg.norm_type(path_param[arg.name]) if arg.auto_cast else path_param[arg.name] params[arg.name] = val - if len(missing) is not 0: + if len(missing) != 0: write(handler, 400, e(Cause.MISSING_FIELD, Cause.MISSING_FIELD[2] .replace("%0", str(len(missing))) .replace("%1", ", ".join(missing)))) @@ -350,9 +360,9 @@ def validate_arg(self, handler, params: dict, queries: dict, path_param: dict) - class Response(Documented): def __init__(self, code: int = 0, - body: Optional[any] = None, + body: Any = None, raw_body: bool = False, - content_type: Union[str, list] = None, + content_type: str | list = None, headers: Optional[dict] = None, doc: Optional[Document] = None): super().__init__(doc) @@ -367,7 +377,7 @@ def header(self, name: str, value: str) -> Response: self.headers[name] = value return self - def body(self, value: any, raw: bool = False) -> Response: + def body(self, value: Any, raw: bool = False) -> Response: self.body_data = value self.raw = raw return self @@ -390,13 +400,13 @@ def __init__(self, cause: Optional[Cause] = None, code: int = 0, headers: Optional[dict] = None, - body: Optional[any] = None, - content_type: Optional[Union[str, list]] = None, + body: Any = None, + content_type: Optional[str | list] = None, doc: Optional[Document] = None): if cause is not None: - super().__init__(cause[0], headers, cause[2], content_type, doc) + super().__init__(cause[0], headers, cause[2], content_type, headers, doc) else: - super().__init__(code, headers, body, content_type, doc) + super().__init__(code, body, False, content_type, headers, doc) self.cause = cause @@ -415,6 +425,8 @@ def error(cause: Optional[Cause] = None, code: int = 0, message: Optional[str] = class EPManager: + known_source: list[str] + def __init__(self): global loader self.signals = [] @@ -501,7 +513,7 @@ def make_cache(self) -> None: cursor[method] = EndPoint(method, rt, path, function, auth, args, bool(paths), docs) self.count += 1 - def get_endpoint(self, method: str, path: str, params: Optional[dict] = None) -> Optional[EndPoint]: + def get_endpoint(self, method: str, path: str, params: dict = {}) -> Optional[EndPoint]: cursor = self.index_tree diff --git a/src/gendoc.py b/src/gendoc.py index 58a010b..65b9600 100644 --- a/src/gendoc.py +++ b/src/gendoc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import pathlib @@ -94,7 +96,7 @@ def b(ex): properties = {} for zz in ex.items(): tz = whats_type_of_this_object(zz[1]) - if tz is "array": + if tz == "array": a_t_field = "" for at in zz[1]: diff --git a/src/route.py b/src/route.py index 5a23b88..3ed5859 100644 --- a/src/route.py +++ b/src/route.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + import enum import json from server.handler_base import AbstractHandlerBase @@ -6,7 +10,7 @@ # This code is deprecated in new futures. -def encode(amfs: any) -> str: +def encode(amfs: Any) -> str: return json.JSONEncoder().encode(amfs) @@ -40,7 +44,7 @@ def __getitem__(self, index): return self.value[index] -def validate(handler, fname: str, value: any, must: str) -> bool: +def validate(handler, fname: str, value: Any, must: str) -> bool: if str(value) in must: return False @@ -52,7 +56,7 @@ def validate(handler, fname: str, value: any, must: str) -> bool: def missing(handler, fields: dict, require: list) -> bool: diff = search_missing(fields, require) - if len(diff) is 0: + if len(diff) == 0: return False write(handler, 400, error(Cause.MISSING_FIELD, Cause.MISSING_FIELD[2] .replace("%0", str(len(diff))) @@ -60,7 +64,7 @@ def missing(handler, fields: dict, require: list) -> bool: return True -def success(handler, code: int, obj: any): +def success(handler, code: int, obj: Any): write(handler, code, encode({ "success": True, "result": obj diff --git a/src/run.py b/src/run.py index 672751d..3ea933e 100644 --- a/src/run.py +++ b/src/run.py @@ -68,7 +68,7 @@ def main(self): if not token.load(): self.log.warn("main", "Token not found. ") self.log.info("auth", "Generating token...") - self.log.info("auth", "Token generated: " + token.generate()) + self.log.info("auth", "Token generated: %s" % token.generate()) self.log.warn( "auth", "Make sure to copy this token now. You won't be able to see it again.") diff --git a/src/server/handler.py b/src/server/handler.py index 8902a58..1beb315 100644 --- a/src/server/handler.py +++ b/src/server/handler.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Optional + import cgi import json import mimetypes @@ -16,6 +20,7 @@ class Handler(ServerHandler): + request: Optional[HTTPRequest] def __init__(self, request, client_address, server): self.logger = server.logger @@ -76,7 +81,10 @@ def call_handler(self, path: str, params, queries): def dynamic_handle(self, path, params, queries): path_param = {} - ep = endpoint.loader.get_endpoint(self.request.method, path, path_param) + ep: Optional[endpoint.EndPoint] = None + + if self.request is not None and self.request.method is not None: + ep = endpoint.loader.get_endpoint(self.request.method, path, path_param) if ep is None: return False @@ -128,7 +136,10 @@ def _send_body(self, body, raw=False, content_types=None): return default = self.config["system"]["request"]["default_content_type"] - accept = self.request.headers["Accept"] if "Accept" in self.request.headers else "" + accept = "" + + if self.request is not None: + self.request.headers["Accept"] if self.request.headers is not None and "Accept" in self.request.headers else "" if content_types is not None: if isinstance(content_types, str): @@ -149,10 +160,16 @@ def log_request(self, **kwargs): if not no_req_log: self.logger.info(get_log_name(), '%s -- %s %s -- "%s %s"' % (kwargs["client"], kwargs["code"], "" if kwargs["message"] is None else kwargs["message"], - self.request.method, kwargs["path"])) + self.request.method if self.request is not None else "", kwargs["path"])) def handle_switch(self): try: + if self.request is None: + raise TypeError("Request instance is None") + + if self.request.path is None: + raise TypeError("Request path is None") + path = parse.urlparse(self.request.path) queries = dict(parse.parse_qsl(path.query)) @@ -160,7 +177,7 @@ def handle_switch(self): self.call_handler(path.path, {}, queries) else: - if "Content-Type" in self.request.headers: + if self.request.headers is not None and "Content-Type" in self.request.headers: content_len = int(self.request.headers.get("content-length").value) content_type = str(self.request.headers["Content-Type"]) @@ -196,7 +213,6 @@ def handle_switch(self): self.logger.warn(get_log_name(), get_stack_trace("server", *sys.exc_info())) def do_auth(self): - self.request: HTTPRequest if "Authorization" not in self.request.headers: route.post_error(self, route.Cause.AUTH_REQUIRED) diff --git a/src/server/handler_base.py b/src/server/handler_base.py index aaf49e8..27fe8c3 100644 --- a/src/server/handler_base.py +++ b/src/server/handler_base.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +from typing import Any +from abc import ABC, abstractmethod + from socket import socket from typing import Optional, BinaryIO from socketserver import StreamRequestHandler @@ -81,34 +86,43 @@ header_limit = 100 -class AbstractHandlerBase(StreamRequestHandler): +class AbstractHandlerBase(ABC, StreamRequestHandler): def __init__(self, request, client_address, server): super().__init__(request, client_address, server) + @abstractmethod def handle(self) -> None: pass + @abstractmethod def handle_parse_error(self, cause: str) -> None: pass + @abstractmethod def handle_request(self) -> None: pass - def send_header(self, name: str, value: any, server_version: str) -> None: + @abstractmethod + def send_header(self, name: str, value: Any, server_version: str) -> None: pass + @abstractmethod def flush_header(self) -> None: pass + @abstractmethod def end_header(self) -> None: pass + @abstractmethod def send_response(self, code: int, message: str, server_version: str) -> None: pass + @abstractmethod def send_body(self, content_type: str, raw_body: bytes) -> None: pass + @abstractmethod def log_request(self, **kwargs) -> None: pass @@ -117,7 +131,7 @@ class CachedHeader(AbstractHandlerBase): def __init__(self): self._response_cache = [] - def send_header(self, name: str, value: any, server_version: str = "HTTP/1.1") -> None: + def send_header(self, name: str, value: Any, server_version: str = "HTTP/1.1") -> None: if server_version != "HTTP/0.9": self._response_cache.append(f"{name}: {str(value)}\r\n".encode("iso-8859-1")) @@ -161,21 +175,22 @@ def _handle(self) -> None: try: req = HTTPParser(self, self.rfile).parse() - if not req: + if req is None: return - if req.protocol >= "HTTP/1.1": + if req.protocol is not None and req.protocol >= "HTTP/1.1": self.multiple = True - if "Connection" in req.headers: - if req.headers["Connection"] == "keep-alive": - self.multiple = True - elif req.headers["Connection"] == "close": - self.multiple = False + if req.headers is not None: + if "Connection" in req.headers: + if req.headers["Connection"] == "keep-alive": + self.multiple = True + elif req.headers["Connection"] == "close": + self.multiple = False - if "Expect" in req.headers: - if req.headers["Expect"] == "100-continue": - req.expect_100 = True + if "Expect" in req.headers: + if req.headers["Expect"] == "100-continue": + req.expect_100 = True self.request = req @@ -192,7 +207,7 @@ def _handle(self) -> None: except ParseException as e: self.handle_parse_error(e.cause) - def send_header(self, name: str, value: any, server_version: str = "HTTP/1.1"): + def send_header(self, name: str, value: Any, server_version: str = "HTTP/1.1"): if name == "Connection": if value.lower() == "keep-alive": self.multiple = True @@ -218,9 +233,9 @@ def decode(line: bytes) -> str: class HTTPRequest: - def __init__(self, handler, method: str = None, path: str = None, protocol: str = None, - headers: HeaderSet = None, rfile: BinaryIO = None, expect_100: bool = False, - parameters: dict = None): + def __init__(self, handler, method: Optional[str] = None, path: Optional[str] = None, protocol: Optional[str] = None, + headers: Optional[HeaderSet] = None, rfile: Optional[BinaryIO] = None, expect_100: bool = False, + parameters: Optional[dict] = None): self.handler = handler self.method = method self.path = path @@ -277,7 +292,7 @@ def _header(self, data: str) -> None: if len(kv) != 2: raise ParseException("MALFORMED_HEADER") - self._response.headers.add(*kv) + self._response.headers.add(*kv) if self._response.headers is not None else None def _first_line(self, byte: bytes) -> None: line = decode(byte) diff --git a/src/utils/header_parse.py b/src/utils/header_parse.py index 1f329e3..b6af1e9 100644 --- a/src/utils/header_parse.py +++ b/src/utils/header_parse.py @@ -1,3 +1,5 @@ +from __future__ import annotations + def _(n): return n.rstrip().lstrip()