Skip to content

Commit

Permalink
Use start_transaction_session in migration script #412
Browse files Browse the repository at this point in the history
  • Loading branch information
joelvdavies committed Dec 9, 2024
1 parent bb46d63 commit 2274dcb
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 31 deletions.
4 changes: 2 additions & 2 deletions inventory_management_system_api/migrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def description(self) -> str:
def forward(self, session: ClientSession):
"""Method for executing the migration."""

def forward_after_transaction(self, session: ClientSession):
def forward_after_transaction(self):
"""Method called after the forward function is called to do anything that can't be done inside a transaction
(ONLY USE IF NECESSARY e.g. dropping a collection)."""

@abstractmethod
def backward(self, session: ClientSession):
"""Method for reversing the migration."""

def backward_after_transaction(self, session: ClientSession):
def backward_after_transaction(self):
"""
Method called after the backward function is called to do anything that can't be done inside a transaction
(ONLY USE IF NECESSARY e.g. dropping a collection).
Expand Down
36 changes: 17 additions & 19 deletions inventory_management_system_api/migrations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from typing import Optional

from inventory_management_system_api.core.database import get_database, mongodb_client
from inventory_management_system_api.core.database import get_database, start_session_transaction
from inventory_management_system_api.migrations.base import BaseMigration

database = get_database()
Expand Down Expand Up @@ -184,16 +184,15 @@ def execute_migrations_forward(migrations: dict[str, BaseMigration]) -> None:
"""

# Run migration inside a session to lock writes and revert the changes if it fails
with mongodb_client.start_session() as session:
with session.start_transaction():
for name, migration in migrations.items():
logger.info("Performing forward migration for '%s'...", name)
migration.forward(session)
set_previous_migration(list(migrations.keys())[-1])
# Run some things outside the transaction e.g. if needing to drop a collection
with start_session_transaction("forward migration") as session:
for name, migration in migrations.items():
logger.info("Finalising forward migration for '%s'...", name)
migration.forward_after_transaction(session)
logger.info("Performing forward migration for '%s'...", name)
migration.forward(session)
set_previous_migration(list(migrations.keys())[-1])
# Run some things outside the transaction e.g. if needing to drop a collection
for name, migration in migrations.items():
logger.info("Finalising forward migration for '%s'...", name)
migration.forward_after_transaction()


def execute_migrations_backward(migrations: dict[str, BaseMigration], final_previous_migration_name: Optional[str]):
Expand All @@ -209,13 +208,12 @@ def execute_migrations_backward(migrations: dict[str, BaseMigration], final_prev
there aren't any.
"""
# Run migration inside a session to lock writes and revert the changes if it fails
with mongodb_client.start_session() as session:
with session.start_transaction():
for name, migration in migrations.items():
logger.info("Performing backward migration for '%s'...", name)
migration.backward(session)
set_previous_migration(final_previous_migration_name)
# Run some things outside the transaction e.g. if needing to drop a collection
with start_session_transaction("backward migration") as session:
for name, migration in migrations.items():
logger.info("Finalising backward migration for '%s'...", name)
migration.backward_after_transaction(session)
logger.info("Performing backward migration for '%s'...", name)
migration.backward(session)
set_previous_migration(final_previous_migration_name)
# Run some things outside the transaction e.g. if needing to drop a collection
for name, migration in migrations.items():
logger.info("Finalising backward migration for '%s'...", name)
migration.backward_after_transaction()
20 changes: 10 additions & 10 deletions test/unit/migrations/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,37 +388,37 @@ def test_get_previous_migration_when_none(mock_database):


@patch("inventory_management_system_api.migrations.core.set_previous_migration")
@patch("inventory_management_system_api.migrations.core.mongodb_client")
def test_execute_migrations_forward(mock_mongodb_client, mock_set_previous_migration):
@patch("inventory_management_system_api.migrations.core.start_session_transaction")
def test_execute_migrations_forward(mock_start_session_transaction, mock_set_previous_migration):
"""Tests that `execute_migrations_forward` functions as expected."""

migrations = {"migration1": MagicMock(), "migration2": MagicMock()}
expected_session = mock_mongodb_client.start_session.return_value.__enter__.return_value
expected_session = mock_start_session_transaction.return_value.__enter__.return_value

execute_migrations_forward(migrations)

expected_session.start_transaction.assert_called_once()
mock_start_session_transaction.assert_called_once_with("forward migration")
for migration in migrations.values():
migration.forward.assert_called_once_with(expected_session)
migration.forward_after_transaction.assert_called_once_with(expected_session)
migration.forward_after_transaction.assert_called_once()

mock_set_previous_migration.assert_called_once_with(list(migrations.keys())[-1])


@patch("inventory_management_system_api.migrations.core.set_previous_migration")
@patch("inventory_management_system_api.migrations.core.mongodb_client")
def test_execute_migrations_backward(mock_mongodb_client, mock_set_previous_migration):
@patch("inventory_management_system_api.migrations.core.start_session_transaction")
def test_execute_migrations_backward(mock_start_session_transaction, mock_set_previous_migration):
"""Tests that `execute_migrations_backward` functions as expected."""

migrations = {"migration1": MagicMock(), "migration2": MagicMock()}
final_previous_migration_name = "final_migration_name"
expected_session = mock_mongodb_client.start_session.return_value.__enter__.return_value
expected_session = mock_start_session_transaction.return_value.__enter__.return_value

execute_migrations_backward(migrations, final_previous_migration_name)

expected_session.start_transaction.assert_called_once()
mock_start_session_transaction.assert_called_once_with("backward migration")
for migration in migrations.values():
migration.backward.assert_called_once_with(expected_session)
migration.backward_after_transaction.assert_called_once_with(expected_session)
migration.backward_after_transaction.assert_called_once()

mock_set_previous_migration.assert_called_once_with(final_previous_migration_name)

0 comments on commit 2274dcb

Please sign in to comment.