Skip to content

Commit

Permalink
Fix bug where non-canonical protocol names were causing settings not …
Browse files Browse the repository at this point in the history
…to show up for existing protocols.
  • Loading branch information
jonathangreen committed Sep 5, 2024
1 parent 6d93cda commit 871cb3c
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 32 deletions.
35 changes: 18 additions & 17 deletions src/palace/manager/api/admin/controller/integration_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,18 @@ def configured_service_info(
f"IntegrationConfiguration {service.name}({service.id}) has no goal set. Skipping."
)
return None

if service.protocol is None or service.protocol not in self.registry:
self.log.warning(
f"Unknown protocol: {service.protocol} for goal {self.registry.goal}"
)
return None

canonical_protocol = self.registry.canonicalize(service.protocol)
return {
"id": service.id,
"name": service.name,
"protocol": service.protocol,
"protocol": canonical_protocol,
"settings": service.settings_dict,
"goal": service.goal.value,
}
Expand Down Expand Up @@ -138,12 +146,6 @@ def configured_services(self) -> list[dict[str, Any]]:
.filter(IntegrationConfiguration.goal == self.registry.goal)
.order_by(IntegrationConfiguration.name)
):
if service.protocol not in self.registry:
self.log.warning(
f"Unknown protocol: {service.protocol} for goal {self.registry.goal}"
)
continue

