diff --git a/lib/charms/observability_libs/v0/cert_handler.py b/lib/charms/observability_libs/v0/cert_handler.py index 88a8374e..db14e00f 100644 --- a/lib/charms/observability_libs/v0/cert_handler.py +++ b/lib/charms/observability_libs/v0/cert_handler.py @@ -64,7 +64,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 0 -LIBPATCH = 8 +LIBPATCH = 9 def is_ip_address(value: str) -> bool: @@ -181,33 +181,40 @@ def _peer_relation(self) -> Optional[Relation]: return self.charm.model.get_relation(self.peer_relation_name, None) def _on_peer_relation_created(self, _): - """Generate the private key and store it in a peer relation.""" - # We're in "relation-created", so the relation should be there + """Generate the CSR if the certificates relation is ready.""" + self._generate_privkey() - # Just in case we already have a private key, do not overwrite it. - # Not sure how this could happen. - # TODO figure out how to go about key rotation. - if not self._private_key: - private_key = generate_private_key() - self._private_key = private_key.decode() - - # Generate CSR here, in case peer events fired after tls-certificate relation events + # check cert relation is ready if not (self.charm.model.get_relation(self.certificates_relation_name)): # peer relation event happened to fire before tls-certificates events. # Abort, and let the "certificates joined" observer create the CSR. + logger.info("certhandler waiting on certificates relation") return + logger.debug("certhandler has peer and certs relation: proceeding to generate csr") self._generate_csr() def _on_certificates_relation_joined(self, _) -> None: - """Generate the CSR and request the certificate creation.""" + """Generate the CSR if the peer relation is ready.""" + self._generate_privkey() + + # check peer relation is there if not self._peer_relation: # tls-certificates relation event happened to fire before peer events. # Abort, and let the "peer joined" relation create the CSR. + logger.info("certhandler waiting on peer relation") return + logger.debug("certhandler has peer and certs relation: proceeding to generate csr") self._generate_csr() + def _generate_privkey(self): + # Generate priv key unless done already + # TODO figure out how to go about key rotation. + if not self._private_key: + private_key = generate_private_key() + self._private_key = private_key.decode() + def _on_config_changed(self, _): # FIXME on config changed, the web_external_url may or may not change. But because every # call to `generate_csr` appends a uuid, CSRs cannot be easily compared to one another. @@ -237,7 +244,12 @@ def _generate_csr( # In case we already have a csr, do not overwrite it by default. if overwrite or renew or not self._csr: private_key = self._private_key - assert private_key is not None # for type checker + if private_key is None: + # FIXME: raise this in a less nested scope by + # generating privkey and csr in the same method. + raise RuntimeError( + "private key unset. call _generate_privkey() before you call this method." + ) csr = generate_csr( private_key=private_key.encode(), subject=self.cert_subject, diff --git a/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py b/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py index 26f2aeb2..a6ad4dfb 100644 --- a/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py +++ b/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py @@ -107,7 +107,7 @@ def setUp(self, *unused): from math import ceil, floor from typing import Callable, Dict, List, Optional, Union -from lightkube import ApiError, Client +from lightkube import ApiError, Client # pyright: ignore from lightkube.core import exceptions from lightkube.models.apps_v1 import StatefulSetSpec from lightkube.models.core_v1 import ( @@ -133,7 +133,7 @@ def setUp(self, *unused): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 4 +LIBPATCH = 6 _Decimal = Union[Decimal, float, str, int] # types that are potentially convertible to Decimal @@ -322,7 +322,7 @@ def __init__(self, namespace: str, statefulset_name: str, container_name: str): self.namespace = namespace self.statefulset_name = statefulset_name self.container_name = container_name - self.client = Client() + self.client = Client() # pyright: ignore def _patched_delta(self, resource_reqs: ResourceRequirements) -> StatefulSet: statefulset = self.client.get( @@ -366,7 +366,7 @@ def is_patched(self, resource_reqs: ResourceRequirements) -> bool: """ return equals_canonically(self.get_templated(), resource_reqs) - def get_templated(self) -> ResourceRequirements: + def get_templated(self) -> Optional[ResourceRequirements]: """Returns the resource limits specified in the StatefulSet template.""" statefulset = self.client.get( StatefulSet, name=self.statefulset_name, namespace=self.namespace @@ -377,7 +377,7 @@ def get_templated(self) -> ResourceRequirements: ) return podspec_tpl.resources - def get_actual(self, pod_name: str) -> ResourceRequirements: + def get_actual(self, pod_name: str) -> Optional[ResourceRequirements]: """Return the resource limits that are in effect for the container in the given pod.""" pod = self.client.get(Pod, name=pod_name, namespace=self.namespace) podspec = self._get_container( @@ -421,7 +421,7 @@ def apply(self, resource_reqs: ResourceRequirements) -> None: class KubernetesComputeResourcesPatch(Object): """A utility for patching the Kubernetes compute resources set up by Juju.""" - on = K8sResourcePatchEvents() + on = K8sResourcePatchEvents() # pyright: ignore def __init__( self, diff --git a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py index e4297aa1..665af886 100644 --- a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py +++ b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py @@ -362,7 +362,7 @@ def _on_scrape_targets_changed(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 42 +LIBPATCH = 44 PYDEPS = ["cosl"] @@ -386,6 +386,7 @@ def _on_scrape_targets_changed(self, event): "basic_auth", "tls_config", "authorization", + "params", } DEFAULT_JOB = { "metrics_path": "/metrics", @@ -764,7 +765,7 @@ def _validate_relation_by_interface_and_direction( actual_relation_interface = relation.interface_name if actual_relation_interface != expected_relation_interface: raise RelationInterfaceMismatchError( - relation_name, expected_relation_interface, actual_relation_interface + relation_name, expected_relation_interface, actual_relation_interface or "None" ) if expected_relation_role == RelationRole.provides: @@ -857,7 +858,7 @@ class MonitoringEvents(ObjectEvents): class MetricsEndpointConsumer(Object): """A Prometheus based Monitoring service.""" - on = MonitoringEvents() + on = MonitoringEvents() # pyright: ignore def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME): """A Prometheus based Monitoring service. @@ -1014,7 +1015,6 @@ def alerts(self) -> dict: try: scrape_metadata = json.loads(relation.data[relation.app]["scrape_metadata"]) identifier = JujuTopology.from_dict(scrape_metadata).identifier - alerts[identifier] = self._tool.apply_label_matchers(alert_rules) # type: ignore except KeyError as e: logger.debug( @@ -1029,6 +1029,10 @@ def alerts(self) -> dict: ) continue + # We need to append the relation info to the identifier. This is to allow for cases for there are two + # relations which eventually scrape the same application. Issue #551. + identifier = f"{identifier}_{relation.name}_{relation.id}" + alerts[identifier] = alert_rules _, errmsg = self._tool.validate_alert_rules(alert_rules) @@ -1294,7 +1298,7 @@ def _resolve_dir_against_charm_path(charm: CharmBase, *path_elements: str) -> st class MetricsEndpointProvider(Object): """A metrics endpoint for Prometheus.""" - on = MetricsEndpointProviderEvents() + on = MetricsEndpointProviderEvents() # pyright: ignore def __init__( self, @@ -1836,14 +1840,16 @@ def _set_prometheus_data(self, event): return jobs = [] + _type_convert_stored( - self._stored.jobs + self._stored.jobs # pyright: ignore ) # list of scrape jobs, one per relation for relation in self.model.relations[self._target_relation]: targets = self._get_targets(relation) if targets and relation.app: jobs.append(self._static_scrape_job(targets, relation.app.name)) - groups = [] + _type_convert_stored(self._stored.alert_rules) # list of alert rule groups + groups = [] + _type_convert_stored( + self._stored.alert_rules # pyright: ignore + ) # list of alert rule groups for relation in self.model.relations[self._alert_rules_relation]: unit_rules = self._get_alert_rules(relation) if unit_rules and relation.app: @@ -1895,7 +1901,7 @@ def set_target_job_data(self, targets: dict, app_name: str, **kwargs) -> None: jobs.append(updated_job) relation.data[self._charm.app]["scrape_jobs"] = json.dumps(jobs) - if not _type_convert_stored(self._stored.jobs) == jobs: + if not _type_convert_stored(self._stored.jobs) == jobs: # pyright: ignore self._stored.jobs = jobs def _on_prometheus_targets_departed(self, event): @@ -1947,7 +1953,7 @@ def remove_prometheus_jobs(self, job_name: str, unit_name: Optional[str] = ""): relation.data[self._charm.app]["scrape_jobs"] = json.dumps(jobs) - if not _type_convert_stored(self._stored.jobs) == jobs: + if not _type_convert_stored(self._stored.jobs) == jobs: # pyright: ignore self._stored.jobs = jobs def _job_name(self, appname) -> str: @@ -2126,7 +2132,7 @@ def set_alert_rule_data(self, name: str, unit_rules: dict, label_rules: bool = T groups.append(updated_group) relation.data[self._charm.app]["alert_rules"] = json.dumps({"groups": groups}) - if not _type_convert_stored(self._stored.alert_rules) == groups: + if not _type_convert_stored(self._stored.alert_rules) == groups: # pyright: ignore self._stored.alert_rules = groups def _on_alert_rules_departed(self, event): @@ -2176,7 +2182,7 @@ def remove_alert_rules(self, group_name: str, unit_name: str) -> None: json.dumps({"groups": groups}) if groups else "{}" ) - if not _type_convert_stored(self._stored.alert_rules) == groups: + if not _type_convert_stored(self._stored.alert_rules) == groups: # pyright: ignore self._stored.alert_rules = groups def _get_alert_rules(self, relation) -> dict: diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v2/tls_certificates.py index f4a08366..99741f51 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v2/tls_certificates.py @@ -287,7 +287,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.serialization import pkcs12 from cryptography.x509.extensions import Extension, ExtensionNotFound -from jsonschema import exceptions, validate # type: ignore[import] +from jsonschema import exceptions, validate # type: ignore[import-untyped] from ops.charm import ( CharmBase, CharmEvents, @@ -298,7 +298,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven ) from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion -from ops.model import Relation, SecretNotFoundError +from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" @@ -308,7 +308,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 16 +LIBPATCH = 20 PYDEPS = ["cryptography", "jsonschema"] @@ -335,7 +335,10 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven "type": "array", "items": { "type": "object", - "properties": {"certificate_signing_request": {"type": "string"}}, + "properties": { + "certificate_signing_request": {"type": "string"}, + "ca": {"type": "boolean"}, + }, "required": ["certificate_signing_request"], }, } @@ -536,22 +539,31 @@ def restore(self, snapshot: dict): class CertificateCreationRequestEvent(EventBase): """Charm Event triggered when a TLS certificate is required.""" - def __init__(self, handle: Handle, certificate_signing_request: str, relation_id: int): + def __init__( + self, + handle: Handle, + certificate_signing_request: str, + relation_id: int, + is_ca: bool = False, + ): super().__init__(handle) self.certificate_signing_request = certificate_signing_request self.relation_id = relation_id + self.is_ca = is_ca def snapshot(self) -> dict: """Returns snapshot.""" return { "certificate_signing_request": self.certificate_signing_request, "relation_id": self.relation_id, + "is_ca": self.is_ca, } def restore(self, snapshot: dict): """Restores snapshot.""" self.certificate_signing_request = snapshot["certificate_signing_request"] self.relation_id = snapshot["relation_id"] + self.is_ca = snapshot["is_ca"] class CertificateRevocationRequestEvent(EventBase): @@ -588,23 +600,26 @@ def restore(self, snapshot: dict): self.chain = snapshot["chain"] -def _load_relation_data(raw_relation_data: dict) -> dict: +def _load_relation_data(relation_data_content: RelationDataContent) -> dict: """Loads relation data from the relation data bag. Json loads all data. Args: - raw_relation_data: Relation data from the databag + relation_data_content: Relation data from the databag Returns: dict: Relation data in dict format. """ certificate_data = dict() - for key in raw_relation_data: - try: - certificate_data[key] = json.loads(raw_relation_data[key]) - except (json.decoder.JSONDecodeError, TypeError): - certificate_data[key] = raw_relation_data[key] + try: + for key in relation_data_content: + try: + certificate_data[key] = json.loads(relation_data_content[key]) + except (json.decoder.JSONDecodeError, TypeError): + certificate_data[key] = relation_data_content[key] + except ModelError: + pass return certificate_data @@ -685,6 +700,7 @@ def generate_certificate( ca_key_password: Optional[bytes] = None, validity: int = 365, alt_names: Optional[List[str]] = None, + is_ca: bool = False, ) -> bytes: """Generates a TLS certificate based on a CSR. @@ -695,6 +711,7 @@ def generate_certificate( ca_key_password: CA private key password validity (int): Certificate validity (in days) alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate Returns: bytes: Certificate @@ -726,7 +743,6 @@ def generate_certificate( .add_extension( x509.SubjectKeyIdentifier.from_public_key(csr_object.public_key()), critical=False ) - .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=False) ) extensions_list = csr_object.extensions @@ -758,6 +774,29 @@ def generate_certificate( critical=extension.critical, ) + if is_ca: + certificate_builder = certificate_builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=True + ) + certificate_builder = certificate_builder.add_extension( + x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + else: + certificate_builder = certificate_builder.add_extension( + x509.BasicConstraints(ca=False, path_length=None), critical=False + ) + certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM) @@ -1171,15 +1210,19 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: certificate_creation_request["certificate_signing_request"] for certificate_creation_request in provider_certificates ] - requirer_unit_csrs = [ - certificate_creation_request["certificate_signing_request"] + requirer_unit_certificate_requests = [ + { + "csr": certificate_creation_request["certificate_signing_request"], + "is_ca": certificate_creation_request.get("ca", False), + } for certificate_creation_request in requirer_csrs ] - for certificate_signing_request in requirer_unit_csrs: - if certificate_signing_request not in provider_csrs: + for certificate_request in requirer_unit_certificate_requests: + if certificate_request["csr"] not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_signing_request, + certificate_signing_request=certificate_request["csr"], relation_id=event.relation.id, + is_ca=certificate_request["is_ca"], ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) @@ -1217,12 +1260,24 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None ) self.remove_certificate(certificate=certificate["certificate"]) - def get_requirer_csrs_with_no_certs( + def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Filters the requirer's units csrs. + """Returns CSR's for which no certificate has been issued. - Keeps the ones for which no certificate was provided. + Example return: [ + { + "relation_id": 0, + "application_name": "tls-certificates-requirer", + "unit_name": "tls-certificates-requirer/0", + "unit_csrs": [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "is_ca": false + } + ] + } + ] Args: relation_id (int): Relation id @@ -1239,6 +1294,7 @@ def get_requirer_csrs_with_no_certs( if not self.certificate_issued_for_csr( app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] csr=csr["certificate_signing_request"], # type: ignore[index] + relation_id=relation_id, ): csrs_without_certs.append(csr) if csrs_without_certs: @@ -1285,17 +1341,21 @@ def get_requirer_csrs( ) return unit_csr_mappings - def certificate_issued_for_csr(self, app_name: str, csr: str) -> bool: + def certificate_issued_for_csr( + self, app_name: str, csr: str, relation_id: Optional[int] + ) -> bool: """Checks whether a certificate has been issued for a given CSR. Args: app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. - + relation_id (Optional[int]): Relation ID Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates()[app_name] + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ + app_name + ] for issued_pair in issued_certificates_per_csr: if "csr" in issued_pair and issued_pair["csr"] == csr: return csr_matches_certificate(csr, issued_pair["certificate"]) @@ -1337,8 +1397,17 @@ def __init__( self.framework.observe(charm.on.update_status, self._on_update_status) @property - def _requirer_csrs(self) -> List[Dict[str, str]]: - """Returns list of requirer's CSRs from relation data.""" + def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: + """Returns list of requirer's CSRs from relation data. + + Example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] + """ relation = self.model.get_relation(self.relationship_name) if not relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") @@ -1361,11 +1430,12 @@ def _provider_certificates(self) -> List[Dict[str, str]]: return [] return provider_relation_data.get("certificates", []) - def _add_requirer_csr(self, csr: str) -> None: + def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: """Adds CSR to relation data. Args: csr (str): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1376,7 +1446,10 @@ def _add_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict = {"certificate_signing_request": csr} + new_csr_dict: Dict[str, Union[bool, str]] = { + "certificate_signing_request": csr, + "ca": is_ca, + } if new_csr_dict in self._requirer_csrs: logger.info("CSR already in relation data - Doing nothing") return @@ -1400,18 +1473,22 @@ def _remove_requirer_csr(self, csr: str) -> None: f"The certificate request can't be completed" ) requirer_csrs = copy.deepcopy(self._requirer_csrs) - csr_dict = {"certificate_signing_request": csr} - if csr_dict not in requirer_csrs: - logger.info("CSR not in relation data - Doing nothing") + if not requirer_csrs: + logger.info("No CSRs in relation data - Doing nothing") return - requirer_csrs.remove(csr_dict) + for requirer_csr in requirer_csrs: + if requirer_csr["certificate_signing_request"] == csr: + requirer_csrs.remove(requirer_csr) relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) - def request_certificate_creation(self, certificate_signing_request: bytes) -> None: + def request_certificate_creation( + self, certificate_signing_request: bytes, is_ca: bool = False + ) -> None: """Request TLS certificate to provider charm. Args: certificate_signing_request (bytes): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1422,7 +1499,7 @@ def request_certificate_creation(self, certificate_signing_request: bytes) -> No f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr(certificate_signing_request.decode().strip()) + self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: @@ -1701,7 +1778,10 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: format=serialization.PublicFormat.SubjectPublicKeyInfo, ): return False - if csr_object.subject != cert_object.subject: + if ( + csr_object.public_key().public_numbers().n # type: ignore[union-attr] + != cert_object.public_key().public_numbers().n # type: ignore[union-attr] + ): return False except ValueError: logger.warning("Could not load certificate or CSR.") diff --git a/lib/charms/traefik_k8s/v2/ingress.py b/lib/charms/traefik_k8s/v2/ingress.py index 0364c8ab..31028e97 100644 --- a/lib/charms/traefik_k8s/v2/ingress.py +++ b/lib/charms/traefik_k8s/v2/ingress.py @@ -50,20 +50,13 @@ def _on_ingress_ready(self, event: IngressPerAppReadyEvent): def _on_ingress_revoked(self, event: IngressPerAppRevokedEvent): logger.info("This app no longer has ingress") """ +import ipaddress import json import logging import socket import typing from dataclasses import dataclass -from typing import ( - Any, - Dict, - List, - MutableMapping, - Optional, - Sequence, - Tuple, -) +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Sequence, Tuple, Union import pydantic from ops.charm import CharmBase, RelationBrokenEvent, RelationEvent @@ -79,7 +72,7 @@ def _on_ingress_revoked(self, event: IngressPerAppRevokedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 6 +LIBPATCH = 8 PYDEPS = ["pydantic<2.0"] @@ -200,7 +193,11 @@ def validate_port(cls, port): # noqa: N805 # pydantic wants 'cls' as first arg class IngressRequirerUnitData(DatabagModel): """Ingress requirer unit databag model.""" - host: str = Field(description="Hostname the unit wishes to be exposed.") + host: str = Field(description="Hostname at which the unit is reachable.") + ip: Optional[str] = Field( + description="IP at which the unit is reachable, " + "IP can only be None if the IP information can't be retrieved from juju." + ) @validator("host", pre=True) def validate_host(cls, host): # noqa: N805 # pydantic wants 'cls' as first arg @@ -208,6 +205,24 @@ def validate_host(cls, host): # noqa: N805 # pydantic wants 'cls' as first arg assert isinstance(host, str), type(host) return host + @validator("ip", pre=True) + def validate_ip(cls, ip): # noqa: N805 # pydantic wants 'cls' as first arg + """Validate ip.""" + if ip is None: + return None + if not isinstance(ip, str): + raise TypeError(f"got ip of type {type(ip)} instead of expected str") + try: + ipaddress.IPv4Address(ip) + return ip + except ipaddress.AddressValueError: + pass + try: + ipaddress.IPv6Address(ip) + return ip + except ipaddress.AddressValueError: + raise ValueError(f"{ip!r} is not a valid ip address") + class RequirerSchema(BaseModel): """Requirer schema for Ingress.""" @@ -244,6 +259,7 @@ def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) observe(rel_events.relation_created, self._handle_relation) observe(rel_events.relation_joined, self._handle_relation) observe(rel_events.relation_changed, self._handle_relation) + observe(rel_events.relation_departed, self._handle_relation) observe(rel_events.relation_broken, self._handle_relation_broken) observe(charm.on.leader_elected, self._handle_upgrade_or_leader) # type: ignore observe(charm.on.upgrade_charm, self._handle_upgrade_or_leader) # type: ignore @@ -540,12 +556,13 @@ def __init__( relation_name: str = DEFAULT_RELATION_NAME, *, host: Optional[str] = None, + ip: Optional[str] = None, port: Optional[int] = None, strip_prefix: bool = False, redirect_https: bool = False, # fixme: this is horrible UX. # shall we switch to manually calling provide_ingress_requirements with all args when ready? - scheme: typing.Callable[[], str] = lambda: "http", + scheme: Union[Callable[[], str], str] = lambda: "http", ): """Constructor for IngressRequirer. @@ -560,9 +577,12 @@ def __init__( relation must be of interface type `ingress` and have "limit: 1") host: Hostname to be used by the ingress provider to address the requiring application; if unspecified, the default Kubernetes service name will be used. + ip: Alternative addressing method other than host to be used by the ingress provider; + if unspecified, binding address from juju network API will be used. strip_prefix: configure Traefik to strip the path prefix. redirect_https: redirect incoming requests to HTTPS. scheme: callable returning the scheme to use when constructing the ingress url. + Or a string, if the scheme is known and stable at charm-init-time. Request Args: port: the port of the service @@ -572,14 +592,14 @@ def __init__( self.relation_name = relation_name self._strip_prefix = strip_prefix self._redirect_https = redirect_https - self._get_scheme = scheme + self._get_scheme = scheme if callable(scheme) else lambda: scheme self._stored.set_default(current_url=None) # type: ignore # if instantiated with a port, and we are related, then # we immediately publish our ingress data to speed up the process. if port: - self._auto_data = host, port + self._auto_data = host, ip, port else: self._auto_data = None @@ -616,14 +636,15 @@ def is_ready(self): def _publish_auto_data(self): if self._auto_data: - host, port = self._auto_data - self.provide_ingress_requirements(host=host, port=port) + host, ip, port = self._auto_data + self.provide_ingress_requirements(host=host, ip=ip, port=port) def provide_ingress_requirements( self, *, scheme: Optional[str] = None, host: Optional[str] = None, + ip: Optional[str] = None, port: int, ): """Publishes the data that Traefik needs to provide ingress. @@ -632,34 +653,48 @@ def provide_ingress_requirements( scheme: Scheme to be used; if unspecified, use the one used by __init__. host: Hostname to be used by the ingress provider to address the requirer unit; if unspecified, FQDN will be used instead + ip: Alternative addressing method other than host to be used by the ingress provider. + if unspecified, binding address from juju network API will be used. port: the port of the service (required) """ for relation in self.relations: - self._provide_ingress_requirements(scheme, host, port, relation) + self._provide_ingress_requirements(scheme, host, ip, port, relation) def _provide_ingress_requirements( self, scheme: Optional[str], host: Optional[str], + ip: Optional[str], port: int, relation: Relation, ): if self.unit.is_leader(): self._publish_app_data(scheme, port, relation) - self._publish_unit_data(host, relation) + self._publish_unit_data(host, ip, relation) def _publish_unit_data( self, host: Optional[str], + ip: Optional[str], relation: Relation, ): if not host: host = socket.getfqdn() + if ip is None: + network_binding = self.charm.model.get_binding(relation) + if ( + network_binding is not None + and (bind_address := network_binding.network.bind_address) is not None + ): + ip = str(bind_address) + else: + log.error("failed to retrieve ip information from juju") + unit_databag = relation.data[self.unit] try: - IngressRequirerUnitData(host=host).dump(unit_databag) + IngressRequirerUnitData(host=host, ip=ip).dump(unit_databag) except pydantic.ValidationError as e: msg = "failed to validate unit data" log.info(msg, exc_info=True) # log to INFO because this might be expected