From 15274df8074d79b0d5196da9ea3b594fc636e36c Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Thu, 1 Aug 2024 15:14:38 -0500 Subject: [PATCH] Add support for custom flow decorators to `prefect deploy` (#14782) --- src/prefect/cli/deploy.py | 17 +- src/prefect/flows.py | 159 ++++++++- src/prefect/utilities/callables.py | 8 +- src/prefect/utilities/importtools.py | 196 +++++++---- tests/cli/test_deploy.py | 44 +++ .../wrapped-flow-project/__init__.py | 0 .../wrapped-flow-project/flow.py | 6 + .../wrapped-flow-project/missing_imports.py | 8 + .../wrapped-flow-project/utils.py | 11 + tests/test_flows.py | 317 +++++++++++++++++- 10 files changed, 682 insertions(+), 84 deletions(-) create mode 100644 tests/test-projects/wrapped-flow-project/__init__.py create mode 100644 tests/test-projects/wrapped-flow-project/flow.py create mode 100644 tests/test-projects/wrapped-flow-project/missing_imports.py create mode 100644 tests/test-projects/wrapped-flow-project/utils.py diff --git a/src/prefect/cli/deploy.py b/src/prefect/cli/deploy.py index 8b70c3ef4567..63fb15319c73 100644 --- a/src/prefect/cli/deploy.py +++ b/src/prefect/cli/deploy.py @@ -65,14 +65,14 @@ from prefect.deployments.steps.core import run_steps from prefect.events import DeploymentTriggerTypes, TriggerTypes from prefect.exceptions import ObjectNotFound, PrefectHTTPStatusError -from prefect.flows import load_flow_arguments_from_entrypoint +from prefect.flows import load_flow_from_entrypoint from prefect.settings import ( PREFECT_DEFAULT_WORK_POOL_NAME, PREFECT_UI_URL, ) from prefect.utilities.annotations import NotSet from prefect.utilities.callables import ( - parameter_schema_from_entrypoint, + parameter_schema, ) from prefect.utilities.collections import get_from_dict from prefect.utilities.slugify import slugify @@ -481,20 +481,17 @@ async def _run_single_deploy( ) deploy_config["entrypoint"] = await prompt_entrypoint(app.console) - flow_decorator_arguments = load_flow_arguments_from_entrypoint( - deploy_config["entrypoint"] - ) + flow = load_flow_from_entrypoint(deploy_config["entrypoint"]) + + deploy_config["flow_name"] = flow.name - deploy_config["flow_name"] = flow_decorator_arguments["name"] deployment_name = deploy_config.get("name") if not deployment_name: if not is_interactive(): raise ValueError("A deployment name must be provided.") deploy_config["name"] = prompt("Deployment name", default="default") - deploy_config["parameter_openapi_schema"] = parameter_schema_from_entrypoint( - deploy_config["entrypoint"] - ) + deploy_config["parameter_openapi_schema"] = parameter_schema(flow) deploy_config["schedules"] = _construct_schedules( deploy_config, @@ -675,7 +672,7 @@ async def _run_single_deploy( deploy_config["work_pool"]["job_variables"]["image"] = "{{ build-image.image }}" if not deploy_config.get("description"): - deploy_config["description"] = flow_decorator_arguments.get("description") + deploy_config["description"] = flow.description # save deploy_config before templating deploy_config_before_templating = deepcopy(deploy_config) diff --git a/src/prefect/flows.py b/src/prefect/flows.py index 0fbfe5ed5451..ec19bd5a56e6 100644 --- a/src/prefect/flows.py +++ b/src/prefect/flows.py @@ -1676,6 +1676,7 @@ def load_flow_from_entrypoint( if ":" in entrypoint: # split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff path, func_name = entrypoint.rsplit(":", maxsplit=1) + else: path, func_name = entrypoint.rsplit(".", maxsplit=1) try: @@ -1684,15 +1685,13 @@ def load_flow_from_entrypoint( raise MissingFlowError( f"Flow function with name {func_name!r} not found in {path!r}. " ) from exc - except ScriptError as exc: + except ScriptError: # If the flow has dependencies that are not installed in the current - # environment, fallback to loading the flow via AST parsing. The - # drawback of this approach is that we're unable to actually load the - # function, so we create a placeholder flow that will re-raise this - # exception when called. - + # environment, fallback to loading the flow via AST parsing. if use_placeholder_flow: - flow = load_placeholder_flow(entrypoint=entrypoint, raises=exc) + flow = safe_load_flow_from_entrypoint(entrypoint) + if flow is None: + raise else: raise @@ -1855,6 +1854,147 @@ async def async_placeholder_flow(*args, **kwargs): return Flow(**arguments) +def safe_load_flow_from_entrypoint(entrypoint: str) -> Optional[Flow]: + """ + Load a flow from an entrypoint and return None if an exception is raised. + + Args: + entrypoint: a string in the format `:` + or a module path to a flow function + """ + func_def, source_code = _entrypoint_definition_and_source(entrypoint) + path = None + if ":" in entrypoint: + path = entrypoint.rsplit(":")[0] + namespace = safe_load_namespace(source_code, filepath=path) + if func_def.name in namespace: + return namespace[func_def.name] + else: + # If the function is not in the namespace, if may be due to missing dependencies + # for the function. We will attempt to compile each annotation and default value + # and remove them from the function definition to see if the function can be + # compiled without them. + + return _sanitize_and_load_flow(func_def, namespace) + + +def _sanitize_and_load_flow( + func_def: Union[ast.FunctionDef, ast.AsyncFunctionDef], namespace: Dict[str, Any] +) -> Optional[Flow]: + """ + Attempt to load a flow from the function definition after sanitizing the annotations + and defaults that can't be compiled. + + Args: + func_def: the function definition + namespace: the namespace to load the function into + + Returns: + The loaded function or None if the function can't be loaded + after sanitizing the annotations and defaults. + """ + args = func_def.args.posonlyargs + func_def.args.args + func_def.args.kwonlyargs + if func_def.args.vararg: + args.append(func_def.args.vararg) + if func_def.args.kwarg: + args.append(func_def.args.kwarg) + # Remove annotations that can't be compiled + for arg in args: + if arg.annotation is not None: + try: + code = compile( + ast.Expression(arg.annotation), + filename="", + mode="eval", + ) + exec(code, namespace) + except Exception as e: + logger.debug( + "Failed to evaluate annotation for argument %s due to the following error. Ignoring annotation.", + arg.arg, + exc_info=e, + ) + arg.annotation = None + + # Remove defaults that can't be compiled + new_defaults = [] + for default in func_def.args.defaults: + try: + code = compile(ast.Expression(default), "", "eval") + exec(code, namespace) + new_defaults.append(default) + except Exception as e: + logger.debug( + "Failed to evaluate default value %s due to the following error. Ignoring default.", + default, + exc_info=e, + ) + new_defaults.append( + ast.Constant( + value=None, lineno=default.lineno, col_offset=default.col_offset + ) + ) + func_def.args.defaults = new_defaults + + # Remove kw_defaults that can't be compiled + new_kw_defaults = [] + for default in func_def.args.kw_defaults: + if default is not None: + try: + code = compile(ast.Expression(default), "", "eval") + exec(code, namespace) + new_kw_defaults.append(default) + except Exception as e: + logger.debug( + "Failed to evaluate default value %s due to the following error. Ignoring default.", + default, + exc_info=e, + ) + new_kw_defaults.append( + ast.Constant( + value=None, + lineno=default.lineno, + col_offset=default.col_offset, + ) + ) + else: + new_kw_defaults.append( + ast.Constant( + value=None, + lineno=func_def.lineno, + col_offset=func_def.col_offset, + ) + ) + func_def.args.kw_defaults = new_kw_defaults + + if func_def.returns is not None: + try: + code = compile( + ast.Expression(func_def.returns), filename="", mode="eval" + ) + exec(code, namespace) + except Exception as e: + logger.debug( + "Failed to evaluate return annotation due to the following error. Ignoring annotation.", + exc_info=e, + ) + func_def.returns = None + + # Attempt to compile the function without annotations and defaults that + # can't be compiled + try: + code = compile( + ast.Module(body=[func_def], type_ignores=[]), + filename="", + mode="exec", + ) + exec(code, namespace) + except Exception as e: + logger.debug("Failed to compile: %s", e) + else: + return namespace.get(func_def.name) + + def load_flow_arguments_from_entrypoint( entrypoint: str, arguments: Optional[Union[List[str], Set[str]]] = None ) -> Dict[str, Any]: @@ -1870,6 +2010,9 @@ def load_flow_arguments_from_entrypoint( """ func_def, source_code = _entrypoint_definition_and_source(entrypoint) + path = None + if ":" in entrypoint: + path = entrypoint.rsplit(":")[0] if arguments is None: # If no arguments are provided default to known arguments that are of @@ -1905,7 +2048,7 @@ def load_flow_arguments_from_entrypoint( # if the arg value is not a raw str (i.e. a variable or expression), # then attempt to evaluate it - namespace = safe_load_namespace(source_code) + namespace = safe_load_namespace(source_code, filepath=path) literal_arg_value = ast.get_source_segment(source_code, keyword.value) cleaned_value = ( literal_arg_value.replace("\n", "") if literal_arg_value else "" diff --git a/src/prefect/utilities/callables.py b/src/prefect/utilities/callables.py index fd09949b7f25..9c0fe08c7ac2 100644 --- a/src/prefect/utilities/callables.py +++ b/src/prefect/utilities/callables.py @@ -346,17 +346,19 @@ def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema: Returns: ParameterSchema: The parameter schema for the function. """ + filepath = None if ":" in entrypoint: # split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff path, func_name = entrypoint.rsplit(":", maxsplit=1) source_code = Path(path).read_text() + filepath = path else: path, func_name = entrypoint.rsplit(".", maxsplit=1) spec = importlib.util.find_spec(path) if not spec or not spec.origin: raise ValueError(f"Could not find module {path!r}") source_code = Path(spec.origin).read_text() - signature = _generate_signature_from_source(source_code, func_name) + signature = _generate_signature_from_source(source_code, func_name, filepath) docstring = _get_docstring_from_source(source_code, func_name) return generate_parameter_schema(signature, parameter_docstrings(docstring)) @@ -424,7 +426,7 @@ def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str] def _generate_signature_from_source( - source_code: str, func_name: str + source_code: str, func_name: str, filepath: Optional[str] = None ) -> inspect.Signature: """ Extract the signature of a function from its source code. @@ -440,7 +442,7 @@ def _generate_signature_from_source( """ # Load the namespace from the source code. Missing imports and exceptions while # loading local class definitions are ignored. - namespace = safe_load_namespace(source_code) + namespace = safe_load_namespace(source_code, filepath=filepath) # Parse the source code into an AST parsed_code = ast.parse(source_code) diff --git a/src/prefect/utilities/importtools.py b/src/prefect/utilities/importtools.py index 5e364e37b49f..92f7fbf09c26 100644 --- a/src/prefect/utilities/importtools.py +++ b/src/prefect/utilities/importtools.py @@ -362,79 +362,159 @@ def exec_module(self, _: ModuleType) -> None: sys.modules[self.alias] = root_module -def safe_load_namespace(source_code: str): +def safe_load_namespace( + source_code: str, filepath: Optional[str] = None +) -> Dict[str, Any]: """ - Safely load a namespace from source code. + Safely load a namespace from source code, optionally handling relative imports. - This function will attempt to import all modules and classes defined in the source - code. If an import fails, the error is caught and the import is skipped. This function - will also attempt to compile and evaluate class and function definitions locally. + If a `filepath` is provided, `sys.path` is modified to support relative imports. + Changes to `sys.path` are reverted after completion, but this function is not thread safe + and use of it in threaded contexts may result in undesirable behavior. Args: source_code: The source code to load + filepath: Optional file path of the source code. If provided, enables relative imports. Returns: - The namespace loaded from the source code. Can be used when evaluating source - code. + The namespace loaded from the source code. """ parsed_code = ast.parse(source_code) - namespace = {"__name__": "prefect_safe_namespace_loader"} + namespace: Dict[str, Any] = {"__name__": "prefect_safe_namespace_loader"} - # Remove the body of the if __name__ == "__main__": block from the AST to prevent - # execution of guarded code - new_body = [] - for node in parsed_code.body: - if _is_main_block(node): - continue - new_body.append(node) + # Remove the body of the if __name__ == "__main__": block + new_body = [node for node in parsed_code.body if not _is_main_block(node)] parsed_code.body = new_body - # Walk through the AST and find all import statements - for node in ast.walk(parsed_code): - if isinstance(node, ast.Import): - for alias in node.names: - module_name = alias.name - as_name = alias.asname if alias.asname else module_name - try: - # Attempt to import the module - namespace[as_name] = importlib.import_module(module_name) - logger.debug("Successfully imported %s", module_name) - except ImportError as e: - logger.debug(f"Failed to import {module_name}: {e}") - elif isinstance(node, ast.ImportFrom): - module_name = node.module - if module_name is None: - continue - try: - module = importlib.import_module(module_name) + temp_module = None + original_sys_path = None + + if filepath: + # Setup for relative imports + file_dir = os.path.dirname(os.path.abspath(filepath)) + package_name = os.path.basename(file_dir) + parent_dir = os.path.dirname(file_dir) + + # Save original sys.path and modify it + original_sys_path = sys.path.copy() + sys.path.insert(0, parent_dir) + + # Create a temporary module for import context + temp_module = ModuleType(package_name) + temp_module.__file__ = filepath + temp_module.__package__ = package_name + + # Create a spec for the module + temp_module.__spec__ = ModuleSpec(package_name, None) + temp_module.__spec__.loader = None + temp_module.__spec__.submodule_search_locations = [file_dir] + + try: + for node in parsed_code.body: + if isinstance(node, ast.Import): for alias in node.names: - name = alias.name - asname = alias.asname if alias.asname else name + module_name = alias.name + as_name = alias.asname or module_name try: - # Get the specific attribute from the module - attribute = getattr(module, name) - namespace[asname] = attribute - except AttributeError as e: - logger.debug( - "Failed to retrieve %s from %s: %s", name, module_name, e - ) - except ImportError as e: - logger.debug("Failed to import from %s: %s", node.module, e) - - # Handle local definitions - for node in parsed_code.body: - if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.Assign)): - try: - # Compile and execute each class and function definition and assignment - code = compile( - ast.Module(body=[node], type_ignores=[]), - filename="", - mode="exec", - ) - exec(code, namespace) - except Exception as e: - logger.debug("Failed to compile: %s", e) + namespace[as_name] = importlib.import_module(module_name) + logger.debug("Successfully imported %s", module_name) + except ImportError as e: + logger.debug(f"Failed to import {module_name}: {e}") + elif isinstance(node, ast.ImportFrom): + module_name = node.module or "" + if filepath: + try: + if node.level > 0: + # For relative imports, use the parent package to inform the import + package_parts = temp_module.__package__.split(".") + if len(package_parts) < node.level: + raise ImportError( + "Attempted relative import beyond top-level package" + ) + parent_package = ".".join( + package_parts[: (1 - node.level)] + if node.level > 1 + else package_parts + ) + module = importlib.import_module( + f".{module_name}" if module_name else "", + package=parent_package, + ) + else: + # Absolute imports are handled as normal + module = importlib.import_module(module_name) + + for alias in node.names: + name = alias.name + asname = alias.asname or name + if name == "*": + # Handle 'from module import *' + module_dict = { + k: v + for k, v in module.__dict__.items() + if not k.startswith("_") + } + namespace.update(module_dict) + else: + try: + attribute = getattr(module, name) + namespace[asname] = attribute + except AttributeError as e: + logger.debug( + "Failed to retrieve %s from %s: %s", + name, + module_name, + e, + ) + except ImportError as e: + logger.debug("Failed to import from %s: %s", module_name, e) + else: + # Handle as absolute import when no filepath is provided + try: + module = importlib.import_module(module_name) + for alias in node.names: + name = alias.name + asname = alias.asname or name + if name == "*": + # Handle 'from module import *' + module_dict = { + k: v + for k, v in module.__dict__.items() + if not k.startswith("_") + } + namespace.update(module_dict) + else: + try: + attribute = getattr(module, name) + namespace[asname] = attribute + except AttributeError as e: + logger.debug( + "Failed to retrieve %s from %s: %s", + name, + module_name, + e, + ) + except ImportError as e: + logger.debug("Failed to import from %s: %s", module_name, e) + # Handle local definitions + for node in parsed_code.body: + if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.Assign)): + try: + code = compile( + ast.Module(body=[node], type_ignores=[]), + filename="", + mode="exec", + ) + exec(code, namespace) + except Exception as e: + logger.debug("Failed to compile: %s", e) + + finally: + # Restore original sys.path if it was modified + if original_sys_path: + sys.path[:] = original_sys_path + return namespace diff --git a/tests/cli/test_deploy.py b/tests/cli/test_deploy.py index 32390cdfb999..ea3423413f98 100644 --- a/tests/cli/test_deploy.py +++ b/tests/cli/test_deploy.py @@ -291,6 +291,50 @@ async def test_project_deploy(self, project_dir, prefect_client): assert deployment.job_variables == {"env": "prod"} assert deployment.enforce_parameter_schema is False + async def test_deploy_with_wrapped_flow_decorator( + self, project_dir, work_pool, prefect_client + ): + await run_sync_in_worker_thread( + invoke_and_assert, + command=( + f"deploy ./wrapped-flow-project/flow.py:test_flow -n test-name -p {work_pool.name}" + ), + expected_code=0, + expected_output_does_not_contain=["test-flow"], + expected_output_contains=[ + "wrapped-flow/test-name", + f"prefect worker start --pool '{work_pool.name}'", + ], + ) + + deployment = await prefect_client.read_deployment_by_name( + "wrapped-flow/test-name" + ) + assert deployment.name == "test-name" + assert deployment.work_pool_name == work_pool.name + + async def test_deploy_with_missing_imports( + self, project_dir, work_pool, prefect_client + ): + await run_sync_in_worker_thread( + invoke_and_assert, + command=( + f"deploy ./wrapped-flow-project/missing_imports.py:bloop_flow -n test-name -p {work_pool.name}" + ), + expected_code=0, + expected_output_does_not_contain=["test-flow"], + expected_output_contains=[ + "wrapped-flow/test-name", + f"prefect worker start --pool '{work_pool.name}'", + ], + ) + + deployment = await prefect_client.read_deployment_by_name( + "wrapped-flow/test-name" + ) + assert deployment.name == "test-name" + assert deployment.work_pool_name == work_pool.name + async def test_project_deploy_with_default_work_pool( self, project_dir, prefect_client ): diff --git a/tests/test-projects/wrapped-flow-project/__init__.py b/tests/test-projects/wrapped-flow-project/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test-projects/wrapped-flow-project/flow.py b/tests/test-projects/wrapped-flow-project/flow.py new file mode 100644 index 000000000000..da18ddf8346a --- /dev/null +++ b/tests/test-projects/wrapped-flow-project/flow.py @@ -0,0 +1,6 @@ +from .utils import pipeline_flow + + +@pipeline_flow(name="wrapped-flow") +def test_flow(): + return "I'm a pipeline flow!" diff --git a/tests/test-projects/wrapped-flow-project/missing_imports.py b/tests/test-projects/wrapped-flow-project/missing_imports.py new file mode 100644 index 000000000000..6df54760fbea --- /dev/null +++ b/tests/test-projects/wrapped-flow-project/missing_imports.py @@ -0,0 +1,8 @@ +import bloop + +from .utils import pipeline_flow + + +@pipeline_flow(name="wrapped-flow") +def bloop_flow(bloop: bloop.Bloop = bloop.DEFAULT): + return "I do a thing with bloop!" diff --git a/tests/test-projects/wrapped-flow-project/utils.py b/tests/test-projects/wrapped-flow-project/utils.py new file mode 100644 index 000000000000..b6f0bf5126dd --- /dev/null +++ b/tests/test-projects/wrapped-flow-project/utils.py @@ -0,0 +1,11 @@ +from prefect import flow + + +def pipeline_flow(name=None): + def setup(__fn): + return flow( + __fn, + name=name, + ) + + return setup diff --git a/tests/test_flows.py b/tests/test_flows.py index a7e1e5178302..544245b63220 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -25,6 +25,7 @@ else: import pydantic +import pendulum import pytest import regex as re @@ -52,6 +53,7 @@ Flow, load_flow_arguments_from_entrypoint, load_flow_from_entrypoint, + safe_load_flow_from_entrypoint, ) from prefect.runtime import flow_run as flow_run_ctx from prefect.server.schemas.core import TaskRunResult @@ -2313,9 +2315,8 @@ def dog(): assert flow.name == "dog" assert flow.description == "Says woof!" - # But if the flow is called, it should raise the ScriptError - with pytest.raises(ScriptError): - flow.fn() + # Should still be callable + assert flow() == "woof!" async def test_handling_script_with_unprotected_call_in_flow_script( self, @@ -2367,8 +2368,7 @@ def dog(): # Test with use_placeholder_flow=True (default behavior) flow = load_flow_from_entrypoint(f"{fpath}:dog") assert isinstance(flow, Flow) - with pytest.raises(ScriptError): - flow.fn() + assert flow() == "woof!" # Test with use_placeholder_flow=False with pytest.raises(ScriptError): @@ -4107,3 +4107,310 @@ def test_load_no_flow(self, tmp_path: Path): with pytest.raises(ValueError, match="Could not find flow"): load_flow_arguments_from_entrypoint(entrypoint) + + +class TestSafeLoadFlowFromEntrypoint: + def test_flow_not_found(self, tmp_path: Path): + source_code = dedent( + """ + from prefect import flow + + @flow + def f(): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + + with pytest.raises(ValueError): + safe_load_flow_from_entrypoint(f"{tmp_path}/test.py:g") + + def test_basic_operation(self, tmp_path: Path): + flow_source = dedent( + ''' + + from prefect import flow + + @flow(name="My custom name") + def flow_function(name: str) -> str: + """ + My docstring + + Args: + name (str): A name + """ + return name + ''' + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = safe_load_flow_from_entrypoint(entrypoint) + + assert result is not None + assert isinstance(result, Flow) + assert result.name == "My custom name" + assert result("marvin") == "marvin" + assert result.__doc__ is not None + assert "My docstring" in result.__doc__ + assert "Args:" in result.__doc__ + assert "name (str): A name" in result.__doc__ + + def test_get_parameter_schema_from_safe_loaded_flow(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + + @flow + def flow_function(name: str) -> str: + return name + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = safe_load_flow_from_entrypoint(entrypoint) + + assert result is not None + assert parameter_schema(result).dict() == { + "properties": {"name": {"position": 0, "title": "name", "type": "string"}}, + "required": ["name"], + "title": "Parameters", + "type": "object", + } + + def test_dynamic_name_fstring(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + + version = "1.0" + + @flow(name=f"flow-function-{version}") + def flow_function(name: str) -> str: + return name + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = safe_load_flow_from_entrypoint(entrypoint) + + assert result is not None + assert result.name == "flow-function-1.0" + + def test_dynamic_name_function(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + + def get_name(): + return "from-a-function" + + @flow(name=get_name()) + def flow_function(name: str) -> str: + return name + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = safe_load_flow_from_entrypoint(entrypoint) + + assert result is not None + + def test_dynamic_name_depends_on_missing_import(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + + from non_existent import get_name + + @flow(name=get_name()) + def flow_function(name: str) -> str: + return name + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = safe_load_flow_from_entrypoint(entrypoint) + + # We expect this to be None because the flow function cannot be loaded + assert result is None + + def test_annotations_and_defaults_rely_on_imports(self, tmp_path: Path): + source_code = dedent( + """ + import pendulum + import datetime + from prefect import flow + + @flow(validate_parameters=False) + def f( + x: datetime.datetime, + y: pendulum.DateTime = pendulum.datetime(2025, 1, 1), + z: datetime.timedelta = datetime.timedelta(seconds=5), + ): + return x, y, z + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + result = safe_load_flow_from_entrypoint(f"{tmp_path}/test.py:f") + assert result is not None + assert result(datetime.datetime(2025, 1, 1)) == ( + datetime.datetime(2025, 1, 1), + pendulum.datetime(2025, 1, 1), + datetime.timedelta(seconds=5), + ) + + def test_annotations_rely_on_missing_import(self, tmp_path: Path): + """ + This test ensures missing types for annotations are handled gracefully + for all argument types (positional-only, positional-or-keyword, + keyword-only, varargs, and varkwargs). + """ + flow_source = dedent( + """ + + from prefect import flow + from typing import Dict, Tuple + + from non_existent import Type1, Type2, Type3, Type4, Type5 + + @flow + def flow_function(x: Type1, /, y: Type2, *args: Type4, z: Type3, **kwargs: Type5) -> str: + return x, y, z, args, kwargs + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = safe_load_flow_from_entrypoint(entrypoint) + assert result is not None + assert result(1, 2, 4, z=3, a=5) == (1, 2, 3, (4,), {"a": 5}) + + def test_defaults_rely_on_missing_import(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + + from non_existent import DEFAULT_NAME, DEFAULT_AGE + + @flow + def flow_function(name = DEFAULT_NAME, age = DEFAULT_AGE) -> str: + return name, age + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = safe_load_flow_from_entrypoint(entrypoint) + assert result is not None + assert result() == (None, None) + + def test_function_with_enum_argument(self, tmp_path: Path): + class Color(enum.Enum): + RED = "RED" + GREEN = "GREEN" + BLUE = "BLUE" + + source_code = dedent( + """ + from enum import Enum + + from prefect import flow + + class Color(Enum): + RED = "RED" + GREEN = "GREEN" + BLUE = "BLUE" + + @flow + def f(x: Color = Color.RED): + return x + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + + entrypoint = f"{tmp_path.joinpath('test.py')}:f" + + result = safe_load_flow_from_entrypoint(entrypoint) + assert result is not None + assert result().value == Color.RED.value + + def test_handles_dynamically_created_models(self, tmp_path: Path): + source_code = dedent( + """ + from typing import Optional + from prefect import flow + from pydantic import BaseModel, create_model, Field + + + def get_model() -> BaseModel: + return create_model( + "MyModel", + param=( + int, + Field( + title="param", + default=1, + ), + ), + ) + + + MyModel = get_model() + + + @flow + def f( + param: Optional[MyModel] = None, + ) -> None: + return MyModel() + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + entrypoint = f"{tmp_path.joinpath('test.py')}:f" + + result = safe_load_flow_from_entrypoint(entrypoint) + assert result is not None + assert result().param == 1 + + def test_raises_name_error_when_loaded_flow_cannot_run(self, tmp_path): + source_code = dedent( + """ + from not_a_module import not_a_function + + from prefect import flow + + @flow(description="Says woof!") + def dog(): + return not_a_function('dog') + """ + ) + + tmp_path.joinpath("test.py").write_text(source_code) + entrypoint = f"{tmp_path.joinpath('test.py')}:dog" + + with pytest.raises(NameError, match="name 'not_a_function' is not defined"): + safe_load_flow_from_entrypoint(entrypoint)()