service_info = self.configured_service_info(service)
if service_info is None:
continue
Expand All @@ -168,9 +170,9 @@ def get_existing_service(
"""
Query for an existing service to edit.
Raises ProblemError if the service doesn't exist, or if the protocol
Raises ProblemDetailException if the service doesn't exist, or if the protocol
doesn't match. If the name is provided, the service will be renamed if
necessary and a ProblemError will be raised if the name is already in
necessary and a ProblemDetailException will be raised if the name is already in
use.
"""
service: IntegrationConfiguration | None = get_one(
Expand All @@ -181,8 +183,12 @@ def get_existing_service(
)
if service is None:
raise ProblemDetailException(MISSING_SERVICE)
if protocol is not None and service.protocol != protocol:
if protocol is not None and not self.registry.equivalent(
service.protocol, protocol
):
raise ProblemDetailException(CANNOT_CHANGE_PROTOCOL)
if service.protocol is None or service.protocol not in self.registry:
raise ProblemDetailException(UNKNOWN_PROTOCOL)
if name is not None and service.name != name:
service_with_name = get_one(self._db, IntegrationConfiguration, name=name)
if service_with_name is not None:
Expand Down Expand Up @@ -251,13 +257,8 @@ def get_service(
if protocol is None and _id is None:
raise ProblemDetailException(NO_PROTOCOL_FOR_NEW_SERVICE)

# Lookup the protocol class to make sure it exists
# this will raise a ProblemError if the protocol is unknown
self.get_protocol_class(protocol)

# This should never happen, due to the call to get_protocol_class but
# mypy doesn't know that, so we make sure that protocol is not None before we use it.
assert protocol is not None
if protocol is None or protocol not in self.registry:
raise ProblemDetailException(UNKNOWN_PROTOCOL)

if _id is not None:
# Find an existing service to edit
Expand Down
57 changes: 47 additions & 10 deletions src/palace/manager/service/integration_registry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,23 @@

from collections import defaultdict
from collections.abc import Iterator
from typing import Generic, TypeVar, overload
from typing import Generic, Literal, TypeVar, cast, overload

from palace.manager.core.exceptions import BasePalaceException
from palace.manager.integration.goals import Goals

T = TypeVar("T", covariant=True)
V = TypeVar("V")


class IntegrationRegistryException(ValueError):
class RegistrationException(BasePalaceException, ValueError):
"""An error occurred while registering an integration."""


class LookupException(BasePalaceException, LookupError):
"""An error occurred while looking up an integration."""


class IntegrationRegistry(Generic[T]):
def __init__(self, goal: Goals, integrations: dict[str, type[T]] | None = None):
"""Initialize a new IntegrationRegistry."""
Expand Down Expand Up @@ -52,7 +57,7 @@ def register(

for protocol in [canonical] + aliases:
if protocol in self._lookup and self._lookup[protocol] != integration:
raise IntegrationRegistryException(
raise RegistrationException(
f"Integration {protocol} already registered"
)
self._lookup[protocol] = integration
Expand All @@ -78,34 +83,47 @@ def get(self, protocol: str, default: V | None = None) -> type[T] | V | None:
def get_protocol(self, integration: type[T], default: None = ...) -> str | None:
...

@overload
def get_protocol(self, integration: type[T], default: Literal[False]) -> str:
...

@overload
def get_protocol(self, integration: type[T], default: V) -> str | V:
...

def get_protocol(
self, integration: type[T], default: V | None = None
self, integration: type[T], default: V | None | Literal[False] = None
) -> str | V | None:
"""Look up the canonical protocol for an integration class."""
names = self.get_protocols(integration, default)
if not isinstance(names, list):
return default
return names[0]
# We have to cast here because mypy doesn't understand that
# if default is False, names is a list[str] due to the overload
# for get_protocols.
if names is default:
return cast(V | None, names)
return cast(list[str], names)[0]

@overload
def get_protocols(
self, integration: type[T], default: None = ...
) -> list[str] | None:
...

@overload
def get_protocols(self, integration: type[T], default: Literal[False]) -> list[str]:
...

@overload
def get_protocols(self, integration: type[T], default: V) -> list[str] | V:
...

def get_protocols(
self, integration: type[T], default: V | None = None
self, integration: type[T], default: V | None | Literal[False] = None
) -> list[str] | V | None:
"""Look up all protocols for an integration class."""
if integration not in self._reverse_lookup:
if default is False:
raise LookupException(f"Integration {integration} not found")
return default
return self._reverse_lookup[integration]

Expand All @@ -117,7 +135,7 @@ def integrations(self) -> set[type[T]]:
def update(self, other: IntegrationRegistry[T]) -> None:
"""Update registry to include integrations in other."""
if self.goal != other.goal:
raise IntegrationRegistryException(
raise RegistrationException(
f"IntegrationRegistry's goals must be the same. (Self: {self.goal}, Other: {other.goal})"
)

Expand All @@ -126,13 +144,32 @@ def update(self, other: IntegrationRegistry[T]) -> None:
assert isinstance(names, list)
self.register(integration, canonical=names[0], aliases=names[1:])

def canonicalize(self, protocol: str) -> str:
"""Return the canonical protocol name for a given protocol."""
return self.get_protocol(self[protocol], default=False)

def equivalent(self, protocol1: str | None, protocol2: str | None) -> bool:
"""Return whether two protocols are equivalent."""
if (
protocol1 is None
or protocol1 not in self
or protocol2 is None
or protocol2 not in self
):
return False

return self[protocol1] is self[protocol2]

def __iter__(self) -> Iterator[tuple[str, type[T]]]:
for integration, names in self._reverse_lookup.items():
yield names[0], integration

def __getitem__(self, protocol: str) -> type[T]:
"""Look up an integration class by protocol, using the [] operator."""
return self._lookup[protocol]
try:
return self._lookup[protocol]
except KeyError as e:
raise LookupException(f"Integration {protocol} not found") from e

def __len__(self) -> int:
"""Return the number of registered integration classes."""
Expand Down
122 changes: 122 additions & 0 deletions tests/manager/api/admin/controller/test_integration_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from functools import partial

import pytest

from palace.manager.api.admin.controller.integration_settings import (
IntegrationSettingsController,
)
from palace.manager.integration.base import HasIntegrationConfiguration
from palace.manager.integration.goals import Goals
from palace.manager.integration.settings import BaseSettings
from palace.manager.service.integration_registry.base import IntegrationRegistry
from palace.manager.sqlalchemy.model.integration import IntegrationConfiguration
from palace.manager.util.problem_detail import ProblemDetailException
from tests.fixtures.database import DatabaseTransactionFixture


class MockIntegrationBase(HasIntegrationConfiguration):
@classmethod
def label(cls) -> str:
return cls.__name__

@classmethod
def description(cls) -> str:
return "A mock integration"

@classmethod
def settings_class(cls) -> type[BaseSettings]:
return BaseSettings


class MockIntegration1(MockIntegrationBase):
...


class MockIntegration2(MockIntegrationBase):
...


class MockIntegration3(MockIntegrationBase):
...


class MockController(IntegrationSettingsController[MockIntegrationBase]):
...


class IntegrationSettingsControllerFixture:
def __init__(self, db: DatabaseTransactionFixture) -> None:
self._db = db
self.goal = Goals.PATRON_AUTH_GOAL
self.registry: IntegrationRegistry[MockIntegrationBase] = IntegrationRegistry(
self.goal
)

self.registry.register(MockIntegration1, canonical="mock_integration_1")
self.registry.register(MockIntegration2, canonical="mock_integration_2")
self.registry.register(
MockIntegration3, aliases=["mock_integration_3", "mockIntegration3"]
)

self.controller = MockController(db.session, self.registry)

self.integration_configuration = partial(
db.integration_configuration, goal=self.goal
)


@pytest.fixture
def integration_settings_controller_fixture(
db: DatabaseTransactionFixture,
) -> IntegrationSettingsControllerFixture:
return IntegrationSettingsControllerFixture(db)


class TestIntegrationSettingsController:
def test_configured_service_info(
self,
integration_settings_controller_fixture: IntegrationSettingsControllerFixture,
):
controller = integration_settings_controller_fixture.controller
integration_configuration = (
integration_settings_controller_fixture.integration_configuration
)
integration = integration_configuration("mock_integration_3")
assert controller.configured_service_info(integration) == {
"id": integration.id,
"name": integration.name,
"goal": integration_settings_controller_fixture.goal.value,
"protocol": "MockIntegration3",
"settings": integration.settings_dict,
}

# Integration protocol is not registered
integration = integration_configuration("mock_integration_4")
assert controller.configured_service_info(integration) is None

# Integration has no protocol set
integration = IntegrationConfiguration()
assert controller.configured_service_info(integration) is None

def test_get_existing_service(
self,
integration_settings_controller_fixture: IntegrationSettingsControllerFixture,
db: DatabaseTransactionFixture,
):
controller = integration_settings_controller_fixture.controller
integration_configuration = (
integration_settings_controller_fixture.integration_configuration
)
integration = integration_configuration("MockIntegration1")
assert integration.id is not None
assert controller.get_existing_service(integration.id) is integration
assert (
controller.get_existing_service(
integration.id, protocol="mock_integration_1"
)
is integration
)
with pytest.raises(ProblemDetailException, match="Cannot change protocol"):
controller.get_existing_service(
integration.id, protocol="mock_integration_2"
)
Loading

0 comments on commit 871cb3c

Please sign in to comment.