diff --git a/lib/charms/loki_k8s/v1/loki_push_api.py b/lib/charms/loki_k8s/v1/loki_push_api.py index 4bc4e22a..e2a61959 100644 --- a/lib/charms/loki_k8s/v1/loki_push_api.py +++ b/lib/charms/loki_k8s/v1/loki_push_api.py @@ -490,7 +490,7 @@ def _alert_rules_error(self, event): from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union from urllib import request -from urllib.error import HTTPError +from urllib.error import URLError import yaml from cosl import JujuTopology @@ -519,7 +519,7 @@ def _alert_rules_error(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 6 +LIBPATCH = 7 logger = logging.getLogger(__name__) @@ -2326,7 +2326,7 @@ def _ensure_promtail_binary(self, promtail_binaries: dict, container: Container) try: self._obtain_promtail(promtail_binaries[self._arch], container) - except HTTPError as e: + except URLError as e: msg = f"Promtail binary couldn't be downloaded - {str(e)}" logger.warning(msg) self.on.promtail_digest_error.emit(msg) diff --git a/lib/charms/observability_libs/v0/cert_handler.py b/lib/charms/observability_libs/v0/cert_handler.py index db14e00f..9dcfc8f1 100644 --- a/lib/charms/observability_libs/v0/cert_handler.py +++ b/lib/charms/observability_libs/v0/cert_handler.py @@ -37,22 +37,25 @@ import json import socket from itertools import filterfalse -from typing import List, Optional, Union +from typing import List, Optional, Union, cast try: - from charms.tls_certificates_interface.v2.tls_certificates import ( # type: ignore + from charms.tls_certificates_interface.v3.tls_certificates import ( # type: ignore AllCertificatesInvalidatedEvent, CertificateAvailableEvent, CertificateExpiringEvent, CertificateInvalidatedEvent, - TLSCertificatesRequiresV2, + TLSCertificatesRequiresV3, generate_csr, generate_private_key, ) -except ImportError: +except ImportError as e: raise ImportError( - "charms.tls_certificates_interface.v2.tls_certificates is missing; please get it through charmcraft fetch-lib" - ) + "failed to import charms.tls_certificates_interface.v3.tls_certificates; " + "Either the library itself is missing (please get it through charmcraft fetch-lib) " + "or one of its dependencies is unmet." + ) from e + import logging from ops.charm import CharmBase, RelationBrokenEvent @@ -64,7 +67,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 0 -LIBPATCH = 9 +LIBPATCH = 11 def is_ip_address(value: str) -> bool: @@ -129,7 +132,7 @@ def __init__( self.peer_relation_name = peer_relation_name self.certificates_relation_name = certificates_relation_name - self.certificates = TLSCertificatesRequiresV2(self.charm, self.certificates_relation_name) + self.certificates = TLSCertificatesRequiresV3(self.charm, self.certificates_relation_name) self.framework.observe( self.charm.on.config_changed, @@ -279,7 +282,7 @@ def _generate_csr( if clear_cert: self._ca_cert = "" self._server_cert = "" - self._chain = [] + self._chain = "" def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: """Get the certificate from the event and store it in a peer relation. @@ -301,7 +304,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: if event_csr == self._csr: self._ca_cert = event.ca self._server_cert = event.certificate - self._chain = event.chain + self._chain = event.chain_as_pem() self.on.cert_changed.emit() # pyright: ignore @property @@ -372,21 +375,29 @@ def _server_cert(self, value: str): rel.data[self.charm.unit].update({"certificate": value}) @property - def _chain(self) -> List[str]: + def _chain(self) -> str: if self._peer_relation: - if chain := self._peer_relation.data[self.charm.unit].get("chain", []): - return json.loads(chain) - return [] + if chain := self._peer_relation.data[self.charm.unit].get("chain", ""): + chain = json.loads(chain) + + # In a previous version of this lib, chain used to be a list. + # Convert the List[str] to str, per + # https://github.com/canonical/tls-certificates-interface/pull/141 + if isinstance(chain, list): + chain = "\n\n".join(reversed(chain)) + + return cast(str, chain) + return "" @_chain.setter - def _chain(self, value: List[str]): + def _chain(self, value: str): # Caller must guard. We want the setter to fail loudly. Failure must have a side effect. rel = self._peer_relation assert rel is not None # For type checker rel.data[self.charm.unit].update({"chain": json.dumps(value)}) @property - def chain(self) -> List[str]: + def chain(self) -> str: """Return the ca chain.""" return self._chain diff --git a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py index 665af886..72c3fe72 100644 --- a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py +++ b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py @@ -362,7 +362,7 @@ def _on_scrape_targets_changed(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 44 +LIBPATCH = 45 PYDEPS = ["cosl"] @@ -1537,12 +1537,11 @@ def set_scrape_job_spec(self, _=None): relation.data[self._charm.app]["scrape_metadata"] = json.dumps(self._scrape_metadata) relation.data[self._charm.app]["scrape_jobs"] = json.dumps(self._scrape_jobs) - if alert_rules_as_dict: - # Update relation data with the string representation of the rule file. - # Juju topology is already included in the "scrape_metadata" field above. - # The consumer side of the relation uses this information to name the rules file - # that is written to the filesystem. - relation.data[self._charm.app]["alert_rules"] = json.dumps(alert_rules_as_dict) + # Update relation data with the string representation of the rule file. + # Juju topology is already included in the "scrape_metadata" field above. + # The consumer side of the relation uses this information to name the rules file + # that is written to the filesystem. + relation.data[self._charm.app]["alert_rules"] = json.dumps(alert_rules_as_dict) def _set_unit_ip(self, _=None): """Set unit host address. diff --git a/lib/charms/tempo_k8s/v1/charm_tracing.py b/lib/charms/tempo_k8s/v1/charm_tracing.py index c146e6d3..39ebcd46 100644 --- a/lib/charms/tempo_k8s/v1/charm_tracing.py +++ b/lib/charms/tempo_k8s/v1/charm_tracing.py @@ -146,9 +146,9 @@ def my_tracing_endpoint(self) -> Optional[str]: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 5 -PYDEPS = ["opentelemetry-exporter-otlp-proto-http>=1.21.0"] +PYDEPS = ["opentelemetry-exporter-otlp-proto-http==1.21.0"] logger = logging.getLogger("tracing") @@ -205,7 +205,7 @@ def _get_tracer() -> Optional[Tracer]: return context_tracer.get() else: return None - except LookupError as err: + except LookupError: return None @@ -240,8 +240,8 @@ def _get_tracing_endpoint(tracing_endpoint_getter, self, charm): if tracing_endpoint is None: logger.debug( - "Charm tracing is disabled. Tracing endpoint is not defined - " - "tracing is not available or relation is not set." + f"{charm}.{tracing_endpoint_getter} returned None; quietly disabling " + f"charm_tracing for the run." ) return elif not isinstance(tracing_endpoint, str): @@ -310,7 +310,18 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): } ) provider = TracerProvider(resource=resource) - tracing_endpoint = _get_tracing_endpoint(tracing_endpoint_getter, self, charm) + try: + tracing_endpoint = _get_tracing_endpoint(tracing_endpoint_getter, self, charm) + except Exception: + # if anything goes wrong with retrieving the endpoint, we go on with tracing disabled. + # better than breaking the charm. + logger.exception( + f"exception retrieving the tracing " + f"endpoint from {charm}.{tracing_endpoint_getter}; " + f"proceeding with charm_tracing DISABLED. " + ) + return + if not tracing_endpoint: return @@ -528,6 +539,10 @@ def wrapped_function(*args, **kwargs): # type: ignore name = getattr(callable, "__qualname__", getattr(callable, "__name__", str(callable))) with _span(f"{'(static) ' if static else ''}{qualifier} call: {name}"): # type: ignore if static: + # fixme: do we or don't we need [1:]? + # The _trace_callable decorator doesn't always play nice with @staticmethods. + # Sometimes it will receive 'self', sometimes it won't. + # return callable(*args, **kwargs) # type: ignore return callable(*args[1:], **kwargs) # type: ignore return callable(*args, **kwargs) # type: ignore diff --git a/lib/charms/tempo_k8s/v2/tracing.py b/lib/charms/tempo_k8s/v2/tracing.py index 1660c971..d466531e 100644 --- a/lib/charms/tempo_k8s/v2/tracing.py +++ b/lib/charms/tempo_k8s/v2/tracing.py @@ -75,7 +75,6 @@ def __init__(self, *args): TYPE_CHECKING, Any, Dict, - Iterable, List, Literal, MutableMapping, @@ -105,7 +104,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 1 +LIBPATCH = 2 PYDEPS = ["pydantic"] @@ -818,8 +817,13 @@ def get_endpoint( endpoint = self._get_endpoint(relation or self._relation, protocol=protocol) if not endpoint: requested_protocols = set() - for relation in self.relations: - databag = TracingRequirerAppData.load(relation.data[self._charm.app]) + relations = [relation] if relation else self.relations + for relation in relations: + try: + databag = TracingRequirerAppData.load(relation.data[self._charm.app]) + except DataValidationError: + continue + requested_protocols.update(databag.receivers) if protocol not in requested_protocols: diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py similarity index 79% rename from lib/charms/tls_certificates_interface/v2/tls_certificates.py rename to lib/charms/tls_certificates_interface/v3/tls_certificates.py index 9f67833b..cbdd80d1 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -1,4 +1,4 @@ -# Copyright 2021 Canonical Ltd. +# Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. @@ -7,16 +7,19 @@ This library contains the Requires and Provides classes for handling the tls-certificates interface. +Pre-requisites: + - Juju >= 3.0 + ## Getting Started From a charm directory, fetch the library using `charmcraft`: ```shell -charmcraft fetch-lib charms.tls_certificates_interface.v2.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates ``` Add the following libraries to the charm's `requirements.txt` file: - jsonschema -- cryptography +- cryptography >= 42.0.0 Add the following section to the charm's `charmcraft.yaml` file: ```yaml @@ -36,10 +39,10 @@ Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateCreationRequestEvent, CertificateRevocationRequestEvent, - TLSCertificatesProvidesV2, + TLSCertificatesProvidesV3, generate_private_key, ) from ops.charm import CharmBase, InstallEvent @@ -59,7 +62,7 @@ class ExampleProviderCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - self.certificates = TLSCertificatesProvidesV2(self, "certificates") + self.certificates = TLSCertificatesProvidesV3(self, "certificates") self.framework.observe( self.certificates.on.certificate_request, self._on_certificate_request @@ -126,15 +129,15 @@ def _on_certificate_revocation_request(self, event: CertificateRevocationRequest Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, CertificateRevokedEvent, - TLSCertificatesRequiresV2, + TLSCertificatesRequiresV3, generate_csr, generate_private_key, ) -from ops.charm import CharmBase, RelationJoinedEvent +from ops.charm import CharmBase, RelationCreatedEvent from ops.main import main from ops.model import ActiveStatus, WaitingStatus from typing import Union @@ -145,10 +148,10 @@ class ExampleRequirerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV2(self, "certificates") + self.certificates = TLSCertificatesRequiresV3(self, "certificates") self.framework.observe(self.on.install, self._on_install) self.framework.observe( - self.on.certificates_relation_joined, self._on_certificates_relation_joined + self.on.certificates_relation_created, self._on_certificates_relation_created ) self.framework.observe( self.certificates.on.certificate_available, self._on_certificate_available @@ -176,7 +179,7 @@ def _on_install(self, event) -> None: {"private_key_password": "banana", "private_key": private_key.decode()} ) - def _on_certificates_relation_joined(self, event: RelationJoinedEvent) -> None: + def _on_certificates_relation_created(self, event: RelationCreatedEvent) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -277,15 +280,15 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven import logging import uuid from contextlib import suppress +from dataclasses import dataclass from datetime import datetime, timedelta, timezone from ipaddress import IPv4Address -from typing import Any, Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.serialization import pkcs12 from jsonschema import exceptions, validate from ops.charm import ( CharmBase, @@ -293,21 +296,27 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven RelationBrokenEvent, RelationChangedEvent, SecretExpiredEvent, - UpdateStatusEvent, ) from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion -from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError +from ops.model import ( + Application, + ModelError, + Relation, + RelationDataContent, + SecretNotFoundError, + Unit, +) # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" # Increment this major API version when introducing breaking changes -LIBAPI = 2 +LIBAPI = 3 # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 28 +LIBPATCH = 10 PYDEPS = ["cryptography", "jsonschema"] @@ -422,6 +431,34 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven logger = logging.getLogger(__name__) +@dataclass +class RequirerCSR: + """This class represents a certificate signing request from an interface Requirer.""" + + relation_id: int + application_name: str + unit_name: str + csr: str + is_ca: bool + + +@dataclass +class ProviderCertificate: + """This class represents a certificate from an interface Provider.""" + + relation_id: int + application_name: str + csr: str + certificate: str + ca: str + chain: List[str] + revoked: bool + + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + + class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -455,6 +492,10 @@ def restore(self, snapshot: dict): self.ca = snapshot["ca"] self.chain = snapshot["chain"] + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + class CertificateExpiringEvent(EventBase): """Charm Event triggered when a TLS certificate is almost expired.""" @@ -886,38 +927,6 @@ def generate_certificate( return cert.public_bytes(serialization.Encoding.PEM) -def generate_pfx_package( - certificate: bytes, - private_key: bytes, - package_password: str, - private_key_password: Optional[bytes] = None, -) -> bytes: - """Generate a PFX package to contain the TLS certificate and private key. - - Args: - certificate (bytes): TLS certificate - private_key (bytes): Private key - package_password (str): Password to open the PFX package - private_key_password (bytes): Private key password - - Returns: - bytes: - """ - private_key_object = serialization.load_pem_private_key( - private_key, password=private_key_password - ) - certificate_object = x509.load_pem_x509_certificate(certificate) - name = certificate_object.subject.rfc4514_string() - pfx_bytes = pkcs12.serialize_key_and_certificates( - name=name.encode(), - cert=certificate_object, - key=private_key_object, # type: ignore[arg-type] - cas=None, - encryption_algorithm=serialization.BestAvailableEncryption(package_password.encode()), - ) - return pfx_bytes - - def generate_private_key( password: Optional[bytes] = None, key_size: int = 2048, @@ -1053,6 +1062,27 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: return True +def _relation_data_is_valid( + relation: Relation, app_or_unit: Union[Application, Unit], json_schema: dict +) -> bool: + """Check whether relation data is valid based on json schema. + + Args: + relation (Relation): Relation object + app_or_unit (Union[Application, Unit]): Application or unit object + json_schema (dict): Json schema + + Returns: + bool: Whether relation data is valid. + """ + relation_data = _load_relation_data(relation.data[app_or_unit]) + try: + validate(instance=relation_data, schema=json_schema) + return True + except exceptions.ValidationError: + return False + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -1069,7 +1099,7 @@ class CertificatesRequirerCharmEvents(CharmEvents): all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) -class TLSCertificatesProvidesV2(Object): +class TLSCertificatesProvidesV3(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] @@ -1178,22 +1208,6 @@ def _remove_certificate( certificates.remove(certificate_dict) relation.data[self.model.app]["certificates"] = json.dumps(certificates) - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Use JSON schema validator to validate relation data content. - - Args: - certificates_data (dict): Certificate data dictionary as retrieved from relation data. - - Returns: - bool: True/False depending on whether the relation data follows the json schema. - """ - try: - validate(instance=certificates_data, schema=REQUIRER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False - def revoke_all_certificates(self) -> None: """Revoke all certificates of this provider. @@ -1262,16 +1276,24 @@ def remove_certificate(self, certificate: str) -> None: def get_issued_certificates( self, relation_id: Optional[int] = None - ) -> Dict[str, List[Dict[str, str]]]: - """Return a dictionary of issued certificates. + ) -> List[ProviderCertificate]: + """Return a List of issued (non revoked) certificates. + + Returns: + List: List of ProviderCertificate objects + """ + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] - It returns certificates from all relations if relation_id is not specified. - Certificates are returned per application name and CSR. + def get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued certificates. Returns: - dict: Certificates per application name. + List: List of ProviderCertificate objects """ - certificates: Dict[str, List[Dict[str, str]]] = {} + certificates: List[ProviderCertificate] = [] relations = ( [ relation @@ -1282,19 +1304,22 @@ def get_issued_certificates( else self.model.relations.get(self.relationship_name, []) ) for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) - - certificates[relation.app.name] = [] # type: ignore[union-attr] for certificate in provider_certificates: - if not certificate.get("revoked", False): - certificates[relation.app.name].append( # type: ignore[union-attr] - { - "csr": certificate["certificate_signing_request"], - "certificate": certificate["certificate"], - } - ) - + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=certificate["certificate_signing_request"], + certificate=certificate["certificate"], + ca=certificate["ca"], + chain=certificate["chain"], + revoked=certificate.get("revoked", False), + ) + certificates.append(provider_certificate) return certificates def _on_relation_changed(self, event: RelationChangedEvent) -> None: @@ -1317,124 +1342,77 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: return if not self.model.unit.is_leader(): return - requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) - provider_relation_data = self._load_app_relation_data(event.relation) - if not self._relation_data_is_valid(requirer_relation_data): + if not _relation_data_is_valid(event.relation, event.unit, REQUIRER_JSON_SCHEMA): logger.debug("Relation data did not pass JSON Schema validation") return - provider_certificates = provider_relation_data.get("certificates", []) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) + provider_certificates = self.get_provider_certificates(relation_id=event.relation.id) + requirer_csrs = self.get_requirer_csrs(relation_id=event.relation.id) provider_csrs = [ - certificate_creation_request["certificate_signing_request"] + certificate_creation_request.csr for certificate_creation_request in provider_certificates ] - requirer_unit_certificate_requests = [ - { - "csr": certificate_creation_request["certificate_signing_request"], - "is_ca": certificate_creation_request.get("ca", False), - } - for certificate_creation_request in requirer_csrs - ] - for certificate_request in requirer_unit_certificate_requests: - if certificate_request["csr"] not in provider_csrs: + for certificate_request in requirer_csrs: + if certificate_request.csr not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_request["csr"], - relation_id=event.relation.id, - is_ca=certificate_request["is_ca"], + certificate_signing_request=certificate_request.csr, + relation_id=certificate_request.relation_id, + is_ca=certificate_request.is_ca, ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: """Revoke certificates for which no unit has a CSR. - Goes through all generated certificates and compare against the list of CSRs for all units - of a given relationship. - - Args: - relation_id (int): Relation id + Goes through all generated certificates and compare against the list of CSRs for all units. Returns: None """ - certificates_relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=relation_id - ) - if not certificates_relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = self._load_app_relation_data(certificates_relation) - list_of_csrs: List[str] = [] - for unit in certificates_relation.units: - requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) - list_of_csrs.extend(csr["certificate_signing_request"] for csr in requirer_csrs) - provider_certificates = provider_relation_data.get("certificates", []) + provider_certificates = self.get_provider_certificates(relation_id) + requirer_csrs = self.get_requirer_csrs(relation_id) + list_of_csrs = [csr.csr for csr in requirer_csrs] for certificate in provider_certificates: - if certificate["certificate_signing_request"] not in list_of_csrs: + if certificate.csr not in list_of_csrs: self.on.certificate_revocation_request.emit( - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) - self.remove_certificate(certificate=certificate["certificate"]) + self.remove_certificate(certificate=certificate.certificate) def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: + ) -> List[RequirerCSR]: """Return CSR's for which no certificate has been issued. - Example return: [ - { - "relation_id": 0, - "application_name": "tls-certificates-requirer", - "unit_name": "tls-certificates-requirer/0", - "unit_csrs": [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "is_ca": false - } - ] - } - ] - Args: relation_id (int): Relation id Returns: - list: List of dictionaries that contain the unit's csrs - that don't have a certificate issued. + list: List of RequirerCSR objects. """ - all_unit_csr_mappings = copy.deepcopy(self.get_requirer_csrs(relation_id=relation_id)) - filtered_all_unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - for unit_csr_mapping in all_unit_csr_mappings: - csrs_without_certs = [] - for csr in unit_csr_mapping["unit_csrs"]: # type: ignore[union-attr] - if not self.certificate_issued_for_csr( - app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] - csr=csr["certificate_signing_request"], # type: ignore[index] - relation_id=relation_id, - ): - csrs_without_certs.append(csr) - if csrs_without_certs: - unit_csr_mapping["unit_csrs"] = csrs_without_certs # type: ignore[assignment] - filtered_all_unit_csr_mappings.append(unit_csr_mapping) - return filtered_all_unit_csr_mappings - - def get_requirer_csrs( - self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Return a list of requirers' CSRs grouped by unit. + requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) + outstanding_csrs: List[RequirerCSR] = [] + for relation_csr in requirer_csrs: + if not self.certificate_issued_for_csr( + app_name=relation_csr.application_name, + csr=relation_csr.csr, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: + """Return a list of requirers' CSRs. It returns CSRs from all relations if relation_id is not specified. CSRs are returned per relation id, application name and unit name. Returns: - list: List of dictionaries that contain the unit's csrs - with the following information - relation_id, application_name and unit_name. + list: List[RequirerCSR] """ - unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - + relation_csrs: List[RequirerCSR] = [] relations = ( [ relation @@ -1449,15 +1427,24 @@ def get_requirer_csrs( for unit in relation.units: requirer_relation_data = _load_relation_data(relation.data[unit]) unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) - unit_csr_mappings.append( - { - "relation_id": relation.id, - "application_name": relation.app.name, # type: ignore[union-attr] - "unit_name": unit.name, - "unit_csrs": unit_csrs_list, - } - ) - return unit_csr_mappings + for unit_csr in unit_csrs_list: + csr = unit_csr.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = unit_csr.get("ca", False) + if not relation.app: + logger.warning("No remote app in relation - Skipping") + continue + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=relation.app.name, + unit_name=unit.name, + csr=csr, + is_ca=ca, + ) + relation_csrs.append(relation_csr) + return relation_csrs def certificate_issued_for_csr( self, app_name: str, csr: str, relation_id: Optional[int] @@ -1468,19 +1455,18 @@ def certificate_issued_for_csr( app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. relation_id (Optional[int]): Relation ID + Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ - app_name - ] - for issued_pair in issued_certificates_per_csr: - if "csr" in issued_pair and issued_pair["csr"] == csr: - return csr_matches_certificate(csr, issued_pair["certificate"]) + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.csr == csr and issued_certificate.application_name == app_name: + return csr_matches_certificate(csr, issued_certificate.certificate) return False -class TLSCertificatesRequiresV2(Object): +class TLSCertificatesRequiresV3(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] @@ -1500,6 +1486,8 @@ def __init__( Used to trigger the CertificateExpiring event. Default: 7 days. """ super().__init__(charm, relationship_name) + if not JujuVersion.from_environ().has_secrets: + logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") self.relationship_name = relationship_name self.charm = charm self.expiry_notification_time = expiry_notification_time @@ -1509,32 +1497,39 @@ def __init__( self.framework.observe( charm.on[relationship_name].relation_broken, self._on_relation_broken ) - if JujuVersion.from_environ().has_secrets: - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - else: - self.framework.observe(charm.on.update_status, self._on_update_status) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - @property - def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: + def get_requirer_csrs(self) -> List[RequirerCSR]: """Return list of requirer's CSRs from relation unit data. - Example: - [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "ca": false - } - ] + Returns: + list: List of RequirerCSR objects. """ relation = self.model.get_relation(self.relationship_name) if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") + return [] + requirer_csrs = [] requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - return requirer_relation_data.get("certificate_signing_requests", []) + requirer_csrs_dict = requirer_relation_data.get("certificate_signing_requests", []) + for requirer_csr_dict in requirer_csrs_dict: + csr = requirer_csr_dict.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = requirer_csr_dict.get("ca", False) + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=self.model.app.name, + unit_name=self.model.unit.name, + csr=csr, + is_ca=ca, + ) + requirer_csrs.append(relation_csr) + return requirer_csrs - @property - def _provider_certificates(self) -> List[Dict[str, str]]: + def get_provider_certificates(self) -> List[ProviderCertificate]: """Return list of certificates from the provider's relation data.""" + provider_certificates: List[ProviderCertificate] = [] relation = self.model.get_relation(self.relationship_name) if not relation: logger.debug("No relation: %s", self.relationship_name) @@ -1543,12 +1538,32 @@ def _provider_certificates(self) -> List[Dict[str, str]]: logger.debug("No remote app in relation: %s", self.relationship_name) return [] provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.warning("Provider relation data did not pass JSON Schema validation") - return [] - return provider_relation_data.get("certificates", []) + provider_certificate_dicts = provider_relation_data.get("certificates", []) + for provider_certificate_dict in provider_certificate_dicts: + certificate = provider_certificate_dict.get("certificate") + if not certificate: + logger.warning("No certificate found in relation data - Skipping") + continue + ca = provider_certificate_dict.get("ca") + chain = provider_certificate_dict.get("chain", []) + csr = provider_certificate_dict.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + revoked = provider_certificate_dict.get("revoked", False) + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=csr, + certificate=certificate, + ca=ca, + chain=chain, + revoked=revoked, + ) + provider_certificates.append(provider_certificate) + return provider_certificates - def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: + def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: """Add CSR to relation data. Args: @@ -1564,18 +1579,23 @@ def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict: Dict[str, Union[bool, str]] = { + for requirer_csr in self.get_requirer_csrs(): + if requirer_csr.csr == csr and requirer_csr.is_ca == is_ca: + logger.info("CSR already in relation data - Doing nothing") + return + new_csr_dict = { "certificate_signing_request": csr, "ca": is_ca, } - if new_csr_dict in self._requirer_csrs: - logger.info("CSR already in relation data - Doing nothing") - return - requirer_csrs = copy.deepcopy(self._requirer_csrs) - requirer_csrs.append(new_csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + new_relation_data.append(new_csr_dict) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) - def _remove_requirer_csr(self, csr: str) -> None: + def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: """Remove CSR from relation data. Args: @@ -1590,14 +1610,18 @@ def _remove_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - requirer_csrs = copy.deepcopy(self._requirer_csrs) - if not requirer_csrs: + if not self.get_requirer_csrs(): logger.info("No CSRs in relation data - Doing nothing") return - for requirer_csr in requirer_csrs: + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + for requirer_csr in new_relation_data: if requirer_csr["certificate_signing_request"] == csr: - requirer_csrs.remove(requirer_csr) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + new_relation_data.remove(requirer_csr) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) def request_certificate_creation( self, certificate_signing_request: bytes, is_ca: bool = False @@ -1617,7 +1641,9 @@ def request_certificate_creation( f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca) + self._add_requirer_csr_to_relation_data( + certificate_signing_request.decode().strip(), is_ca=is_ca + ) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: @@ -1633,7 +1659,7 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> Returns: None """ - self._remove_requirer_csr(certificate_signing_request.decode().strip()) + self._remove_requirer_csr_from_relation_data(certificate_signing_request.decode().strip()) logger.info("Certificate revocation sent to provider") def request_certificate_renewal( @@ -1661,107 +1687,62 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") - def get_assigned_certificates(self) -> List[Dict[str, str]]: + def get_assigned_certificates(self) -> List[ProviderCertificate]: """Get a list of certificates that were assigned to this unit. Returns: - List of certificates. For example: - [ - { - "ca": "-----BEGIN CERTIFICATE-----...", - "chain": [ - "-----BEGIN CERTIFICATE-----..." - ], - "certificate": "-----BEGIN CERTIFICATE-----...", - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - } - ] + List: List[ProviderCertificate] """ - final_list = [] - for csr in self.get_certificate_signing_requests(fulfilled_only=True): - assert isinstance(csr["certificate_signing_request"], str) - if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): - final_list.append(cert) - return final_list - - def get_expiring_certificates(self) -> List[Dict[str, str]]: + assigned_certificates = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + assigned_certificates.append(cert) + return assigned_certificates + + def get_expiring_certificates(self) -> List[ProviderCertificate]: """Get a list of certificates that were assigned to this unit that are expiring or expired. Returns: - List of certificates. For example: - [ - { - "ca": "-----BEGIN CERTIFICATE-----...", - "chain": [ - "-----BEGIN CERTIFICATE-----..." - ], - "certificate": "-----BEGIN CERTIFICATE-----...", - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - } - ] + List: List[ProviderCertificate] """ - final_list = [] - for csr in self.get_certificate_signing_requests(fulfilled_only=True): - assert isinstance(csr["certificate_signing_request"], str) - if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): - expiry_time = _get_certificate_expiry_time(cert["certificate"]) + expiring_certificates: List[ProviderCertificate] = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + expiry_time = _get_certificate_expiry_time(cert.certificate) if not expiry_time: continue expiry_notification_time = expiry_time - timedelta( hours=self.expiry_notification_time ) if datetime.now(timezone.utc) > expiry_notification_time: - final_list.append(cert) - return final_list + expiring_certificates.append(cert) + return expiring_certificates def get_certificate_signing_requests( self, fulfilled_only: bool = False, unfulfilled_only: bool = False, - ) -> List[Dict[str, Union[bool, str]]]: + ) -> List[RequirerCSR]: """Get the list of CSR's that were sent to the provider. You can choose to get only the CSR's that have a certificate assigned or only the CSR's - that don't. + that don't. Args: fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. unfulfilled_only (bool): This option will discard CSRs that have certificates signed. Returns: - List of CSR dictionaries. For example: - [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "ca": false - } - ] + List of RequirerCSR objects. """ - final_list = [] - for csr in self._requirer_csrs: - assert isinstance(csr["certificate_signing_request"], str) - cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"]) + csrs = [] + for requirer_csr in self.get_requirer_csrs(): + cert = self._find_certificate_in_relation_data(requirer_csr.csr) if (unfulfilled_only and cert) or (fulfilled_only and not cert): continue - final_list.append(csr) - - return final_list + csrs.append(requirer_csr) - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Check whether relation data is valid based on json schema. - - Args: - certificates_data: Certificate data in dict format. - - Returns: - bool: Whether relation data is valid. - """ - try: - validate(instance=certificates_data, schema=PROVIDER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False + return csrs def _on_relation_changed(self, event: RelationChangedEvent) -> None: """Handle relation changed event. @@ -1771,9 +1752,8 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: If the provider certificate is revoked, emit a CertificateInvalidateEvent, otherwise emit a CertificateAvailableEvent. - When Juju secrets are available, remove the secret for revoked certificate, - or add a secret with the correct expiry time for new certificates. - + Remove the secret for revoked certificate, or add a secret with the correct expiry + time for new certificates. Args: event: Juju event @@ -1781,51 +1761,48 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ + if not event.app: + logger.warning("No remote app in relation - Skipping") + return + if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA): + logger.debug("Relation data did not pass JSON Schema validation") + return + provider_certificates = self.get_provider_certificates() requirer_csrs = [ - certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in self._requirer_csrs + certificate_creation_request.csr + for certificate_creation_request in self.get_requirer_csrs() ] - for certificate in self._provider_certificates: - if certificate["certificate_signing_request"] in requirer_csrs: - if certificate.get("revoked", False): - if JujuVersion.from_environ().has_secrets: - with suppress(SecretNotFoundError): - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.remove_all_revisions() + for certificate in provider_certificates: + if certificate.csr in requirer_csrs: + if certificate.revoked: + with suppress(SecretNotFoundError): + secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") + secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) else: - if JujuVersion.from_environ().has_secrets: - try: - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.set_content({"certificate": certificate["certificate"]}) - secret.set_info( - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), - ) - except SecretNotFoundError: - secret = self.charm.unit.add_secret( - {"certificate": certificate["certificate"]}, - label=f"{LIBID}-{certificate['certificate_signing_request']}", - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), - ) + try: + secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") + secret.set_content({"certificate": certificate.certificate}) + secret.set_info( + expire=self._get_next_secret_expiry_time(certificate.certificate), + ) + except SecretNotFoundError: + secret = self.charm.unit.add_secret( + {"certificate": certificate.certificate}, + label=f"{LIBID}-{certificate.csr}", + expire=self._get_next_secret_expiry_time(certificate.certificate), + ) self.on.certificate_available.emit( - certificate_signing_request=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate_signing_request=certificate.csr, + certificate=certificate.certificate, + ca=certificate.ca, + chain=certificate.chain, ) def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: @@ -1849,7 +1826,7 @@ def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: return _get_closest_future_time(expiry_notification_time, expiry_time) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handle relation broken event. + """Handle Relation Broken Event. Emitting `all_certificates_invalidated` from `relation-broken` rather than `relation-departed` since certs are stored in app data. @@ -1863,7 +1840,7 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: self.on.all_certificates_invalidated.emit() def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Handle secret expired event. + """Handle Secret Expired Event. Loads the certificate from the secret, and will emit 1 of 2 events. @@ -1881,13 +1858,13 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): return csr = event.secret.label[len(f"{LIBID}-") :] - certificate_dict = self._find_certificate_in_relation_data(csr) - if not certificate_dict: + provider_certificate = self._find_certificate_in_relation_data(csr) + if not provider_certificate: # A secret expired but we did not find matching certificate. Cleaning up event.secret.remove_all_revisions() return - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) + expiry_time = _get_certificate_expiry_time(provider_certificate.certificate) if not expiry_time: # A secret expired but matching certificate is invalid. Cleaning up event.secret.remove_all_revisions() @@ -1896,64 +1873,28 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: if datetime.now(timezone.utc) < expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], + certificate=provider_certificate.certificate, expiry=expiry_time.isoformat(), ) event.secret.set_info( - expire=_get_certificate_expiry_time(certificate_dict["certificate"]), + expire=_get_certificate_expiry_time(provider_certificate.certificate), ) else: logger.warning("Certificate is expired") self.on.certificate_invalidated.emit( reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) + self.request_certificate_revocation(provider_certificate.certificate.encode()) event.secret.remove_all_revisions() - def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: + def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: """Return the certificate that match the given CSR.""" - for certificate_dict in self._provider_certificates: - if certificate_dict["certificate_signing_request"] != csr: + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.csr != csr: continue - return certificate_dict + return provider_certificate return None - - def _on_update_status(self, event: UpdateStatusEvent) -> None: - """Handle update status event. - - Goes through each certificate in the "certificates" relation and checks their expiry date. - If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if - they are expired, emits a CertificateExpiredEvent. - - Args: - event (UpdateStatusEvent): Juju event - - Returns: - None - """ - for certificate_dict in self._provider_certificates: - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) - if not expiry_time: - continue - time_difference = expiry_time - datetime.now(timezone.utc) - if time_difference.total_seconds() < 0: - logger.warning("Certificate is expired") - self.on.certificate_invalidated.emit( - reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], - ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) - continue - if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): - logger.warning("Certificate almost expired") - self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], - expiry=expiry_time.isoformat(), - )