Rollback Operation for Inserts and Upserts (#3718)
[W-14522337](https://gus.lightning.force.com/lightning/r/ADM_Work__c/a07EE00001et8KzYAI/view)
and
[W-14522327](https://gus.lightning.force.com/lightning/r/ADM_Work__c/a07EE00001etNspYAE/view)

The implementation is as follows:
1. For inserts, we maintain a separate table, `{sobject}_insert_rollback`,
that records only the `sf_id` of each record created (whether through
insert or upsert).
2. For upserts, before we load the records, we store the previous values
of the records that the upsert will update (rather than create) in a
table called `{sobject}_upsert_rollback`. It holds the `sf_id` along with
all the fields given in the `mapping.yml` file.
3. When a rollback occurs, the rollback tables are processed in reverse
order of creation (see the sketch after this list):
    a. For inserts, we delete every stored `sf_id`.
    b. For upserts, we upsert the previous values back.
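
A rough sketch of the mechanism, using hypothetical table names and
standalone SQLAlchemy rather than the actual task classes:

```python
from sqlalchemy import Column, MetaData, Table, Unicode, create_engine

engine = create_engine("sqlite://")
metadata = MetaData()

# One sf_id-only table per inserted sObject, created in load order...
Table(
    "Account_insert_rollback",
    metadata,
    Column("Id", Unicode(255), primary_key=True),
)
# ...and one table of previous field values per upserted sObject.
Table(
    "Contact_upsert_rollback",
    metadata,
    Column("Id", Unicode(255), primary_key=True),
    Column("LastName", Unicode(255)),
)
metadata.create_all(engine)

# Rollback walks the tables in reverse order, mirroring the
# reversed(metadata.sorted_tables) iteration in the diff below.
for table in reversed(metadata.sorted_tables):
    if table.name.endswith("insert_rollback"):
        print(f"would DELETE the sf_ids stored in {table.name}")
    elif table.name.endswith("upsert_rollback"):
        print(f"would UPSERT the previous values from {table.name}")
```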

---------

Co-authored-by: Naman Jain <[email protected]>
Co-authored-by: Jaipal Reddy Kasturi <[email protected]>
3 people authored Dec 28, 2023
1 parent a9ff001 commit a627caa
Showing 6 changed files with 707 additions and 43 deletions.
199 changes: 192 additions & 7 deletions cumulusci/tasks/bulkdata/load.py
@@ -9,6 +9,7 @@
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session

from cumulusci.core.enums import StrEnum
from cumulusci.core.exceptions import BulkDataException, TaskOptionsError
from cumulusci.core.utils import process_bool_arg
from cumulusci.salesforce_api.org_schema import get_org_schema
@@ -28,9 +29,11 @@
)
from cumulusci.tasks.bulkdata.step import (
DEFAULT_BULK_BATCH_SIZE,
DataApi,
DataOperationJobResult,
DataOperationStatus,
DataOperationType,
RestApiDmlOperation,
get_dml_operation,
)
from cumulusci.tasks.bulkdata.upsert_utils import (
@@ -88,6 +91,9 @@ class LoadData(SqlAlchemyMixin, BaseSalesforceApiTask):
"org_shape_match_only": {
"description": "When True, all path options are ignored and only a dataset matching the org shape name will be loaded. Defaults to False."
},
"enable_rollback": {
"description": "When True, performs rollback operation incase of error. Defaults to False"
},
}
row_warning_limit = 10

@@ -115,6 +121,9 @@ def _init_options(self, kwargs):
self.options["set_recently_viewed"] = process_bool_arg(
self.options.get("set_recently_viewed", True)
)
self.options["enable_rollback"] = process_bool_arg(
self.options.get("enable_rollback", False)
)
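
`process_bool_arg` normalizes string forms of booleans, so the flag can be
supplied as a string option as well as a literal boolean. A minimal check
(behavior assumed from how the other boolean options above are parsed):

```python
from cumulusci.core.utils import process_bool_arg

# String and boolean forms should normalize identically.
print(process_bool_arg("True"))  # expected: True
print(process_bool_arg(False))   # expected: False
```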

def _init_dataset(self):
"""Find the dataset paths to use with the following sequence:
@@ -261,13 +270,33 @@ def _execute_step(
step, query = self.configure_step(mapping)

with tempfile.TemporaryFile(mode="w+t") as local_ids:
# Store the previous values of the records before the upsert
# so that we can roll back to them on failure
if (
mapping.action
in [
DataOperationType.ETL_UPSERT,
DataOperationType.UPSERT,
DataOperationType.UPDATE,
]
and self.options["enable_rollback"]
):
UpdateRollback.prepare_for_rollback(
self, step, self._stream_queried_data(mapping, local_ids, query)
)
step.start()
step.load_records(self._stream_queried_data(mapping, local_ids, query))
step.end()

# Process Job Results
if step.job_result.status is not DataOperationStatus.JOB_FAILURE:
local_ids.seek(0)
self._process_job_results(mapping, step, local_ids)
elif (
step.job_result.status is DataOperationStatus.JOB_FAILURE
and self.options["enable_rollback"]
):
Rollback._perform_rollback(self)

return step.job_result

@@ -454,7 +483,7 @@ def _process_job_results(self, mapping, step, local_ids):
id_table_name = self._initialize_id_table(mapping, self.reset_oids)
conn = self.session.connection()

- results_generator = self._generate_results_id_map(step, local_ids)
+ sf_id_results = self._generate_results_id_map(step, local_ids)

# If we know we have no successful inserts, don't attempt to persist Ids.
# Do, however, drain the generator to get error-checking behavior.
@@ -465,11 +494,8 @@ def _process_job_results(self, mapping, step, local_ids):
connection=conn,
table=self.metadata.tables[id_table_name],
columns=("id", "sf_id"),
- record_iterable=results_generator,
+ record_iterable=sf_id_results,
)
- else:
-     for r in results_generator:
-         pass  # Drain generator to validate results

# Contact records for Person Accounts are inserted during an Account
# sf_object step. Insert records into the Contact ID table for
@@ -496,16 +522,37 @@

def _generate_results_id_map(self, step, local_ids):
"""Consume results from load and prepare rows for id table.
- Raise BulkDataException on row errors if configured to do so."""
+ Raise BulkDataException on row errors if configured to do so.
+ Adds created records to the insert_rollback table.
+ Performs rollback in case of any errors if enable_rollback is True."""
error_checker = RowErrorChecker(
self.logger, self.options["ignore_row_errors"], self.row_warning_limit
)
local_ids = (lid.strip("\n") for lid in local_ids)
sf_id_results = []
created_results = []
failed_results = []
for result, local_id in zip(step.get_results(), local_ids):
if result.success:
yield (local_id, result.id)
sf_id_results.append([local_id, result.id])
if result.created:
created_results.append([result.id])
else:
failed_results.append([result, local_id])

# We record failed_results separately: if an unsuccessful record appeared
# mid-batch and raised immediately, we would not capture all the successful ids
for result, local_id in failed_results:
try:
error_checker.check_for_row_error(result, local_id)
except Exception as e:
if self.options["enable_rollback"]:
CreateRollback.prepare_for_rollback(self, step, created_results)
Rollback._perform_rollback(self)
raise e
if self.options["enable_rollback"]:
CreateRollback.prepare_for_rollback(self, step, created_results)
return sf_id_results
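
The partitioning matters because `check_for_row_error` may raise on the
first failed row. A standalone illustration with hypothetical data:

```python
# (local_id, succeeded?) pairs with a failure in the middle of the batch.
results = [("1", True), ("2", False), ("3", True)]

created, failed = [], []
for local_id, success in results:
    (created if success else failed).append(local_id)

# Both successful ids are captured even though row "2" failed mid-batch,
# so a rollback triggered by that failure can delete every created record.
print(created)  # ['1', '3']
print(failed)   # ['2']
```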

def _initialize_id_table(self, mapping, should_reset_table):
"""initalize or find table to hold the inserted SF Ids
@@ -568,6 +615,9 @@ def _init_db(self):
self.metadata.bind = connection
self.inspector = inspect(parent_engine)

# empty the record of initialized tables
Rollback._initialized_rollback_tables_api = {}

# initialize the automap mapping
self.base = automap_base(bind=connection, metadata=self.metadata)
self.base.prepare(connection, reflect=True)
@@ -810,6 +860,141 @@ def _set_viewed(self) -> T.List["SetRecentlyViewedInfo"]:
return results


class RollbackType(StrEnum):
"""Enum to specify type of rollback"""

UPSERT = "upsert_rollback"
INSERT = "insert_rollback"
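
Because `StrEnum` members behave as plain strings, the
`f"{step.sobject}_{rollback_type}"` naming below and the `endswith()`
checks in `Rollback._perform_rollback` both operate directly on the
member's value. A standalone stand-in for the real class in
`cumulusci.core.enums`:

```python
from enum import Enum


class RollbackType(str, Enum):
    # Stand-in for cumulusci.core.enums.StrEnum behavior.
    UPSERT = "upsert_rollback"
    INSERT = "insert_rollback"

    def __str__(self) -> str:
        return self.value


print(f"Account_{RollbackType.INSERT}")                         # Account_insert_rollback
print("Account_insert_rollback".endswith(RollbackType.INSERT))  # True
```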


class Rollback:
# Store each table name and its corresponding API (REST or Bulk)
_initialized_rollback_tables_api = {}

@staticmethod
def _create_tables_for_rollback(context, step, rollback_type: RollbackType) -> str:
"""Create the tables required for upsert and insert rollback"""
table_name = f"{step.sobject}_{rollback_type}"

if table_name not in Rollback._initialized_rollback_tables_api:
common_columns = [Column("Id", Unicode(255), primary_key=True)]

additional_columns = (
[Column(field, Unicode(255)) for field in step.fields if field != "Id"]
if rollback_type is RollbackType.UPSERT
else []
)

columns = common_columns + additional_columns

# Create the table
rollback_table = Table(table_name, context.metadata, *columns)
rollback_table.create()

# Store the API in the initialized tables dictionary
if isinstance(step, RestApiDmlOperation):
Rollback._initialized_rollback_tables_api[table_name] = DataApi.REST
else:
Rollback._initialized_rollback_tables_api[table_name] = DataApi.BULK

return table_name
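
For example, an upsert step for `Contact` mapping `Id`, `FirstName`, and
`LastName` would yield a rollback table with exactly those columns
(hypothetical field list, standalone SQLAlchemy):

```python
from sqlalchemy import Column, MetaData, Table, Unicode

fields = ["Id", "FirstName", "LastName"]  # hypothetical step.fields
metadata = MetaData()

columns = [Column("Id", Unicode(255), primary_key=True)] + [
    Column(field, Unicode(255)) for field in fields if field != "Id"
]
table = Table("Contact_upsert_rollback", metadata, *columns)

print([c.name for c in table.columns])  # ['Id', 'FirstName', 'LastName']
```

An insert rollback table gets only the `Id` column, since deleting a
created record requires nothing more.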

@staticmethod
def _perform_rollback(context):
"""Perform total rollback"""
context.logger.info("--Initiated Rollback Procedure--")
for table in reversed(context.metadata.sorted_tables):
if table.name.endswith(RollbackType.INSERT):
CreateRollback._perform_rollback(context, table)
elif table.name.endswith(RollbackType.UPSERT):
UpdateRollback._perform_rollback(context, table)
context.logger.info("--Finished Rollback Procedure--")


class UpdateRollback:
@staticmethod
def prepare_for_rollback(context, step, records):
"""Retrieve previous values for records being updated"""
results, columns = step.get_prev_record_values(records)
if results:
table_name = Rollback._create_tables_for_rollback(
context, step, RollbackType.UPSERT
)
conn = context.session.connection()
sql_bulk_insert_from_records(
connection=conn,
table=context.metadata.tables[table_name],
columns=columns,
record_iterable=results,
)

@staticmethod
def _perform_rollback(context, table: Table) -> None:
"""Perform rollback for updated records"""
sf_object = table.name.split(f"_{RollbackType.UPSERT.value}")[0]
records = context.session.query(table).all()

if records:
context.logger.info(f"Reverting upserts for {sf_object}")
api_options = {"update_key": "Id"}
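# Upserting with "Id" as the update key matches existing records,
# writing their stored previous values back.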

# Use get_dml_operation to create an UPSERT step
step = get_dml_operation(
sobject=sf_object,
operation=DataOperationType.UPSERT,
api_options=api_options,
context=context,
fields=[column.name for column in table.columns],
api=Rollback._initialized_rollback_tables_api[table.name],
volume=len(records),
)
step.start()
step.load_records(records)
step.end()
context.logger.info("Done")


class CreateRollback:
@staticmethod
def prepare_for_rollback(context, step, records):
"""Store the sf_ids of all records that were created
to prepare for rollback"""
if records:
table_name = Rollback._create_tables_for_rollback(
context, step, RollbackType.INSERT
)
conn = context.session.connection()
sql_bulk_insert_from_records(
connection=conn,
table=context.metadata.tables[table_name],
columns=["Id"],
record_iterable=records,
)

@staticmethod
def _perform_rollback(context, table: Table) -> None:
"""Perform rollback for insert operation"""
sf_object = table.name.split(f"_{RollbackType.INSERT.value}")[0]
records = context.session.query(table).all()

if records:
context.logger.info(f"Deleting {sf_object} records")
# Perform DELETE operation using get_dml_operation
step = get_dml_operation(
sobject=sf_object,
operation=DataOperationType.DELETE,
fields=["Id"],
api_options={},
context=context,
api=Rollback._initialized_rollback_tables_api[table.name],
volume=len(records),
)
step.start()
step.load_records(records)
step.end()
context.logger.info("Done")
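
Both `_perform_rollback` methods feed `session.query(table).all()` straight
into `load_records`: the rows come back as tuples in column order, which is
the record shape the DML step consumes. A standalone check with an
in-memory SQLite database and a hypothetical Id value:

```python
from sqlalchemy import Column, MetaData, Table, Unicode, create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")
metadata = MetaData()
table = Table(
    "Account_insert_rollback",
    metadata,
    Column("Id", Unicode(255), primary_key=True),
)
metadata.create_all(engine)

with Session(engine) as session:
    session.execute(table.insert().values(Id="001000000000001AAA"))
    print(session.query(table).all())  # [('001000000000001AAA',)]
```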


class StepResultInfo(T.NamedTuple):
"""Represent a Step Result in a form easily convertible to JSON"""
