Skip to content

Commit

Permalink
Add built-in types for "dbt Classes" and refine type checking impleme…
Browse files Browse the repository at this point in the history
…ntation
  • Loading branch information
peterallenwebb committed Jan 13, 2025
1 parent 130c2d9 commit cf3617c
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 69 deletions.
4 changes: 2 additions & 2 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def parse_signature(self, node: Union[jinja2.nodes.Macro, jinja2.nodes.CallBlock
arg = self.parse_assign_target(name_only=True)
arg.set_ctx("param")

type_name: Optional[str]
type_name: Optional[MacroType]
if self.stream.skip_if("colon"):
node.has_type_annotations = True # type: ignore
type_name = self.parse_type_name()
else:
type_name = ""
type_name = None

node.arg_types.append(type_name) # type: ignore

Expand Down
184 changes: 131 additions & 53 deletions dbt_common/clients/jinja_macro_call.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
import dataclasses
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Iterable

import jinja2
import jinja2.nodes

from dbt_common.clients.jinja import get_environment, MacroType

PRIMITIVE_TYPES = ["Any", "bool", "float", "int", "str"]
DBT_CLASSES = ["Column", "Relation", "Result"]


class FailureType(Enum):
TYPE_MISMATCH = "mismatch"
UNKNOWN_TYPE = "unknown"
TYPE_MISMATCH = "type_mismatch"
UNKNOWN_TYPE = "unknown_type"
PARAMETER_COUNT = "param_count"
EXTRA_ARGUMENT = "extra_arg"
MISSING_ARGUMENT = "missing_arg"

@dataclasses.dataclass
class TypeCheckFailure:
type: FailureType
msg: str

@dataclasses.dataclass
class DbtMacroCall:
class MacroCallChecker:
"""An instance of this class represents a jinja macro call in a template
for the purposes of recording information for type checking."""

Expand All @@ -29,16 +34,134 @@ class DbtMacroCall:
kwarg_types: Dict[str, Optional[MacroType]] = dataclasses.field(default_factory=dict)

@classmethod
def from_call(cls, call: jinja2.nodes.Call, name: str) -> "DbtMacroCall":
def from_call(cls, call: jinja2.nodes.Call, name: str) -> "MacroCallChecker":
dbt_call = cls(name, "")
for arg in call.args: # type: ignore
dbt_call.arg_types.append(cls.get_type(arg))
dbt_call.arg_types.append(TypeChecker.get_type(arg))
for arg in call.kwargs: # type: ignore
dbt_call.kwarg_types[arg.key] = cls.get_type(arg.value)
dbt_call.kwarg_types[arg.key] = TypeChecker.get_type(arg.value)
return dbt_call

def check(self, macro_text: str) -> List[TypeCheckFailure]:
failures: List[TypeCheckFailure] = []

macro_checker = MacroChecker.from_jinja(macro_text)

unassigned_args = list(macro_checker.args)

# Each positional argument in this call should correspond to an expected
# positional argument with a compatible type.
for i, arg_type in enumerate(self.arg_types):
target_name = macro_checker.args[i]
target_type = macro_checker.arg_types[i]
unassigned_args.remove(target_name)
if arg_type is not None and target_type is not None and arg_type != target_type:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {target_type.name} for argument {target_name} but found {arg_type.name}/"))

# Each keyword argument in this call should correspond to an expected
# argument that has not already been assigned, and have a compatible type.
for arg_name, arg_type in self.kwarg_types.items():
if arg_name not in macro_checker.args:
failures.append(TypeCheckFailure(FailureType.EXTRA_ARGUMENT, f"Unexpected keyword argument {arg_name}."))
elif arg_name not in unassigned_args:
failures.append(TypeCheckFailure(FailureType.EXTRA_ARGUMENT, f"Argument {arg_name} was specified more than once."))
else:
unassigned_args.remove(arg_name)
expected_type = macro_checker.get_arg_type(arg_name)
if arg_type is not None and expected_type is not None and arg_type != expected_type:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} for argument {arg_name} but found {arg_type.name}/"))

# Any remaining unassigned parameters must have a default.
for arg_name in unassigned_args:
if not macro_checker.has_default(arg_name):
failures.append(TypeCheckFailure(FailureType.MISSING_ARGUMENT, f"Missing argument {arg_name}."))

# Check that any arguments specified by keyword have the correct type
for arg_name, arg_type in self.kwarg_types.items():
expected_type = macro_checker.get_arg_type(arg_name)
if arg_type is not None and expected_type is not None and arg_type != expected_type:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} as argument {arg_name} but found {arg_type.name}/"))

return failures


@dataclasses.dataclass
class MacroChecker:
_jinja_macro: jinja2.nodes.Macro

@property
def args(self) -> List[str]:
return [a.name for a in self._jinja_macro.args]

