Skip to content

Commit

Permalink
A couple minor typing fixes (#14859)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Aug 6, 2024
1 parent 853a01e commit aea46a8
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions src/prefect/cli/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import functools
import sys
from typing import List, Optional
from typing import Callable, List, Optional

import typer
from rich.console import Console
Expand Down Expand Up @@ -96,7 +96,7 @@ def add_typer(
typer_instance: "PrefectTyper",
*args,
no_args_is_help: bool = True,
aliases: List[str] = None,
aliases: Optional[List[str]] = None,
**kwargs,
) -> None:
"""
Expand All @@ -121,7 +121,7 @@ def command(
self,
name: Optional[str] = None,
*args,
aliases: List[str] = None,
aliases: Optional[List[str]] = None,
deprecated: bool = False,
deprecated_start_date: Optional[str] = None,
deprecated_help: str = "",
Expand All @@ -136,7 +136,7 @@ def command(
`deprecated_name` and `deprecated_start_date` must be provided.
"""

def wrapper(fn):
def wrapper(original_fn: Callable):
# click doesn't support async functions, so we wrap them in
# asyncio.run(). This has the advantage of keeping the function in
# the main thread, which means signal handling works for e.g. the
Expand All @@ -145,16 +145,19 @@ def wrapper(fn):
# can not be called nested). In that (rare) circumstance, refactor
# the CLI command so its business logic can be invoked separately
# from its entrypoint.
if is_async_fn(fn):
_fn = fn
if is_async_fn(original_fn):
async_fn = original_fn

@functools.wraps(fn)
def fn(*args, **kwargs):
return asyncio.run(_fn(*args, **kwargs))
@functools.wraps(original_fn)
def sync_fn(*args, **kwargs):
return asyncio.run(async_fn(*args, **kwargs))

fn.aio = _fn
sync_fn.aio = async_fn
wrapped_fn = sync_fn
else:
wrapped_fn = original_fn

fn = with_cli_exception_handling(fn)
wrapped_fn = with_cli_exception_handling(wrapped_fn)
if deprecated:
if not deprecated_name or not deprecated_start_date:
raise ValueError(
Expand All @@ -165,15 +168,19 @@ def fn(*args, **kwargs):
start_date=deprecated_start_date,
help=deprecated_help,
)
fn = with_deprecated_message(command_deprecated_message)(fn)
wrapped_fn = with_deprecated_message(command_deprecated_message)(
wrapped_fn
)
elif self.deprecated:
fn = with_deprecated_message(self.deprecated_message)(fn)
wrapped_fn = with_deprecated_message(self.deprecated_message)(
wrapped_fn
)

# register fn with its original name
command_decorator = super(PrefectTyper, self).command(
name=name, *args, **kwargs
)
original_command = command_decorator(fn)
original_command = command_decorator(wrapped_fn)

# register fn for each alias, e.g. @marvin_app.command(aliases=["r"])
if aliases:
Expand All @@ -182,7 +189,7 @@ def fn(*args, **kwargs):
name=alias,
*args,
**{k: v for k, v in kwargs.items() if k != "aliases"},
)(fn)
)(wrapped_fn)

return original_command

Expand Down

0 comments on commit aea46a8

Please sign in to comment.