From b8c27aa06d9d3892cbb1d1d69e78e604a527adae Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Thu, 22 Aug 2024 15:13:55 -0500 Subject: [PATCH] Add Block Type registration for its nested block types in field annotations (#15032) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ladislav Gál <129292521+GalLadislav@users.noreply.github.com> Co-authored-by: Ladislav Gál --- src/prefect/blocks/core.py | 69 ++++++++------- tests/blocks/test_core.py | 169 ++++++++++++++++++++++++++++++++++++- 2 files changed, 207 insertions(+), 31 deletions(-) diff --git a/src/prefect/blocks/core.py b/src/prefect/blocks/core.py index 71201860de65..e81b79d94eaa 100644 --- a/src/prefect/blocks/core.py +++ b/src/prefect/blocks/core.py @@ -118,18 +118,20 @@ def _is_subclass(cls, parent_cls) -> bool: Checks if a given class is a subclass of another class. Unlike issubclass, this will not throw an exception if cls is an instance instead of a type. """ - return inspect.isclass(cls) and issubclass(cls, parent_cls) + # For python<=3.11 inspect.isclass() will return True for parametrized types (e.g. list[str]) + # so we need to check for get_origin() to avoid TypeError for issubclass. + return inspect.isclass(cls) and not get_origin(cls) and issubclass(cls, parent_cls) def _collect_secret_fields(name: str, type_: Type, secrets: List[str]) -> None: """ Recursively collects all secret fields from a given type and adds them to the - secrets list, supporting nested Union / BaseModel fields. Also, note, this function - mutates the input secrets list, thus does not return anything. + secrets list, supporting nested Union / Dict / Tuple / List / BaseModel fields. + Also, note, this function mutates the input secrets list, thus does not return anything. """ - if get_origin(type_) is Union: - for union_type in get_args(type_): - _collect_secret_fields(name, union_type, secrets) + if get_origin(type_) in (Union, dict, list, tuple): + for nested_type in get_args(type_): + _collect_secret_fields(name, nested_type, secrets) return elif _is_subclass(type_, BaseModel): for field in type_.__fields__.values(): @@ -239,25 +241,29 @@ def schema_extra(schema: Dict[str, Any], model: Type["Block"]): # create block schema references refs = schema["block_schema_references"] = {} + + def collect_block_schema_references( + field_name: str, annotation: type + ) -> None: + """Walk through the annotation and collect block schemas for any nested blocks.""" + if Block.is_block_class(annotation): + if isinstance(refs.get(field_name), list): + refs[field_name].append( + annotation._to_block_schema_reference_dict() + ) + elif isinstance(refs.get(field_name), dict): + refs[field_name] = [ + refs[field_name], + annotation._to_block_schema_reference_dict(), + ] + else: + refs[field_name] = annotation._to_block_schema_reference_dict() + if get_origin(annotation) in (Union, list, tuple, dict): + for type_ in get_args(annotation): + collect_block_schema_references(field_name, type_) + for field in model.__fields__.values(): - if Block.is_block_class(field.type_): - refs[field.name] = field.type_._to_block_schema_reference_dict() - if get_origin(field.type_) is Union: - for type_ in get_args(field.type_): - if Block.is_block_class(type_): - if isinstance(refs.get(field.name), list): - refs[field.name].append( - type_._to_block_schema_reference_dict() - ) - elif isinstance(refs.get(field.name), dict): - refs[field.name] = [ - refs[field.name], - type_._to_block_schema_reference_dict(), - ] - else: - refs[ - field.name - ] = type_._to_block_schema_reference_dict() + collect_block_schema_references(field.name, field.type_) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -886,13 +892,16 @@ async def register_type_and_schema(cls, client: "PrefectClient" = None): "subclass and not on a Block interface class directly." ) + async def register_blocks_in_annotation(annotation: type) -> None: + """Walk through the annotation and register any nested blocks.""" + if Block.is_block_class(annotation): + await annotation.register_type_and_schema(client=client) + elif get_origin(annotation) in (Union, tuple, list, dict): + for inner_annotation in get_args(annotation): + await register_blocks_in_annotation(inner_annotation) + for field in cls.__fields__.values(): - if Block.is_block_class(field.type_): - await field.type_.register_type_and_schema(client=client) - if get_origin(field.type_) is Union: - for type_ in get_args(field.type_): - if Block.is_block_class(type_): - await type_.register_type_and_schema(client=client) + await register_blocks_in_annotation(field.annotation) try: block_type = await client.read_block_type_by_slug( diff --git a/tests/blocks/test_core.py b/tests/blocks/test_core.py index 96adca21f9e3..5e4470469343 100644 --- a/tests/blocks/test_core.py +++ b/tests/blocks/test_core.py @@ -2,7 +2,7 @@ import json import warnings from textwrap import dedent -from typing import Dict, Type, Union +from typing import Dict, List, Tuple, Type, Union from unittest.mock import Mock from uuid import UUID, uuid4 @@ -1230,6 +1230,173 @@ class Umbrella(Block): ) assert umbrella_block_schema is not None + async def test_register_nested_block_list(self, prefect_client: PrefectClient): + class A(Block): + a: str + + class B(Block): + b: str + + class ListCollection(Block): + a_list: List[A] + b_list: List[B] + + await ListCollection.register_type_and_schema() + + a_block_type = await prefect_client.read_block_type_by_slug(slug="a") + assert a_block_type is not None + b_block_type = await prefect_client.read_block_type_by_slug(slug="b") + assert b_block_type is not None + + list_collection_block_type = await prefect_client.read_block_type_by_slug( + slug="listcollection" + ) + assert list_collection_block_type is not None + + a_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=A._calculate_schema_checksum() + ) + assert a_block_schema is not None + b_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=B._calculate_schema_checksum() + ) + assert b_block_schema is not None + + list_collection_block_type = await prefect_client.read_block_schema_by_checksum( + checksum=ListCollection._calculate_schema_checksum() + ) + assert list_collection_block_type is not None + + async def test_register_nested_block_tuple(self, prefect_client: PrefectClient): + class A(Block): + a: str + + class B(Block): + b: str + + class C(Block): + c: str + + class TupleCollection(Block): + a_tuple: Tuple[A] + b_tuple: Tuple[B, ...] + c_tuple: Tuple[C, ...] + + await TupleCollection.register_type_and_schema() + + a_block_type = await prefect_client.read_block_type_by_slug(slug="a") + assert a_block_type is not None + b_block_type = await prefect_client.read_block_type_by_slug(slug="b") + assert b_block_type is not None + c_block_type = await prefect_client.read_block_type_by_slug(slug="c") + assert c_block_type is not None + + tuple_collection_block_type = await prefect_client.read_block_type_by_slug( + slug="tuplecollection" + ) + assert tuple_collection_block_type is not None + + a_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=A._calculate_schema_checksum() + ) + assert a_block_schema is not None + b_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=B._calculate_schema_checksum() + ) + assert b_block_schema is not None + c_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=C._calculate_schema_checksum() + ) + assert c_block_schema is not None + + tuple_collection_block_type = ( + await prefect_client.read_block_schema_by_checksum( + checksum=TupleCollection._calculate_schema_checksum() + ) + ) + assert tuple_collection_block_type is not None + + async def test_register_nested_block_dict(self, prefect_client: PrefectClient): + class A(Block): + a: str + + class B(Block): + b: str + + class DictCollection(Block): + block_dict: Dict[A, B] + + await DictCollection.register_type_and_schema() + + a_block_type = await prefect_client.read_block_type_by_slug(slug="a") + assert a_block_type is not None + b_block_type = await prefect_client.read_block_type_by_slug(slug="b") + assert b_block_type is not None + + dict_collection_block_type = await prefect_client.read_block_type_by_slug( + slug="dictcollection" + ) + assert dict_collection_block_type is not None + + a_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=A._calculate_schema_checksum() + ) + assert a_block_schema is not None + b_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=B._calculate_schema_checksum() + ) + assert b_block_schema is not None + + dict_collection_block_type = await prefect_client.read_block_schema_by_checksum( + checksum=DictCollection._calculate_schema_checksum() + ) + assert dict_collection_block_type is not None + + async def test_register_nested_block_type_nested( + self, prefect_client: PrefectClient + ): + class A(Block): + a: str + + class B(Block): + b: str + + class C(Block): + c: str + + class Depth(Block): + depths: List[Union[A, List[Union[B, Tuple[C, ...]]]]] + + await Depth.register_type_and_schema() + + a_block_type = await prefect_client.read_block_type_by_slug(slug="a") + assert a_block_type is not None + b_block_type = await prefect_client.read_block_type_by_slug(slug="b") + assert b_block_type is not None + c_block_type = await prefect_client.read_block_type_by_slug(slug="c") + assert c_block_type is not None + + depth_block_type = await prefect_client.read_block_type_by_slug(slug="depth") + assert depth_block_type is not None + + a_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=A._calculate_schema_checksum() + ) + assert a_block_schema is not None + b_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=B._calculate_schema_checksum() + ) + assert b_block_schema is not None + c_block_schema = await prefect_client.read_block_schema_by_checksum( + checksum=C._calculate_schema_checksum() + ) + assert c_block_schema is not None + + depth_block_type = await prefect_client.read_block_schema_by_checksum( + checksum=Depth._calculate_schema_checksum() + ) + assert depth_block_type is not None + async def test_register_raises_block_base_class(self): with pytest.raises( InvalidBlockRegistration,