From 7f01ee6424e37162fa10c70ff681088db939d020 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Fri, 10 May 2024 10:38:23 -0300 Subject: [PATCH] Set application name in PG connection string (#1828) * Set application name in PG connection string. That way we will get more info from RDS stats. * Make sure scripts get application_name set as well. --- bin/configuration/add_saml_federations.py | 2 +- bin/hold_notifications | 2 +- bin/integration_test | 3 +- bin/patron_activity_sync_notifications | 4 ++- bin/playtime_reporting | 3 +- bin/playtime_summation | 3 +- src/palace/manager/api/app.py | 2 +- src/palace/manager/celery/task.py | 4 ++- src/palace/manager/core/scripts.py | 2 +- src/palace/manager/sqlalchemy/session.py | 37 ++++++++++++++----- tests/fixtures/database.py | 2 +- tests/manager/celery/test_task.py | 6 +++- tests/manager/sqlalchemy/test_session.py | 44 +++++++++++++++++++---- 13 files changed, 85 insertions(+), 29 deletions(-) diff --git a/bin/configuration/add_saml_federations.py b/bin/configuration/add_saml_federations.py index 4d59c464b9..8865f263e9 100755 --- a/bin/configuration/add_saml_federations.py +++ b/bin/configuration/add_saml_federations.py @@ -7,7 +7,7 @@ from palace.manager.sqlalchemy.model.saml import SAMLFederation from palace.manager.sqlalchemy.session import production_session -with closing(production_session()) as db: +with closing(production_session("add_saml_federations")) as db: incommon_federation = ( db.query(SAMLFederation) .filter(SAMLFederation.type == incommon.FEDERATION_TYPE) diff --git a/bin/hold_notifications b/bin/hold_notifications index 033569127a..5129411756 100755 --- a/bin/hold_notifications +++ b/bin/hold_notifications @@ -5,4 +5,4 @@ from palace.manager.core.jobs.holds_notification import HoldsNotificationMonitor from palace.manager.sqlalchemy.session import production_session -HoldsNotificationMonitor(production_session()).run() +HoldsNotificationMonitor(production_session(HoldsNotificationMonitor)).run() diff --git a/bin/integration_test b/bin/integration_test index 6103656e6c..ce63e84ec8 100755 --- a/bin/integration_test +++ b/bin/integration_test @@ -3,6 +3,5 @@ from palace.manager.core.jobs.integration_test import IntegrationTest -from palace.manager.sqlalchemy.session import production_session -IntegrationTest(production_session(initialize_data=False)).run() +IntegrationTest().run() diff --git a/bin/patron_activity_sync_notifications b/bin/patron_activity_sync_notifications index f27d0c0a38..4fb4ab63e0 100755 --- a/bin/patron_activity_sync_notifications +++ b/bin/patron_activity_sync_notifications @@ -7,4 +7,6 @@ from palace.manager.core.jobs.patron_activity_sync import ( ) from palace.manager.sqlalchemy.session import production_session -PatronActivitySyncNotificationScript(production_session()).run() +PatronActivitySyncNotificationScript( + production_session(PatronActivitySyncNotificationScript) +).run() diff --git a/bin/playtime_reporting b/bin/playtime_reporting index 78de124f5a..751444f0d9 100755 --- a/bin/playtime_reporting +++ b/bin/playtime_reporting @@ -3,6 +3,5 @@ from palace.manager.core.jobs.playtime_entries import PlaytimeEntriesEmailReportsScript -from palace.manager.sqlalchemy.session import production_session -PlaytimeEntriesEmailReportsScript(production_session(initialize_data=False)).run() +PlaytimeEntriesEmailReportsScript().run() diff --git a/bin/playtime_summation b/bin/playtime_summation index 00d16c4c83..cf5b73d7ba 100755 --- a/bin/playtime_summation +++ b/bin/playtime_summation @@ -3,6 +3,5 @@ from palace.manager.core.jobs.playtime_entries import PlaytimeEntriesSummationScript -from palace.manager.sqlalchemy.session import production_session -PlaytimeEntriesSummationScript(production_session(initialize_data=False)).run() +PlaytimeEntriesSummationScript().run() diff --git a/src/palace/manager/api/app.py b/src/palace/manager/api/app.py index 705288b6c3..d32d2071f4 100644 --- a/src/palace/manager/api/app.py +++ b/src/palace/manager/api/app.py @@ -81,7 +81,7 @@ def initialize_circulation_manager(): def initialize_database(): - session_factory = SessionManager.sessionmaker() + session_factory = SessionManager.sessionmaker(application_name="manager") _db = flask_scoped_session(session_factory, app) app._db = _db diff --git a/src/palace/manager/celery/task.py b/src/palace/manager/celery/task.py index 5a3d3e238e..c327aa4469 100644 --- a/src/palace/manager/celery/task.py +++ b/src/palace/manager/celery/task.py @@ -68,7 +68,9 @@ def session_maker(self) -> sessionmaker[Session]: worker utilization in production. """ if self._session_maker is None: - engine = SessionManager.engine(poolclass=NullPool) + engine = SessionManager.engine( + poolclass=NullPool, application_name=self.name + ) maker = sessionmaker(bind=engine) SessionManager.setup_event_listener(maker) self._session_maker = maker diff --git a/src/palace/manager/core/scripts.py b/src/palace/manager/core/scripts.py index db088f2461..4614e96d7b 100644 --- a/src/palace/manager/core/scripts.py +++ b/src/palace/manager/core/scripts.py @@ -61,7 +61,7 @@ class Script: @property def _db(self) -> Session: if not hasattr(self, "_session"): - self._session = production_session() + self._session = production_session(self.__class__) return self._session @property diff --git a/src/palace/manager/sqlalchemy/session.py b/src/palace/manager/sqlalchemy/session.py index 8f9d731fcb..cc415b69ee 100644 --- a/src/palace/manager/sqlalchemy/session.py +++ b/src/palace/manager/sqlalchemy/session.py @@ -6,7 +6,7 @@ from pydantic.json import pydantic_encoder from sqlalchemy import create_engine, event, literal_column, select, table, text -from sqlalchemy.engine import Connection, Engine +from sqlalchemy.engine import Connection, Engine, make_url from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.pool import Pool @@ -19,6 +19,7 @@ from palace.manager.sqlalchemy.model.key import Key from palace.manager.sqlalchemy.util import get_one_or_create from palace.manager.util.datetime_helpers import utc_now +from palace.manager.util.log import LoggerMixin from palace.manager.util.resources import resources_dir DEBUG = False @@ -37,16 +38,30 @@ def json_serializer(*args, **kwargs) -> str: return json.dumps(*args, default=json_encoder, **kwargs) -class SessionManager: +class SessionManager(LoggerMixin): # A function that calculates recursively equivalent identifiers # is also defined in SQL. RECURSIVE_EQUIVALENTS_FUNCTION = "recursive_equivalents.sql" @classmethod def engine( - cls, url: str | None = None, poolclass: type[Pool] | None = None + cls, + url: str | None = None, + poolclass: type[Pool] | None = None, + application_name: str | None = None, ) -> Engine: url = url or Configuration.database_url() + url_obj = make_url(url) + if application_name is not None: + if "application_name" in url_obj.query.keys(): + cls.logger().warning( + "Overwriting existing application_name in database URL " + f"({url_obj.render_as_string(hide_password=True)}) with {application_name}" + ) + url = url_obj.set( + query={**url_obj.query, "application_name": application_name} + ).render_as_string(hide_password=False) + return create_engine( url, echo=DEBUG, @@ -63,8 +78,8 @@ def setup_event_listener( return session @classmethod - def sessionmaker(cls): - bind_obj = cls.engine() + def sessionmaker(cls, application_name: str | None = None): + bind_obj = cls.engine(application_name=application_name) session_factory = sessionmaker(bind=bind_obj) cls.setup_event_listener(session_factory) return session_factory @@ -76,8 +91,8 @@ def initialize_schema(cls, engine): Base.metadata.create_all(engine) @classmethod - def session(cls, url, initialize_data=True, initialize_schema=True): - engine = cls.engine(url) + def session(cls, url: str | None = None, application_name: str | None = None): + engine = cls.engine(url, application_name=application_name) connection = engine.connect() return cls.session_from_connection(connection) @@ -164,10 +179,14 @@ def initialize_data(cls, session: Session): return session -def production_session(initialize_data=True) -> Session: +def production_session(application_name: type[object] | str) -> Session: + if isinstance(application_name, str): + application_name = application_name + else: + application_name = f"{application_name.__module__}.{application_name.__name__}" url = Configuration.database_url() if url.startswith('"'): url = url[1:] logging.debug("Database url: %s", url) - _db = SessionManager.session(url, initialize_data=initialize_data) + _db = SessionManager.session(url, application_name=application_name) return _db diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index 8c65370681..42d43d5197 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -268,7 +268,7 @@ def __init__(self, database_name: DatabaseCreationFixture) -> None: self.connection = self.engine.connect() def engine_factory(self) -> Engine: - return SessionManager.engine(self.database_name.url) + return SessionManager.engine(self.database_name.url, application_name="test") def drop_existing_schema(self) -> None: metadata_obj = MetaData() diff --git a/tests/manager/celery/test_task.py b/tests/manager/celery/test_task.py index a8d946a41f..3b70289d54 100644 --- a/tests/manager/celery/test_task.py +++ b/tests/manager/celery/test_task.py @@ -8,7 +8,9 @@ def test_task_session_maker() -> None: + task_name = "test-task" task = Task() + task.name = task_name # If session maker is not initialized, it should be None assert task._session_maker is None @@ -25,7 +27,9 @@ def test_task_session_maker() -> None: patch("palace.manager.celery.task.sessionmaker") as mock_sessionmaker, ): assert task.session_maker == mock_sessionmaker.return_value - mock_session_manager.engine.assert_called_once_with(poolclass=NullPool) + mock_session_manager.engine.assert_called_once_with( + poolclass=NullPool, application_name=task_name + ) mock_sessionmaker.assert_called_once_with( bind=mock_session_manager.engine.return_value ) diff --git a/tests/manager/sqlalchemy/test_session.py b/tests/manager/sqlalchemy/test_session.py index 0672879f57..09f1dd333f 100644 --- a/tests/manager/sqlalchemy/test_session.py +++ b/tests/manager/sqlalchemy/test_session.py @@ -1,5 +1,7 @@ from unittest.mock import patch +import pytest + from palace.manager.core.config import Configuration from palace.manager.sqlalchemy.model.coverage import Timestamp from palace.manager.sqlalchemy.session import ( @@ -32,7 +34,9 @@ def test_initialize_data_does_not_reset_timestamp( @patch("palace.manager.sqlalchemy.session.create_engine") @patch.object(Configuration, "database_url") - def test_engine(self, mock_database_url, mock_create_engine): + def test_engine( + self, mock_database_url, mock_create_engine, caplog: pytest.LogCaptureFixture + ): expected_args = { "echo": False, "json_serializer": json_serializer, @@ -41,9 +45,9 @@ def test_engine(self, mock_database_url, mock_create_engine): } # If a URL is passed in, it's used. - SessionManager.engine("url") + SessionManager.engine("postgres://url") mock_database_url.assert_not_called() - mock_create_engine.assert_called_once_with("url", **expected_args) + mock_create_engine.assert_called_once_with("postgres://url", **expected_args) mock_create_engine.reset_mock() # If no URL is passed in, the URL from the configuration is used. @@ -52,12 +56,29 @@ def test_engine(self, mock_database_url, mock_create_engine): mock_create_engine.assert_called_once_with( mock_database_url.return_value, **expected_args ) + mock_create_engine.reset_mock() + + # If we pass in an application name, it's added to the URL. + SessionManager.engine("postgres://url", application_name="test-app") + mock_create_engine.assert_called_once_with( + "postgres://url?application_name=test-app", **expected_args + ) + mock_create_engine.reset_mock() + + # If the URL already has an application name, it's overwritten. + SessionManager.engine( + "postgres://url?application_name=old-app", application_name="test-app" + ) + mock_create_engine.assert_called_once_with( + "postgres://url?application_name=test-app", **expected_args + ) + assert "Overwriting existing application_name in database URL" in caplog.text @patch.object(SessionManager, "engine") @patch.object(SessionManager, "session_from_connection") def test_session(self, mock_session_from_connection, mock_engine): session = SessionManager.session("test-url") - mock_engine.assert_called_once_with("test-url") + mock_engine.assert_called_once_with("test-url", application_name=None) mock_engine.return_value.connect.assert_called_once() mock_session_from_connection.assert_called_once_with( mock_engine.return_value.connect.return_value @@ -71,7 +92,18 @@ def test_production_session(mock_database_url, mock_session): # Make sure production_session() calls session() with the URL from the # configuration. mock_database_url.return_value = "test-url" - session = production_session() + session = production_session("test-app") mock_database_url.assert_called_once() - mock_session.assert_called_once_with("test-url", initialize_data=True) + mock_session.assert_called_once_with("test-url", application_name="test-app") assert session == mock_session.return_value + + # production_session can also be called with a class that sets the application name + mock_session.reset_mock() + + class Mock: + ... + + production_session(Mock) + mock_session.assert_called_once_with( + "test-url", application_name="tests.manager.sqlalchemy.test_session.Mock" + )