Skip to content

Commit

Permalink
Add Block Type registration for its nested block types in field annot…
Browse files Browse the repository at this point in the history
…ations (#15032)

Co-authored-by: Ladislav Gál <[email protected]>
Co-authored-by: Ladislav Gál <[email protected]>
  • Loading branch information
3 people authored Aug 22, 2024
1 parent a1286e1 commit b8c27aa
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 31 deletions.
69 changes: 39 additions & 30 deletions src/prefect/blocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,20 @@ def _is_subclass(cls, parent_cls) -> bool:
Checks if a given class is a subclass of another class. Unlike issubclass,
this will not throw an exception if cls is an instance instead of a type.
"""
return inspect.isclass(cls) and issubclass(cls, parent_cls)
# For python<=3.11 inspect.isclass() will return True for parametrized types (e.g. list[str])
# so we need to check for get_origin() to avoid TypeError for issubclass.
return inspect.isclass(cls) and not get_origin(cls) and issubclass(cls, parent_cls)


def _collect_secret_fields(name: str, type_: Type, secrets: List[str]) -> None:
"""
Recursively collects all secret fields from a given type and adds them to the
secrets list, supporting nested Union / BaseModel fields. Also, note, this function
mutates the input secrets list, thus does not return anything.
secrets list, supporting nested Union / Dict / Tuple / List / BaseModel fields.
Also, note, this function mutates the input secrets list, thus does not return anything.
"""
if get_origin(type_) is Union:
for union_type in get_args(type_):
_collect_secret_fields(name, union_type, secrets)
if get_origin(type_) in (Union, dict, list, tuple):
for nested_type in get_args(type_):
_collect_secret_fields(name, nested_type, secrets)
return
elif _is_subclass(type_, BaseModel):
for field in type_.__fields__.values():
Expand Down Expand Up @@ -239,25 +241,29 @@ def schema_extra(schema: Dict[str, Any], model: Type["Block"]):

# create block schema references
refs = schema["block_schema_references"] = {}

def collect_block_schema_references(
field_name: str, annotation: type
) -> None:
"""Walk through the annotation and collect block schemas for any nested blocks."""
if Block.is_block_class(annotation):
if isinstance(refs.get(field_name), list):
refs[field_name].append(
annotation._to_block_schema_reference_dict()
)
elif isinstance(refs.get(field_name), dict):
refs[field_name] = [
refs[field_name],
annotation._to_block_schema_reference_dict(),
]
else:
refs[field_name] = annotation._to_block_schema_reference_dict()
if get_origin(annotation) in (Union, list, tuple, dict):
for type_ in get_args(annotation):
collect_block_schema_references(field_name, type_)

for field in model.__fields__.values():
if Block.is_block_class(field.type_):
refs[field.name] = field.type_._to_block_schema_reference_dict()
if get_origin(field.type_) is Union:
for type_ in get_args(field.type_):
if Block.is_block_class(type_):
if isinstance(refs.get(field.name), list):
refs[field.name].append(
type_._to_block_schema_reference_dict()
)
elif isinstance(refs.get(field.name), dict):
refs[field.name] = [
refs[field.name],
type_._to_block_schema_reference_dict(),
]
else:
refs[
field.name
] = type_._to_block_schema_reference_dict()
collect_block_schema_references(field.name, field.type_)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -886,13 +892,16 @@ async def register_type_and_schema(cls, client: "PrefectClient" = None):
"subclass and not on a Block interface class directly."
)

async def register_blocks_in_annotation(annotation: type) -> None:
"""Walk through the annotation and register any nested blocks."""
if Block.is_block_class(annotation):
await annotation.register_type_and_schema(client=client)
elif get_origin(annotation) in (Union, tuple, list, dict):
for inner_annotation in get_args(annotation):
await register_blocks_in_annotation(inner_annotation)

for field in cls.__fields__.values():
if Block.is_block_class(field.type_):
await field.type_.register_type_and_schema(client=client)
if get_origin(field.type_) is Union:
for type_ in get_args(field.type_):
if Block.is_block_class(type_):
await type_.register_type_and_schema(client=client)
await register_blocks_in_annotation(field.annotation)

try:
block_type = await client.read_block_type_by_slug(
Expand Down
169 changes: 168 additions & 1 deletion tests/blocks/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import warnings
from textwrap import dedent
from typing import Dict, Type, Union
from typing import Dict, List, Tuple, Type, Union
from unittest.mock import Mock
from uuid import UUID, uuid4

Expand Down Expand Up @@ -1230,6 +1230,173 @@ class Umbrella(Block):
)
assert umbrella_block_schema is not None

async def test_register_nested_block_list(self, prefect_client: PrefectClient):
class A(Block):
a: str

class B(Block):
b: str

class ListCollection(Block):
a_list: List[A]
b_list: List[B]

await ListCollection.register_type_and_schema()

a_block_type = await prefect_client.read_block_type_by_slug(slug="a")
assert a_block_type is not None
b_block_type = await prefect_client.read_block_type_by_slug(slug="b")
assert b_block_type is not None

list_collection_block_type = await prefect_client.read_block_type_by_slug(
slug="listcollection"
)
assert list_collection_block_type is not None

a_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=A._calculate_schema_checksum()
)
assert a_block_schema is not None
b_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=B._calculate_schema_checksum()
)
assert b_block_schema is not None

list_collection_block_type = await prefect_client.read_block_schema_by_checksum(
checksum=ListCollection._calculate_schema_checksum()
)
assert list_collection_block_type is not None

async def test_register_nested_block_tuple(self, prefect_client: PrefectClient):
class A(Block):
a: str

class B(Block):
b: str

class C(Block):
c: str

class TupleCollection(Block):
a_tuple: Tuple[A]
b_tuple: Tuple[B, ...]
c_tuple: Tuple[C, ...]

await TupleCollection.register_type_and_schema()

a_block_type = await prefect_client.read_block_type_by_slug(slug="a")
assert a_block_type is not None
b_block_type = await prefect_client.read_block_type_by_slug(slug="b")
assert b_block_type is not None
c_block_type = await prefect_client.read_block_type_by_slug(slug="c")
assert c_block_type is not None

tuple_collection_block_type = await prefect_client.read_block_type_by_slug(
slug="tuplecollection"
)
assert tuple_collection_block_type is not None

a_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=A._calculate_schema_checksum()
)
assert a_block_schema is not None
b_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=B._calculate_schema_checksum()
)
assert b_block_schema is not None
c_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=C._calculate_schema_checksum()
)
assert c_block_schema is not None

tuple_collection_block_type = (
await prefect_client.read_block_schema_by_checksum(
checksum=TupleCollection._calculate_schema_checksum()
)
)
assert tuple_collection_block_type is not None

async def test_register_nested_block_dict(self, prefect_client: PrefectClient):
class A(Block):
a: str

class B(Block):
b: str

class DictCollection(Block):
block_dict: Dict[A, B]

await DictCollection.register_type_and_schema()

a_block_type = await prefect_client.read_block_type_by_slug(slug="a")
assert a_block_type is not None
b_block_type = await prefect_client.read_block_type_by_slug(slug="b")
assert b_block_type is not None

dict_collection_block_type = await prefect_client.read_block_type_by_slug(
slug="dictcollection"
)
assert dict_collection_block_type is not None

a_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=A._calculate_schema_checksum()
)
assert a_block_schema is not None
b_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=B._calculate_schema_checksum()
)
assert b_block_schema is not None

dict_collection_block_type = await prefect_client.read_block_schema_by_checksum(
checksum=DictCollection._calculate_schema_checksum()
)
assert dict_collection_block_type is not None

async def test_register_nested_block_type_nested(
self, prefect_client: PrefectClient
):
class A(Block):
a: str

class B(Block):
b: str

class C(Block):
c: str

class Depth(Block):
depths: List[Union[A, List[Union[B, Tuple[C, ...]]]]]

await Depth.register_type_and_schema()

a_block_type = await prefect_client.read_block_type_by_slug(slug="a")
assert a_block_type is not None
b_block_type = await prefect_client.read_block_type_by_slug(slug="b")
assert b_block_type is not None
c_block_type = await prefect_client.read_block_type_by_slug(slug="c")
assert c_block_type is not None

depth_block_type = await prefect_client.read_block_type_by_slug(slug="depth")
assert depth_block_type is not None

a_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=A._calculate_schema_checksum()
)
assert a_block_schema is not None
b_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=B._calculate_schema_checksum()
)
assert b_block_schema is not None
c_block_schema = await prefect_client.read_block_schema_by_checksum(
checksum=C._calculate_schema_checksum()
)
assert c_block_schema is not None

depth_block_type = await prefect_client.read_block_schema_by_checksum(
checksum=Depth._calculate_schema_checksum()
)
assert depth_block_type is not None

async def test_register_raises_block_base_class(self):
with pytest.raises(
InvalidBlockRegistration,
Expand Down

0 comments on commit b8c27aa

Please sign in to comment.