Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make pyright happy #152

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions oras/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
from typing import Dict, Literal, Optional

import requests

from oras.auth.base import AuthBackend
from oras.logger import logger

from .basic import BasicAuth
from .token import TokenAuth

auth_backends = {"token": TokenAuth, "basic": BasicAuth}
AuthBackendName = Literal["token", "basic"]

auth_backends: Dict[AuthBackendName, type[AuthBackend]] = {
"token": TokenAuth,
"basic": BasicAuth,
}


def get_auth_backend(name="token", session=None, **kwargs):
def get_auth_backend(
name: AuthBackendName = "token",
session: Optional[requests.Session] = None,
**kwargs,
):
backend = auth_backends.get(name)
if not backend:
logger.exit(f"Authentication backend {backend} is not known.")
backend = backend(**kwargs)
backend.session = session or requests.Session()
return logger.exit(f"Authentication backend {backend} is not known.")
_session = session or requests.Session()
backend = backend(session=_session, **kwargs)
return backend
83 changes: 31 additions & 52 deletions oras/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,49 @@
__copyright__ = "Copyright The ORAS Authors."
__license__ = "Apache-2.0"

import abc
from typing import Dict, Optional, Tuple

from typing import Optional
import requests

import oras.auth.utils as auth_utils
import oras.container
import oras.decorator as decorator
import oras.utils
from oras.logger import logger
from oras.types import container_type


class AuthBackend:
class AuthBackend(abc.ABC):
"""
Generic (and default) auth backend.
"""

def __init__(self, *args, **kwargs):
def __init__(self, session: requests.Session):
self._auths: dict = {}
self.session = session
self.headers: Dict[str, str] = {}

@abc.abstractmethod
def authenticate_request(
self,
original: requests.Response,
headers: Optional[dict[str, str]] = None,
refresh=False,
) -> Tuple[dict[str, str], bool]:
pass

def get_auth_header(self):
raise NotImplementedError

def get_container(self, name: container_type) -> oras.container.Container:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this method as it relies on self.hostname which is not available

def set_header(self, name: str, value: str):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method was referenced in multiple places but did not actually exist. So I added it by copying it from Provider.

"""
Courtesy function to get a container from a URI.
Courtesy function to set a header

:param name: unique resource identifier to parse
:type name: oras.container.Container or str
:param name: header name to set
:type name: str
:param value: header value to set
:type value: str
"""
if isinstance(name, oras.container.Container):
return name
return oras.container.Container(name, registry=self.hostname)
self.headers.update({name: value})

def get_auth_header(self):
raise NotImplementedError

def logout(self, hostname: str):
"""
Expand Down Expand Up @@ -81,8 +93,9 @@ def _load_auth(self, hostname: str) -> bool:
return True
return False

@decorator.ensure_container
def load_configs(self, container: container_type, configs: Optional[list] = None):
def load_configs(
self, container: oras.container.Container, configs: Optional[list] = None
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the container type as it relied on self.get_container() method above trough the decorator. Which as said would not have worked anyways

):
"""
Load configs to discover credentials for a specific container.

