diff --git a/oras/auth/__init__.py b/oras/auth/__init__.py index 394efd0b..f2050245 100644 --- a/oras/auth/__init__.py +++ b/oras/auth/__init__.py @@ -1,17 +1,29 @@ +from typing import Dict, Literal, Optional + import requests +from oras.auth.base import AuthBackend from oras.logger import logger from .basic import BasicAuth from .token import TokenAuth -auth_backends = {"token": TokenAuth, "basic": BasicAuth} +AuthBackendName = Literal["token", "basic"] + +auth_backends: Dict[AuthBackendName, type[AuthBackend]] = { + "token": TokenAuth, + "basic": BasicAuth, +} -def get_auth_backend(name="token", session=None, **kwargs): +def get_auth_backend( + name: AuthBackendName = "token", + session: Optional[requests.Session] = None, + **kwargs, +): backend = auth_backends.get(name) if not backend: - logger.exit(f"Authentication backend {backend} is not known.") - backend = backend(**kwargs) - backend.session = session or requests.Session() + return logger.exit(f"Authentication backend {backend} is not known.") + _session = session or requests.Session() + backend = backend(session=_session, **kwargs) return backend diff --git a/oras/auth/base.py b/oras/auth/base.py index 37098e50..eab8caa3 100644 --- a/oras/auth/base.py +++ b/oras/auth/base.py @@ -2,37 +2,49 @@ __copyright__ = "Copyright The ORAS Authors." __license__ = "Apache-2.0" +import abc +from typing import Dict, Optional, Tuple -from typing import Optional +import requests import oras.auth.utils as auth_utils import oras.container -import oras.decorator as decorator +import oras.utils from oras.logger import logger -from oras.types import container_type -class AuthBackend: +class AuthBackend(abc.ABC): """ Generic (and default) auth backend. """ - def __init__(self, *args, **kwargs): + def __init__(self, session: requests.Session): self._auths: dict = {} + self.session = session + self.headers: Dict[str, str] = {} + + @abc.abstractmethod + def authenticate_request( + self, + original: requests.Response, + headers: Optional[dict[str, str]] = None, + refresh=False, + ) -> Tuple[dict[str, str], bool]: + pass - def get_auth_header(self): - raise NotImplementedError - - def get_container(self, name: container_type) -> oras.container.Container: + def set_header(self, name: str, value: str): """ - Courtesy function to get a container from a URI. + Courtesy function to set a header - :param name: unique resource identifier to parse - :type name: oras.container.Container or str + :param name: header name to set + :type name: str + :param value: header value to set + :type value: str """ - if isinstance(name, oras.container.Container): - return name - return oras.container.Container(name, registry=self.hostname) + self.headers.update({name: value}) + + def get_auth_header(self): + raise NotImplementedError def logout(self, hostname: str): """ @@ -81,8 +93,9 @@ def _load_auth(self, hostname: str) -> bool: return True return False - @decorator.ensure_container - def load_configs(self, container: container_type, configs: Optional[list] = None): + def load_configs( + self, container: oras.container.Container, configs: Optional[list] = None + ): """ Load configs to discover credentials for a specific container. @@ -96,7 +109,7 @@ def load_configs(self, container: container_type, configs: Optional[list] = None """ if not self._auths: self._auths = auth_utils.load_configs(configs) - for registry in oras.utils.iter_localhosts(container.registry): # type: ignore + for registry in oras.utils.iter_localhosts(container.registry): if self._load_auth(registry): return @@ -120,37 +133,3 @@ def set_basic_auth(self, username: str, password: str): :type password: str """ self._basic_auth = auth_utils.get_basic_auth(username, password) - - def request_anonymous_token(self, h: auth_utils.authHeader, headers: dict) -> bool: - """ - Given no basic auth, fall back to trying to request an anonymous token. - - Returns: boolean if headers have been updated with token. - """ - if not h.realm: - logger.debug("Request anonymous token: no realm provided, exiting early") - return headers, False - - params = {} - if h.service: - params["service"] = h.service - if h.scope: - params["scope"] = h.scope - - logger.debug(f"Final params are {params}") - response = self.session.request("GET", h.realm, params=params) - if response.status_code != 200: - logger.debug(f"Response for anon token failed: {response.text}") - return headers, False - - # From https://docs.docker.com/registry/spec/auth/token/ section - # We can get token OR access_token OR both (when both they are identical) - data = response.json() - token = data.get("token") or data.get("access_token") - - # Update the headers but not self.token (expects Basic) - if token: - headers["Authorization"] = {"Authorization": "Bearer %s" % token} - - logger.debug("Warning: no token or access_token present in response.") - return headers, False diff --git a/oras/auth/basic.py b/oras/auth/basic.py index bd9e72ac..28ee7bff 100644 --- a/oras/auth/basic.py +++ b/oras/auth/basic.py @@ -3,6 +3,7 @@ __license__ = "Apache-2.0" import os +from typing import Optional, Tuple import requests @@ -14,10 +15,10 @@ class BasicAuth(AuthBackend): Generic (and default) auth backend. """ - def __init__(self): + def __init__(self, session: requests.Session): username = os.environ.get("ORAS_USER") password = os.environ.get("ORAS_PASS") - super().__init__() + super().__init__(session=session) if username and password: self.set_basic_auth(username, password) @@ -28,8 +29,11 @@ def get_auth_header(self): return {"Authorization": "Basic %s" % self._basic_auth} def authenticate_request( - self, original: requests.Response, headers: dict, refresh=False - ): + self, + original: requests.Response, + headers: Optional[dict[str, str]] = None, + refresh=False, + ) -> Tuple[dict[str, str], bool]: """ Authenticate Request Given a response, look for a Www-Authenticate header to parse. diff --git a/oras/auth/token.py b/oras/auth/token.py index cdb8fd0a..5c5491e0 100644 --- a/oras/auth/token.py +++ b/oras/auth/token.py @@ -2,6 +2,8 @@ __copyright__ = "Copyright The ORAS Authors." __license__ = "Apache-2.0" +from typing import Optional, Tuple + import requests import oras.auth.utils as auth_utils @@ -15,9 +17,9 @@ class TokenAuth(AuthBackend): Token (OAuth2) style auth. """ - def __init__(self): + def __init__(self, session: requests.Session): self.token = None - super().__init__() + super().__init__(session=session) def _logout(self): self.token = None @@ -44,8 +46,11 @@ def reset_basic_auth(self): self.set_header("Authorization", "Basic %s" % self._basic_auth) def authenticate_request( - self, original: requests.Response, headers: dict, refresh=False - ): + self, + original: requests.Response, + headers: Optional[dict[str, str]] = None, + refresh=False, + ) -> Tuple[dict[str, str], bool]: """ Authenticate Request Given a response, look for a Www-Authenticate header to parse. @@ -55,6 +60,8 @@ def authenticate_request( :param original: original response to get the Www-Authenticate header :type original: requests.Response """ + _headers = headers or {} + if refresh: self.token = None @@ -63,12 +70,12 @@ def authenticate_request( logger.debug( "Www-Authenticate not found in original response, cannot authenticate." ) - return headers, False + return _headers, False # If we have a token, set auth header (base64 encoded user/pass) if self.token: - headers["Authorization"] = "Bearer %s" % self.token - return headers, True + _headers["Authorization"] = "Bearer %s" % self.token + return _headers, True h = auth_utils.parse_auth_header(authHeaderRaw) @@ -78,23 +85,23 @@ def authenticate_request( if anon_token: logger.debug("Successfully obtained anonymous token!") self.token = anon_token - headers["Authorization"] = "Bearer %s" % self.token - return headers, True + _headers["Authorization"] = "Bearer %s" % self.token + return _headers, True # Next try for logged in token token = self.request_token(h) if token: self.token = token - headers["Authorization"] = "Bearer %s" % self.token - return headers, True + _headers["Authorization"] = "Bearer %s" % self.token + return _headers, True logger.error( "This endpoint requires a token. Please use " "basic auth with a username or password." ) - return headers, False + return _headers, False - def request_token(self, h: auth_utils.authHeader) -> bool: + def request_token(self, h: auth_utils.authHeader): """ Request an authenticated token and save for later.s """ @@ -113,16 +120,18 @@ def request_token(self, h: auth_utils.authHeader) -> bool: } ) + assert h.realm is not None, "realm must be defined" + # Ensure the realm starts with http - if not h.realm.startswith("http"): # type: ignore - h.realm = f"{self.prefix}://{h.realm}" + if not h.realm.startswith("http"): + h.realm = f"http://{h.realm}" # TODO: Should this be https # If the www-authenticate included a scope, honor it! if h.scope: logger.debug(f"Scope: {h.scope}") params["scope"] = h.scope - authResponse = self.session.get(h.realm, headers=headers, params=params) # type: ignore + authResponse = self.session.get(h.realm, headers=headers, params=params) if authResponse.status_code != 200: logger.debug(f"Auth response was not successful: {authResponse.text}") return @@ -131,11 +140,9 @@ def request_token(self, h: auth_utils.authHeader) -> bool: info = authResponse.json() return info.get("token") or info.get("access_token") - def request_anonymous_token(self, h: auth_utils.authHeader) -> bool: + def request_anonymous_token(self, h: auth_utils.authHeader) -> Optional[str]: """ Given no basic auth, fall back to trying to request an anonymous token. - - Returns: boolean if headers have been updated with token. """ if not h.realm: logger.debug("Request anonymous token: no realm provided, exiting early") diff --git a/oras/auth/utils.py b/oras/auth/utils.py index 6e3ba5d4..bb0f02c3 100644 --- a/oras/auth/utils.py +++ b/oras/auth/utils.py @@ -18,17 +18,16 @@ def load_configs(configs: Optional[List[str]] = None): :param configs: list of configuration paths to load, defaults to None :type configs: optional list """ - configs = configs or [] + _configs = configs or [] default_config = oras.utils.find_docker_config() # Add the default docker config if default_config: - configs.append(default_config) - configs = set(configs) # type: ignore + _configs.append(default_config) # Load configs until we find our registry hostname auths = {} - for config in configs: + for config in set(_configs): if not os.path.exists(config): logger.warning(f"{config} does not exist.") continue diff --git a/oras/container.py b/oras/container.py index 81f117dd..61a9a242 100644 --- a/oras/container.py +++ b/oras/container.py @@ -108,7 +108,7 @@ def parse(self, name: str): raise ValueError( f"{name} does not match a recognized registry unique resource identifier. Try //:" ) - items = match.groupdict() # type: ignore + items = match.groupdict() self.repository = items["repository"] self.registry = items["registry"] or self.registry self.namespace = items["namespace"] diff --git a/oras/decorator.py b/oras/decorator.py index 885ec52a..e156c866 100644 --- a/oras/decorator.py +++ b/oras/decorator.py @@ -1,83 +1,51 @@ +from __future__ import annotations + __author__ = "Vanessa Sochat" __copyright__ = "Copyright The ORAS Authors." __license__ = "Apache-2.0" + +import functools import time -from functools import partial, update_wrapper +from typing import TYPE_CHECKING from oras.logger import logger +if TYPE_CHECKING: + from oras.provider import Registry -class Decorator: - """ - Shared parent decorator class - """ - - def __init__(self, func): - update_wrapper(self, func) - self.func = func - - def __get__(self, obj, objtype): - return partial(self.__call__, obj) - - -class ensure_container(Decorator): - """ - Ensure the first argument is a container, and not a string. - """ - def __call__(self, cls, *args, **kwargs): +def ensure_container(fn): + @functools.wraps(fn) + def wrapper(self: Registry, *args, **kwargs): if "container" in kwargs: - kwargs["container"] = cls.get_container(kwargs["container"]) + kwargs["container"] = self.get_container(kwargs["container"]) elif args: - container = cls.get_container(args[0]) + container = self.get_container(args[0]) args = (container, *args[1:]) - return self.func(cls, *args, **kwargs) + return fn(self, *args, **kwargs) + return wrapper -class classretry(Decorator): + +def retry(func): """ - Retry a function that is part of a class + A simple retry decorator """ - def __init__(self, func, attempts=5, timeout=2): - super().__init__(func) - self.attempts = attempts - self.timeout = timeout - - def __call__(self, cls, *args, **kwargs): + @functools.wraps(func) + def wrapper(*args, **kwargs): + attempts = 5 + timeout = 2 attempt = 0 - attempts = self.attempts - timeout = self.timeout while attempt < attempts: try: - return self.func(cls, *args, **kwargs) + return func(*args, **kwargs) except Exception as e: sleep = timeout + 3**attempt logger.info(f"Retrying in {sleep} seconds - error: {e}") time.sleep(sleep) attempt += 1 - return self.func(cls, *args, **kwargs) - - -def retry(attempts, timeout=2): - """ - A simple retry decorator - """ - - def decorator(func): - def inner(*args, **kwargs): - attempt = 0 - while attempt < attempts: - try: - return func(*args, **kwargs) - except Exception as e: - sleep = timeout + 3**attempt - logger.info(f"Retrying in {sleep} seconds - error: {e}") - time.sleep(sleep) - attempt += 1 - return func(*args, **kwargs) - - return inner + return func(*args, **kwargs) - return decorator + return wrapper diff --git a/oras/main/login.py b/oras/main/login.py index df1fac91..d8bd4898 100644 --- a/oras/main/login.py +++ b/oras/main/login.py @@ -33,12 +33,11 @@ def login( :param dockercfg_str: docker config path :type dockercfg_str: list """ - if not dockercfg_path: - dockercfg_path = oras.utils.find_docker_config(exists=False) - if os.path.exists(dockercfg_path): # type: ignore - cfg = oras.utils.read_json(dockercfg_path) # type: ignore + _dockercfg_path = dockercfg_path or "~/.docker/config.json" + if os.path.exists(_dockercfg_path): + cfg = oras.utils.read_json(_dockercfg_path) else: - oras.utils.mkdir_p(os.path.dirname(dockercfg_path)) # type: ignore + oras.utils.mkdir_p(os.path.dirname(_dockercfg_path)) cfg = {"auths": {}} if registry in cfg["auths"]: cfg["auths"][registry]["auth"] = auth_utils.get_basic_auth( @@ -48,5 +47,5 @@ def login( cfg["auths"][registry] = { "auth": auth_utils.get_basic_auth(username, password) } - oras.utils.write_json(cfg, dockercfg_path) # type: ignore + oras.utils.write_json(cfg, _dockercfg_path) return {"Status": "Login Succeeded"} diff --git a/oras/oci.py b/oras/oci.py index 4865f8ea..1a53332d 100644 --- a/oras/oci.py +++ b/oras/oci.py @@ -29,7 +29,7 @@ class Annotations: Create a new set of annotations """ - def __init__(self, filename=None): + def __init__(self, filename: Optional[str] = None): self.lookup = {} self.load(filename) @@ -41,7 +41,7 @@ def add(self, section, key, value): self.lookup[section] = {} self.lookup[section][key] = value - def load(self, filename: str): + def load(self, filename: Optional[str]): if filename and os.path.exists(filename): self.lookup = oras.utils.read_json(filename) if filename and not os.path.exists(filename): diff --git a/oras/provider.py b/oras/provider.py index 10a68fa2..0420176f 100644 --- a/oras/provider.py +++ b/oras/provider.py @@ -6,6 +6,7 @@ import os import sys import urllib +import urllib.parse from contextlib import contextmanager, nullcontext from dataclasses import asdict from http.cookiejar import DefaultCookiePolicy @@ -23,6 +24,7 @@ import oras.oci import oras.schemas import oras.utils +import oras.version from oras.logger import logger from oras.types import container_type from oras.utils.fileio import PathAndOptionalContent @@ -49,7 +51,7 @@ def __init__( hostname: Optional[str] = None, insecure: bool = False, tls_verify: bool = True, - auth_backend: str = "token", + auth_backend: oras.auth.AuthBackendName = "token", ): """ Create an ORAS client. @@ -86,7 +88,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return "[oras-client]" - def version(self, return_items: bool = False) -> Union[dict, str]: + def version(self, return_items: bool = False) -> Union[dict[str, str], str]: """ Get the version of the client. @@ -109,7 +111,7 @@ def version(self, return_items: bool = False) -> Union[dict, str]: # Otherwise return a string that can be printed return "\n".join(["%s: %s" % (k, v) for k, v in versions.items()]) - def delete_tags(self, name: str, tags=Union[str, list]) -> List[str]: + def delete_tags(self, name: str, tags: Union[str, List[str]]) -> List[str]: """ Delete one or more tags for a unique resource identifier. @@ -120,10 +122,9 @@ def delete_tags(self, name: str, tags=Union[str, list]) -> List[str]: :param tags: single or multiple tags name to delete :type N: string or list """ - if isinstance(tags, str): - tags = [tags] + _tags = [tags] if isinstance(tags, str) else tags deleted = [] - for tag in tags: + for tag in _tags: if self.delete_tag(name, tag): deleted.append(tag) return deleted @@ -199,10 +200,10 @@ def login( # Fallback to manual login except Exception: return login.DockerClient().login( - username=username, # type: ignore - password=password, # type: ignore + username=username, + password=password, registry=hostname, # type: ignore - dockercfg_path=config_path, + dockercfg_path=config_path, # type: ignore ) def set_header(self, name: str, value: str): @@ -306,9 +307,10 @@ def delete_tag(self, container: container_type, tag: str) -> bool: :param tag: name of tag to delete :type tag: str """ + assert isinstance(container, oras.container.Container) logger.debug(f"Deleting tag {tag} for {container}") - head_url = f"{self.prefix}://{container.manifest_url(tag)}" # type: ignore + head_url = f"{self.prefix}://{container.manifest_url(tag)}" # get digest of manifest to delete response = self.do_request( @@ -324,14 +326,14 @@ def delete_tag(self, container: container_type, tag: str) -> bool: if not digest: raise RuntimeError("Expected to find Docker-Content-Digest header.") - delete_url = f"{self.prefix}://{container.manifest_url(digest)}" # type: ignore + delete_url = f"{self.prefix}://{container.manifest_url(digest)}" response = self.do_request(delete_url, "DELETE") if response.status_code != 202: raise RuntimeError("Delete was not successful: {response.json()}") return True @decorator.ensure_container - def get_tags(self, container: container_type, N=None) -> List[str]: + def get_tags(self, container: container_type, N: Optional[int] = None) -> List[str]: """ Retrieve tags for a package. @@ -340,18 +342,21 @@ def get_tags(self, container: container_type, N=None) -> List[str]: :param N: limit number of tags, None for all (default) :type N: Optional[int] """ + assert isinstance(container, oras.container.Container) retrieve_all = N is None - tags_url = f"{self.prefix}://{container.tags_url(N=N)}" # type: ignore + tags_url = f"{self.prefix}://{container.tags_url(N=N)}" tags: List[str] = [] - def extract_tags(response: requests.Response): + def extract_tags(response: requests.Response) -> bool: """ Determine if we should continue based on new tags and under limit. """ json = response.json() new_tags = json.get("tags") or [] tags.extend(new_tags) - return len(new_tags) and (retrieve_all or len(tags) < N) + if not len(tags) > 0: + return False + return retrieve_all if N is None else len(tags) < N self._do_paginated_request(tags_url, callable=extract_tags) @@ -414,8 +419,9 @@ def get_blob( :param head: use head to determine if blob exists :type head: bool """ + assert isinstance(container, oras.container.Container) method = "GET" if not head else "HEAD" - blob_url = f"{self.prefix}://{container.get_blob_url(digest)}" # type: ignore + blob_url = f"{self.prefix}://{container.get_blob_url(digest)}" return self.do_request(blob_url, method, headers=self.headers, stream=stream) def get_container(self, name: container_type) -> oras.container.Container: @@ -699,7 +705,7 @@ def push( manifest_config: Optional[str] = None, annotation_file: Optional[str] = None, manifest_annotations: Optional[dict] = None, - subject: Optional[str] = None, + subject: Optional[oras.oci.Subject] = None, do_chunked: bool = False, chunk_size: int = oras.defaults.default_chunksize, ) -> requests.Response: @@ -729,7 +735,9 @@ def push( """ container = self.get_container(target) files = files or [] - self.auth.load_configs(container, configs=config_path) + self.auth.load_configs( + container, configs=[config_path] if config_path else None + ) # Prepare a new manifest manifest = oras.oci.NewManifest() @@ -808,7 +816,7 @@ def push( manifest["annotations"] = manifest_annots if subject: - manifest["subject"] = asdict(subject) + manifest["subject"] = asdict(subject) # type: ignore # Prepare the manifest config (temporary or one provided) config_annots = annotset.get_annotations("$config") @@ -867,7 +875,9 @@ def pull( :type target: str """ container = self.get_container(target) - self.auth.load_configs(container, configs=config_path) + self.auth.load_configs( + container, configs=[config_path] if config_path else None + ) manifest = self.get_manifest(container, allowed_media_type) outdir = outdir or oras.utils.get_tmpdir() overwrite = overwrite @@ -920,27 +930,28 @@ def get_manifest( :param allowed_media_type: one or more allowed media types :type allowed_media_type: str """ + assert isinstance(container, oras.container.Container) if not allowed_media_type: allowed_media_type = [oras.defaults.default_manifest_media_type] headers = {"Accept": ";".join(allowed_media_type)} - get_manifest = f"{self.prefix}://{container.manifest_url()}" # type: ignore + get_manifest = f"{self.prefix}://{container.manifest_url()}" response = self.do_request(get_manifest, "GET", headers=headers) self._check_200_response(response) manifest = response.json() jsonschema.validate(manifest, schema=oras.schemas.manifest) return manifest - @decorator.classretry + @decorator.retry def do_request( self, url: str, method: str = "GET", data: Optional[Union[dict, bytes]] = None, - headers: Optional[dict] = None, + headers: Optional[dict[str, str]] = None, json: Optional[dict] = None, stream: bool = False, - ): + ) -> requests.Response: """ Do a request. This is a wrapper around requests to handle retry auth. diff --git a/oras/utils/fileio.py b/oras/utils/fileio.py index db308200..a4da8827 100644 --- a/oras/utils/fileio.py +++ b/oras/utils/fileio.py @@ -270,7 +270,7 @@ def read_in_chunks(image: Union[TextIO, io.BufferedReader], chunk_size: int = 10 data = image.read(chunk_size) if not data: break - yield data + yield data.encode() if isinstance(data, str) else data def write_json(json_obj: dict, filename: str, mode: str = "w") -> str: diff --git a/oras/utils/request.py b/oras/utils/request.py index 7fa48ae5..0ed9b305 100644 --- a/oras/utils/request.py +++ b/oras/utils/request.py @@ -23,15 +23,14 @@ def iter_localhosts(name: str): yield name -def find_docker_config(exists: bool = True): +def find_docker_config(): """ Return the docker default config path. """ path = os.path.expanduser("~/.docker/config.json") # Allow the caller to request the path regardless of existing - if os.path.exists(path) or not exists: - return path + return path if os.path.exists(path) else None def append_url_params(url: str, params: dict) -> str: