Skip to content

Commit

Permalink
Merge pull request #433 from ral-facilities/handle-property-migration…
Browse files Browse the repository at this point in the history
…-conflict-#412

Handle property migration write conflicts #412
  • Loading branch information
joelvdavies authored Dec 9, 2024
2 parents 2930356 + 2274dcb commit 98ef6b5
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 79 deletions.
30 changes: 29 additions & 1 deletion inventory_management_system_api/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
Module for connecting to a MongoDB database.
"""

from typing import Annotated
from contextlib import contextmanager
from typing import Annotated, Generator

from fastapi import Depends
from pymongo import MongoClient
from pymongo.client_session import ClientSession
from pymongo.database import Database
from pymongo.errors import OperationFailure

from inventory_management_system_api.core.config import config
from inventory_management_system_api.core.exceptions import WriteConflictError

db_config = config.database
mongodb_client = MongoClient(
Expand All @@ -28,4 +32,28 @@ def get_database() -> Database:
return mongodb_client[db_config.name.get_secret_value()]


@contextmanager
def start_session_transaction(action_description: str) -> Generator[ClientSession, None, None]:
"""
Starts a MongoDB session followed by a transaction and returns the session to use.
Also handles write conflicts.
:param action_description: Description of what the transaction is doing so it can be used in any raised errors.
:raises WriteConflictError: If there a write conflict during the transaction.
:returns: MongoDB session that has a transaction started on it.
"""

with mongodb_client.start_session() as session:
with session.start_transaction():
try:
yield session
except OperationFailure as exc:
if "write conflict" in str(exc).lower():
raise WriteConflictError(
f"Write conflict while {action_description}. Please try again later."
) from exc
raise exc


DatabaseDep = Annotated[Database, Depends(get_database)]
6 changes: 6 additions & 0 deletions inventory_management_system_api/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,9 @@ class InvalidActionError(DatabaseError):
"""
Exception raised when trying to update an item's catalogue item ID
"""


class WriteConflictError(DatabaseError):
"""
Exception raised when a transaction has a write conflict.
"""
4 changes: 2 additions & 2 deletions inventory_management_system_api/migrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def description(self) -> str:
def forward(self, session: ClientSession):
"""Method for executing the migration."""

def forward_after_transaction(self, session: ClientSession):
def forward_after_transaction(self):
"""Method called after the forward function is called to do anything that can't be done inside a transaction
(ONLY USE IF NECESSARY e.g. dropping a collection)."""

@abstractmethod
def backward(self, session: ClientSession):
"""Method for reversing the migration."""

def backward_after_transaction(self, session: ClientSession):
def backward_after_transaction(self):
"""
Method called after the backward function is called to do anything that can't be done inside a transaction
(ONLY USE IF NECESSARY e.g. dropping a collection).
Expand Down
36 changes: 17 additions & 19 deletions inventory_management_system_api/migrations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from typing import Optional

from inventory_management_system_api.core.database import get_database, mongodb_client
from inventory_management_system_api.core.database import get_database, start_session_transaction
from inventory_management_system_api.migrations.base import BaseMigration

database = get_database()
Expand Down Expand Up @@ -184,16 +184,15 @@ def execute_migrations_forward(migrations: dict[str, BaseMigration]) -> None:
"""

# Run migration inside a session to lock writes and revert the changes if it fails
with mongodb_client.start_session() as session:
with session.start_transaction():
for name, migration in migrations.items():
logger.info("Performing forward migration for '%s'...", name)
migration.forward(session)
set_previous_migration(list(migrations.keys())[-1])
# Run some things outside the transaction e.g. if needing to drop a collection
with start_session_transaction("forward migration") as session:
for name, migration in migrations.items():
logger.info("Finalising forward migration for '%s'...", name)
migration.forward_after_transaction(session)
logger.info("Performing forward migration for '%s'...", name)
migration.forward(session)
set_previous_migration(list(migrations.keys())[-1])
# Run some things outside the transaction e.g. if needing to drop a collection
for name, migration in migrations.items():
logger.info("Finalising forward migration for '%s'...", name)
migration.forward_after_transaction()


def execute_migrations_backward(migrations: dict[str, BaseMigration], final_previous_migration_name: Optional[str]):
Expand All @@ -209,13 +208,12 @@ def execute_migrations_backward(migrations: dict[str, BaseMigration], final_prev
there aren't any.
"""
# Run migration inside a session to lock writes and revert the changes if it fails
with mongodb_client.start_session() as session:
with session.start_transaction():
for name, migration in migrations.items():
logger.info("Performing backward migration for '%s'...", name)
migration.backward(session)
set_previous_migration(final_previous_migration_name)
# Run some things outside the transaction e.g. if needing to drop a collection
with start_session_transaction("backward migration") as session:
for name, migration in migrations.items():
logger.info("Finalising backward migration for '%s'...", name)
migration.backward_after_transaction(session)
logger.info("Performing backward migration for '%s'...", name)
migration.backward(session)
set_previous_migration(final_previous_migration_name)
# Run some things outside the transaction e.g. if needing to drop a collection
for name, migration in migrations.items():
logger.info("Finalising backward migration for '%s'...", name)
migration.backward_after_transaction()
11 changes: 10 additions & 1 deletion inventory_management_system_api/routers/v1/catalogue_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
InvalidObjectIdError,
LeafCatalogueCategoryError,
MissingRecordError,
WriteConflictError,
)
from inventory_management_system_api.schemas.breadcrumbs import BreadcrumbsGetSchema
from inventory_management_system_api.schemas.catalogue_category import (
CATALOGUE_CATEGORY_WITH_CHILD_NON_EDITABLE_FIELDS,
CatalogueCategoryPatchSchema,
CatalogueCategoryPostSchema,
CatalogueCategoryPropertyPatchSchema,
CatalogueCategoryPropertyPostSchema,
CatalogueCategoryPropertySchema,
CatalogueCategorySchema,
CATALOGUE_CATEGORY_WITH_CHILD_NON_EDITABLE_FIELDS,
)
from inventory_management_system_api.services.catalogue_category import CatalogueCategoryService
from inventory_management_system_api.services.catalogue_category_property import CatalogueCategoryPropertyService
Expand Down Expand Up @@ -272,6 +273,10 @@ def create_property(
message = str(exc)
logger.exception(message)
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=message) from exc
except WriteConflictError as exc:
message = str(exc)
logger.exception(message)
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=message) from exc


@router.patch(
Expand Down Expand Up @@ -320,3 +325,7 @@ def partial_update_property(
message = str(exc)
logger.exception(message)
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=message) from exc
except WriteConflictError as exc:
message = str(exc)
logger.exception(message)
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=message) from exc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from fastapi import Depends

from inventory_management_system_api.core.database import mongodb_client
from inventory_management_system_api.core.database import start_session_transaction
from inventory_management_system_api.core.exceptions import InvalidActionError, MissingRecordError
from inventory_management_system_api.models.catalogue_category import (
AllowedValues,
Expand Down Expand Up @@ -105,32 +105,31 @@ def create(
)

# Run all subsequent edits within a transaction to ensure they will all succeed or fail together
with mongodb_client.start_session() as session:
with session.start_transaction():
# Firstly update the catalogue category
catalogue_category_property_out = self._catalogue_category_repository.create_property(
catalogue_category_id, catalogue_category_property_in, session=session
)
with start_session_transaction("adding property") as session:
# Firstly update the catalogue category
catalogue_category_property_out = self._catalogue_category_repository.create_property(
catalogue_category_id, catalogue_category_property_in, session=session
)

property_in = PropertyIn(
id=str(catalogue_category_property_in.id),
name=catalogue_category_property_in.name,
value=catalogue_category_property.default_value,
unit=unit_value,
unit_id=catalogue_category_property.unit_id,
)
property_in = PropertyIn(
id=str(catalogue_category_property_in.id),
name=catalogue_category_property_in.name,
value=catalogue_category_property.default_value,
unit=unit_value,
unit_id=catalogue_category_property.unit_id,
)

# Add property to all catalogue items of the catalogue category
self._catalogue_item_repository.insert_property_to_all_matching(
catalogue_category_id, property_in, session=session
)
# Add property to all catalogue items of the catalogue category
self._catalogue_item_repository.insert_property_to_all_matching(
catalogue_category_id, property_in, session=session
)

# Add property to all items of the catalogue items
# Obtain a list of ids to do this rather than iterate one by one as its faster. Limiting factor
# would be memory to store these ids and the network bandwidth it takes to send the request to the
# database but for 10000 items being updated this only takes 4.92 KB
catalogue_item_ids = self._catalogue_item_repository.list_ids(catalogue_category_id, session=session)
self._item_repository.insert_property_to_all_in(catalogue_item_ids, property_in, session=session)
# Add property to all items of the catalogue items
# Obtain a list of ids to do this rather than iterate one by one as its faster. Limiting factor
# would be memory to store these ids and the network bandwidth it takes to send the request to the
# database but for 10000 items being updated this only takes 4.92 KB
catalogue_item_ids = self._catalogue_item_repository.list_ids(catalogue_category_id, session=session)
self._item_repository.insert_property_to_all_in(catalogue_item_ids, property_in, session=session)

return catalogue_category_property_out

Expand Down Expand Up @@ -228,20 +227,19 @@ def update(
property_in = CatalogueCategoryPropertyIn(**{**existing_property_out.model_dump(), **update_data})

# Run all subsequent edits within a transaction to ensure they will all succeed or fail together
with mongodb_client.start_session() as session:
with session.start_transaction():
# Firstly update the catalogue category
property_out = self._catalogue_category_repository.update_property(
catalogue_category_id, catalogue_category_property_id, property_in, session=session
)
with start_session_transaction("updating property") as session:
# Firstly update the catalogue category
property_out = self._catalogue_category_repository.update_property(
catalogue_category_id, catalogue_category_property_id, property_in, session=session
)

# Avoid propagating changes unless absolutely necessary
if updating_name:
self._catalogue_item_repository.update_names_of_all_properties_with_id(
catalogue_category_property_id, catalogue_category_property.name, session=session
)
self._item_repository.update_names_of_all_properties_with_id(
catalogue_category_property_id, catalogue_category_property.name, session=session
)
# Avoid propagating changes unless absolutely necessary
if updating_name:
self._catalogue_item_repository.update_names_of_all_properties_with_id(
catalogue_category_property_id, catalogue_category_property.name, session=session
)
self._item_repository.update_names_of_all_properties_with_id(
catalogue_category_property_id, catalogue_category_property.name, session=session
)

return property_out
55 changes: 55 additions & 0 deletions test/unit/core/test_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Unit tests for functions inside the `database` module.
"""

from unittest.mock import patch

import pytest
from pymongo.errors import OperationFailure

from inventory_management_system_api.core.database import start_session_transaction
from inventory_management_system_api.core.exceptions import WriteConflictError


@patch("inventory_management_system_api.core.database.mongodb_client")
def test_start_session_transaction(mock_mongodb_client):
"""Test `start_session_transaction`."""

expected_session = mock_mongodb_client.start_session.return_value.__enter__.return_value

with start_session_transaction("testing") as session:
pass

assert expected_session == session
expected_session.start_transaction.assert_called_once()


@patch("inventory_management_system_api.core.database.mongodb_client")
def test_start_session_transaction_with_operation_failure(mock_mongodb_client):
"""Test `start_session_transaction` when there is an operation failure inside the transaction."""

expected_session = mock_mongodb_client.start_session.return_value.__enter__.return_value

with pytest.raises(OperationFailure) as exc:
with start_session_transaction("testing") as session:
raise OperationFailure("Some operation error.")

assert expected_session == session
expected_session.start_transaction.assert_called_once()
assert str(exc.value) == "Some operation error."


@patch("inventory_management_system_api.core.database.mongodb_client")
def test_start_session_transaction_with_operation_failure_write_conflict(mock_mongodb_client):
"""Test `start_session_transaction` when there is an operation failure due to a write conflict inside the
transaction."""

expected_session = mock_mongodb_client.start_session.return_value.__enter__.return_value

with pytest.raises(WriteConflictError) as exc:
with start_session_transaction("testing") as session:
raise OperationFailure("Write conflict during plan execution and yielding is disabled.")

assert expected_session == session
expected_session.start_transaction.assert_called_once()
assert str(exc.value) == "Write conflict while testing. Please try again later."
20 changes: 10 additions & 10 deletions test/unit/migrations/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,37 +388,37 @@ def test_get_previous_migration_when_none(mock_database):


@patch("inventory_management_system_api.migrations.core.set_previous_migration")
@patch("inventory_management_system_api.migrations.core.mongodb_client")
def test_execute_migrations_forward(mock_mongodb_client, mock_set_previous_migration):
@patch("inventory_management_system_api.migrations.core.start_session_transaction")
def test_execute_migrations_forward(mock_start_session_transaction, mock_set_previous_migration):
"""Tests that `execute_migrations_forward` functions as expected."""

migrations = {"migration1": MagicMock(), "migration2": MagicMock()}
expected_session = mock_mongodb_client.start_session.return_value.__enter__.return_value
expected_session = mock_start_session_transaction.return_value.__enter__.return_value

execute_migrations_forward(migrations)

expected_session.start_transaction.assert_called_once()
mock_start_session_transaction.assert_called_once_with("forward migration")
for migration in migrations.values():
migration.forward.assert_called_once_with(expected_session)
migration.forward_after_transaction.assert_called_once_with(expected_session)
migration.forward_after_transaction.assert_called_once()

mock_set_previous_migration.assert_called_once_with(list(migrations.keys())[-1])


@patch("inventory_management_system_api.migrations.core.set_previous_migration")
@patch("inventory_management_system_api.migrations.core.mongodb_client")
def test_execute_migrations_backward(mock_mongodb_client, mock_set_previous_migration):
@patch("inventory_management_system_api.migrations.core.start_session_transaction")
def test_execute_migrations_backward(mock_start_session_transaction, mock_set_previous_migration):
"""Tests that `execute_migrations_backward` functions as expected."""

migrations = {"migration1": MagicMock(), "migration2": MagicMock()}
final_previous_migration_name = "final_migration_name"
expected_session = mock_mongodb_client.start_session.return_value.__enter__.return_value
expected_session = mock_start_session_transaction.return_value.__enter__.return_value

execute_migrations_backward(migrations, final_previous_migration_name)

expected_session.start_transaction.assert_called_once()
mock_start_session_transaction.assert_called_once_with("backward migration")
for migration in migrations.values():
migration.backward.assert_called_once_with(expected_session)
migration.backward_after_transaction.assert_called_once_with(expected_session)
migration.backward_after_transaction.assert_called_once()

mock_set_previous_migration.assert_called_once_with(final_previous_migration_name)
Loading

0 comments on commit 98ef6b5

Please sign in to comment.