Expand All @@ -96,7 +109,7 @@ def load_configs(self, container: container_type, configs: Optional[list] = None
"""
if not self._auths:
self._auths = auth_utils.load_configs(configs)
for registry in oras.utils.iter_localhosts(container.registry): # type: ignore
for registry in oras.utils.iter_localhosts(container.registry):
if self._load_auth(registry):
return

Expand All @@ -120,37 +133,3 @@ def set_basic_auth(self, username: str, password: str):
:type password: str
"""
self._basic_auth = auth_utils.get_basic_auth(username, password)

def request_anonymous_token(self, h: auth_utils.authHeader, headers: dict) -> bool:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To my understanding this is dead code and completely incompatible with the overriding method in token auth, which is actually being used.

"""
Given no basic auth, fall back to trying to request an anonymous token.

Returns: boolean if headers have been updated with token.
"""
if not h.realm:
logger.debug("Request anonymous token: no realm provided, exiting early")
return headers, False

params = {}
if h.service:
params["service"] = h.service
if h.scope:
params["scope"] = h.scope

logger.debug(f"Final params are {params}")
response = self.session.request("GET", h.realm, params=params)
if response.status_code != 200:
logger.debug(f"Response for anon token failed: {response.text}")
return headers, False

# From https://docs.docker.com/registry/spec/auth/token/ section
# We can get token OR access_token OR both (when both they are identical)
data = response.json()
token = data.get("token") or data.get("access_token")

# Update the headers but not self.token (expects Basic)
if token:
headers["Authorization"] = {"Authorization": "Bearer %s" % token}

logger.debug("Warning: no token or access_token present in response.")
return headers, False
12 changes: 8 additions & 4 deletions oras/auth/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__license__ = "Apache-2.0"

import os
from typing import Optional, Tuple

import requests

Expand All @@ -14,10 +15,10 @@ class BasicAuth(AuthBackend):
Generic (and default) auth backend.
"""

def __init__(self):
def __init__(self, session: requests.Session):
username = os.environ.get("ORAS_USER")
password = os.environ.get("ORAS_PASS")
super().__init__()
super().__init__(session=session)
if username and password:
self.set_basic_auth(username, password)

Expand All @@ -28,8 +29,11 @@ def get_auth_header(self):
return {"Authorization": "Basic %s" % self._basic_auth}

def authenticate_request(
self, original: requests.Response, headers: dict, refresh=False
):
self,
original: requests.Response,
headers: Optional[dict[str, str]] = None,
refresh=False,
) -> Tuple[dict[str, str], bool]:
"""
Authenticate Request
Given a response, look for a Www-Authenticate header to parse.
Expand Down
45 changes: 26 additions & 19 deletions oras/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
__copyright__ = "Copyright The ORAS Authors."
__license__ = "Apache-2.0"

from typing import Optional, Tuple

import requests

import oras.auth.utils as auth_utils
Expand All @@ -15,9 +17,9 @@ class TokenAuth(AuthBackend):
Token (OAuth2) style auth.
"""

def __init__(self):
def __init__(self, session: requests.Session):
self.token = None
super().__init__()
super().__init__(session=session)

def _logout(self):
self.token = None
Expand All @@ -44,8 +46,11 @@ def reset_basic_auth(self):
self.set_header("Authorization", "Basic %s" % self._basic_auth)

def authenticate_request(
self, original: requests.Response, headers: dict, refresh=False
):
self,
original: requests.Response,
headers: Optional[dict[str, str]] = None,
refresh=False,
) -> Tuple[dict[str, str], bool]:
"""
Authenticate Request
Given a response, look for a Www-Authenticate header to parse.
Expand All @@ -55,6 +60,8 @@ def authenticate_request(
:param original: original response to get the Www-Authenticate header
:type original: requests.Response
"""
_headers = headers or {}

if refresh:
self.token = None

Expand All @@ -63,12 +70,12 @@ def authenticate_request(
logger.debug(
"Www-Authenticate not found in original response, cannot authenticate."
)
return headers, False
return _headers, False

# If we have a token, set auth header (base64 encoded user/pass)
if self.token:
headers["Authorization"] = "Bearer %s" % self.token
return headers, True
_headers["Authorization"] = "Bearer %s" % self.token
return _headers, True

h = auth_utils.parse_auth_header(authHeaderRaw)

Expand All @@ -78,23 +85,23 @@ def authenticate_request(
if anon_token:
logger.debug("Successfully obtained anonymous token!")
self.token = anon_token
headers["Authorization"] = "Bearer %s" % self.token
return headers, True
_headers["Authorization"] = "Bearer %s" % self.token
return _headers, True

# Next try for logged in token
token = self.request_token(h)
if token:
self.token = token
headers["Authorization"] = "Bearer %s" % self.token
return headers, True
_headers["Authorization"] = "Bearer %s" % self.token
return _headers, True

logger.error(
"This endpoint requires a token. Please use "
"basic auth with a username or password."
)
return headers, False
return _headers, False

def request_token(self, h: auth_utils.authHeader) -> bool:
def request_token(self, h: auth_utils.authHeader):
"""
Request an authenticated token and save for later.s
"""
Expand All @@ -113,16 +120,18 @@ def request_token(self, h: auth_utils.authHeader) -> bool:
}
)

assert h.realm is not None, "realm must be defined"

# Ensure the realm starts with http
if not h.realm.startswith("http"): # type: ignore
h.realm = f"{self.prefix}://{h.realm}"
if not h.realm.startswith("http"):
h.realm = f"http://{h.realm}" # TODO: Should this be htts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! fixed


# If the www-authenticate included a scope, honor it!
if h.scope:
logger.debug(f"Scope: {h.scope}")
params["scope"] = h.scope

authResponse = self.session.get(h.realm, headers=headers, params=params) # type: ignore
authResponse = self.session.get(h.realm, headers=headers, params=params)
if authResponse.status_code != 200:
logger.debug(f"Auth response was not successful: {authResponse.text}")
return
Expand All @@ -131,11 +140,9 @@ def request_token(self, h: auth_utils.authHeader) -> bool:
info = authResponse.json()
return info.get("token") or info.get("access_token")

def request_anonymous_token(self, h: auth_utils.authHeader) -> bool:
def request_anonymous_token(self, h: auth_utils.authHeader) -> Optional[str]:
"""
Given no basic auth, fall back to trying to request an anonymous token.

Returns: boolean if headers have been updated with token.
"""
if not h.realm:
logger.debug("Request anonymous token: no realm provided, exiting early")
Expand Down
7 changes: 3 additions & 4 deletions oras/auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,16 @@ def load_configs(configs: Optional[List[str]] = None):
:param configs: list of configuration paths to load, defaults to None
:type configs: optional list
"""
configs = configs or []
_configs = configs or []
default_config = oras.utils.find_docker_config()

# Add the default docker config
if default_config:
configs.append(default_config)
configs = set(configs) # type: ignore
_configs.append(default_config)

# Load configs until we find our registry hostname
auths = {}
for config in configs:
for config in set(_configs):
if not os.path.exists(config):
logger.warning(f"{config} does not exist.")
continue
Expand Down
2 changes: 1 addition & 1 deletion oras/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def parse(self, name: str):
raise ValueError(
f"{name} does not match a recognized registry unique resource identifier. Try <registry>/<namespace>/<repository>:<tag|digest>"
)
items = match.groupdict() # type: ignore
items = match.groupdict()
self.repository = items["repository"]
self.registry = items["registry"] or self.registry
self.namespace = items["namespace"]
Expand Down
Loading
Loading