diff --git a/src/palace/manager/api/admin/controller/integration_settings.py b/src/palace/manager/api/admin/controller/integration_settings.py index fe4cf36c9f..ff9824bfd2 100644 --- a/src/palace/manager/api/admin/controller/integration_settings.py +++ b/src/palace/manager/api/admin/controller/integration_settings.py @@ -106,10 +106,18 @@ def configured_service_info( f"IntegrationConfiguration {service.name}({service.id}) has no goal set. Skipping." ) return None + + if service.protocol is None or service.protocol not in self.registry: + self.log.warning( + f"Unknown protocol: {service.protocol} for goal {self.registry.goal}" + ) + return None + + canonical_protocol = self.registry.canonicalize(service.protocol) return { "id": service.id, "name": service.name, - "protocol": service.protocol, + "protocol": canonical_protocol, "settings": service.settings_dict, "goal": service.goal.value, } @@ -138,12 +146,6 @@ def configured_services(self) -> list[dict[str, Any]]: .filter(IntegrationConfiguration.goal == self.registry.goal) .order_by(IntegrationConfiguration.name) ): - if service.protocol not in self.registry: - self.log.warning( - f"Unknown protocol: {service.protocol} for goal {self.registry.goal}" - ) - continue - service_info = self.configured_service_info(service) if service_info is None: continue @@ -168,9 +170,9 @@ def get_existing_service( """ Query for an existing service to edit. - Raises ProblemError if the service doesn't exist, or if the protocol + Raises ProblemDetailException if the service doesn't exist, or if the protocol doesn't match. If the name is provided, the service will be renamed if - necessary and a ProblemError will be raised if the name is already in + necessary and a ProblemDetailException will be raised if the name is already in use. """ service: IntegrationConfiguration | None = get_one( @@ -181,8 +183,12 @@ def get_existing_service( ) if service is None: raise ProblemDetailException(MISSING_SERVICE) - if protocol is not None and service.protocol != protocol: + if protocol is not None and not self.registry.equivalent( + service.protocol, protocol + ): raise ProblemDetailException(CANNOT_CHANGE_PROTOCOL) + if service.protocol is None or service.protocol not in self.registry: + raise ProblemDetailException(UNKNOWN_PROTOCOL) if name is not None and service.name != name: service_with_name = get_one(self._db, IntegrationConfiguration, name=name) if service_with_name is not None: @@ -251,13 +257,8 @@ def get_service( if protocol is None and _id is None: raise ProblemDetailException(NO_PROTOCOL_FOR_NEW_SERVICE) - # Lookup the protocol class to make sure it exists - # this will raise a ProblemError if the protocol is unknown - self.get_protocol_class(protocol) - - # This should never happen, due to the call to get_protocol_class but - # mypy doesn't know that, so we make sure that protocol is not None before we use it. - assert protocol is not None + if protocol is None or protocol not in self.registry: + raise ProblemDetailException(UNKNOWN_PROTOCOL) if _id is not None: # Find an existing service to edit diff --git a/src/palace/manager/service/integration_registry/base.py b/src/palace/manager/service/integration_registry/base.py index d3eee73e73..badb9a809e 100644 --- a/src/palace/manager/service/integration_registry/base.py +++ b/src/palace/manager/service/integration_registry/base.py @@ -2,18 +2,23 @@ from collections import defaultdict from collections.abc import Iterator -from typing import Generic, TypeVar, overload +from typing import Generic, Literal, TypeVar, cast, overload +from palace.manager.core.exceptions import BasePalaceException from palace.manager.integration.goals import Goals T = TypeVar("T", covariant=True) V = TypeVar("V") -class IntegrationRegistryException(ValueError): +class RegistrationException(BasePalaceException, ValueError): """An error occurred while registering an integration.""" +class LookupException(BasePalaceException, LookupError): + """An error occurred while looking up an integration.""" + + class IntegrationRegistry(Generic[T]): def __init__(self, goal: Goals, integrations: dict[str, type[T]] | None = None): """Initialize a new IntegrationRegistry.""" @@ -52,7 +57,7 @@ def register( for protocol in [canonical] + aliases: if protocol in self._lookup and self._lookup[protocol] != integration: - raise IntegrationRegistryException( + raise RegistrationException( f"Integration {protocol} already registered" ) self._lookup[protocol] = integration @@ -78,18 +83,25 @@ def get(self, protocol: str, default: V | None = None) -> type[T] | V | None: def get_protocol(self, integration: type[T], default: None = ...) -> str | None: ... + @overload + def get_protocol(self, integration: type[T], default: Literal[False]) -> str: + ... + @overload def get_protocol(self, integration: type[T], default: V) -> str | V: ... def get_protocol( - self, integration: type[T], default: V | None = None + self, integration: type[T], default: V | None | Literal[False] = None ) -> str | V | None: """Look up the canonical protocol for an integration class.""" names = self.get_protocols(integration, default) - if not isinstance(names, list): - return default - return names[0] + # We have to cast here because mypy doesn't understand that + # if default is False, names is a list[str] due to the overload + # for get_protocols. + if names is default: + return cast(V | None, names) + return cast(list[str], names)[0] @overload def get_protocols( @@ -97,15 +109,21 @@ def get_protocols( ) -> list[str] | None: ... + @overload + def get_protocols(self, integration: type[T], default: Literal[False]) -> list[str]: + ... + @overload def get_protocols(self, integration: type[T], default: V) -> list[str] | V: ... def get_protocols( - self, integration: type[T], default: V | None = None + self, integration: type[T], default: V | None | Literal[False] = None ) -> list[str] | V | None: """Look up all protocols for an integration class.""" if integration not in self._reverse_lookup: + if default is False: + raise LookupException(f"Integration {integration} not found") return default return self._reverse_lookup[integration] @@ -117,7 +135,7 @@ def integrations(self) -> set[type[T]]: def update(self, other: IntegrationRegistry[T]) -> None: """Update registry to include integrations in other.""" if self.goal != other.goal: - raise IntegrationRegistryException( + raise RegistrationException( f"IntegrationRegistry's goals must be the same. (Self: {self.goal}, Other: {other.goal})" ) @@ -126,13 +144,32 @@ def update(self, other: IntegrationRegistry[T]) -> None: assert isinstance(names, list) self.register(integration, canonical=names[0], aliases=names[1:]) + def canonicalize(self, protocol: str) -> str: + """Return the canonical protocol name for a given protocol.""" + return self.get_protocol(self[protocol], default=False) + + def equivalent(self, protocol1: str | None, protocol2: str | None) -> bool: + """Return whether two protocols are equivalent.""" + if ( + protocol1 is None + or protocol1 not in self + or protocol2 is None + or protocol2 not in self + ): + return False + + return self[protocol1] is self[protocol2] + def __iter__(self) -> Iterator[tuple[str, type[T]]]: for integration, names in self._reverse_lookup.items(): yield names[0], integration def __getitem__(self, protocol: str) -> type[T]: """Look up an integration class by protocol, using the [] operator.""" - return self._lookup[protocol] + try: + return self._lookup[protocol] + except KeyError as e: + raise LookupException(f"Integration {protocol} not found") from e def __len__(self) -> int: """Return the number of registered integration classes.""" diff --git a/tests/manager/api/admin/controller/test_integration_settings.py b/tests/manager/api/admin/controller/test_integration_settings.py new file mode 100644 index 0000000000..6528d4063e --- /dev/null +++ b/tests/manager/api/admin/controller/test_integration_settings.py @@ -0,0 +1,122 @@ +from functools import partial + +import pytest + +from palace.manager.api.admin.controller.integration_settings import ( + IntegrationSettingsController, +) +from palace.manager.integration.base import HasIntegrationConfiguration +from palace.manager.integration.goals import Goals +from palace.manager.integration.settings import BaseSettings +from palace.manager.service.integration_registry.base import IntegrationRegistry +from palace.manager.sqlalchemy.model.integration import IntegrationConfiguration +from palace.manager.util.problem_detail import ProblemDetailException +from tests.fixtures.database import DatabaseTransactionFixture + + +class MockIntegrationBase(HasIntegrationConfiguration): + @classmethod + def label(cls) -> str: + return cls.__name__ + + @classmethod + def description(cls) -> str: + return "A mock integration" + + @classmethod + def settings_class(cls) -> type[BaseSettings]: + return BaseSettings + + +class MockIntegration1(MockIntegrationBase): + ... + + +class MockIntegration2(MockIntegrationBase): + ... + + +class MockIntegration3(MockIntegrationBase): + ... + + +class MockController(IntegrationSettingsController[MockIntegrationBase]): + ... + + +class IntegrationSettingsControllerFixture: + def __init__(self, db: DatabaseTransactionFixture) -> None: + self._db = db + self.goal = Goals.PATRON_AUTH_GOAL + self.registry: IntegrationRegistry[MockIntegrationBase] = IntegrationRegistry( + self.goal + ) + + self.registry.register(MockIntegration1, canonical="mock_integration_1") + self.registry.register(MockIntegration2, canonical="mock_integration_2") + self.registry.register( + MockIntegration3, aliases=["mock_integration_3", "mockIntegration3"] + ) + + self.controller = MockController(db.session, self.registry) + + self.integration_configuration = partial( + db.integration_configuration, goal=self.goal + ) + + +@pytest.fixture +def integration_settings_controller_fixture( + db: DatabaseTransactionFixture, +) -> IntegrationSettingsControllerFixture: + return IntegrationSettingsControllerFixture(db) + + +class TestIntegrationSettingsController: + def test_configured_service_info( + self, + integration_settings_controller_fixture: IntegrationSettingsControllerFixture, + ): + controller = integration_settings_controller_fixture.controller + integration_configuration = ( + integration_settings_controller_fixture.integration_configuration + ) + integration = integration_configuration("mock_integration_3") + assert controller.configured_service_info(integration) == { + "id": integration.id, + "name": integration.name, + "goal": integration_settings_controller_fixture.goal.value, + "protocol": "MockIntegration3", + "settings": integration.settings_dict, + } + + # Integration protocol is not registered + integration = integration_configuration("mock_integration_4") + assert controller.configured_service_info(integration) is None + + # Integration has no protocol set + integration = IntegrationConfiguration() + assert controller.configured_service_info(integration) is None + + def test_get_existing_service( + self, + integration_settings_controller_fixture: IntegrationSettingsControllerFixture, + db: DatabaseTransactionFixture, + ): + controller = integration_settings_controller_fixture.controller + integration_configuration = ( + integration_settings_controller_fixture.integration_configuration + ) + integration = integration_configuration("MockIntegration1") + assert integration.id is not None + assert controller.get_existing_service(integration.id) is integration + assert ( + controller.get_existing_service( + integration.id, protocol="mock_integration_1" + ) + is integration + ) + with pytest.raises(ProblemDetailException, match="Cannot change protocol"): + controller.get_existing_service( + integration.id, protocol="mock_integration_2" + ) diff --git a/tests/manager/service/integration_registry/test_base.py b/tests/manager/service/integration_registry/test_base.py index 7d0a525337..c5f3d9ae97 100644 --- a/tests/manager/service/integration_registry/test_base.py +++ b/tests/manager/service/integration_registry/test_base.py @@ -5,7 +5,8 @@ from palace.manager.integration.goals import Goals from palace.manager.service.integration_registry.base import ( IntegrationRegistry, - IntegrationRegistryException, + LookupException, + RegistrationException, ) @@ -47,7 +48,7 @@ def test_registry_register_raises_value_error_if_name_already_registered( registry.register(object) # registering a different object with the same name raises an error - with pytest.raises(IntegrationRegistryException): + with pytest.raises(RegistrationException): registry.register(list, canonical="object") @@ -100,7 +101,7 @@ def test_registry_get_returns_default_if_name_not_registered( assert registry.get("test_class", "default") == "default" # __get__ throws KeyError - with pytest.raises(KeyError): + with pytest.raises(LookupException): _ = registry["test_class"] @@ -113,6 +114,46 @@ def test_registry_get_protocol_returns_default_if_integration_not_registered( # default is not none assert registry.get_protocol(object, "default") == "default" + # default is a list + assert registry.get_protocol(object, ["default"]) == ["default"] + + # If default is False, raises exception + with pytest.raises(LookupException): + registry.get_protocol(object, False) + + +def test_registry_canonicalize(registry: IntegrationRegistry): + """Test that canonicalize() works.""" + registry.register(object, canonical="test") + assert registry.canonicalize("test") == "test" + assert registry.canonicalize("object") == "test" + + with pytest.raises(LookupException): + registry.canonicalize("not_registered") + + +@pytest.mark.parametrize( + "protocol1, protocol2, expected", + [ + ("test", "test", True), + ("object", "test", True), + ("list", "list", True), + ("object", "list", False), + ("object", "not_registered", False), + ("not_registered", "not_registered", False), + ("not_registered1", "not_registered2", False), + ], +) +def test_registry_equivalent( + protocol1: str, protocol2: str, expected: bool, registry: IntegrationRegistry +): + """Test that equivalent() works.""" + registry.register(object, canonical="test") + registry.register(list) + + assert registry.equivalent(protocol1, protocol2) is expected + assert registry.equivalent(protocol2, protocol1) is expected + def test_registry_update(): """Test that update() works.""" @@ -138,7 +179,7 @@ def test_registry_update_raises_different_goals(): registry = IntegrationRegistry(Goals.PATRON_AUTH_GOAL) registry2 = IntegrationRegistry(Goals.LICENSE_GOAL) - with pytest.raises(IntegrationRegistryException): + with pytest.raises(RegistrationException): registry.update(registry2) @@ -171,7 +212,7 @@ def test_registry_add_errors(): registry = IntegrationRegistry(Goals.PATRON_AUTH_GOAL) registry2 = IntegrationRegistry(Goals.LICENSE_GOAL) - with pytest.raises(IntegrationRegistryException): + with pytest.raises(RegistrationException): registry + registry2 with pytest.raises(TypeError):