From 3b6aedf0f48b6cdbd3266ac35433a0dc315cf8cc Mon Sep 17 00:00:00 2001 From: Alex Batisse Date: Mon, 2 Dec 2024 15:06:39 +0100 Subject: [PATCH] [DPE-5987] Add expose-external configuration option (#172) --- .github/workflows/ci.yaml | 2 +- config.yaml | 11 +- .../data_platform_libs/v0/data_models.py | 354 ++++++++++++++++++ lib/charms/zookeeper/v0/client.py | 77 +++- src/charm.py | 15 +- src/core/cluster.py | 104 +++-- src/core/models.py | 49 ++- src/core/structured_config.py | 23 ++ src/core/stubs.py | 27 ++ src/core/workload.py | 2 +- src/events/tls.py | 26 +- src/events/upgrade.py | 5 +- src/managers/config.py | 13 +- src/managers/k8s.py | 215 +++++++++++ src/managers/quorum.py | 3 +- src/managers/tls.py | 70 ++++ tests/integration/test_provider.py | 4 +- tests/unit/conftest.py | 36 +- tests/unit/test_charm.py | 2 + tests/unit/test_client.py | 12 +- tests/unit/test_config.py | 10 +- tests/unit/test_provider.py | 58 ++- tests/unit/test_structured_config.py | 73 ++++ tests/unit/test_tls.py | 119 +++++- 24 files changed, 1199 insertions(+), 111 deletions(-) create mode 100644 lib/charms/data_platform_libs/v0/data_models.py create mode 100644 src/core/structured_config.py create mode 100644 src/managers/k8s.py create mode 100644 tests/unit/test_structured_config.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 38e6117a..98359c36 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -167,7 +167,7 @@ jobs: provider: lxd juju-channel: 3.4/stable bootstrap-options: "--agent-version 3.4.2" - lxd-channel: 5.20/stable # FIXME: https://warthogs.atlassian.net/browse/DPE-4047 + lxd-channel: 5.21/stable - name: Download packed charm(s) uses: actions/download-artifact@v4 with: diff --git a/config.yaml b/config.yaml index e8c0e025..17a0458a 100644 --- a/config.yaml +++ b/config.yaml @@ -16,7 +16,14 @@ options: type: int default: 2000 log-level: - description: 'Level of logging for the different components operated by the charm. Possible values: ERROR, WARNING, INFO, DEBUG' + description: "Level of logging for the different components operated by the charm. Possible values: ERROR, WARNING, INFO, DEBUG" type: string default: "INFO" - + expose-external: + description: "String to determine how to expose the ZooKeeper cluster externally from the Kubernetes cluster. Possible values: 'nodeport', 'loadbalancer', 'false'" + type: string + default: "false" + loadbalancer-extra-annotations: + description: "String in json format to describe extra configuration for load balancers. Needed for some cloud providers or services." + type: string + default: "" diff --git a/lib/charms/data_platform_libs/v0/data_models.py b/lib/charms/data_platform_libs/v0/data_models.py new file mode 100644 index 00000000..a1dbb829 --- /dev/null +++ b/lib/charms/data_platform_libs/v0/data_models.py @@ -0,0 +1,354 @@ +# Copyright 2023 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Library to provide simple API for promoting typed, validated and structured dataclass in charms. 
+ +Dict-like data structure are often used in charms. They are used for config, action parameters +and databag. This library aims at providing simple API for using pydantic BaseModel-derived class +in charms, in order to enhance: +* Validation, by embedding custom business logic to validate single parameters or even have + validators that acts across different fields +* Parsing, by loading data into pydantic object we can both allow for other types (e.g. float) to + be used in configuration/parameters as well as specify even nested complex objects for databags +* Static typing checks, by moving from dict-like object to classes with typed-annotated properties, + that can be statically checked using mypy to ensure that the code is correct. + +Pydantic models can be used on: + +* Charm Configuration (as defined in config.yaml) +* Actions parameters (as defined in actions.yaml) +* Application/Unit Databag Information (thus making it more structured and encoded) + + +## Creating models + +Any data-structure can be modeled using dataclasses instead of dict-like objects (e.g. storing +config, action parameters and databags). Within pydantic, we can define dataclasses that provides +also parsing and validation on standard dataclass implementation: + +```python + +from charms.data_platform_libs.v0.data_models import BaseConfigModel + +class MyConfig(BaseConfigModel): + + my_key: int + + @validator("my_key") + def is_lower_than_100(cls, v: int): + if v > 100: + raise ValueError("Too high") + +``` + +This should allow to collapse both parsing and validation as the dataclass object is parsed and +created: + +```python +dataclass = MyConfig(my_key="1") + +dataclass.my_key # this returns 1 (int) +dataclass["my_key"] # this returns 1 (int) + +dataclass = MyConfig(my_key="102") # this returns a ValueError("Too High") +``` + +## Charm Configuration Model + +Using the class above, we can implement parsing and validation of configuration by simply +extending our charms using the `TypedCharmBase` class, as shown below. + +```python +class MyCharm(TypedCharmBase[MyConfig]): + config_type = MyConfig + + # everywhere in the code you will have config property already parsed and validate + def my_method(self): + self.config: MyConfig +``` + +## Action parameters + +In order to parse action parameters, we can use a decorator to be applied to action event +callbacks, as shown below. + +```python +@validate_params(PullActionModel) +def _pull_site_action( + self, event: ActionEvent, + params: Optional[Union[PullActionModel, ValidationError]] = None +): + if isinstance(params, ValidationError): + # handle errors + else: + # do stuff +``` + +Note that this changes the signature of the callbacks by adding an extra parameter with the parsed +counterpart of the `event.params` dict-like field. If validation fails, we return (not throw!) the +exception, to be handled (or raised) in the callback. + +## Databag + +In order to parse databag fields, we define a decorator to be applied to base relation event +callbacks. + +```python +@parse_relation_data(app_model=AppDataModel, unit_model=UnitDataModel) +def _on_cluster_relation_joined( + self, event: RelationEvent, + app_data: Optional[Union[AppDataModel, ValidationError]] = None, + unit_data: Optional[Union[UnitDataModel, ValidationError]] = None +) -> None: + ... +``` + +The parameters `app_data` and `unit_data` refers to the databag of the entity which fired the +RelationEvent. 
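For illustration, the `AppDataModel` and `UnitDataModel` referenced above can be plain pydantic models; the field names below are hypothetical and not part of this library:

```python
class AppDataModel(RelationDataModel):
    replicas: int

class UnitDataModel(RelationDataModel):
    address: str
```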
+ +When we want to access to a relation databag outsides of an action, it can be useful also to +compact multiple databags into a single object (if there are no conflicting fields), e.g. + +```python + +class ProviderDataBag(BaseClass): + provider_key: str + +class RequirerDataBag(BaseClass): + requirer_key: str + +class MergedDataBag(ProviderDataBag, RequirerDataBag): + pass + +merged_data = get_relation_data_as( + MergedDataBag, relation.data[self.app], relation.data[relation.app] +) + +merged_data.requirer_key +merged_data.provider_key + +``` + +The above code can be generalized to other kinds of merged objects, e.g. application and unit, and +it can be extended to multiple sources beyond 2: + +```python +merged_data = get_relation_data_as( + MergedDataBag, relation.data[self.app], relation.data[relation.app], ... +) +``` + +""" + +import json +from functools import reduce, wraps +from typing import Callable, Generic, MutableMapping, Optional, Type, TypeVar, Union + +import pydantic +from ops.charm import ActionEvent, CharmBase, RelationEvent +from ops.model import RelationDataContent +from pydantic import BaseModel, ValidationError + +# The unique Charmhub library identifier, never change it +LIBID = "cb2094c5b07d47e1bf346aaee0fcfcfe" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 4 + +PYDEPS = ["ops>=2.0.0", "pydantic>=1.10,<2"] + +G = TypeVar("G") +T = TypeVar("T", bound=BaseModel) +AppModel = TypeVar("AppModel", bound=BaseModel) +UnitModel = TypeVar("UnitModel", bound=BaseModel) + +DataBagNativeTypes = (int, str, float) + + +class BaseConfigModel(BaseModel): + """Class to be used for defining the structured configuration options.""" + + def __getitem__(self, x): + """Return the item using the notation instance[key].""" + return getattr(self, x.replace("-", "_")) + + +class TypedCharmBase(CharmBase, Generic[T]): + """Class to be used for extending config-typed charms.""" + + config_type: Type[T] + + @property + def config(self) -> T: + """Return a config instance validated and parsed using the provided pydantic class.""" + translated_keys = {k.replace("-", "_"): v for k, v in self.model.config.items()} + return self.config_type(**translated_keys) + + +def validate_params(cls: Type[T]): + """Return a decorator to allow pydantic parsing of action parameters. + + Args: + cls: Pydantic class representing the model to be used for parsing the content of the + action parameter + """ + + def decorator( + f: Callable[[CharmBase, ActionEvent, Union[T, ValidationError]], G] + ) -> Callable[[CharmBase, ActionEvent], G]: + @wraps(f) + def event_wrapper(self: CharmBase, event: ActionEvent): + try: + params = cls( + **{key.replace("-", "_"): value for key, value in event.params.items()} + ) + except ValidationError as e: + params = e + return f(self, event, params) + + return event_wrapper + + return decorator + + +def write(relation_data: RelationDataContent, model: BaseModel): + """Write the data contained in a domain object to the relation databag. 
+ + Args: + relation_data: pointer to the relation databag + model: instance of pydantic model to be written + """ + for key, value in model.dict(exclude_none=False).items(): + if value: + relation_data[key.replace("_", "-")] = ( + str(value) + if any(isinstance(value, _type) for _type in DataBagNativeTypes) + else json.dumps(value) + ) + else: + relation_data[key.replace("_", "-")] = "" + + +def read(relation_data: MutableMapping[str, str], obj: Type[T]) -> T: + """Read data from a relation databag and parse it into a domain object. + + Args: + relation_data: pointer to the relation databag + obj: pydantic class representing the model to be used for parsing + """ + return obj( + **{ + field_name: ( + relation_data[parsed_key] + if field.outer_type_ in DataBagNativeTypes + else json.loads(relation_data[parsed_key]) + ) + for field_name, field in obj.__fields__.items() + # pyright: ignore[reportGeneralTypeIssues] + if (parsed_key := field_name.replace("_", "-")) in relation_data + if relation_data[parsed_key] + } + ) + + +def parse_relation_data( + app_model: Optional[Type[AppModel]] = None, unit_model: Optional[Type[UnitModel]] = None +): + """Return a decorator to allow pydantic parsing of the app and unit databags. + + Args: + app_model: Pydantic class representing the model to be used for parsing the content of the + app databag. None if no parsing ought to be done. + unit_model: Pydantic class representing the model to be used for parsing the content of the + unit databag. None if no parsing ought to be done. + """ + + def decorator( + f: Callable[ + [ + CharmBase, + RelationEvent, + Optional[Union[AppModel, ValidationError]], + Optional[Union[UnitModel, ValidationError]], + ], + G, + ] + ) -> Callable[[CharmBase, RelationEvent], G]: + @wraps(f) + def event_wrapper(self: CharmBase, event: RelationEvent): + try: + app_data = ( + read(event.relation.data[event.app], app_model) + if app_model is not None and event.app + else None + ) + except pydantic.ValidationError as e: + app_data = e + + try: + unit_data = ( + read(event.relation.data[event.unit], unit_model) + if unit_model is not None and event.unit + else None + ) + except pydantic.ValidationError as e: + unit_data = e + + return f(self, event, app_data, unit_data) + + return event_wrapper + + return decorator + + +class RelationDataModel(BaseModel): + """Base class to be used for creating data models to be used for relation databags.""" + + def write(self, relation_data: RelationDataContent): + """Write data to a relation databag. + + Args: + relation_data: pointer to the relation databag + """ + return write(relation_data, self) + + @classmethod + def read(cls, relation_data: RelationDataContent) -> "RelationDataModel": + """Read data from a relation databag and parse it as an instance of the pydantic class. + + Args: + relation_data: pointer to the relation databag + """ + return read(relation_data, cls) + + +def get_relation_data_as( + model_type: Type[AppModel], + *relation_data: RelationDataContent, +) -> Union[AppModel, ValidationError]: + """Return a merged representation of the provider and requirer databag into a single object. 
+ + Args: + model_type: pydantic class representing the merged databag + relation_data: list of RelationDataContent of provider/requirer/unit sides + """ + try: + app_data = read(reduce(lambda x, y: dict(x) | dict(y), relation_data, {}), model_type) + except pydantic.ValidationError as e: + app_data = e + return app_data diff --git a/lib/charms/zookeeper/v0/client.py b/lib/charms/zookeeper/v0/client.py index 84e0ab35..d0b2d78a 100644 --- a/lib/charms/zookeeper/v0/client.py +++ b/lib/charms/zookeeper/v0/client.py @@ -74,7 +74,7 @@ def update_cluster(new_members: List[str], event: EventBase) -> None: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 8 logger = logging.getLogger(__name__) @@ -101,6 +101,12 @@ class QuorumLeaderNotFoundError(Exception): pass +class NoUnitFoundError(Exception): + """Generic exception for when there are no running zk unit in the app.""" + + pass + + class ZooKeeperManager: """Handler for performing ZK commands.""" @@ -114,6 +120,7 @@ def __init__( keyfile_path: Optional[str] = "", keyfile_password: Optional[str] = "", certfile_path: Optional[str] = "", + read_only: bool = True, ): self.hosts = hosts self.username = username @@ -123,12 +130,21 @@ def __init__( self.keyfile_path = keyfile_path self.keyfile_password = keyfile_password self.certfile_path = certfile_path - self.leader = "" + self.zk_host = "" + self.read_only = read_only - try: - self.leader = self.get_leader() - except RetryError: - raise QuorumLeaderNotFoundError("quorum leader not found") + if not read_only: + try: + self.zk_host = self.get_leader() + except RetryError: + raise QuorumLeaderNotFoundError("quorum leader not found") + + else: + try: + self.zk_host = self.get_any_unit() + + except RetryError: + raise NoUnitFoundError @retry( wait=wait_fixed(3), @@ -170,6 +186,35 @@ def get_leader(self) -> str: return leader or "" + @retry( + wait=wait_fixed(3), + stop=stop_after_attempt(2), + retry=retry_if_not_result(lambda result: True if result else False), + ) + def get_any_unit(self) -> str: + any_host = None + for host in self.hosts: + try: + with ZooKeeperClient( + host=host, + client_port=self.client_port, + username=self.username, + password=self.password, + use_ssl=self.use_ssl, + keyfile_path=self.keyfile_path, + keyfile_password=self.keyfile_password, + certfile_path=self.certfile_path, + ) as zk: + response = zk.srvr + if response: + any_host = host + break + except KazooTimeoutError: # in the case of having a dead unit in relation data + logger.debug(f"TIMEOUT - {host}") + continue + + return any_host or "" + @property def server_members(self) -> Set[str]: """The current members within the ZooKeeper quorum. @@ -179,7 +224,7 @@ def server_members(self) -> Set[str]: e.g {"server.1=10.141.78.207:2888:3888:participant;0.0.0.0:2181"} """ with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, @@ -200,7 +245,7 @@ def config_version(self) -> int: The zookeeper config version decoded from base16 """ with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, @@ -221,7 +266,7 @@ def members_syncing(self) -> bool: True if any members are syncing. Otherwise False. 
""" with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, @@ -305,7 +350,7 @@ def add_members(self, members: Iterable[str]) -> None: # specific connection to leader with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, @@ -330,7 +375,7 @@ def remove_members(self, members: Iterable[str]) -> None: for member in members: member_id = re.findall(r"server.([0-9]+)", member)[0] with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, @@ -356,7 +401,7 @@ def leader_znodes(self, path: str) -> Set[str]: Set of all nested child zNodes """ with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, @@ -377,7 +422,7 @@ def create_znode_leader(self, path: str, acls: List[ACL] | None = None) -> None: acls: the ACLs to be set on that path """ with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, @@ -396,7 +441,7 @@ def set_acls_znode_leader(self, path: str, acls: List[ACL] | None = None) -> Non acls: the new ACLs to be set on that path """ with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, @@ -414,7 +459,7 @@ def delete_znode_leader(self, path: str) -> None: path: the zNode path to delete """ with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, @@ -432,7 +477,7 @@ def get_version(self) -> str: String of ZooKeeper service version """ with ZooKeeperClient( - host=self.leader, + host=self.zk_host, client_port=self.client_port, username=self.username, password=self.password, diff --git a/src/charm.py b/src/charm.py index 397d97e2..f1f3b3bf 100755 --- a/src/charm.py +++ b/src/charm.py @@ -7,13 +7,13 @@ import logging import time +from charms.data_platform_libs.v0.data_models import TypedCharmBase from charms.grafana_agent.v0.cos_agent import COSAgentProvider from charms.rolling_ops.v0.rollingops import RollingOpsManager from charms.zookeeper.v0.client import QuorumLeaderNotFoundError from kazoo.exceptions import BadArgumentsError, BadVersionError, ReconfigInProcessError from ops import ( ActiveStatus, - CharmBase, EventBase, InstallEvent, RelationDepartedEvent, @@ -22,10 +22,11 @@ StatusBase, StorageAttachedEvent, WaitingStatus, + main, ) -from ops.main import main from core.cluster import ClusterState +from core.structured_config import CharmConfig from events.backup import BackupEvents from events.password_actions import PasswordActionEvents from events.provider import ProviderEvents @@ -45,16 +46,21 @@ Status, ) from managers.config import ConfigManager +from managers.k8s import K8sManager from managers.quorum import QuorumManager from managers.tls import TLSManager from workload import ZKWorkload logger = logging.getLogger(__name__) +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("httpxcore").setLevel(logging.WARNING) -class ZooKeeperCharm(CharmBase): +class ZooKeeperCharm(TypedCharmBase[CharmConfig]): """Charmed Operator for ZooKeeper.""" + config_type = CharmConfig + def __init__(self, *args): super().__init__(*args) self.name = CHARM_KEY @@ -84,6 +90,9 @@ def __init__(self, *args): 
self.config_manager = ConfigManager( state=self.state, workload=self.workload, substrate=SUBSTRATE, config=self.config ) + self.k8s_manager = K8sManager( + pod_name=self.state.unit_server.pod_name, namespace=self.model.name + ) # --- LIB EVENT HANDLERS --- diff --git a/src/core/cluster.py b/src/core/cluster.py index bfe45108..4d8ea4cc 100644 --- a/src/core/cluster.py +++ b/src/core/cluster.py @@ -4,7 +4,8 @@ """Collection of global cluster state for the ZooKeeper quorum.""" import logging -from typing import Dict, Set +from ipaddress import IPv4Address, IPv6Address +from typing import TYPE_CHECKING from charms.data_platform_libs.v0.data_interfaces import ( DatabaseProviderData, @@ -12,10 +13,13 @@ DataPeerOtherUnitData, DataPeerUnitData, ) -from ops.framework import Framework, Object +from lightkube.core.exceptions import ApiError as LightKubeApiError +from ops.framework import Object from ops.model import Relation, Unit +from tenacity import retry, retry_if_exception_cause_type, stop_after_attempt, wait_fixed from core.models import SUBSTRATES, ZKClient, ZKCluster, ZKServer +from core.stubs import ExposeExternal from literals import ( CLIENT_PORT, PEER, @@ -25,13 +29,16 @@ Status, ) +if TYPE_CHECKING: + from charm import ZooKeeperCharm + logger = logging.getLogger(__name__) class ClusterState(Object): """Collection of global cluster state for Framework/Object.""" - def __init__(self, charm: Framework | Object, substrate: SUBSTRATES): + def __init__(self, charm: "ZooKeeperCharm", substrate: SUBSTRATES): super().__init__(parent=charm, key="charm_state") self.substrate: SUBSTRATES = substrate @@ -41,6 +48,7 @@ def __init__(self, charm: Framework | Object, substrate: SUBSTRATES): ) self.client_provider_interface = DatabaseProviderData(self.model, relation_name=REL_NAME) self._servers_data = {} + self.config = charm.config # --- RAW RELATION --- @@ -50,7 +58,7 @@ def peer_relation(self) -> Relation | None: return self.model.get_relation(PEER) @property - def client_relations(self) -> Set[Relation]: + def client_relations(self) -> set[Relation]: """The relations of all client applications.""" return set(self.model.relations[REL_NAME]) @@ -67,7 +75,7 @@ def unit_server(self) -> ZKServer: ) @property - def peer_units_data_interfaces(self) -> Dict[Unit, DataPeerOtherUnitData]: + def peer_units_data_interfaces(self) -> dict[Unit, DataPeerOtherUnitData]: """The cluster peer relation.""" if not self.peer_relation or not self.peer_relation.units: return {} @@ -90,7 +98,7 @@ def cluster(self) -> ZKCluster: ) @property - def servers(self) -> Set[ZKServer]: + def servers(self) -> set[ZKServer]: """Grabs all servers in the current peer relation, including the running unit server. 
Returns: @@ -114,7 +122,7 @@ def servers(self) -> Set[ZKServer]: return servers @property - def clients(self) -> Set[ZKClient]: + def clients(self) -> set[ZKClient]: """The state for all related client Applications.""" clients = set() for relation in self.client_relations: @@ -129,10 +137,8 @@ def clients(self) -> Set[ZKClient]: substrate=self.substrate, local_app=self.cluster.app, password=self.cluster.client_passwords.get(f"relation-{relation.id}", ""), - uris=",".join( - [f"{endpoint}:{self.client_port}" for endpoint in self.endpoints] - ), - endpoints=",".join(self.endpoints), + uris=self.endpoints, + endpoints=self.endpoints, tls="enabled" if self.cluster.tls else "disabled", ) ) @@ -141,6 +147,16 @@ def clients(self) -> Set[ZKClient]: # --- CLUSTER INIT --- + @property + def bind_address(self) -> IPv4Address | IPv6Address | str: + """The network binding address from the peer relation.""" + bind_address = None + if self.peer_relation: + if binding := self.model.get_binding(self.peer_relation): + bind_address = binding.network.bind_address + + return bind_address or "" + @property def client_port(self) -> int: """The port for clients to use. @@ -150,24 +166,66 @@ def client_port(self) -> int: 2181 if TLS is not enabled 2182 if TLS is enabled """ - if self.cluster.tls: - return SECURE_CLIENT_PORT - - return CLIENT_PORT + return SECURE_CLIENT_PORT if self.cluster.tls else CLIENT_PORT @property - def endpoints(self) -> list[str]: - """The connection uris for all started ZooKeeper units. - - Returns: - List of unit addresses + @retry( + wait=wait_fixed(5), + stop=stop_after_attempt(3), + retry=retry_if_exception_cause_type(LightKubeApiError), + reraise=True, + ) + def endpoints_external(self) -> str: + """Comma-separated string of connection uris for all started ZooKeeper unit, for external access. + + K8s only. 
""" - return sorted( - [server.host if self.substrate == "k8s" else server.ip for server in self.servers] + auth = "plain" if not self.cluster.tls else "tls" + expose = self.config.expose_external + if expose is ExposeExternal.NODEPORT: + # We might have several of them if we run on multiple k8s nodes + return ",".join( + sorted( + { + f"{server.node_ip}:{self.unit_server.k8s.get_nodeport(auth)}" + for server in self.servers + } + ) + ) + + elif expose is ExposeExternal.LOADBALANCER: + # There should be only one host + return f"{next(iter(self.servers)).loadbalancer_ip}:{self.client_port}" + + else: # pragma: nocover + # ExposeExternal.FALSE already covered + raise ValueError(f"{expose} not recognized.") + + @property + def endpoints(self) -> str: + """Comma-separated string of connection uris for all started ZooKeeper units.""" + if self.substrate == "k8s" and self.config.expose_external is not ExposeExternal.FALSE: + try: + return self.endpoints_external + except LightKubeApiError as e: + logger.debug(e) + return "" + + return ",".join( + sorted( + [ + ( + f"{server.internal_address}:{self.client_port}" + if self.substrate == "k8s" + else f"{server.internal_address}:{self.client_port}" + ) + for server in self.servers + ] + ) ) @property - def started_servers(self) -> Set[ZKServer]: + def started_servers(self) -> set[ZKServer]: """The server states of all started peer-related Units.""" return {server for server in self.servers if server.started} diff --git a/src/core/models.py b/src/core/models.py index ed132c26..641a80c9 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -7,6 +7,7 @@ import logging import warnings from collections.abc import MutableMapping +from functools import cached_property from typing import Literal from charms.data_platform_libs.v0.data_interfaces import Data, DataPeerData, DataPeerUnitData @@ -15,6 +16,7 @@ from core.stubs import RestoreStep, S3ConnectionInfo from literals import CHARM_USERS, CLIENT_PORT, ELECTION_PORT, SECRETS_APP, SERVER_PORT +from managers.k8s import K8sManager logger = logging.getLogger(__name__) @@ -317,6 +319,10 @@ def __init__( ): super().__init__(relation, data_interface, component, substrate) self.unit = component + self.k8s = K8sManager( + pod_name=self.pod_name, + namespace=self.unit._backend.model_name, + ) @property def unit_id(self) -> int: @@ -367,11 +373,11 @@ def server_id(self) -> int: return self.unit_id + 1 @property - def host(self) -> str: - """The hostname for the unit.""" + def internal_address(self) -> str: + """The hostname for the unit, for internal communication.""" host = "" if self.substrate == "vm": - for key in ["hostname", "ip", "private-address"]: + for key in ["ip", "hostname", "private-address"]: if host := self.relation_data.get(key, ""): break @@ -383,7 +389,7 @@ def host(self) -> str: @property def server_string(self) -> str: """The server string for the ZooKeeper server.""" - return f"server.{self.server_id}={self.host}:{SERVER_PORT}:{ELECTION_PORT}:participant;0.0.0.0:{CLIENT_PORT}" + return f"server.{self.server_id}={self.internal_address}:{SERVER_PORT}:{ELECTION_PORT}:participant;0.0.0.0:{CLIENT_PORT}" # -- TLS -- @@ -445,18 +451,31 @@ def ca_cert(self) -> str: """The root CA contents for the unit to use for TLS.""" return self.relation_data.get("ca-cert", self.ca) - @property - def sans(self) -> dict[str, list[str]]: - """The Subject Alternative Name for the unit's TLS certificates.""" - if not all([self.ip, self.hostname, self.fqdn]): - return {} - - return { - "sans_ip": [self.ip], - 
"sans_dns": [self.hostname, self.fqdn], - } - @property def restore_progress(self) -> RestoreStep: """Latest restore flow step the unit went through.""" return RestoreStep(self.relation_data.get("restore-progress", "")) + + @property + def pod_name(self) -> str: + """The name of the K8s Pod for the unit. + + K8s-only. + """ + return self.unit.name.replace("/", "-") + + @cached_property + def node_ip(self) -> str: + """The IPV4/IPV6 IP address of the Node the unit is on. + + K8s-only. + """ + return self.k8s.get_node_ip(self.pod_name) + + @cached_property + def loadbalancer_ip(self) -> str: + """The IPV4/IPV6 IP address of the LoadBalancer exposing the unit. + + K8s-only. + """ + return self.k8s.get_loadbalancer() diff --git a/src/core/structured_config.py b/src/core/structured_config.py new file mode 100644 index 00000000..082a0a0f --- /dev/null +++ b/src/core/structured_config.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Structured configuration for the ZooKeeper charm.""" +import logging + +from charms.data_platform_libs.v0.data_models import BaseConfigModel +from pydantic import Field + +from core.stubs import ExposeExternal, LogLevel + +logger = logging.getLogger(__name__) + + +class CharmConfig(BaseConfigModel): + """Manager for the structured configuration.""" + + init_limit: int = Field(gt=0) + sync_limit: int = Field(gt=0) + tick_time: int = Field(gt=0) + log_level: LogLevel + expose_external: ExposeExternal diff --git a/src/core/stubs.py b/src/core/stubs.py index 645e646c..084890c8 100644 --- a/src/core/stubs.py +++ b/src/core/stubs.py @@ -3,9 +3,36 @@ # See LICENSE file for licensing details. """Types module.""" +from dataclasses import dataclass from enum import Enum from typing import TypedDict + +class LogLevel(str, Enum): + """Enum for the `log-level` field.""" + + INFO = "INFO" + WARNING = "WARNING" + ERROR = "ERROR" + DEBUG = "DEBUG" + + +class ExposeExternal(str, Enum): + """Enum for the `expose-external` field.""" + + FALSE = "false" + NODEPORT = "nodeport" + LOADBALANCER = "loadbalancer" + + +@dataclass +class SANs: + """Subject Alternative Name (SAN)s used to create multi-domains certificates.""" + + sans_ip: list[str] + sans_dns: list[str] + + S3ConnectionInfo = TypedDict( "S3ConnectionInfo", { diff --git a/src/core/workload.py b/src/core/workload.py index 6f3005f4..2bc0f5c4 100644 --- a/src/core/workload.py +++ b/src/core/workload.py @@ -147,7 +147,7 @@ def write(self, content: str, path: str) -> None: ... @abstractmethod - def exec(self, command: list[str], working_dir: str | None = None) -> None: + def exec(self, command: list[str], working_dir: str | None = None) -> str: """Runs a command on the workload substrate.""" ... 
diff --git a/src/events/tls.py b/src/events/tls.py index 72e63b92..1e0e9d6e 100644 --- a/src/events/tls.py +++ b/src/events/tls.py @@ -5,6 +5,7 @@ """Event handler for related applications on the `certificates` relation interface.""" import base64 import logging +import os import re from typing import TYPE_CHECKING @@ -17,7 +18,7 @@ from ops.charm import ActionEvent, RelationCreatedEvent, RelationJoinedEvent from ops.framework import EventBase, Object -from literals import Status +from literals import SUBSTRATE, Status if TYPE_CHECKING: from charm import ZooKeeperCharm @@ -93,11 +94,16 @@ def _on_certificates_joined(self, event: RelationJoinedEvent) -> None: } ) + subject = ( + os.uname()[1] if SUBSTRATE == "k8s" else self.charm.state.unit_server.internal_address + ) + sans = self.charm.tls_manager.build_sans() + csr = generate_csr( private_key=self.charm.state.unit_server.private_key.encode("utf-8"), - subject=self.charm.state.unit_server.host, - sans_ip=self.charm.state.unit_server.sans.get("sans_ip", []), - sans_dns=self.charm.state.unit_server.sans.get("sans_dns", []), + subject=subject, + sans_ip=sans.sans_ip, + sans_dns=sans.sans_dns, ) self.charm.state.unit_server.update({"csr": csr.decode("utf-8").strip()}) @@ -120,6 +126,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: self.charm.tls_manager.set_certificate() self.charm.tls_manager.set_truststore() self.charm.tls_manager.set_p12_keystore() + self.charm.on.config_changed.emit() def _on_certificate_expiring(self, _: EventBase) -> None: """Handler for `certificates_expiring` event when certs need renewing.""" @@ -127,11 +134,16 @@ def _on_certificate_expiring(self, _: EventBase) -> None: logger.error("Missing unit private key and/or old csr") return + subject = ( + os.uname()[1] if SUBSTRATE == "k8s" else self.charm.state.unit_server.internal_address + ) + sans = self.charm.tls_manager.build_sans() + new_csr = generate_csr( private_key=self.charm.state.unit_server.private_key.encode("utf-8"), - subject=self.charm.state.unit_server.host, - sans_ip=self.charm.state.unit_server.sans["sans_ip"], - sans_dns=self.charm.state.unit_server.sans["sans_dns"], + subject=subject, + sans_ip=sans.sans_ip, + sans_dns=sans.sans_dns, ) self.certificates.request_certificate_renewal( diff --git a/src/events/upgrade.py b/src/events/upgrade.py index 750ab64e..97fe700c 100644 --- a/src/events/upgrade.py +++ b/src/events/upgrade.py @@ -53,10 +53,11 @@ def idle(self) -> bool: def client(self) -> ZooKeeperManager: """Cached client manager application for performing ZK commands.""" return ZooKeeperManager( - hosts=[server.host for server in self.charm.state.started_servers], + hosts=[server.internal_address for server in self.charm.state.started_servers], client_port=CLIENT_PORT, username="super", password=self.charm.state.cluster.internal_user_credentials.get("super", ""), + read_only=False, ) @retry(stop=stop_after_attempt(5), wait=wait_random(min=1, max=5), reraise=True) @@ -77,7 +78,7 @@ def build_upgrade_stack(self) -> list[int]: upgrade_stack = [] for server in self.charm.state.servers: # upgrade quorum leader last - if server.host == self.client.leader: + if server.internal_address == self.client.zk_host: upgrade_stack.insert(0, server.unit_id) else: upgrade_stack.append(server.unit_id) diff --git a/src/managers/config.py b/src/managers/config.py index d654fab1..fd0ca98f 100644 --- a/src/managers/config.py +++ b/src/managers/config.py @@ -6,9 +6,8 @@ import logging from textwrap import dedent -from ops.model import 
ConfigData - from core.cluster import SUBSTRATES, ClusterState +from core.structured_config import CharmConfig from core.workload import WorkloadBase from literals import JMX_PORT, METRICS_PROVIDER_PORT @@ -58,7 +57,7 @@ def __init__( state: ClusterState, workload: WorkloadBase, substrate: SUBSTRATES, - config: ConfigData, + config: CharmConfig, ): self.state = state self.workload = workload @@ -72,16 +71,8 @@ def log_level(self) -> str: Returns: String with these possible values: DEBUG, INFO, WARN, ERROR """ - # FIXME: use pydantic config models for this validation instead - permitted_levels = ["DEBUG", "INFO", "WARNING", "ERROR"] config_log_level = self.config["log-level"] - if config_log_level not in permitted_levels: - logger.error( - f"Invalid log-level config value of {config_log_level}. Must be one of {','.join(permitted_levels)}. Defaulting to 'INFO'" - ) - config_log_level = "INFO" - # Remapping to WARN that is generally used in Java applications based on log4j and logback. if config_log_level == "WARNING": return "WARN" diff --git a/src/managers/k8s.py b/src/managers/k8s.py new file mode 100644 index 00000000..3d07c4e2 --- /dev/null +++ b/src/managers/k8s.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Manager for handling ZooKeeper Kubernetes resources.""" +import logging +from typing import Literal + +from lightkube.core.client import Client +from lightkube.core.exceptions import ApiError +from lightkube.models.core_v1 import LoadBalancerIngress, ServicePort, ServiceSpec +from lightkube.models.meta_v1 import ObjectMeta +from lightkube.resources.core_v1 import Node, Pod, Service + +from literals import CLIENT_PORT, SECURE_CLIENT_PORT + +logger = logging.getLogger(__name__) + + +class K8sManager: + """Manager for handling ZooKeeper Kubernetes resources.""" + + def __init__(self, pod_name: str, namespace: str) -> None: + self.pod_name = pod_name + self.app_name, _, _ = pod_name.rpartition("-") + self.namespace = namespace + self.exposer_service_name = f"{self.app_name}-exposer" + + @property + def client(self) -> Client: + """The Lightkube client.""" + return Client( + field_manager=self.app_name, + namespace=self.namespace, + ) + + def apply_service(self, service: Service) -> None: + """Apply a given Service.""" + try: + self.client.apply(service) + except ApiError as e: + if e.status.code == 403: + logger.error("Could not apply service, application needs `juju trust`") + return + if e.status.code == 422 and "port is already allocated" in e.status.message: + logger.error(e.status.message) + return + else: + raise + + def remove_service(self, service_name: str) -> None: + """Remove the exposer service.""" + try: + self.client.delete(Service, name=service_name) + except ApiError as e: + if e.status.code == 403: + logger.error("Could not apply service, application needs `juju trust`") + return + if e.status.code == 404: + return + else: + raise + + def build_nodeport_service(self) -> Service: + """Build the exposer service for 'nodeport' configuration option.""" + # Pods are incrementally added to the StatefulSet, so we will always have a "0". + # Even if the "0" unit is not the leader, we just want a reference to the StatefulSet + # which owns the "0" pod. 
+ pod = self.get_pod(f"{self.app_name}-0") + if not pod.metadata: + raise Exception(f"Could not find metadata for {pod}") + + ports = [ + ServicePort( + protocol="TCP", + port=CLIENT_PORT, + targetPort=CLIENT_PORT, + name=f"{self.exposer_service_name}-plain", + ), + ServicePort( + protocol="TCP", + port=SECURE_CLIENT_PORT, + targetPort=SECURE_CLIENT_PORT, + name=f"{self.exposer_service_name}-tls", + ), + ] + + return Service( + metadata=ObjectMeta( + name=self.exposer_service_name, + namespace=self.namespace, + # owned by the StatefulSet + ownerReferences=pod.metadata.ownerReferences, + ), + spec=ServiceSpec( + externalTrafficPolicy="Local", + type="NodePort", + selector={"app.kubernetes.io/name": self.app_name}, + ports=ports, + ), + ) + + def build_loadbalancer_service(self) -> Service: + """Build the exposer service for 'loadbalancer' configuration option.""" + # Pods are incrementally added to the StatefulSet, so we will always have a "0". + # Even if the "0" unit is not the leader, we just want a reference to the StatefulSet + # which owns the "0" pod. + pod = self.get_pod(f"{self.app_name}-0") + if not pod.metadata: + raise Exception(f"Could not find metadata for {pod}") + + ports = [ + ServicePort( + protocol="TCP", + port=CLIENT_PORT, + targetPort=CLIENT_PORT, + name=f"{self.exposer_service_name}-plain", + ), + ServicePort( + protocol="TCP", + port=SECURE_CLIENT_PORT, + targetPort=SECURE_CLIENT_PORT, + name=f"{self.exposer_service_name}-tls", + ), + ] + + return Service( + metadata=ObjectMeta( + name=self.exposer_service_name, + namespace=self.namespace, + # owned by the StatefulSet + ownerReferences=pod.metadata.ownerReferences, + ), + spec=ServiceSpec( + externalTrafficPolicy="Local", + type="LoadBalancer", + selector={"app.kubernetes.io/name": self.app_name}, + ports=ports, + ), + ) + + def get_pod(self, pod_name: str) -> Pod: + """Gets the Pod via the K8s API.""" + return self.client.get( + res=Pod, + name=pod_name, + ) + + def get_service(self, service_name: str) -> Service: + """Gets the Service via the K8s API.""" + return self.client.get( + res=Service, + name=service_name, + ) + + def get_node(self, pod_name: str) -> Node: + """Gets the Node the Pod is running on via the K8s API.""" + pod = self.get_pod(pod_name) + if not pod.spec or not pod.spec.nodeName: + raise Exception("Could not find podSpec or nodeName") + + return self.client.get( + Node, + name=pod.spec.nodeName, + ) + + def get_node_ip(self, pod_name: str) -> str: + """Gets the IP Address of the Node of a given Pod via the K8s API.""" + try: + node = self.get_node(pod_name) + except ApiError as e: + if e.status.code == 403: + return "" + + if not node.status or not node.status.addresses: + raise Exception(f"No status found for {node}") + + for addresses in node.status.addresses: + if addresses.type in ["ExternalIP", "InternalIP", "Hostname"]: + return addresses.address + + return "" + + def get_nodeport(self, auth: Literal["plain", "tls"]) -> int: + """Gets the NodePort number for the service via the K8s API.""" + if not (service := self.get_service(self.exposer_service_name)): + raise Exception("Unable to find Service") + + if not service.spec or not service.spec.ports: + raise Exception("Could not find Service spec or ports") + + for port in service.spec.ports: + if str(port.name).endswith(auth): + return port.nodePort + + raise Exception(f"Unable to find NodePort using {auth} for the {service} service") + + def get_loadbalancer(self) -> str: + """Gets the LoadBalancer address for the service via the K8s 
API.""" + if not (service := self.get_service(self.exposer_service_name)): + raise Exception("Unable to find Service") + + if ( + not service.status + or not (lb_status := service.status.loadBalancer) + or not lb_status.ingress + ): + raise Exception("Could not find Service status or LoadBalancer") + + lb: LoadBalancerIngress + for lb in lb_status.ingress: + if lb.ip is not None: + return lb.ip + + raise Exception(f"Unable to find LoadBalancer ingress for the {service} service") diff --git a/src/managers/quorum.py b/src/managers/quorum.py index 58ec2750..52e74a26 100644 --- a/src/managers/quorum.py +++ b/src/managers/quorum.py @@ -39,13 +39,14 @@ def client(self) -> ZooKeeperManager: """Cached client manager application for performing ZK commands.""" admin_username = "super" admin_password = self.state.cluster.internal_user_credentials.get(admin_username, "") - active_hosts = [server.host for server in self.state.started_servers] + active_hosts = [server.internal_address for server in self.state.started_servers] return ZooKeeperManager( hosts=active_hosts, client_port=CLIENT_PORT, username=admin_username, password=admin_password, + read_only=False, ) @dataclass diff --git a/src/managers/tls.py b/src/managers/tls.py index dd8c09b9..00f9fbba 100644 --- a/src/managers/tls.py +++ b/src/managers/tls.py @@ -4,11 +4,15 @@ """Manager for building necessary files for Java TLS auth.""" import logging +import socket import subprocess import ops.pebble +from lightkube.core.exceptions import ApiError as LightKubeApiError +from tenacity import retry, retry_if_exception_cause_type, stop_after_attempt, wait_fixed from core.cluster import SUBSTRATES, ClusterState +from core.stubs import SANs from core.workload import WorkloadBase from literals import GROUP, USER @@ -23,6 +27,72 @@ def __init__(self, state: ClusterState, workload: WorkloadBase, substrate: SUBST self.workload = workload self.substrate = substrate + @retry( + wait=wait_fixed(5), + stop=stop_after_attempt(3), + retry=retry_if_exception_cause_type(LightKubeApiError), + reraise=True, + ) + def build_sans(self) -> SANs: + """Builds a SAN structure of DNS names and IPs for the unit.""" + if self.substrate == "vm": + return SANs( + sans_ip=[self.state.unit_server.internal_address], + sans_dns=[self.state.unit_server.unit.name, socket.getfqdn()], + ) + else: + sans_ip = [str(self.state.bind_address)] + + if node_ip := self.state.unit_server.node_ip: + sans_ip.append(node_ip) + + try: + sans_ip.append(self.state.unit_server.loadbalancer_ip) + except Exception: + pass + + return SANs( + sans_ip=sorted(sans_ip), + sans_dns=sorted( + [ + self.state.unit_server.internal_address.split(".")[0], + self.state.unit_server.internal_address, + socket.getfqdn(), + ] + ), + ) + + def get_current_sans(self) -> SANs | None: + """Gets the current SANs for the unit cert.""" + if not self.state.unit_server.certificate: + return + + command = ["openssl", "x509", "-noout", "-ext", "subjectAltName", "-in", "server.pem"] + + try: + sans_lines = self.workload.exec( + command=command, working_dir=self.workload.paths.conf_path + ).splitlines() + except (subprocess.CalledProcessError, ops.pebble.ExecError) as e: + logger.error(e.stdout) + raise e + + for line in sans_lines: + if "DNS" in line and "IP" in line: + break + + sans_ip = [] + sans_dns = [] + for item in line.split(", "): + san_type, san_value = item.split(":") + + if san_type.strip() == "DNS": + sans_dns.append(san_value) + if san_type.strip() == "IP Address": + sans_ip.append(san_value) + + return 
SANs(sans_ip=sorted(sans_ip), sans_dns=sorted(sans_dns)) + def set_private_key(self) -> None: """Sets the unit private-key.""" if not self.state.unit_server.private_key: diff --git a/tests/integration/test_provider.py b/tests/integration/test_provider.py index 95ec848c..b8b71cb4 100644 --- a/tests/integration/test_provider.py +++ b/tests/integration/test_provider.py @@ -33,9 +33,9 @@ async def test_deploy_charms_relate_active(ops_test: OpsTest, zk_charm): ops_test.model.deploy(zk_charm, application_name=APP_NAME, num_units=3), ops_test.model.deploy(app_charm, application_name=DUMMY_NAME_1, num_units=1), ) - await ops_test.model.wait_for_idle(apps=[APP_NAME, DUMMY_NAME_1]) + await ops_test.model.wait_for_idle(apps=[APP_NAME, DUMMY_NAME_1], timeout=1000) await ops_test.model.add_relation(APP_NAME, DUMMY_NAME_1) - await ops_test.model.wait_for_idle(apps=[APP_NAME, DUMMY_NAME_1]) + await ops_test.model.wait_for_idle(apps=[APP_NAME, DUMMY_NAME_1], timeout=1000) assert ops_test.model.applications[APP_NAME].status == "active" assert ops_test.model.applications[DUMMY_NAME_1].status == "active" assert ping_servers(ops_test) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 59c3dcd2..f6301fa4 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,12 +1,14 @@ # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from ops import JujuVersion from tests.unit.test_charm import PropertyMock +from literals import SUBSTRATE + @pytest.fixture(autouse=True) def patched_idle(mocker, request): @@ -56,3 +58,35 @@ def patched_etc_hosts_environment(): def juju_has_secrets(mocker): """Using Juju3 we should always have secrets available.""" mocker.patch.object(JujuVersion, "has_secrets", new_callable=PropertyMock).return_value = True + + +@pytest.fixture(autouse=True) +def patched_k8s_client(monkeypatch): + with monkeypatch.context() as m: + m.setattr("lightkube.core.client.GenericSyncClient", Mock()) + yield + + +@pytest.fixture(autouse=True) +def patched_node_ip(): + if SUBSTRATE == "k8s": + with patch( + "core.models.ZKServer.node_ip", + new_callable=PropertyMock, + return_value="111.111.111.111", + ) as patched_node_ip: + yield patched_node_ip + else: + yield + + +@pytest.fixture(autouse=True) +def patched_node_port(): + if SUBSTRATE == "k8s": + with patch( + "managers.k8s.K8sManager.get_nodeport", + return_value=30000, + ) as patched_node_port: + yield patched_node_port + else: + yield diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py index 41eaea95..11b341c1 100644 --- a/tests/unit/test_charm.py +++ b/tests/unit/test_charm.py @@ -742,6 +742,7 @@ def test_init_server_calls_necessary_methods(ctx: Context, base_state: State) -> patch("managers.tls.TLSManager.set_truststore") as patched_truststore, patch("managers.tls.TLSManager.set_p12_keystore") as patched_keystore, patch("workload.ZKWorkload.start") as start, + patch("managers.tls.TLSManager.get_current_sans", return_value=""), patch( "charms.rolling_ops.v0.rollingops.RollingOpsManager._on_acquire_lock", autospec=True, @@ -1232,6 +1233,7 @@ def test_port_updates_if_tls(ctx: Context, base_state: State) -> None: PEER, local_app_data={"quorum": "ssl", "relation-0": "mellon", "tls": "enabled"}, local_unit_data={"private-address": "treebeard", "state": "started"}, + peers_data={}, ) client_relation = Relation(REL_NAME, "application", remote_app_data={"database": "app"}) state_in = dataclasses.replace(base_state, 
relations=[cluster_peer, client_relation]) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 89939ce6..01c1fbac 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -97,13 +97,13 @@ def test_is_ready(): def test_init_raises_if_leader_not_found(): with patch("charms.zookeeper.v0.client.KazooClient", return_value=DummyClient(follower=True)): with pytest.raises(QuorumLeaderNotFoundError): - ZooKeeperManager(hosts=["host"], username="", password="") + ZooKeeperManager(hosts=["host"], username="", password="", read_only=False) def test_init_finds_leader(): with patch("charms.zookeeper.v0.client.KazooClient", return_value=DummyClient()): - zk = ZooKeeperManager(hosts=["host"], username="", password="") - assert zk.leader == "host" + zk = ZooKeeperManager(hosts=["host"], username="", password="", read_only=False) + assert zk.zk_host == "host" def test_members_syncing(): @@ -134,7 +134,7 @@ def test_add_members_correct_args(reconfig): def test_add_members_runs_on_leader(_): with patch("charms.zookeeper.v0.client.KazooClient", return_value=DummyClient()) as client: zk = ZooKeeperManager(hosts=["server.1=bilbo.baggins"], username="", password="") - zk.leader = "leader" + zk.zk_host = "leader" zk.add_members(["server.2=sam.gamgee"]) calls = client.call_args_list @@ -197,7 +197,7 @@ def test_remove_members_handles_zeroes(reconfig): def test_remove_members_runs_on_leader(_): with patch("charms.zookeeper.v0.client.KazooClient", return_value=DummyClient()) as client: zk = ZooKeeperManager(hosts=["server.1=bilbo.baggins"], username="", password="") - zk.leader = "leader" + zk.zk_host = "leader" zk.remove_members(["server.2=sam.gamgee"]) calls = client.call_args_list @@ -220,7 +220,7 @@ def test_remove_members_runs_on_leader(_): def test_get_version(_): with patch("charms.zookeeper.v0.client.KazooClient", return_value=DummyClient()): zk = ZooKeeperManager(hosts=["server"], username="", password="") - zk.leader = "leader" + zk.zk_host = "leader" version = zk.get_version() diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 0ac02c10..042361e3 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -189,7 +189,10 @@ def test_tls_enabled(ctx: Context, base_state: State) -> None: state_in = dataclasses.replace(base_state, relations=[cluster_peer]) # When - with ctx(ctx.on.config_changed(), state_in) as manager: + with ( + patch("managers.tls.TLSManager.get_current_sans", return_value=""), + ctx(ctx.on.config_changed(), state_in) as manager, + ): charm = cast(ZooKeeperCharm, manager.charm) # Then @@ -275,7 +278,10 @@ def test_tls_ssl_quorum(ctx: Context, base_state: State) -> None: state_in = dataclasses.replace(base_state, relations=[cluster_peer]) # When - with ctx(ctx.on.config_changed(), state_in) as manager: + with ( + patch("managers.tls.TLSManager.get_current_sans", return_value=""), + ctx(ctx.on.config_changed(), state_in) as manager, + ): charm = cast(ZooKeeperCharm, manager.charm) # Then diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py index eeba377b..01d9c855 100644 --- a/tests/unit/test_provider.py +++ b/tests/unit/test_provider.py @@ -2,6 +2,7 @@ # Copyright 2022 Canonical Ltd. # See LICENSE file for licensing details. 
import dataclasses +import json import logging from pathlib import Path from typing import cast @@ -9,11 +10,11 @@ import pytest import yaml -from ops import RelationBrokenEvent +from ops import MaintenanceStatus, RelationBrokenEvent from ops.testing import Container, Context, PeerRelation, Relation, State from charm import ZooKeeperCharm -from literals import CONTAINER, PEER, REL_NAME, SUBSTRATE +from literals import CONTAINER, PEER, REL_NAME, SUBSTRATE, Status logger = logging.getLogger(__name__) @@ -34,6 +35,12 @@ def base_state(): return state +@pytest.fixture() +def charm_configuration(): + """Enable direct mutation on configuration dict.""" + return json.loads(json.dumps(CONFIG)) + + @pytest.fixture() def ctx() -> Context: ctx = Context(ZooKeeperCharm, meta=METADATA, config=CONFIG, actions=ACTIONS, unit_id=0) @@ -212,3 +219,50 @@ def test_client_relation_broken_removes_passwords(ctx: Context, base_state: Stat # Then assert not charm.state.cluster.client_passwords + + +@pytest.mark.skipif(SUBSTRATE == "vm", reason="K8s services not used on VM charms") +def test_expose_external_service_down_disconnect_clients( + charm_configuration: dict, base_state: State +) -> None: + # Given + charm_configuration["options"]["expose-external"]["default"] = "nodeport" + cluster_peer = PeerRelation( + PEER, + PEER, + peers_data={}, + local_unit_data={"state": "started"}, + local_app_data={ + "sync-password": "mellon", + "super-password": "mellon", + }, + ) + client_relation = Relation( + REL_NAME, + "application", + remote_app_data={"database": "balrog"}, + local_app_data={"endpoints": "9.9.9.9:2181"}, + ) + state_in = dataclasses.replace(base_state, relations=[cluster_peer, client_relation]) + ctx = Context( + ZooKeeperCharm, meta=METADATA, config=charm_configuration, actions=ACTIONS, unit_id=0 + ) + + # When + with ( + patch( + "core.cluster.ClusterState.stable", + new_callable=PropertyMock, + return_value=Status.ACTIVE, + ), + patch( + "core.cluster.ClusterState.endpoints_external", + new_callable=PropertyMock, + return_value="", + ), + ): + state_out = ctx.run(ctx.on.config_changed(), state_in) + + # Then + assert not state_out.get_relation(client_relation.id).local_app_data.get("endpoints", "") + assert isinstance(state_out.unit_status, MaintenanceStatus) diff --git a/tests/unit/test_structured_config.py b/tests/unit/test_structured_config.py new file mode 100644 index 00000000..407d43c3 --- /dev/null +++ b/tests/unit/test_structured_config.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +import logging +from pathlib import Path +from typing import Iterable + +import pytest +import yaml +from pydantic import ValidationError + +from core.structured_config import CharmConfig + +logger = logging.getLogger(__name__) +CONFIG = yaml.safe_load(Path("./config.yaml").read_text()) + + +def to_underscore(string: str) -> str: + """Convert dashes to underscores. + + This function is used to automatically generate aliases for our charm since the + config.yaml file uses kebab-case. 
+ """ + return string.replace("-", "_") + + +def check_valid_values(field: str, accepted_values: Iterable) -> None: + """Check the correctness of the passed values for a field.""" + flat_config_options = { + to_underscore(option_name): mapping.get("default") + for option_name, mapping in CONFIG["options"].items() + } + for value in accepted_values: + CharmConfig(**{**flat_config_options, **{field: value}}) + + +def check_invalid_values(field: str, erroneus_values: Iterable) -> None: + """Check the incorrectness of the passed values for a field.""" + flat_config_options = { + to_underscore(option_name): mapping.get("default") + for option_name, mapping in CONFIG["options"].items() + } + for value in erroneus_values: + with pytest.raises(ValidationError) as excinfo: + CharmConfig(**{**flat_config_options, **{field: value}}) + + assert field in excinfo.value.errors()[0]["loc"] + + +def test_values_gt_zero() -> None: + """Check fields greater than zero.""" + gt_zero_fields = ["init_limit", "sync_limit", "tick_time"] + erroneus_values = [0, -2147483649, -34] + valid_values = [42, 1000, 1, 9223372036854775807] + for field in gt_zero_fields: + check_invalid_values(field, erroneus_values) + check_valid_values(field, valid_values) + + +def test_incorrect_log_level(): + """Accepted log-level values must be part of the defined enumeration and uppercase.""" + erroneus_values = ["", "something_else", "warning", "DEBUG,INFO"] + valid_values = ["INFO", "WARNING", "DEBUG", "ERROR"] + check_invalid_values("log_level", erroneus_values) + check_valid_values("log_level", valid_values) + + +def test_incorrect_expose_external(): + erroneus_values = ["", "something_else", "false,nodeport", "load_balancer"] + valid_values = ["false", "nodeport", "loadbalancer"] + check_invalid_values("expose_external", erroneus_values) + check_valid_values("expose_external", valid_values) diff --git a/tests/unit/test_tls.py b/tests/unit/test_tls.py index a40dba69..439d16c5 100644 --- a/tests/unit/test_tls.py +++ b/tests/unit/test_tls.py @@ -6,6 +6,8 @@ import dataclasses import json +import socket +from ipaddress import IPv4Address from pathlib import Path from typing import cast from unittest.mock import DEFAULT, Mock, PropertyMock, patch @@ -15,7 +17,8 @@ from ops.testing import Container, Context, PeerRelation, Relation, Secret, State from charm import ZooKeeperCharm -from literals import CERTS_REL_NAME, CONTAINER, PEER, SUBSTRATE, Status +from core.stubs import SANs +from literals import CERTS_REL_NAME, CHARM_KEY, CONTAINER, PEER, SUBSTRATE, Status CONFIG = yaml.safe_load(Path("./config.yaml").read_text()) ACTIONS = yaml.safe_load(Path("./actions.yaml").read_text()) @@ -38,6 +41,12 @@ def base_state(): return state +@pytest.fixture() +def charm_configuration(): + """Enable direct mutation on configuration dict.""" + return json.loads(json.dumps(CONFIG)) + + @pytest.fixture() def ctx() -> Context: ctx = Context(ZooKeeperCharm, meta=METADATA, config=CONFIG, actions=ACTIONS, unit_id=0) @@ -212,7 +221,11 @@ def test_certificates_joined_creates_new_key_trust_store_password( with ( patch("core.cluster.ClusterState.stable", new_callable=PropertyMock, return_value=True), patch("core.models.ZKCluster.tls", new_callable=PropertyMock, return_value=True), - patch("core.models.ZKServer.host", new_callable=PropertyMock, return_value="host"), + patch( + "core.models.ZKServer.internal_address", + new_callable=PropertyMock, + return_value="1.1.1.1", + ), ctx(ctx.on.relation_joined(tls_relation), state_in) as manager, ): charm = 
cast(ZooKeeperCharm, manager.charm) @@ -284,13 +297,17 @@ def test_certificates_available_succeeds(ctx: Context, base_state: State) -> Non ) # When - with patch.multiple( - "managers.tls.TLSManager", - set_private_key=DEFAULT, - set_ca=DEFAULT, - set_certificate=DEFAULT, - set_truststore=DEFAULT, - set_p12_keystore=DEFAULT, + with ( + patch.multiple( + "managers.tls.TLSManager", + set_private_key=DEFAULT, + set_ca=DEFAULT, + set_certificate=DEFAULT, + set_truststore=DEFAULT, + set_p12_keystore=DEFAULT, + get_current_sans=lambda _: None, + ), + patch("workload.ZKWorkload.write"), ): state_out = ctx.run(ctx.on.relation_changed(tls_relation), state_in) @@ -333,13 +350,17 @@ def test_renew_certificates_auto_reload(ctx: Context, base_state: State) -> None state_in = dataclasses.replace(base_state, relations=[cluster_peer, tls_relation]) # When - with patch.multiple( - "managers.tls.TLSManager", - set_private_key=DEFAULT, - set_ca=DEFAULT, - set_certificate=DEFAULT, - set_truststore=DEFAULT, - set_p12_keystore=DEFAULT, + with ( + patch.multiple( + "managers.tls.TLSManager", + set_private_key=DEFAULT, + set_ca=DEFAULT, + set_certificate=DEFAULT, + set_truststore=DEFAULT, + set_p12_keystore=DEFAULT, + get_current_sans=lambda _: None, + ), + patch("workload.ZKWorkload.write"), ): state_out = ctx.run(ctx.on.relation_changed(tls_relation), state_in) @@ -371,6 +392,7 @@ def test_certificates_available_halfway_through_upgrade_succeeds( set_certificate=DEFAULT, set_truststore=DEFAULT, set_p12_keystore=DEFAULT, + get_current_sans=lambda _: None, ), ctx(ctx.on.relation_changed(tls_relation), state_in) as manager, ): @@ -541,3 +563,68 @@ def test_set_tls_private_key(ctx: Context, base_state: State) -> None: # Then assert charm.state.unit_server.csr != "csr" + + +@pytest.mark.parametrize("expose_external", ["false", "nodeport", "loadbalancer"]) +def test_sans_external_access( + charm_configuration: dict, base_state: State, expose_external: str +) -> None: + # Given + charm_configuration["options"]["expose-external"]["default"] = expose_external + ctx = Context( + ZooKeeperCharm, meta=METADATA, config=charm_configuration, actions=ACTIONS, unit_id=0 + ) + cluster_peer = PeerRelation( + PEER, PEER, local_unit_data={"private-address": "treebeard"}, peers_data={} + ) + state_in = dataclasses.replace(base_state, relations=[cluster_peer]) + sock_dns = socket.getfqdn() + + # When + if SUBSTRATE == "vm": + with ( + patch("workload.ZKWorkload.write"), + ctx(ctx.on.config_changed(), state_in) as manager, + ): + charm = cast(ZooKeeperCharm, manager.charm) + built_sans = charm.tls_manager.build_sans() + + # Then + assert built_sans == SANs( + sans_ip=["treebeard"], + sans_dns=[f"{CHARM_KEY}/0", sock_dns], + ) + + # When + if SUBSTRATE == "k8s": + with ( + patch( + "core.cluster.ClusterState.bind_address", + new_callable=PropertyMock, + return_value=IPv4Address("2.2.2.2"), + ), + patch( + "core.models.ZKServer.loadbalancer_ip", + new_callable=PropertyMock, + return_value="3.3.3.3", + ), + ctx(ctx.on.config_changed(), state_in) as manager, + ): + charm = cast(ZooKeeperCharm, manager.charm) + built_sans = charm.tls_manager.build_sans() + + # Then + assert sorted(built_sans.sans_dns) == sorted( + [ + f"{CHARM_KEY}-0", + f"{CHARM_KEY}-0.{CHARM_KEY}-endpoints", + sock_dns, + ] + ) + assert "2.2.2.2" in "".join(built_sans.sans_ip) + + if expose_external == "nodeport": + assert "111.111.111.111" in "".join(built_sans.sans_ip) + + if expose_external == "loadbalancer": + assert "3.3.3.3" in "".join(built_sans.sans_ip)
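Taken together, `expose-external` selects which exposer Service the new `K8sManager` builds and which addresses end up in client endpoints. A minimal sketch of how those pieces could be exercised, assuming a reachable Kubernetes API and illustrative pod/namespace names:

```python
from managers.k8s import K8sManager

# Hypothetical pod and namespace names for illustration.
k8s = K8sManager(pod_name="zookeeper-k8s-0", namespace="dev")

# expose-external=nodeport: one NodePort Service per application; clients
# reach <node-ip>:<nodeport> for the plain (2181) or TLS (2182) target port.
k8s.apply_service(k8s.build_nodeport_service())
plain_port = k8s.get_nodeport("plain")

# expose-external=loadbalancer: clients reach <loadbalancer-ip>:2181/2182.
k8s.apply_service(k8s.build_loadbalancer_service())
lb_address = k8s.get_loadbalancer()

# expose-external=false: the exposer Service is removed and the charm falls
# back to the internal per-unit addresses.
k8s.remove_service(k8s.exposer_service_name)
```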