Skip to content

Commit

Permalink
Set application name in PG connection string (#1828)
Browse files Browse the repository at this point in the history
* Set application name in PG connection string. That way we will get more info from RDS stats.
* Make sure scripts get application_name set as well.
  • Loading branch information
jonathangreen authored May 10, 2024
1 parent 8c58d16 commit 7f01ee6
Show file tree
Hide file tree
Showing 13 changed files with 85 additions and 29 deletions.
2 changes: 1 addition & 1 deletion bin/configuration/add_saml_federations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from palace.manager.sqlalchemy.model.saml import SAMLFederation
from palace.manager.sqlalchemy.session import production_session

with closing(production_session()) as db:
with closing(production_session("add_saml_federations")) as db:
incommon_federation = (
db.query(SAMLFederation)
.filter(SAMLFederation.type == incommon.FEDERATION_TYPE)
Expand Down
2 changes: 1 addition & 1 deletion bin/hold_notifications
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from palace.manager.core.jobs.holds_notification import HoldsNotificationMonitor
from palace.manager.sqlalchemy.session import production_session

HoldsNotificationMonitor(production_session()).run()
HoldsNotificationMonitor(production_session(HoldsNotificationMonitor)).run()
3 changes: 1 addition & 2 deletions bin/integration_test
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


from palace.manager.core.jobs.integration_test import IntegrationTest
from palace.manager.sqlalchemy.session import production_session

IntegrationTest(production_session(initialize_data=False)).run()
IntegrationTest().run()
4 changes: 3 additions & 1 deletion bin/patron_activity_sync_notifications
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ from palace.manager.core.jobs.patron_activity_sync import (
)
from palace.manager.sqlalchemy.session import production_session

PatronActivitySyncNotificationScript(production_session()).run()
PatronActivitySyncNotificationScript(
production_session(PatronActivitySyncNotificationScript)
).run()
3 changes: 1 addition & 2 deletions bin/playtime_reporting
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


from palace.manager.core.jobs.playtime_entries import PlaytimeEntriesEmailReportsScript
from palace.manager.sqlalchemy.session import production_session

PlaytimeEntriesEmailReportsScript(production_session(initialize_data=False)).run()
PlaytimeEntriesEmailReportsScript().run()
3 changes: 1 addition & 2 deletions bin/playtime_summation
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


from palace.manager.core.jobs.playtime_entries import PlaytimeEntriesSummationScript
from palace.manager.sqlalchemy.session import production_session

PlaytimeEntriesSummationScript(production_session(initialize_data=False)).run()
PlaytimeEntriesSummationScript().run()
2 changes: 1 addition & 1 deletion src/palace/manager/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def initialize_circulation_manager():


def initialize_database():
session_factory = SessionManager.sessionmaker()
session_factory = SessionManager.sessionmaker(application_name="manager")
_db = flask_scoped_session(session_factory, app)
app._db = _db

Expand Down
4 changes: 3 additions & 1 deletion src/palace/manager/celery/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def session_maker(self) -> sessionmaker[Session]:
worker utilization in production.
"""
if self._session_maker is None:
engine = SessionManager.engine(poolclass=NullPool)
engine = SessionManager.engine(
poolclass=NullPool, application_name=self.name
)
maker = sessionmaker(bind=engine)
SessionManager.setup_event_listener(maker)
self._session_maker = maker
Expand Down
2 changes: 1 addition & 1 deletion src/palace/manager/core/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Script:
@property
def _db(self) -> Session:
if not hasattr(self, "_session"):
self._session = production_session()
self._session = production_session(self.__class__)
return self._session

@property
Expand Down
37 changes: 28 additions & 9 deletions src/palace/manager/sqlalchemy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydantic.json import pydantic_encoder
from sqlalchemy import create_engine, event, literal_column, select, table, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine import Connection, Engine, make_url
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import Pool

Expand All @@ -19,6 +19,7 @@
from palace.manager.sqlalchemy.model.key import Key
from palace.manager.sqlalchemy.util import get_one_or_create
from palace.manager.util.datetime_helpers import utc_now
from palace.manager.util.log import LoggerMixin
from palace.manager.util.resources import resources_dir

DEBUG = False
Expand All @@ -37,16 +38,30 @@ def json_serializer(*args, **kwargs) -> str:
return json.dumps(*args, default=json_encoder, **kwargs)


class SessionManager:
class SessionManager(LoggerMixin):
# A function that calculates recursively equivalent identifiers
# is also defined in SQL.
RECURSIVE_EQUIVALENTS_FUNCTION = "recursive_equivalents.sql"

@classmethod
def engine(
cls, url: str | None = None, poolclass: type[Pool] | None = None
cls,
url: str | None = None,
poolclass: type[Pool] | None = None,
application_name: str | None = None,
) -> Engine:
url = url or Configuration.database_url()
url_obj = make_url(url)
if application_name is not None:
if "application_name" in url_obj.query.keys():
cls.logger().warning(
"Overwriting existing application_name in database URL "
f"({url_obj.render_as_string(hide_password=True)}) with {application_name}"
)
url = url_obj.set(
query={**url_obj.query, "application_name": application_name}
).render_as_string(hide_password=False)

return create_engine(
url,
echo=DEBUG,
Expand All @@ -63,8 +78,8 @@ def setup_event_listener(
return session

@classmethod
def sessionmaker(cls):
bind_obj = cls.engine()
def sessionmaker(cls, application_name: str | None = None):
bind_obj = cls.engine(application_name=application_name)
session_factory = sessionmaker(bind=bind_obj)
cls.setup_event_listener(session_factory)
return session_factory
Expand All @@ -76,8 +91,8 @@ def initialize_schema(cls, engine):
Base.metadata.create_all(engine)

@classmethod
def session(cls, url, initialize_data=True, initialize_schema=True):
engine = cls.engine(url)
def session(cls, url: str | None = None, application_name: str | None = None):
engine = cls.engine(url, application_name=application_name)
connection = engine.connect()
return cls.session_from_connection(connection)

Expand Down Expand Up @@ -164,10 +179,14 @@ def initialize_data(cls, session: Session):
return session


def production_session(initialize_data=True) -> Session:
def production_session(application_name: type[object] | str) -> Session:
if isinstance(application_name, str):
application_name = application_name
else:
application_name = f"{application_name.__module__}.{application_name.__name__}"
url = Configuration.database_url()
if url.startswith('"'):
url = url[1:]
logging.debug("Database url: %s", url)
_db = SessionManager.session(url, initialize_data=initialize_data)
_db = SessionManager.session(url, application_name=application_name)
return _db
2 changes: 1 addition & 1 deletion tests/fixtures/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def __init__(self, database_name: DatabaseCreationFixture) -> None:
self.connection = self.engine.connect()

def engine_factory(self) -> Engine:
return SessionManager.engine(self.database_name.url)
return SessionManager.engine(self.database_name.url, application_name="test")

def drop_existing_schema(self) -> None:
metadata_obj = MetaData()
Expand Down
6 changes: 5 additions & 1 deletion tests/manager/celery/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@


def test_task_session_maker() -> None:
task_name = "test-task"
task = Task()
task.name = task_name

# If session maker is not initialized, it should be None
assert task._session_maker is None
Expand All @@ -25,7 +27,9 @@ def test_task_session_maker() -> None:
patch("palace.manager.celery.task.sessionmaker") as mock_sessionmaker,
):
assert task.session_maker == mock_sessionmaker.return_value
mock_session_manager.engine.assert_called_once_with(poolclass=NullPool)
mock_session_manager.engine.assert_called_once_with(
poolclass=NullPool, application_name=task_name
)
mock_sessionmaker.assert_called_once_with(
bind=mock_session_manager.engine.return_value
)
Expand Down
44 changes: 38 additions & 6 deletions tests/manager/sqlalchemy/test_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from unittest.mock import patch

import pytest

from palace.manager.core.config import Configuration
from palace.manager.sqlalchemy.model.coverage import Timestamp
from palace.manager.sqlalchemy.session import (
Expand Down Expand Up @@ -32,7 +34,9 @@ def test_initialize_data_does_not_reset_timestamp(

@patch("palace.manager.sqlalchemy.session.create_engine")
@patch.object(Configuration, "database_url")
def test_engine(self, mock_database_url, mock_create_engine):
def test_engine(
self, mock_database_url, mock_create_engine, caplog: pytest.LogCaptureFixture
):
expected_args = {
"echo": False,
"json_serializer": json_serializer,
Expand All @@ -41,9 +45,9 @@ def test_engine(self, mock_database_url, mock_create_engine):
}

# If a URL is passed in, it's used.
SessionManager.engine("url")
SessionManager.engine("postgres://url")
mock_database_url.assert_not_called()
mock_create_engine.assert_called_once_with("url", **expected_args)
mock_create_engine.assert_called_once_with("postgres://url", **expected_args)
mock_create_engine.reset_mock()

# If no URL is passed in, the URL from the configuration is used.
Expand All @@ -52,12 +56,29 @@ def test_engine(self, mock_database_url, mock_create_engine):
mock_create_engine.assert_called_once_with(
mock_database_url.return_value, **expected_args
)
mock_create_engine.reset_mock()

# If we pass in an application name, it's added to the URL.
SessionManager.engine("postgres://url", application_name="test-app")
mock_create_engine.assert_called_once_with(
"postgres://url?application_name=test-app", **expected_args
)
mock_create_engine.reset_mock()

# If the URL already has an application name, it's overwritten.
SessionManager.engine(
"postgres://url?application_name=old-app", application_name="test-app"
)
mock_create_engine.assert_called_once_with(
"postgres://url?application_name=test-app", **expected_args
)
assert "Overwriting existing application_name in database URL" in caplog.text

@patch.object(SessionManager, "engine")
@patch.object(SessionManager, "session_from_connection")
def test_session(self, mock_session_from_connection, mock_engine):
session = SessionManager.session("test-url")
mock_engine.assert_called_once_with("test-url")
mock_engine.assert_called_once_with("test-url", application_name=None)
mock_engine.return_value.connect.assert_called_once()
mock_session_from_connection.assert_called_once_with(
mock_engine.return_value.connect.return_value
Expand All @@ -71,7 +92,18 @@ def test_production_session(mock_database_url, mock_session):
# Make sure production_session() calls session() with the URL from the
# configuration.
mock_database_url.return_value = "test-url"
session = production_session()
session = production_session("test-app")
mock_database_url.assert_called_once()
mock_session.assert_called_once_with("test-url", initialize_data=True)
mock_session.assert_called_once_with("test-url", application_name="test-app")
assert session == mock_session.return_value

# production_session can also be called with a class that sets the application name
mock_session.reset_mock()

class Mock:
...

production_session(Mock)
mock_session.assert_called_once_with(
"test-url", application_name="tests.manager.sqlalchemy.test_session.Mock"
)

0 comments on commit 7f01ee6

Please sign in to comment.