@property
def arg_types(self) -> List[Optional[MacroType]]:
return self._jinja_macro.arg_types # type: ignore

@property
def defaults(self) -> List[str]:
return self._jinja_macro.defaults

def get_arg_type(self, arg_name: str) -> Optional[MacroType]:
args = self.args
if arg_name not in args:
return None
else:
return self.arg_types[args.index(arg_name)]

def has_default(self, arg_name: str) -> bool:
args = self.args
return args.index(arg_name) >= len(self.args) - len(self.defaults)

@classmethod
def get_type(cls, param: Any) -> Optional[MacroType]:
def from_jinja(cls, jinja_text: str) -> "MacroChecker":
template = get_environment(None, capture_macros=True).parse(jinja_text)
jinja_macro = template.body[0]

if not isinstance(jinja_macro, jinja2.nodes.Macro):
raise Exception("Expected jinja macro.")

return MacroChecker(jinja_macro)

def type_check(self) -> List[TypeCheckFailure]:
# Every annotated parameter of the macro being called must have a valid
# type.
failures: List[TypeCheckFailure] = []
for arg_type in self._jinja_macro.arg_types: # type: ignore
failures = TypeChecker.check(arg_type)
if failures:
failures.extend(failures)

return failures


class TypeChecker:
@staticmethod
def check(t: Optional[MacroType]) -> List[TypeCheckFailure]:
if t is None or len(t.type_params) == 0 and t.name in (PRIMITIVE_TYPES + DBT_CLASSES):
return []

failures: List[TypeCheckFailure] = []
if t.name == "Dict":
if len(t.type_params) != 2:
failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected two type parameters for Dict[], found {len(t.type_params)}."))
else:
if t.type_params[0].name not in PRIMITIVE_TYPES:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, "First type parameter of Dict[] must be a primitive type."))

failures.extend(TypeChecker.check(t.type_params[1]))
elif t.name in ("List", "Optional"):
if len(t.type_params) != 1:
failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected one type parameter for {t.name}[], found {len(t.type_params)}."))

failures.extend(TypeChecker.check(t.type_params[0]))
else:
failures.append(TypeCheckFailure(FailureType.UNKNOWN_TYPE, f"Unknown type {t.name} encountered."))

return failures


@staticmethod
def get_type(param: Any) -> Optional[MacroType]:
if isinstance(param, jinja2.nodes.Name):
return None # TODO: infer types from variable names

Expand Down Expand Up @@ -69,48 +192,3 @@ def get_type(cls, param: Any) -> Optional[MacroType]:
return None

return None

def check_type(self, t: MacroType) -> List[TypeCheckFailure]:
if len(t.type_params) == 0 and t.name in PRIMITIVE_TYPES:
return []

failures: List[TypeCheckFailure] = []
if t.name == "Dict":
if len(t.type_params) != 2:
failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected two type parameters for Dict[], found {len(t.type_params)}."))
else:
if t.type_params[0].name not in PRIMITIVE_TYPES:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, "First type parameter of Dict[] must be a primitive type."))

failures.extend(self.check_type(t.type_params[1]))
elif t.name in ("List", "Optional"):
if len(t.type_params) != 1:
failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, "Expected one type parameter for {t.name}[], found {len(t.type_params)}."))

failures.extend(self.check_type(t.type_params[0]))

return failures

def check(self, macro_text: str) -> List[TypeCheckFailure]:
failures: List[TypeCheckFailure] = []
template = get_environment(None, capture_macros=True).parse(macro_text)
jinja_macro = template.body[0]

# This could be arguably be done elsewhere, but check that every
# parameter passed to the macro has a valid type.
for arg_type in jinja_macro.arg_types:
failures = self.check_type(arg_type)
if failures:
failures.extend(failures)

# Check that each positional argument matches the type of the
for i, arg_type in enumerate(self.arg_types):
expected_type = jinja_macro.arg_types[i]
if arg_type != expected_type:
failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} as argument {i} but found {arg_type.name}/"))

# Check whether there were more positional arguments than expected.
if len(self.arg_types) > len(jinja_macro.arg_types):
failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected {len(self.arg_types)} type arguments, got {len(jinja_macro.arg_types)}."))

return failures
71 changes: 57 additions & 14 deletions tests/unit/test_jinja_macro_call.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
from dbt_common.clients.jinja import MacroType
from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall, FailureType
from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DBT_CLASSES, FailureType, MacroCallChecker, MacroChecker

single_param_macro_text = """{% macro call_me(param: TYPE) %}
{% endmacro %}"""


