diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index 59a9722..aaed2e5 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -331,10 +331,14 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 37 +LIBPATCH = 39 PYDEPS = ["ops>=2.0.0"] +# Starting from what LIBPATCH number to apply legacy solutions +# v0.17 was the last version without secrets +LEGACY_SUPPORT_FROM = 17 + logger = logging.getLogger(__name__) Diff = namedtuple("Diff", "added changed deleted") @@ -351,36 +355,16 @@ def _on_topic_requested(self, event: TopicRequestedEvent): GROUP_MAPPING_FIELD = "secret_group_mapping" GROUP_SEPARATOR = "@" +MODEL_ERRORS = { + "not_leader": "this unit is not the leader", + "no_label_and_uri": "ERROR either URI or label should be used for getting an owned secret but not both", + "owner_no_refresh": "ERROR secret owner cannot use --refresh", +} -class SecretGroup(str): - """Secret groups specific type.""" - - -class SecretGroupsAggregate(str): - """Secret groups with option to extend with additional constants.""" - def __init__(self): - self.USER = SecretGroup("user") - self.TLS = SecretGroup("tls") - self.EXTRA = SecretGroup("extra") - - def __setattr__(self, name, value): - """Setting internal constants.""" - if name in self.__dict__: - raise RuntimeError("Can't set constant!") - else: - super().__setattr__(name, SecretGroup(value)) - - def groups(self) -> list: - """Return the list of stored SecretGroups.""" - return list(self.__dict__.values()) - - def get_group(self, group: str) -> Optional[SecretGroup]: - """If the input str translates to a group name, return that.""" - return SecretGroup(group) if group in self.groups() else None - - -SECRET_GROUPS = SecretGroupsAggregate() +############################################################################## +# Exceptions +############################################################################## class DataInterfacesError(Exception): @@ -407,6 +391,15 @@ class IllegalOperationError(DataInterfacesError): """To be used when an operation is not allowed to be performed.""" +############################################################################## +# Global helpers / utilities +############################################################################## + +############################################################################## +# Databag handling and comparison methods +############################################################################## + + def get_encoded_dict( relation: Relation, member: Union[Unit, Application], field: str ) -> Optional[Dict[str, str]]: @@ -482,6 +475,11 @@ def diff(event: RelationChangedEvent, bucket: Optional[Union[Unit, Application]] return Diff(added, changed, deleted) +############################################################################## +# Module decorators +############################################################################## + + def leader_only(f): """Decorator to ensure that only leader can perform given operation.""" @@ -536,6 +534,36 @@ def wrapper(self, *args, **kwargs): return wrapper +def legacy_apply_from_version(version: int) -> Callable: + """Decorator to decide whether to apply a legacy function or not. 
+ + Based on LEGACY_SUPPORT_FROM module variable value, the importer charm may only want + to apply legacy solutions starting from a specific LIBPATCH. + + NOTE: All 'legacy' functions have to be defined and called in a way that they return `None`. + This results in cleaner and more secure execution flows in case the function may be disabled. + This requirement implicitly means that legacy functions change the internal state strictly, + don't return information. + """ + + def decorator(f: Callable[..., None]): + """Signature is ensuring None return value.""" + f.legacy_version = version + + def wrapper(self, *args, **kwargs) -> None: + if version >= LEGACY_SUPPORT_FROM: + return f(self, *args, **kwargs) + + return wrapper + + return decorator + + +############################################################################## +# Helper classes +############################################################################## + + class Scope(Enum): """Peer relations scope.""" @@ -543,9 +571,35 @@ class Scope(Enum): UNIT = "unit" -################################################################################ -# Secrets internal caching -################################################################################ +class SecretGroup(str): + """Secret groups specific type.""" + + +class SecretGroupsAggregate(str): + """Secret groups with option to extend with additional constants.""" + + def __init__(self): + self.USER = SecretGroup("user") + self.TLS = SecretGroup("tls") + self.EXTRA = SecretGroup("extra") + + def __setattr__(self, name, value): + """Setting internal constants.""" + if name in self.__dict__: + raise RuntimeError("Can't set constant!") + else: + super().__setattr__(name, SecretGroup(value)) + + def groups(self) -> list: + """Return the list of stored SecretGroups.""" + return list(self.__dict__.values()) + + def get_group(self, group: str) -> Optional[SecretGroup]: + """If the input str translates to a group name, return that.""" + return SecretGroup(group) if group in self.groups() else None + + +SECRET_GROUPS = SecretGroupsAggregate() class CachedSecret: @@ -554,6 +608,8 @@ class CachedSecret: The data structure is precisely re-using/simulating as in the actual Secret Storage """ + KNOWN_MODEL_ERRORS = [MODEL_ERRORS["no_label_and_uri"], MODEL_ERRORS["owner_no_refresh"]] + def __init__( self, model: Model, @@ -571,6 +627,95 @@ def __init__( self.legacy_labels = legacy_labels self.current_label = None + @property + def meta(self) -> Optional[Secret]: + """Getting cached secret meta-information.""" + if not self._secret_meta: + if not (self._secret_uri or self.label): + return + + try: + self._secret_meta = self._model.get_secret(label=self.label) + except SecretNotFoundError: + # Falling back to seeking for potential legacy labels + self._legacy_compat_find_secret_by_old_label() + + # If still not found, to be checked by URI, to be labelled with the proposed label + if not self._secret_meta and self._secret_uri: + self._secret_meta = self._model.get_secret(id=self._secret_uri, label=self.label) + return self._secret_meta + + ########################################################################## + # Backwards compatibility / Upgrades + ########################################################################## + # These functions are used to keep backwards compatibility on rolling upgrades + # Policy: + # All data is kept intact until the first write operation. (This allows a minimal + # grace period during which rollbacks are fully safe. For more info see the spec.) 
+    # All data involves:
+    #   - databag contents
+    #   - secrets content
+    #   - secret labels (!!!)
+    # Legacy functions must return None, and leave an equally consistent state whether
+    # they are executed or skipped (as an execution environment on a high enough
+    # version may not require them).
+
+    # Compatibility
+
+    @legacy_apply_from_version(34)
+    def _legacy_compat_find_secret_by_old_label(self) -> None:
+        """Compatibility function, allowing to find a secret by a legacy label.
+
+        This functionality is typically needed when secret labels changed over an upgrade.
+        Until the first write operation, we need to maintain data as it was, including keeping
+        the old secret label. In order to keep track of the old label currently used to access
+        the secret, an additional 'current_label' field is defined.
+        """
+        for label in self.legacy_labels:
+            try:
+                self._secret_meta = self._model.get_secret(label=label)
+            except SecretNotFoundError:
+                pass
+            else:
+                if label != self.label:
+                    self.current_label = label
+                return
+
+    # Migrations
+
+    @legacy_apply_from_version(34)
+    def _legacy_migration_to_new_label_if_needed(self) -> None:
+        """Helper function to re-create the secret with a different label.
+
+        Juju does not provide a way to change secret labels.
+        Thus whenever moving from a secrets version that involves secret label changes,
+        we "re-create" the existing secret, and attach the new label to the new
+        secret, to be used from then on.
+
+        Note: we replace the old secret with a new one "in place", as we can't
+        easily switch the containing SecretCache structure to point to a new secret.
+        Instead we are changing the 'self' (CachedSecret) object to point to the
+        new instance.
+        """
+        if not self.current_label or not (self.meta and self._secret_meta):
+            return
+
+        # Create a new secret with the new label
+        content = self._secret_meta.get_content()
+        self._secret_uri = None
+
+        # It would be nice to be able to check whether we are the owner of the secret...
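+        # Only the secret owner can re-create it (for app-scoped secrets, the leader);
+        # a "not the leader" ModelError raised by add_secret() on other units is
+        # deliberately swallowed below, so their state is left untouched for now.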
+ try: + self._secret_meta = self.add_secret(content, label=self.label) + except ModelError as err: + if MODEL_ERRORS["not_leader"] not in str(err): + raise + self.current_label = None + + ########################################################################## + # Public functions + ########################################################################## + def add_secret( self, content: Dict[str, str], @@ -593,28 +738,6 @@ def add_secret( self._secret_meta = secret return self._secret_meta - @property - def meta(self) -> Optional[Secret]: - """Getting cached secret meta-information.""" - if not self._secret_meta: - if not (self._secret_uri or self.label): - return - - for label in [self.label] + self.legacy_labels: - try: - self._secret_meta = self._model.get_secret(label=label) - except SecretNotFoundError: - pass - else: - if label != self.label: - self.current_label = label - break - - # If still not found, to be checked by URI, to be labelled with the proposed label - if not self._secret_meta and self._secret_uri: - self._secret_meta = self._model.get_secret(id=self._secret_uri, label=self.label) - return self._secret_meta - def get_content(self) -> Dict[str, str]: """Getting cached secret content.""" if not self._secret_content: @@ -624,35 +747,14 @@ def get_content(self) -> Dict[str, str]: except (ValueError, ModelError) as err: # https://bugs.launchpad.net/juju/+bug/2042596 # Only triggered when 'refresh' is set - known_model_errors = [ - "ERROR either URI or label should be used for getting an owned secret but not both", - "ERROR secret owner cannot use --refresh", - ] if isinstance(err, ModelError) and not any( - msg in str(err) for msg in known_model_errors + msg in str(err) for msg in self.KNOWN_MODEL_ERRORS ): raise # Due to: ValueError: Secret owner cannot use refresh=True self._secret_content = self.meta.get_content() return self._secret_content - def _move_to_new_label_if_needed(self): - """Helper function to re-create the secret with a different label.""" - if not self.current_label or not (self.meta and self._secret_meta): - return - - # Create a new secret with the new label - content = self._secret_meta.get_content() - self._secret_uri = None - - # I wish we could just check if we are the owners of the secret... - try: - self._secret_meta = self.add_secret(content, label=self.label) - except ModelError as err: - if "this unit is not the leader" not in str(err): - raise - self.current_label = None - def set_content(self, content: Dict[str, str]) -> None: """Setting cached secret content.""" if not self.meta: @@ -663,7 +765,7 @@ def set_content(self, content: Dict[str, str]) -> None: return if content: - self._move_to_new_label_if_needed() + self._legacy_migration_to_new_label_if_needed() self.meta.set_content(content) self._secret_content = content else: @@ -926,6 +1028,23 @@ def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" raise NotImplementedError + # Optional overrides + + def _legacy_apply_on_fetch(self) -> None: + """This function should provide a list of compatibility functions to be applied when fetching (legacy) data.""" + pass + + def _legacy_apply_on_update(self, fields: List[str]) -> None: + """This function should provide a list of compatibility functions to be applied when writing data. + + Since data may be at a legacy version, migration may be mandatory. 
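+
+        For illustration, an override may look like the peer-relation variant
+        defined further below:
+
+            def _legacy_apply_on_update(self, fields: List[str]) -> None:
+                relation = self._model.relations[self.relation_name][0]
+                self._legacy_migration_remove_secret_from_databag(relation, fields)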
+ """ + pass + + def _legacy_apply_on_delete(self, fields: List[str]) -> None: + """This function should provide a list of compatibility functions to be applied when deleting (legacy) data.""" + pass + # Internal helper methods @staticmethod @@ -1178,6 +1297,16 @@ def get_relation(self, relation_name, relation_id) -> Relation: return relation + def get_secret_uri(self, relation: Relation, group: SecretGroup) -> Optional[str]: + """Get the secret URI for the corresponding group.""" + secret_field = self._generate_secret_field_name(group) + return relation.data[self.component].get(secret_field) + + def set_secret_uri(self, relation: Relation, group: SecretGroup, secret_uri: str) -> None: + """Set the secret URI for the corresponding group.""" + secret_field = self._generate_secret_field_name(group) + relation.data[self.component][secret_field] = secret_uri + def fetch_relation_data( self, relation_ids: Optional[List[int]] = None, @@ -1194,6 +1323,8 @@ def fetch_relation_data( a dict of the values stored in the relation data bag for all relation instances (indexed by the relation ID). """ + self._legacy_apply_on_fetch() + if not relation_name: relation_name = self.relation_name @@ -1232,6 +1363,8 @@ def fetch_my_relation_data( NOTE: Since only the leader can read the relation's 'this_app'-side Application databag, the functionality is limited to leaders """ + self._legacy_apply_on_fetch() + if not relation_name: relation_name = self.relation_name @@ -1263,6 +1396,8 @@ def fetch_my_relation_field( @leader_only def update_relation_data(self, relation_id: int, data: dict) -> None: """Update the data within the relation.""" + self._legacy_apply_on_update(list(data.keys())) + relation_name = self.relation_name relation = self.get_relation(relation_name, relation_id) return self._update_relation_data(relation, data) @@ -1270,6 +1405,8 @@ def update_relation_data(self, relation_id: int, data: dict) -> None: @leader_only def delete_relation_data(self, relation_id: int, fields: List[str]) -> None: """Remove field from the relation.""" + self._legacy_apply_on_delete(fields) + relation_name = self.relation_name relation = self.get_relation(relation_name, relation_id) return self._delete_relation_data(relation, fields) @@ -1336,8 +1473,7 @@ def _add_relation_secret( uri_to_databag=True, ) -> bool: """Add a new Juju Secret that will be registered in the relation databag.""" - secret_field = self._generate_secret_field_name(group_mapping) - if uri_to_databag and relation.data[self.component].get(secret_field): + if uri_to_databag and self.get_secret_uri(relation, group_mapping): logging.error("Secret for relation %s already exists, not adding again", relation.id) return False @@ -1348,7 +1484,7 @@ def _add_relation_secret( # According to lint we may not have a Secret ID if uri_to_databag and secret.meta and secret.meta.id: - relation.data[self.component][secret_field] = secret.meta.id + self.set_secret_uri(relation, group_mapping, secret.meta.id) # Return the content that was added return True @@ -1449,8 +1585,7 @@ def _get_relation_secret( if not relation: return - secret_field = self._generate_secret_field_name(group_mapping) - if secret_uri := relation.data[self.local_app].get(secret_field): + if secret_uri := self.get_secret_uri(relation, group_mapping): return self.secrets.get(label, secret_uri) def _fetch_specific_relation_data( @@ -1603,11 +1738,10 @@ def _register_secrets_to_relation(self, relation: Relation, params_name_list: Li for group in SECRET_GROUPS.groups(): secret_field = 
self._generate_secret_field_name(group) - if secret_field in params_name_list: - if secret_uri := relation.data[relation.app].get(secret_field): - self._register_secret_to_relation( - relation.name, relation.id, secret_uri, group - ) + if secret_field in params_name_list and ( + secret_uri := self.get_secret_uri(relation, group) + ): + self._register_secret_to_relation(relation.name, relation.id, secret_uri, group) def _is_resource_created_for_relation(self, relation: Relation) -> bool: if not relation.app: @@ -1618,6 +1752,17 @@ def _is_resource_created_for_relation(self, relation: Relation) -> bool: ) return bool(data.get("username")) and bool(data.get("password")) + # Public functions + + def get_secret_uri(self, relation: Relation, group: SecretGroup) -> Optional[str]: + """Getting relation secret URI for the corresponding Secret Group.""" + secret_field = self._generate_secret_field_name(group) + return relation.data[relation.app].get(secret_field) + + def set_secret_uri(self, relation: Relation, group: SecretGroup, uri: str) -> None: + """Setting relation secret URI is not possible for a Requirer.""" + raise NotImplementedError("Requirer can not change the relation secret URI.") + def is_resource_created(self, relation_id: Optional[int] = None) -> bool: """Check if the resource has been created. @@ -1768,7 +1913,6 @@ def __init__( secret_field_name: Optional[str] = None, deleted_label: Optional[str] = None, ): - """Manager of base client relations.""" RequirerData.__init__( self, model, @@ -1779,6 +1923,11 @@ def __init__( self.secret_field_name = secret_field_name if secret_field_name else self.SECRET_FIELD_NAME self.deleted_label = deleted_label self._secret_label_map = {} + + # Legacy information holders + self._legacy_labels = [] + self._legacy_secret_uri = None + # Secrets that are being dynamically added within the scope of this event handler run self._new_secrets = [] self._additional_secret_group_mapping = additional_secret_group_mapping @@ -1853,10 +2002,12 @@ def set_secret( value: The string value of the secret group_mapping: The name of the "secret group", in case the field is to be added to an existing secret """ + self._legacy_apply_on_update([field]) + full_field = self._field_to_internal_name(field, group_mapping) if self.secrets_enabled and full_field not in self.current_secret_fields: self._new_secrets.append(full_field) - if self._no_group_with_databag(field, full_field): + if self.valid_field_pattern(field, full_field): self.update_relation_data(relation_id, {full_field: value}) # Unlike for set_secret(), there's no harm using this operation with static secrets @@ -1869,6 +2020,8 @@ def get_secret( group_mapping: Optional[SecretGroup] = None, ) -> Optional[str]: """Public interface method to fetch secrets only.""" + self._legacy_apply_on_fetch() + full_field = self._field_to_internal_name(field, group_mapping) if ( self.secrets_enabled @@ -1876,7 +2029,7 @@ def get_secret( and field not in self.current_secret_fields ): return - if self._no_group_with_databag(field, full_field): + if self.valid_field_pattern(field, full_field): return self.fetch_my_relation_field(relation_id, full_field) @dynamic_secrets_only @@ -1887,14 +2040,19 @@ def delete_secret( group_mapping: Optional[SecretGroup] = None, ) -> Optional[str]: """Public interface method to delete secrets only.""" + self._legacy_apply_on_delete([field]) + full_field = self._field_to_internal_name(field, group_mapping) if self.secrets_enabled and full_field not in self.current_secret_fields: 
             logger.warning(f"Secret {field} from group {group_mapping} was not found")
             return
-        if self._no_group_with_databag(field, full_field):
+
+        if self.valid_field_pattern(field, full_field):
             self.delete_relation_data(relation_id, [full_field])
 
+    ##########################################################################
     # Helpers
+    ##########################################################################
 
     @staticmethod
     def _field_to_internal_name(field: str, group: Optional[SecretGroup]) -> str:
@@ -1936,10 +2094,69 @@ def _content_for_secret_group(
                 if k in self.secret_fields
             }
 
-    # Backwards compatibility
+    def valid_field_pattern(self, field: str, full_field: str) -> bool:
+        """Check that no secret group is referenced when secrets are not enabled.
+
+        Secret groups cannot be used with versions that do not yet support secrets.
+        """
+        if not self.secrets_enabled and full_field != field:
+            logger.error(
+                f"Can't access {full_field}: no secrets available (i.e. no secret groups either)."
+            )
+            return False
+        return True
+
+    ##########################################################################
+    # Backwards compatibility / Upgrades
+    ##########################################################################
+    # These functions are used to keep backwards compatibility on upgrades
+    # Policy:
+    # All data is kept intact until the first write operation. (This allows a minimal
+    # grace period during which rollbacks are fully safe. For more info see spec.)
+    # All data involves:
+    #   - databag
+    #   - secrets content
+    #   - secret labels (!!!)
+    # Legacy functions must return None, and leave an equally consistent state whether
+    # they are executed or skipped (as an execution environment on a high enough
+    # version may not require them).
+
+    # Full legacy stack for each operation
+
+    def _legacy_apply_on_fetch(self) -> None:
+        """All legacy functions to be applied on fetch."""
+        relation = self._model.relations[self.relation_name][0]
+        self._legacy_compat_generate_prev_labels()
+        self._legacy_compat_secret_uri_from_databag(relation)
+
+    def _legacy_apply_on_update(self, fields) -> None:
+        """All legacy functions to be applied on update."""
+        relation = self._model.relations[self.relation_name][0]
+        self._legacy_compat_generate_prev_labels()
+        self._legacy_compat_secret_uri_from_databag(relation)
+        self._legacy_migration_remove_secret_from_databag(relation, fields)
+        self._legacy_migration_remove_secret_field_name_from_databag(relation)
+
+    def _legacy_apply_on_delete(self, fields) -> None:
+        """All legacy functions to be applied on delete."""
+        relation = self._model.relations[self.relation_name][0]
+        self._legacy_compat_generate_prev_labels()
+        self._legacy_compat_secret_uri_from_databag(relation)
+        self._legacy_compat_check_deleted_label(relation, fields)
+
+    # Compatibility
+
+    @legacy_apply_from_version(18)
+    def _legacy_compat_check_deleted_label(self, relation, fields) -> None:
+        """Helper function for legacy behavior.
+
+        As long as https://bugs.launchpad.net/juju/+bug/2028094 wasn't fixed,
+        we did not delete fields but rather kept them in the secret with a string value
+        expressing invalidity. This function is maintaining that behavior when needed.
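+        (In other words, a "deleted" field may still exist in the secret, holding
+        the configured 'deleted_label' sentinel value rather than being removed.)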
+        """
+        if not self.deleted_label:
+            return
 
-    def _check_deleted_label(self, relation, fields) -> None:
-        """Helper function for legacy behavior."""
         current_data = self.fetch_my_relation_data([relation.id], fields)
         if current_data is not None:
             # Check if the secret we wanna delete actually exists
@@ -1952,7 +2169,43 @@ def _check_deleted_label(self, relation, fields) -> None:
                 ", ".join(non_existent),
             )
 
-    def _remove_secret_from_databag(self, relation, fields: List[str]) -> None:
+    @legacy_apply_from_version(18)
+    def _legacy_compat_secret_uri_from_databag(self, relation) -> None:
+        """Fetching the secret URI from the databag, in case it is stored there."""
+        self._legacy_secret_uri = relation.data[self.component].get(
+            self._generate_secret_field_name(), None
+        )
+
+    @legacy_apply_from_version(34)
+    def _legacy_compat_generate_prev_labels(self) -> None:
+        """Generator for legacy secret label names, for backwards compatibility.
+
+        Secret label is part of the data that MUST be maintained across rolling upgrades.
+        In case a secret label changes, the old label must be recognized after
+        upgrades, and left intact until the first write operation -- when we roll over
+        to the new label.
+
+        This function keeps "memory" of previously used secret labels.
+        NOTE: Return value takes the decorator into account -- all 'legacy' functions may return `None`.
+
+        v0.34 (rev69): Fixing issue https://github.com/canonical/data-platform-libs/issues/155
+        meant moving from '<app_name>.<scope>' (i.e. 'mysql.app', 'mysql.unit')
+        to labels '<relation_name>.<app_name>.<scope>' (like 'peer.mysql.app')
+        """
+        if self._legacy_labels:
+            return
+
+        result = []
+        members = [self._model.app.name]
+        if self.scope:
+            members.append(self.scope.value)
+        result.append(f"{'.'.join(members)}")
+        self._legacy_labels = result
+
+    # Migration
+
+    @legacy_apply_from_version(18)
+    def _legacy_migration_remove_secret_from_databag(self, relation, fields: List[str]) -> None:
         """For Rolling Upgrades -- when moving from databag to secrets usage.
 
         Practically what happens here is to remove stuff from the databag that is
@@ -1966,10 +2219,16 @@ def _remove_secret_from_databag(self, relation, fields: List[str]) -> None:
             if self._fetch_relation_data_without_secrets(self.component, relation, [field]):
                 self._delete_relation_data_without_secrets(self.component, relation, [field])
 
-    def _remove_secret_field_name_from_databag(self, relation) -> None:
+    @legacy_apply_from_version(18)
+    def _legacy_migration_remove_secret_field_name_from_databag(self, relation) -> None:
         """Making sure that the old databag URI is gone.
 
         This action should not be executed more than once.
+
+        There was a phase (before moving secrets usage to libs) when charms saved the peer
+        secret URI to the databag, and used this URI from then on to retrieve their secret.
+        When upgrading to charm versions using this library, we need to add a label to the
+        secret and access it via label from then on, and remove the old traces from the databag.
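+        (I.e. once the secret can be fetched by its label, the databag key generated
+        by _generate_secret_field_name() -- the old URI reference -- is popped.)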
""" # Nothing to do if 'internal-secret' is not in the databag if not (relation.data[self.component].get(self._generate_secret_field_name())): @@ -1985,25 +2244,9 @@ def _remove_secret_field_name_from_databag(self, relation) -> None: # Databag reference to the secret URI can be removed, now that it's labelled relation.data[self.component].pop(self._generate_secret_field_name(), None) - def _previous_labels(self) -> List[str]: - """Generator for legacy secret label names, for backwards compatibility.""" - result = [] - members = [self._model.app.name] - if self.scope: - members.append(self.scope.value) - result.append(f"{'.'.join(members)}") - return result - - def _no_group_with_databag(self, field: str, full_field: str) -> bool: - """Check that no secret group is attempted to be used together with databag.""" - if not self.secrets_enabled and full_field != field: - logger.error( - f"Can't access {full_field}: no secrets available (i.e. no secret groups either)." - ) - return False - return True - + ########################################################################## # Event handlers + ########################################################################## def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" @@ -2013,7 +2256,9 @@ def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: """Event emitted when the secret has changed.""" pass + ########################################################################## # Overrides of Relation Data handling functions + ########################################################################## def _generate_secret_label( self, relation_name: str, relation_id: int, group_mapping: SecretGroup @@ -2050,13 +2295,14 @@ def _get_relation_secret( return label = self._generate_secret_label(relation_name, relation_id, group_mapping) - secret_uri = relation.data[self.component].get(self._generate_secret_field_name(), None) # URI or legacy label is only to applied when moving single legacy secret to a (new) label if group_mapping == SECRET_GROUPS.EXTRA: # Fetching the secret with fallback to URI (in case label is not yet known) # Label would we "stuck" on the secret in case it is found - return self.secrets.get(label, secret_uri, legacy_labels=self._previous_labels()) + return self.secrets.get( + label, self._legacy_secret_uri, legacy_labels=self._legacy_labels + ) return self.secrets.get(label) def _get_group_secret_contents( @@ -2086,7 +2332,6 @@ def _fetch_my_specific_relation_data( @either_static_or_dynamic_secrets def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" - self._remove_secret_from_databag(relation, list(data.keys())) _, normal_fields = self._process_secret_fields( relation, self.secret_fields, @@ -2095,7 +2340,6 @@ def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> Non data=data, uri_to_databag=False, ) - self._remove_secret_field_name_from_databag(relation) normal_content = {k: v for k, v in data.items() if k in normal_fields} self._update_relation_data_without_secrets(self.component, relation, normal_content) @@ -2104,8 +2348,6 @@ def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> Non def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: """Delete data available (directily or indirectly -- i.e. 
secrets) from the relation for owner/this_app.""" if self.secret_fields and self.deleted_label: - # Legacy, backwards compatibility - self._check_deleted_label(relation, fields) _, normal_fields = self._process_secret_fields( relation, @@ -2141,7 +2383,9 @@ def fetch_relation_field( "fetch_my_relation_data() and fetch_my_relation_field()" ) + ########################################################################## # Public functions -- inherited + ########################################################################## fetch_my_relation_data = Data.fetch_my_relation_data fetch_my_relation_field = Data.fetch_my_relation_field @@ -2606,6 +2850,14 @@ def set_version(self, relation_id: int, version: str) -> None: """ self.update_relation_data(relation_id, {"version": version}) + def set_subordinated(self, relation_id: int) -> None: + """Raises the subordinated flag in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + """ + self.update_relation_data(relation_id, {"subordinated": "true"}) + class DatabaseProviderEventHandlers(EventHandlers): """Provider-side of the database relation handlers.""" @@ -2842,6 +3094,21 @@ def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the database relation has changed.""" + is_subordinate = False + remote_unit_data = None + for key in event.relation.data.keys(): + if isinstance(key, Unit) and not key.name.startswith(self.charm.app.name): + remote_unit_data = event.relation.data[key] + elif isinstance(key, Application) and key.name != self.charm.app.name: + is_subordinate = event.relation.data[key].get("subordinated") == "true" + + if is_subordinate: + if not remote_unit_data: + return + + if remote_unit_data.get("state") != "ready": + return + # Check which data has changed to emit customs events. 
diff = self._diff(event) diff --git a/metadata.yaml b/metadata.yaml index 123f250..f5c6992 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -44,6 +44,9 @@ requires: spark-service-account: interface: spark_service_account limit: 1 + zookeeper: + interface: zookeeper + limit: 1 provides: diff --git a/src/charm.py b/src/charm.py index 8f277b1..eef8cc7 100755 --- a/src/charm.py +++ b/src/charm.py @@ -22,6 +22,7 @@ from events.kyuubi import KyuubiEvents from events.metastore import MetastoreEvents from events.s3 import S3Events +from events.zookeeper import ZookeeperEvents # Log messages can be retrieved using juju debug-log logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ def __init__(self, *args): self.hub_events = SparkIntegrationHubEvents(self, self.context, self.workload) self.metastore_events = MetastoreEvents(self, self.context, self.workload) self.auth_events = AuthenticationEvents(self, self.context, self.workload) + self.zookeeper_events = ZookeeperEvents(self, self.context, self.workload) self.action_events = ActionEvents(self, self.context, self.workload) diff --git a/src/config/kyuubi.py b/src/config/kyuubi.py index 77f41b0..ba9ce84 100644 --- a/src/config/kyuubi.py +++ b/src/config/kyuubi.py @@ -7,16 +7,19 @@ from typing import Optional -from constants import AUTHENTICATION_TABLE_NAME -from core.domain import DatabaseConnectionInfo +from constants import AUTHENTICATION_TABLE_NAME, HA_ZNODE_NAME +from core.domain import DatabaseConnectionInfo, ZookeeperInfo from utils.logging import WithLogging class KyuubiConfig(WithLogging): """Kyuubi Configurations.""" - def __init__(self, db_info: Optional[DatabaseConnectionInfo]): + def __init__( + self, db_info: Optional[DatabaseConnectionInfo], zookeeper_info: Optional[ZookeeperInfo] + ): self.db_info = db_info + self.zookeeper_info = zookeeper_info def _get_db_connection_url(self) -> str: endpoint = self.db_info.endpoint @@ -28,6 +31,13 @@ def _get_authentication_query(self) -> str: "WHERE username=${user} AND passwd=${password}" ) + def _get_zookeeper_auth_digest(self) -> str: + if not self.zookeeper_info: + return "" + username = self.zookeeper_info.username + password = self.zookeeper_info.password + return f"{username}:{password}" + @property def _auth_conf(self) -> dict[str, str]: if not self.db_info: @@ -41,9 +51,22 @@ def _auth_conf(self) -> dict[str, str]: "kyuubi.authentication.jdbc.query": self._get_authentication_query(), } + @property + def _ha_conf(self) -> dict[str, str]: + if not self.zookeeper_info: + return {} + return { + "kyuubi.ha.addresses": self.zookeeper_info.uris, + # FIXME: Get this value from self.context.zookeeper.uris when znode created by + # zookeeper charm has enough permissions for Kyuubi to work + "kyuubi.ha.namespace": HA_ZNODE_NAME, + "kyuubi.ha.zookeeper.auth.type": "DIGEST", + "kyuubi.ha.zookeeper.auth.digest": self._get_zookeeper_auth_digest(), + } + def to_dict(self) -> dict[str, str]: """Return the dict representation of the configuration file.""" - return self._auth_conf + return self._auth_conf | self._ha_conf @property def contents(self) -> str: diff --git a/src/constants.py b/src/constants.py index e167c11..7985b55 100644 --- a/src/constants.py +++ b/src/constants.py @@ -5,22 +5,30 @@ KYUUBI_CONTAINER_NAME = "kyuubi" KYUUBI_SERVICE_NAME = "kyuubi" + +# Database related literals METASTORE_DATABASE_NAME = "hivemetastore" AUTHENTICATION_DATABASE_NAME = "auth_db" AUTHENTICATION_TABLE_NAME = "kyuubi_users" POSTGRESQL_DEFAULT_DATABASE = "postgres" +# Relation names S3_INTEGRATOR_REL = 
"s3-credentials" POSTGRESQL_METASTORE_DB_REL = "metastore-db" POSTGRESQL_AUTH_DB_REL = "auth-db" SPARK_SERVICE_ACCOUNT_REL = "spark-service-account" +ZOOKEEPER_REL = "zookeeper" +KYUUBI_CLIENT_RELATION_NAME = "jdbc" +# Literals related to K8s NAMESPACE_CONFIG_NAME = "namespace" SERVICE_ACCOUNT_CONFIG_NAME = "service-account" +# Literals related to Kyuubi JDBC_PORT = 10009 - KYUUBI_OCI_IMAGE = "ghcr.io/canonical/charmed-spark-kyuubi:3.4-22.04_edge" - DEFAULT_ADMIN_USERNAME = "admin" -KYUUBI_CLIENT_RELATION_NAME = "jdbc" + +# Zookeeper literals +HA_ZNODE_NAME = "/kyuubi" +HA_ZNODE_NAME_TEMP = "/kyuubi-temp" diff --git a/src/core/context.py b/src/core/context.py index d3d0e59..84c9651 100644 --- a/src/core/context.py +++ b/src/core/context.py @@ -10,13 +10,20 @@ from common.relation.spark_sa import RequirerData from constants import ( AUTHENTICATION_DATABASE_NAME, + HA_ZNODE_NAME_TEMP, METASTORE_DATABASE_NAME, POSTGRESQL_AUTH_DB_REL, POSTGRESQL_METASTORE_DB_REL, S3_INTEGRATOR_REL, SPARK_SERVICE_ACCOUNT_REL, + ZOOKEEPER_REL, +) +from core.domain import ( + DatabaseConnectionInfo, + S3ConnectionInfo, + SparkServiceAccountInfo, + ZookeeperInfo, ) -from core.domain import DatabaseConnectionInfo, S3ConnectionInfo, SparkServiceAccountInfo from utils.logging import WithLogging @@ -36,6 +43,13 @@ def __init__(self, model: Model, config: ConfigData): extra_user_roles="superuser", ) + # FIXME: The database_name currently requested is a dummy name + # This should be replaced with the name of actual znode when znode created + # by zookeeper charm has enough permissions for Kyuubi to work + self.zookeeper_requirer_data = DatabaseRequirerData( + self.model, ZOOKEEPER_REL, database_name=HA_ZNODE_NAME_TEMP + ) + @property def _s3_relation(self) -> Relation | None: """The S3 relation.""" @@ -46,6 +60,11 @@ def _spark_account_relation(self) -> Relation | None: """The integration hub relation.""" return self.model.get_relation(SPARK_SERVICE_ACCOUNT_REL) + @property + def _zookeeper_relation(self) -> Relation | None: + """The zookeeper relation.""" + return self.model.get_relation(ZOOKEEPER_REL) + # --- DOMAIN OBJECTS --- @property @@ -91,6 +110,19 @@ def service_account(self) -> SparkServiceAccountInfo | None: ): return account + @property + def zookeeper(self) -> ZookeeperInfo | None: + """The state of the Zookeeper information.""" + return ( + ZookeeperInfo(rel, self.zookeeper_requirer_data, rel.app) + if (rel := self._zookeeper_relation) + else None + ) + def is_authentication_enabled(self) -> bool: """Returns whether the authentication has been enabled in the Kyuubi charm.""" return bool(self.auth_db) + + def is_ha_enabled(self) -> bool: + """Returns whether HA has been enabled in the Kyuubi charm.""" + return bool(self.zookeeper) diff --git a/src/core/domain.py b/src/core/domain.py index 4a15a58..d026f91 100644 --- a/src/core/domain.py +++ b/src/core/domain.py @@ -34,7 +34,7 @@ class Status(Enum): MISSING_INTEGRATION_HUB = BlockedStatus("Missing integration hub relation") INVALID_NAMESPACE = BlockedStatus("Invalid config option: namespace") INVALID_SERVICE_ACCOUNT = BlockedStatus("Invalid config option: service-account") - + WAITING_ZOOKEEPER = MaintenanceStatus("Waiting for zookeeper credentials") ACTIVE = ActiveStatus("") @@ -152,3 +152,102 @@ def service_account(self): def namespace(self): """Namespace used for running Spark jobs.""" return self.relation_data["namespace"] + + +class ZookeeperInfo(RelationState): + """State collection metadata for a the Zookeeper relation.""" + + def __init__( + self, + 
relation: Relation | None, + data_interface: Data, + local_app: Application | None = None, + ): + super().__init__(relation, data_interface, None) + self._local_app = local_app + + @property + def username(self) -> str: + """Username to connect to ZooKeeper.""" + if not self.relation: + return "" + + return ( + self.data_interface.fetch_relation_field( + relation_id=self.relation.id, field="username" + ) + or "" + ) + + @property + def password(self) -> str: + """Password of the ZooKeeper user.""" + if not self.relation: + return "" + + return ( + self.data_interface.fetch_relation_field( + relation_id=self.relation.id, field="password" + ) + or "" + ) + + @property + def endpoints(self) -> str: + """IP/host where ZooKeeper is located.""" + if not self.relation: + return "" + + return ( + self.data_interface.fetch_relation_field( + relation_id=self.relation.id, field="endpoints" + ) + or "" + ) + + @property + def database(self) -> str: + """Path allocated for Kyuubi on ZooKeeper.""" + if not self.relation: + return "" + + return ( + self.data_interface.fetch_relation_field( + relation_id=self.relation.id, field="database" + ) + or "" + ) + + @property + def uris(self) -> str: + """Comma separated connection string, containing endpoints.""" + if not self.relation: + return "" + + return ",".join( + sorted( # sorting as they may be disordered + ( + self.data_interface.fetch_relation_field( + relation_id=self.relation.id, field="uris" + ) + or "" + ).split(",") + ) + ).replace(self.database, "") + + @property + def zookeeper_connected(self) -> bool: + """Checks if there is an active ZooKeeper relation with all necessary data. + + Returns: + True if ZooKeeper is currently related with sufficient relation data + for a broker to connect with. Otherwise False + """ + if not all([self.username, self.password, self.database, self.uris]): + return False + + return True + + def __bool__(self) -> bool: + """Return whether this class object has sufficient information.""" + return self.zookeeper_connected diff --git a/src/core/workload/kyuubi.py b/src/core/workload/kyuubi.py index 87d54bb..a570d21 100644 --- a/src/core/workload/kyuubi.py +++ b/src/core/workload/kyuubi.py @@ -28,6 +28,12 @@ def __init__(self, container: Container, user: User = User()): self.container = container self.user = user + def get_ip_address(self) -> str: + """Return the IP address of the unit running the workload.""" + hostname = socket.getfqdn() + ip_address = socket.gethostbyname(hostname) + return ip_address + def get_jdbc_endpoint(self) -> str: """Return the JDBC endpoint to connect to Kyuubi server.""" hostname = socket.getfqdn() diff --git a/src/events/actions.py b/src/events/actions.py index bd174cf..9304671 100644 --- a/src/events/actions.py +++ b/src/events/actions.py @@ -7,7 +7,7 @@ from ops import CharmBase from ops.charm import ActionEvent -from constants import DEFAULT_ADMIN_USERNAME +from constants import DEFAULT_ADMIN_USERNAME, HA_ZNODE_NAME, JDBC_PORT from core.context import Context from core.domain import Status from core.workload import KyuubiWorkloadBase @@ -39,15 +39,22 @@ def _on_get_jdbc_endpoint(self, event: ActionEvent): if not self.workload.ready(): event.fail("The action failed because the workload is not ready yet.") return - if ( - not self.get_app_status( - s3_info=self.context.s3, service_account=self.context.service_account - ) - != Status.ACTIVE - ): + if self.get_app_status() != Status.ACTIVE.value: event.fail("The action failed because the charm is not in active state.") return - result = 
{"endpoint": self.workload.get_jdbc_endpoint()} + + if self.context.is_ha_enabled(): + address = self.context.zookeeper.uris + # FIXME: Get this value from self.context.zookeeper.uris when znode created by + # zookeeper charm has enough permissions for Kyuubi to work + namespace = HA_ZNODE_NAME + if not address.endswith("/"): + address += "/" + endpoint = f"jdbc:hive2://{address};serviceDiscoveryMode=zooKeeper;zooKeeperNamespace={namespace}" + else: + address = self.workload.get_ip_address() + endpoint = f"jdbc:hive2://{address}:{JDBC_PORT}/" + result = {"endpoint": endpoint} event.set_results(result) def _on_get_password(self, event: ActionEvent) -> None: @@ -61,12 +68,7 @@ def _on_get_password(self, event: ActionEvent) -> None: if not self.workload.ready(): event.fail("The action failed because the workload is not ready yet.") return - if ( - not self.get_app_status( - s3_info=self.context.s3, service_account=self.context.service_account - ) - != Status.ACTIVE - ): + if self.get_app_status() != Status.ACTIVE.value: event.fail("The action failed because the charm is not in active state.") return password = self.auth.get_password(DEFAULT_ADMIN_USERNAME) @@ -89,12 +91,7 @@ def _on_set_password(self, event: ActionEvent) -> None: if not self.workload.ready(): event.fail("The action failed because the workload is not ready yet.") return - if ( - not self.get_app_status( - s3_info=self.context.s3, service_account=self.context.service_account - ) - != Status.ACTIVE - ): + if self.get_app_status() != Status.ACTIVE.value: event.fail("The action failed because the charm is not in active state.") return diff --git a/src/events/auth.py b/src/events/auth.py index d40dbfc..127c904 100644 --- a/src/events/auth.py +++ b/src/events/auth.py @@ -56,6 +56,7 @@ def _on_auth_db_created(self, event: DatabaseCreatedEvent) -> None: metastore_db_info=self.context.metastore_db, auth_db_info=self.context.auth_db, service_account_info=self.context.service_account, + zookeeper_info=self.context.zookeeper, ) @compute_status @@ -67,6 +68,7 @@ def _on_auth_db_endpoints_changed(self, event) -> None: metastore_db_info=self.context.metastore_db, auth_db_info=self.context.auth_db, service_account_info=self.context.service_account, + zookeeper_info=self.context.zookeeper, ) @compute_status @@ -78,6 +80,7 @@ def _on_auth_db_relation_removed(self, event) -> None: metastore_db_info=self.context.metastore_db, auth_db_info=None, service_account_info=self.context.service_account, + zookeeper_info=self.context.zookeeper, ) @compute_status diff --git a/src/events/base.py b/src/events/base.py index ab13f35..3ccd63e 100644 --- a/src/events/base.py +++ b/src/events/base.py @@ -10,7 +10,7 @@ from ops import CharmBase, EventBase, Object, StatusBase from core.context import Context -from core.domain import S3ConnectionInfo, SparkServiceAccountInfo, Status +from core.domain import Status from core.workload import KyuubiWorkloadBase from managers.k8s import K8sManager from managers.s3 import S3Manager @@ -26,28 +26,28 @@ class BaseEventHandler(Object, WithLogging): def get_app_status( self, - s3_info: S3ConnectionInfo | None, - service_account: SparkServiceAccountInfo | None, ) -> StatusBase: """Return the status of the charm.""" if not self.workload.ready(): return Status.WAITING_PEBBLE.value - if s3_info: - s3_manager = S3Manager(s3_info=s3_info) + if self.context.s3: + s3_manager = S3Manager(s3_info=self.context.s3) if not s3_manager.verify(): return Status.INVALID_CREDENTIALS.value - if not service_account: + if not 
self.context.service_account: return Status.MISSING_INTEGRATION_HUB.value - k8s_manager = K8sManager(service_account_info=service_account, workload=self.workload) + k8s_manager = K8sManager( + service_account_info=self.context.service_account, workload=self.workload + ) # Check whether any one of object storage backend has been configured # Currently, we do this check on the basis of presence of Spark properties # TODO: Rethink on this approach with a more sturdy solution if ( - not s3_info + not self.context.s3 and not k8s_manager.is_s3_configured() and not k8s_manager.is_azure_storage_configured() ): @@ -59,6 +59,9 @@ def get_app_status( if not k8s_manager.is_service_account_valid(): return Status.INVALID_SERVICE_ACCOUNT.value + if self.context._zookeeper_relation and not self.context.zookeeper: + return Status.WAITING_ZOOKEEPER.value + return Status.ACTIVE.value @@ -72,12 +75,8 @@ def wrapper_hook(event_handler: BaseEventHandler, event: EventBase): """Return output after resetting statuses.""" res = hook(event_handler, event) if event_handler.charm.unit.is_leader(): - event_handler.charm.app.status = event_handler.get_app_status( - event_handler.context.s3, event_handler.context.service_account - ) - event_handler.charm.unit.status = event_handler.get_app_status( - event_handler.context.s3, event_handler.context.service_account - ) + event_handler.charm.app.status = event_handler.get_app_status() + event_handler.charm.unit.status = event_handler.get_app_status() return res return wrapper_hook diff --git a/src/events/integration_hub.py b/src/events/integration_hub.py index 2c503de..e461263 100644 --- a/src/events/integration_hub.py +++ b/src/events/integration_hub.py @@ -59,6 +59,7 @@ def _on_account_granted(self, _: ServiceAccountGrantedEvent): metastore_db_info=self.context.metastore_db, auth_db_info=self.context.auth_db, service_account_info=self.context.service_account, + zookeeper_info=self.context.zookeeper, ) @compute_status @@ -71,4 +72,5 @@ def _on_account_gone(self, _: ServiceAccountGoneEvent): metastore_db_info=self.context.metastore_db, auth_db_info=self.context.auth_db, service_account_info=None, + zookeeper_info=self.context.zookeeper, ) diff --git a/src/events/kyuubi.py b/src/events/kyuubi.py index ca8a0af..d9f0c35 100644 --- a/src/events/kyuubi.py +++ b/src/events/kyuubi.py @@ -50,6 +50,7 @@ def _on_config_changed(self, event: ops.ConfigChangedEvent) -> None: metastore_db_info=self.context.metastore_db, auth_db_info=self.context.auth_db, service_account_info=self.context.service_account, + zookeeper_info=self.context.zookeeper, ) @compute_status @@ -67,4 +68,5 @@ def _on_kyuubi_pebble_ready(self, event: ops.PebbleReadyEvent): metastore_db_info=self.context.metastore_db, auth_db_info=self.context.auth_db, service_account_info=self.context.service_account, + zookeeper_info=self.context.zookeeper, ) diff --git a/src/events/metastore.py b/src/events/metastore.py index f322465..a11294f 100644 --- a/src/events/metastore.py +++ b/src/events/metastore.py @@ -52,6 +52,7 @@ def _on_metastore_db_created(self, event: DatabaseCreatedEvent) -> None: metastore_db_info=self.context.metastore_db, auth_db_info=self.context.auth_db, service_account_info=self.context.service_account, + zookeeper_info=self.context.zookeeper, ) @compute_status @@ -63,4 +64,5 @@ def _on_metastore_db_relation_removed(self, event) -> None: metastore_db_info=None, auth_db_info=self.context.auth_db, service_account_info=self.context.service_account, + zookeeper_info=self.context.zookeeper, ) diff --git 
a/src/events/s3.py b/src/events/s3.py
index 1d5d1ee..ebee424 100644
--- a/src/events/s3.py
+++ b/src/events/s3.py
@@ -47,6 +47,7 @@ def _on_s3_credential_changed(self, _: CredentialsChangedEvent):
             metastore_db_info=self.context.metastore_db,
             auth_db_info=self.context.auth_db,
             service_account_info=self.context.service_account,
+            zookeeper_info=self.context.zookeeper,
         )
 
     @compute_status
@@ -59,4 +60,5 @@ def _on_s3_credential_gone(self, _: CredentialsGoneEvent):
             metastore_db_info=self.context.metastore_db,
             auth_db_info=self.context.auth_db,
             service_account_info=self.context.service_account,
+            zookeeper_info=self.context.zookeeper,
         )
diff --git a/src/events/zookeeper.py b/src/events/zookeeper.py
new file mode 100644
index 0000000..eeceae2
--- /dev/null
+++ b/src/events/zookeeper.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+# Copyright 2024 Canonical Limited
+# See LICENSE file for licensing details.
+
+"""Zookeeper related event handlers."""
+
+from charms.data_platform_libs.v0.data_interfaces import DatabaseRequirerEventHandlers
+from ops import CharmBase
+
+from constants import ZOOKEEPER_REL
+from core.context import Context
+from core.workload import KyuubiWorkloadBase
+from events.base import BaseEventHandler, compute_status
+from managers.kyuubi import KyuubiManager
+from utils.logging import WithLogging
+
+
+class ZookeeperEvents(BaseEventHandler, WithLogging):
+    """Class implementing Zookeeper integration event hooks."""
+
+    def __init__(self, charm: CharmBase, context: Context, workload: KyuubiWorkloadBase):
+        super().__init__(charm, "zookeeper")
+
+        self.charm = charm
+        self.context = context
+        self.workload = workload
+
+        self.kyuubi = KyuubiManager(self.workload)
+        self.zookeeper_handler = DatabaseRequirerEventHandlers(
+            self.charm, self.context.zookeeper_requirer_data
+        )
+
+        self.framework.observe(
+            self.charm.on[ZOOKEEPER_REL].relation_changed, self._on_zookeeper_changed
+        )
+        self.framework.observe(
+            self.charm.on[ZOOKEEPER_REL].relation_broken, self._on_zookeeper_broken
+        )
+
+    @compute_status
+    def _on_zookeeper_changed(self, _):
+        self.logger.info("Zookeeper relation changed...")
+        self.kyuubi.update(
+            s3_info=self.context.s3,
+            metastore_db_info=self.context.metastore_db,
+            auth_db_info=self.context.auth_db,
+            service_account_info=self.context.service_account,
+            zookeeper_info=self.context.zookeeper,
+        )
+
+    @compute_status
+    def _on_zookeeper_broken(self, _):
+        self.logger.info("Zookeeper relation broken...")
+        self.kyuubi.update(
+            s3_info=self.context.s3,
+            metastore_db_info=self.context.metastore_db,
+            auth_db_info=self.context.auth_db,
+            service_account_info=self.context.service_account,
+            zookeeper_info=None,
+        )
diff --git a/src/managers/kyuubi.py b/src/managers/kyuubi.py
index 1c00063..1ce60fb 100644
--- a/src/managers/kyuubi.py
+++ b/src/managers/kyuubi.py
@@ -7,7 +7,12 @@
 from config.hive import HiveConfig
 from config.kyuubi import KyuubiConfig
 from config.spark import SparkConfig
-from core.domain import DatabaseConnectionInfo, S3ConnectionInfo, SparkServiceAccountInfo
+from core.domain import (
+    DatabaseConnectionInfo,
+    S3ConnectionInfo,
+    SparkServiceAccountInfo,
+    ZookeeperInfo,
+)
 from core.workload import KyuubiWorkloadBase
 from utils.logging import WithLogging
 
@@ -24,13 +29,14 @@ def update(
         metastore_db_info: DatabaseConnectionInfo | None,
         auth_db_info: DatabaseConnectionInfo | None,
         service_account_info: SparkServiceAccountInfo | None,
+        zookeeper_info: ZookeeperInfo | None,
     ):
         """Update Kyuubi service and restart it."""
         spark_config = 
SparkConfig( s3_info=s3_info, service_account_info=service_account_info ).contents hive_config = HiveConfig(db_info=metastore_db_info).contents - kyuubi_config = KyuubiConfig(db_info=auth_db_info).contents + kyuubi_config = KyuubiConfig(db_info=auth_db_info, zookeeper_info=zookeeper_info).contents self.workload.write(spark_config, self.workload.SPARK_PROPERTIES_FILE) self.workload.write(hive_config, self.workload.HIVE_CONFIGURATION_FILE) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 2c77c3e..6bafefe 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -62,6 +62,7 @@ class IntegrationTestsCharms(BaseModel): s3: TestCharm postgres: TestCharm integration_hub: TestCharm + zookeeper: TestCharm @pytest.fixture(scope="module") @@ -88,6 +89,15 @@ def charm_versions() -> IntegrationTestsCharms: "trust": True, } ), + zookeeper=TestCharm( + **{ + "name": "zookeeper-k8s", + "channel": "3/edge", + "series": "jammy", + "alias": "zookeeper", + "num_units": 3, + } + ), ) @@ -141,14 +151,15 @@ def s3_bucket_and_creds(): @pytest.fixture(scope="module") -def test_pod(): +def test_pod(ops_test): logger.info("Preparing test pod fixture...") kyuubi_image = METADATA["resources"]["kyuubi-image"]["upstream-source"] + namespace = ops_test.model_name with open(TEST_POD_SPEC_FILE) as tf: template = Template(tf.read()) - pod_spec = template.substitute(kyuubi_image=kyuubi_image) + pod_spec = template.substitute(kyuubi_image=kyuubi_image, namespace=namespace) # Create test pod by applying pod spec apply_result = subprocess.run( @@ -160,7 +171,17 @@ def test_pod(): # Wait until the pod is in ready state wait_result = subprocess.run( - ["kubectl", "wait", "--for", "condition=Ready", f"pod/{pod_name}", "--timeout", "60s"] + [ + "kubectl", + "wait", + "--for", + "condition=Ready", + f"pod/{pod_name}", + "-n", + namespace, + "--timeout", + "60s", + ] ) assert wait_result.returncode == 0 @@ -169,5 +190,7 @@ def test_pod(): # Cleanup by deleting the pod that was creatd logger.info("Deleting test pod fixture...") - delete_result = subprocess.run(["kubectl", "delete", "pod", pod_name], check=True) + delete_result = subprocess.run( + ["kubectl", "delete", "pod", "-n", namespace, pod_name], check=True + ) assert delete_result.returncode == 0 diff --git a/tests/integration/setup/testpod_spec.yaml.template b/tests/integration/setup/testpod_spec.yaml.template index f802e8d..0e4a44d 100644 --- a/tests/integration/setup/testpod_spec.yaml.template +++ b/tests/integration/setup/testpod_spec.yaml.template @@ -5,6 +5,7 @@ apiVersion: v1 kind: Pod metadata: name: testpod + namespace: ${namespace} spec: containers: - image: ${kyuubi_image} diff --git a/tests/integration/test_charm.py b/tests/integration/test_charm.py index 5fe720b..9a707c5 100644 --- a/tests/integration/test_charm.py +++ b/tests/integration/test_charm.py @@ -3,6 +3,7 @@ # See LICENSE file for licensing details. 
 import logging
+import re
 import subprocess
 import time
 import uuid
@@ -18,6 +19,7 @@
 from constants import (
     AUTHENTICATION_DATABASE_NAME,
+    HA_ZNODE_NAME,
     KYUUBI_CLIENT_RELATION_NAME,
     METASTORE_DATABASE_NAME,
 )
@@ -209,6 +211,7 @@ async def test_jdbc_endpoint_with_default_metastore(ops_test: OpsTest, test_pod)
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_default_metastore",
             "table_default_metastore",
@@ -270,6 +273,7 @@ async def test_jdbc_endpoint_with_postgres_metastore(ops_test: OpsTest, test_pod)
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_postgres_metastore",
             "table_postgres_metastore",
@@ -357,6 +361,7 @@ async def test_jdbc_endpoint_after_removing_postgresql_metastore(
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_default_metastore_2",
             "table_default_metastore_2",
@@ -450,6 +455,7 @@ async def test_jdbc_endpoint_no_credentials(ops_test: OpsTest, test_pod):
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_111",
             "table_111",
@@ -487,6 +493,7 @@ async def test_jdbc_endpoint_invalid_credentials(ops_test: OpsTest, test_pod):
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_222",
             "table_222",
@@ -535,6 +542,7 @@ async def test_jdbc_endpoint_valid_credentials(ops_test: OpsTest, test_pod):
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_333",
             "table_333",
@@ -596,6 +604,7 @@ async def test_set_password_action(ops_test: OpsTest, test_pod):
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_444",
             "table_444",
@@ -706,6 +715,7 @@ async def test_kyuubi_client_relation_joined(ops_test: OpsTest, test_pod, charm_
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_666",
             "tbl_666",
@@ -805,6 +815,7 @@ async def test_kyuubi_client_relation_removed(ops_test: OpsTest, test_pod, charm
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_777",
             "tbl_777",
@@ -855,6 +866,7 @@ async def test_remove_authentication(ops_test: OpsTest, test_pod, charm_versions
         [
             "./tests/integration/test_jdbc_endpoint.sh",
             test_pod,
+            ops_test.model_name,
             jdbc_endpoint,
             "db_555",
             "table_555",
@@ -869,6 +881,113 @@
     assert process.returncode == 0
 
 
+@pytest.mark.abort_on_fail
+async def test_integration_with_zookeeper(ops_test: OpsTest, test_pod, charm_versions):
+    """Test the charm by integrating it with Zookeeper."""
+    # Deploy the zookeeper charm and wait for it to become active
+    logger.info("Deploying zookeeper-k8s charm...")
+    await ops_test.model.deploy(**charm_versions.zookeeper.deploy_dict())
+
+    logger.info("Waiting for zookeeper app to be active and idle...")
+    await ops_test.model.wait_for_idle(
+        apps=[APP_NAME, charm_versions.zookeeper.application_name], timeout=1000, status="active"
+    )
+
+    logger.info("Integrating kyuubi charm with zookeeper charm...")
+    await ops_test.model.integrate(charm_versions.zookeeper.application_name, APP_NAME)
+
+    logger.info("Waiting for zookeeper-k8s and kyuubi charms to be active and idle...")
+    await ops_test.model.wait_for_idle(
+        apps=[APP_NAME, charm_versions.zookeeper.application_name], timeout=1000, status="active"
+    )
+
+    logger.info("Running action 'get-jdbc-endpoint' on 
kyuubi-k8s unit...") + kyuubi_unit = ops_test.model.applications[APP_NAME].units[0] + action = await kyuubi_unit.run_action( + action_name="get-jdbc-endpoint", + ) + result = await action.wait() + + jdbc_endpoint = result.results.get("endpoint") + logger.info(f"JDBC endpoint: {jdbc_endpoint}") + + assert "serviceDiscoveryMode=zooKeeper" in jdbc_endpoint + assert f"zooKeeperNamespace={HA_ZNODE_NAME}" in jdbc_endpoint + assert re.match( + r"jdbc:hive2://(.*),(.*),(.*)/;serviceDiscoveryMode=zooKeeper;zooKeeperNamespace=.*", + jdbc_endpoint, + ) + + logger.info("Testing JDBC endpoint by connecting with beeline with no credentials ...") + process = subprocess.run( + [ + "./tests/integration/test_jdbc_endpoint.sh", + test_pod, + ops_test.model_name, + jdbc_endpoint, + "db_999", + "table_999", + ], + capture_output=True, + ) + print("========== test_jdbc_endpoint.sh STDOUT =================") + print(process.stdout.decode()) + print("========== test_jdbc_endpoint.sh STDERR =================") + print(process.stderr.decode()) + logger.info(f"JDBC endpoint test returned with status {process.returncode}") + assert process.returncode == 0 + + +@pytest.mark.abort_on_fail +async def test_remove_zookeeper_relation(ops_test: OpsTest, test_pod, charm_versions): + """Test the charm after the zookeeper relation has been broken.""" + logger.info("Removing relation between zookeeper-k8s and kyuubi-k8s...") + await ops_test.model.applications[APP_NAME].remove_relation( + f"{APP_NAME}:zookeeper", f"{charm_versions.zookeeper.application_name}:zookeeper" + ) + + logger.info("Waiting for zookeeper-k8s and kyuubi-k8s apps to be idle and active...") + await ops_test.model.wait_for_idle( + apps=[APP_NAME, charm_versions.zookeeper.application_name], timeout=1000, status="active" + ) + + logger.info("Running action 'get-jdbc-endpoint' on kyuubi-k8s unit...") + kyuubi_unit = ops_test.model.applications[APP_NAME].units[0] + action = await kyuubi_unit.run_action( + action_name="get-jdbc-endpoint", + ) + result = await action.wait() + + jdbc_endpoint = result.results.get("endpoint") + logger.info(f"JDBC endpoint: {jdbc_endpoint}") + + assert "serviceDiscoveryMode=zooKeeper" not in jdbc_endpoint + assert f"zooKeeperNamespace={HA_ZNODE_NAME}" not in jdbc_endpoint + assert not re.match( + r"jdbc:hive2://(.*),(.*),(.*)/;serviceDiscoveryMode=zooKeeper;zooKeeperNamespace=.*", + jdbc_endpoint, + ) + + logger.info("Testing JDBC endpoint by connecting with beeline with no credentials ...") + process = subprocess.run( + [ + "./tests/integration/test_jdbc_endpoint.sh", + test_pod, + ops_test.model_name, + jdbc_endpoint, + "db_101010", + "table_101010", + ], + capture_output=True, + ) + print("========== test_jdbc_endpoint.sh STDOUT =================") + print(process.stdout.decode()) + print("========== test_jdbc_endpoint.sh STDERR =================") + print(process.stderr.decode()) + logger.info(f"JDBC endpoint test returned with status {process.returncode}") + assert process.returncode == 0 + + @pytest.mark.skip(reason="This tests need re-write and fixes on integration hub level") @pytest.mark.abort_on_fail async def test_read_spark_properties_from_secrets(ops_test: OpsTest, test_pod): @@ -924,6 +1043,7 @@ async def test_read_spark_properties_from_secrets(ops_test: OpsTest, test_pod): [ "./tests/integration/test_jdbc_endpoint.sh", test_pod, + ops_test.model_name, jdbc_endpoint, "db_888", "table_888", diff --git a/tests/integration/test_jdbc_endpoint.sh b/tests/integration/test_jdbc_endpoint.sh index 8e46bf8..d307ce0 100755 --- 
diff --git a/tests/integration/test_jdbc_endpoint.sh b/tests/integration/test_jdbc_endpoint.sh
index 8e46bf8..d307ce0 100755
--- a/tests/integration/test_jdbc_endpoint.sh
+++ b/tests/integration/test_jdbc_endpoint.sh
@@ -3,21 +3,22 @@
 # See LICENSE file for licensing details.
 
 POD_NAME=$1
-JDBC_ENDPOINT=$2
-DB_NAME=${3:-testdb}
-TABLE_NAME=${4:-testtable}
-USERNAME=${5:-}
-PASSWORD=${6:-}
+NAMESPACE=${2:-default}
+JDBC_ENDPOINT=$3
+DB_NAME=${4:-testdb}
+TABLE_NAME=${5:-testtable}
+USERNAME=${6:-}
+PASSWORD=${7:-}
 
 SQL_COMMANDS=$(cat ./tests/integration/setup/test.sql | sed "s/db_name/$DB_NAME/g" | sed "s/table_name/$TABLE_NAME/g")
 
 if [ -z "${USERNAME}" ]; then
-    echo -e "$(kubectl exec $POD_NAME -- \
+    echo -e "$(kubectl exec $POD_NAME -n $NAMESPACE -- \
         env CMDS="$SQL_COMMANDS" ENDPOINT="$JDBC_ENDPOINT" \
         /bin/bash -c 'echo "$CMDS" | /opt/kyuubi/bin/beeline -u $ENDPOINT' )" > /tmp/test_beeline.out
 else
-    echo -e "$(kubectl exec $POD_NAME -- \
+    echo -e "$(kubectl exec $POD_NAME -n $NAMESPACE -- \
         env CMDS="$SQL_COMMANDS" ENDPOINT="$JDBC_ENDPOINT" USER="$USERNAME" PASSWD="$PASSWORD"\
         /bin/bash -c 'echo "$CMDS" | /opt/kyuubi/bin/beeline -u $ENDPOINT -n $USER -p $PASSWD' )" > /tmp/test_beeline.out
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 2c54656..9cc8537 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -7,7 +7,12 @@
 from scenario.state import next_relation_id
 
 from charm import KyuubiCharm
-from constants import KYUUBI_CONTAINER_NAME, S3_INTEGRATOR_REL, SPARK_SERVICE_ACCOUNT_REL
+from constants import (
+    KYUUBI_CONTAINER_NAME,
+    S3_INTEGRATOR_REL,
+    SPARK_SERVICE_ACCOUNT_REL,
+    ZOOKEEPER_REL,
+)
 
 
 @pytest.fixture
@@ -93,3 +98,23 @@ def spark_service_account_relation():
         local_app_data={"service-account": "kyuubi", "namespace": "spark"},
         remote_app_data={"service-account": "kyuubi", "namespace": "spark"},
     )
+
+
+@pytest.fixture
+def zookeeper_relation():
+    """Provide fixture for the Zookeeper relation."""
+    relation_id = next_relation_id(update=True)
+
+    return Relation(
+        endpoint=ZOOKEEPER_REL,
+        interface="zookeeper",
+        remote_app_name="zookeeper-k8s",
+        relation_id=relation_id,
+        local_app_data={"database": "/kyuubi"},
+        remote_app_data={
+            "uris": "host1:2181,host2:2181,host3:2181",
+            "username": "foobar",
+            "password": "foopassbarword",
+            "database": "/kyuubi",
+        },
+    )
diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py
index 46f86b0..3f8984b 100644
--- a/tests/unit/test_charm.py
+++ b/tests/unit/test_charm.py
@@ -25,6 +25,15 @@ def parse_spark_properties(tmp_path: Path) -> dict[str, str]:
     )
 
 
+def parse_kyuubi_configurations(tmp_path: Path) -> dict[str, str]:
+    """Parse and return Kyuubi configurations from the conf file in the container."""
+    file_path = tmp_path / Path(KyuubiWorkload.KYUUBI_CONFIGURATION_FILE).relative_to("/opt")
+    with file_path.open("r") as fid:
+        return dict(
+            row.rsplit("=", maxsplit=1) for line in fid.readlines() if (row := line.strip())
+        )
+
+
 def test_start_kyuubi(kyuubi_context):
     state = State(
         config={},
@@ -204,6 +213,89 @@ def test_object_storage_backend_removed(
     assert state_after_relation_broken.unit_status == Status.MISSING_OBJECT_STORAGE_BACKEND.value
 
 
+@patch("managers.s3.S3Manager.verify", return_value=True)
+@patch("managers.k8s.K8sManager.is_namespace_valid", return_value=True)
+@patch("managers.k8s.K8sManager.is_service_account_valid", return_value=True)
+@patch("managers.k8s.K8sManager.is_s3_configured", return_value=True)
+@patch("config.spark.SparkConfig._get_spark_master", return_value="k8s://https://spark.master")
+@patch("config.spark.SparkConfig._sa_conf", return_value={})
+def test_zookeeper_relation_joined(
+    mock_sa_conf,
+    mock_get_master,
+    mock_s3_configured,
+    mock_valid_sa,
+    mock_valid_ns,
+    mock_s3_verify,
+    tmp_path,
+    kyuubi_context,
+    kyuubi_container,
+    s3_relation,
+    spark_service_account_relation,
+    zookeeper_relation,
+):
+    state = State(
+        relations=[s3_relation, spark_service_account_relation, zookeeper_relation],
+        containers=[kyuubi_container],
+    )
+    out = kyuubi_context.run(zookeeper_relation.changed_event, state)
+    assert out.unit_status == Status.ACTIVE.value
+
+    kyuubi_configurations = parse_kyuubi_configurations(tmp_path)
+
+    # Assert some of the keys
+    assert (
+        kyuubi_configurations["kyuubi.ha.namespace"]
+        == zookeeper_relation.remote_app_data["database"]
+    )
+    assert (
+        kyuubi_configurations["kyuubi.ha.addresses"] == zookeeper_relation.remote_app_data["uris"]
+    )
+    assert kyuubi_configurations["kyuubi.ha.zookeeper.auth.type"] == "DIGEST"
+    assert (
+        kyuubi_configurations["kyuubi.ha.zookeeper.auth.digest"]
+        == f"{zookeeper_relation.remote_app_data['username']}:{zookeeper_relation.remote_app_data['password']}"
+    )
+
+
+@patch("managers.s3.S3Manager.verify", return_value=True)
+@patch("managers.k8s.K8sManager.is_namespace_valid", return_value=True)
+@patch("managers.k8s.K8sManager.is_service_account_valid", return_value=True)
+@patch("managers.k8s.K8sManager.is_s3_configured", return_value=True)
+@patch("config.spark.SparkConfig._get_spark_master", return_value="k8s://https://spark.master")
+@patch("config.spark.SparkConfig._sa_conf", return_value={})
+def test_zookeeper_relation_broken(
+    mock_sa_conf,
+    mock_get_master,
+    mock_s3_configured,
+    mock_valid_sa,
+    mock_valid_ns,
+    mock_s3_verify,
+    tmp_path,
+    kyuubi_context,
+    kyuubi_container,
+    s3_relation,
+    spark_service_account_relation,
+    zookeeper_relation,
+):
+    state = State(
+        relations=[s3_relation, spark_service_account_relation, zookeeper_relation],
+        containers=[kyuubi_container],
+    )
+    state_after_relation_changed = kyuubi_context.run(zookeeper_relation.changed_event, state)
+    state_after_relation_broken = kyuubi_context.run(
+        zookeeper_relation.broken_event, state_after_relation_changed
+    )
+    assert state_after_relation_broken.unit_status == Status.ACTIVE.value
+
+    kyuubi_configurations = parse_kyuubi_configurations(tmp_path)
+
+    # Assert HA configurations do not exist in the Kyuubi configuration file
+    assert "kyuubi.ha.namespace" not in kyuubi_configurations
+    assert "kyuubi.ha.addresses" not in kyuubi_configurations
+    assert "kyuubi.ha.zookeeper.auth.type" not in kyuubi_configurations
+    assert "kyuubi.ha.zookeeper.auth.digest" not in kyuubi_configurations
+
+
 @patch("managers.s3.S3Manager.verify", return_value=True)
 @patch("managers.k8s.K8sManager.is_namespace_valid", return_value=True)
 @patch("managers.k8s.K8sManager.is_service_account_valid", return_value=True)
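The assertions in these two unit tests pin down the contract between the Zookeeper relation data and the HA entries written to the Kyuubi configuration file. A hypothetical helper showing that mapping (not the charm's actual implementation; the key names are exactly the ones asserted above):

    def zookeeper_ha_config(remote_app_data: dict[str, str]) -> dict[str, str]:
        # Hypothetical sketch: translate the relation databag published by
        # zookeeper-k8s into the kyuubi.ha.* entries the unit tests expect.
        return {
            "kyuubi.ha.addresses": remote_app_data["uris"],
            "kyuubi.ha.namespace": remote_app_data["database"],
            "kyuubi.ha.zookeeper.auth.type": "DIGEST",
            "kyuubi.ha.zookeeper.auth.digest": (
                f"{remote_app_data['username']}:{remote_app_data['password']}"
            ),
        }

    # Example with the fixture data from tests/unit/conftest.py:
    print(
        zookeeper_ha_config(
            {
                "uris": "host1:2181,host2:2181,host3:2181",
                "username": "foobar",
                "password": "foopassbarword",
                "database": "/kyuubi",
            }
        )
    )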