diff --git a/snowfakery/data_generator_runtime.py b/snowfakery/data_generator_runtime.py index c7ccdddd..74aa5693 100644 --- a/snowfakery/data_generator_runtime.py +++ b/snowfakery/data_generator_runtime.py @@ -83,6 +83,13 @@ def __init__(self, nicknames_and_tables: Mapping[str, str], id_manager: IdManage def first_new_id(self, tablename): return self.orig_used_ids.get(tablename, 0) + 1 + def last_id_for_table(self, tablename): + last_obj = self.last_seen_obj_by_table.get(tablename) + if last_obj: + return last_obj.id + else: + return self.orig_used_ids.get(tablename) + class Globals: """Globally named objects and other aspects of global scope @@ -129,8 +136,7 @@ def register_object( self.transients.nicknamed_objects[nickname] = obj if persistent_object: self.persistent_objects_by_table[obj._tablename] = obj - else: - self.transients.last_seen_obj_by_table[obj._tablename] = obj + self.transients.last_seen_obj_by_table[obj._tablename] = obj @property def object_names(self): diff --git a/snowfakery/data_generator_runtime_object_model.py b/snowfakery/data_generator_runtime_object_model.py index 42250386..f6013e37 100644 --- a/snowfakery/data_generator_runtime_object_model.py +++ b/snowfakery/data_generator_runtime_object_model.py @@ -333,7 +333,7 @@ def render(self, context: RuntimeContext) -> FieldValue: ) with self.exception_handling( - "Cannot evaluate function `{}`:\n {e}", self.function_name + "Cannot evaluate function `{}`:\n {e}", [self.function_name] ): value = evaluate_function(func, self.args, self.kwargs, context) diff --git a/snowfakery/standard_plugins/Salesforce.py b/snowfakery/standard_plugins/Salesforce.py index 24c5e850..d4862f4b 100644 --- a/snowfakery/standard_plugins/Salesforce.py +++ b/snowfakery/standard_plugins/Salesforce.py @@ -330,6 +330,9 @@ def ProfileId(self, name): # TODO: Tests for this class class SOQLDatasetImpl(DatasetBase): + iterator = None + tempdir = None + def __init__(self, plugin, *args, **kwargs): from cumulusci.tasks.bulkdata.step import ( get_query_operation, diff --git a/snowfakery/template_funcs.py b/snowfakery/template_funcs.py index 0358d40b..0ccdcc83 100644 --- a/snowfakery/template_funcs.py +++ b/snowfakery/template_funcs.py @@ -254,14 +254,14 @@ def random_reference(self, tablename: str, scope: str = "current-iteration"): """ globls = self.context.interpreter.globals - last_object = globls.transients.last_seen_obj_by_table.get(tablename) - if last_object: - last_id = last_object.id + last_id = globls.transients.last_id_for_table(tablename) + if last_id: if scope == "prior-and-current-iterations": first_id = 1 warnings.warn("Global scope is an experimental feature.") elif scope == "current-iteration": first_id = globls.first_new_id(tablename) + last_id = max(first_id, last_id) else: raise DataGenError( f"Scope must be 'prior-and-current-iterations' or 'current-iteration' not {scope}", diff --git a/tests/parent-child-just-once.yml b/tests/parent-child-just-once.yml new file mode 100644 index 00000000..fa623391 --- /dev/null +++ b/tests/parent-child-just-once.yml @@ -0,0 +1,11 @@ +- object: Parent + just_once: true + nickname: ParentNickname + +- object: Child + fields: + parent: + random_reference: Parent + # should fail: + # parent2: + # random_reference: ParentNickname diff --git a/tests/test_restartability.py b/tests/test_continuation.py similarity index 81% rename from tests/test_restartability.py rename to tests/test_continuation.py index 14bf5679..1c1c14a3 100644 --- a/tests/test_restartability.py +++ b/tests/test_continuation.py @@ -4,7 +4,7 @@ from snowfakery.data_generator import generate -class TestRestart: +class TestContinuation: def test_nicknames_persist(self, generated_rows): yaml = """ - object: foo @@ -103,3 +103,27 @@ def test_circular_references(self, write_row): StringIO(yaml_data), continuation_file=StringIO(continuation_yaml), ) + + def test_reference_just_once(self, generated_rows): + yaml_data = """ + - object: Parent + just_once: true + + - object: Child + fields: + parent: + random_reference: Parent + """ + generate_twice(yaml_data) + assert generated_rows() + + +def generate_twice(yaml): + continuation_file = StringIO() + generate(StringIO(yaml), generate_continuation_file=continuation_file) + next_contination_file = StringIO() + generate( + StringIO(yaml), + continuation_file=StringIO(continuation_file.getvalue()), + generate_continuation_file=next_contination_file, + ) diff --git a/tests/test_references.py b/tests/test_references.py index 5027109d..24574eb1 100644 --- a/tests/test_references.py +++ b/tests/test_references.py @@ -490,3 +490,31 @@ def test_reference_really_wrong_type(self): with pytest.raises(DataGenError) as e: generate(StringIO(yaml)) assert "can't get reference to object" in str(e).lower() + + def test_random_reference_to_just_once_obj(self, generated_rows): + yaml = """ + - object: Parent + just_once: true + + - object: Child + fields: + parent: + random_reference: Parent + """ + generate(StringIO(yaml), stopping_criteria=StoppingCriteria("Child", 2)) + assert len(generated_rows.mock_calls) == 3 + + def test_random_reference_to_nickname_fails(self): + yaml = """ + - object: Parent + nickname: ParentNickname + just_once: true + + - object: Child + fields: + parent: + random_reference: ParentNickname + """ + with pytest.raises(DataGenError) as e: + generate(StringIO(yaml)) + assert "there is no table named parent" in str(e).lower()