def test_primitive_type_checks() -> None:
"""Test that primitive types can all be used to annotate macro parameters."""
for type_name in PRIMITIVE_TYPES:
macro_text = single_param_macro_text.replace("TYPE", type_name)
call = DbtMacroCall("call_me", "call_me", [MacroType(type_name, [])], {})
call = MacroCallChecker("call_me", "call_me", [MacroType(type_name, [])], {})
failures = call.check(macro_text)
assert not failures


def test_primitive_type_checks_wrong() -> None:
for type_name in PRIMITIVE_TYPES:
def test_dbt_class_type_checks() -> None:
"""Test that 'dbt Classes' like Relation, Column, and Result can all be used
to annotate macro parameters."""
for type_name in DBT_CLASSES:
macro_text = single_param_macro_text.replace("TYPE", type_name)
call = MacroCallChecker("call_me", "call_me", [MacroType(type_name, [])], {})
failures = call.check(macro_text)
assert not failures

def test_type_checks_wrong() -> None:
"""Test that calls to annotated macros with incorrect types fail type checks."""
for type_name in PRIMITIVE_TYPES + DBT_CLASSES:
macro_text = single_param_macro_text.replace("TYPE", type_name)
wrong_type = next(t for t in PRIMITIVE_TYPES if t != type_name)
call = DbtMacroCall("call_me", "call_me", [MacroType(wrong_type, [])], {})
call = MacroCallChecker("call_me", "call_me", [MacroType(wrong_type, [])], {})
failures = call.check(macro_text)
assert len([f for f in failures if f.type == FailureType.TYPE_MISMATCH]) == 1

Expand All @@ -26,7 +37,7 @@ def test_list_type_checks() -> None:
for type_name in PRIMITIVE_TYPES:
macro_text = single_param_macro_text.replace("TYPE", f"List[{type_name}]")
expected_type = MacroType("List", [MacroType(type_name)])
call = DbtMacroCall("call_me", "call_me", [expected_type], {})
call = MacroCallChecker("call_me", "call_me", [expected_type], {})
failures = call.check(macro_text)
assert not failures

Expand All @@ -35,17 +46,49 @@ def test_dict_type_checks() -> None:
for type_name in PRIMITIVE_TYPES:
macro_text = single_param_macro_text.replace("TYPE", f"Dict[{type_name}, {type_name}]")
expected_type = MacroType("Dict", [MacroType(type_name), MacroType(type_name)])
call = DbtMacroCall("call_me", "call_me", [expected_type], {})
call = MacroCallChecker("call_me", "call_me", [expected_type], {})
assert not any(call.check(macro_text))


kwarg_param_macro_text = """{% macro call_me(param: int = 10, arg_one = "val1", arg_two: int = 2, arg_three: str = "val3" ) %}
kwarg_param_macro_text = """{% macro call_me(arg1: int, arg2: int, arg3: str = "val3", arg4: int = 4, arg5: str = "val5") %}
{% endmacro %}"""


# Better structured exceptions
# Test detection of macro called with too few positional args
# Test detection of macro called with too many positional args
# Test detection of macro called with keyword arg having wrong type
# Test detection of macro called with non-existent keyword arg
# Test detection of macro with invalid default value for param type
def test_too_few_pos_args() -> None:
call = MacroCallChecker("call_me", "", [MacroType("int")])
failures = call.check(kwarg_param_macro_text)
assert len(failures) == 1
assert failures[0].type == FailureType.MISSING_ARGUMENT


def test_unknown_kwarg() -> None:
call = MacroCallChecker("call_me", "", [MacroType("int"), MacroType("int")], {"unk": MacroType("str")})
failures = call.check(kwarg_param_macro_text)
assert len(failures) == 1
assert failures[0].type == FailureType.EXTRA_ARGUMENT


def test_kwarg_type() -> None:
"""Test that annotated kwargs pass type checks when used by name."""
call = MacroCallChecker("call_me", "", [MacroType("int"), MacroType("int")], {"arg3": MacroType("str")})
failures = call.check(kwarg_param_macro_text)
assert not failures


def test_wrong_kwarg_type() -> None:
"""Test that annotated kwargs pass type checks fail when the wrong type is used."""
call = MacroCallChecker("call_me", "", [], {"arg3": MacroType("int")})
failures = call.check(kwarg_param_macro_text)
assert failures[0].type == FailureType.TYPE_MISMATCH

# TODO: Test detection of macro with invalid default value for param type
# TODO: Test detection of macro called with invalid variable parameter, as known from macro parameter annotation.


def test_unknown_type_check() -> None:
"""Test that macro parameter annotations with unknown types fail type checks."""
macro_text = single_param_macro_text.replace("TYPE", "Invalid")
checker = MacroChecker.from_jinja(macro_text)
failures = checker.type_check()
assert failures
assert any(f for f in failures if f.type == FailureType.UNKNOWN_TYPE)

0 comments on commit cf3617c

Please sign in to comment.