generated from dbt-labs/dbt-oss-template
-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move some code from dbt-core to dbt-common.
- Loading branch information
1 parent
243568e
commit 3a20cdc
Showing
3 changed files
with
160 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import dataclasses | ||
from typing import Any, Dict, List, Optional | ||
|
||
import jinja2 | ||
|
||
from dbt_common.clients.jinja import get_environment, MacroType | ||
|
||
PRIMITIVE_TYPES = ["Any", "bool", "float", "int", "str"] | ||
|
||
|
||
@dataclasses.dataclass | ||
class TypeCheckFailure: | ||
msg: str | ||
|
||
|
||
@dataclasses.dataclass | ||
class DbtMacroCall: | ||
"""An instance of this class represents a jinja macro call in a template | ||
for the purposes of recording information for type checking.""" | ||
|
||
name: str | ||
source: str | ||
arg_types: List[Optional[MacroType]] = dataclasses.field(default_factory=list) | ||
kwarg_types: Dict[str, Optional[MacroType]] = dataclasses.field(default_factory=dict) | ||
|
||
@classmethod | ||
def from_call(cls, call: jinja2.nodes.Call, name: str) -> "DbtMacroCall": | ||
dbt_call = cls(name, "") | ||
for arg in call.args: # type: ignore | ||
dbt_call.arg_types.append(cls.get_type(arg)) | ||
for arg in call.kwargs: # type: ignore | ||
dbt_call.kwarg_types[arg.key] = cls.get_type(arg.value) | ||
return dbt_call | ||
|
||
@classmethod | ||
def get_type(cls, param: Any) -> Optional[MacroType]: | ||
if isinstance(param, jinja2.nodes.Name): | ||
return None # TODO: infer types from variable names | ||
|
||
if isinstance(param, jinja2.nodes.Call): | ||
return None # TODO: infer types from function/macro calls | ||
|
||
if isinstance(param, jinja2.nodes.Getattr): | ||
return None # TODO: infer types from . operator | ||
|
||
if isinstance(param, jinja2.nodes.Concat): | ||
return None | ||
|
||
if isinstance(param, jinja2.nodes.Const): | ||
if isinstance(param.value, str): # type: ignore | ||
return MacroType("str") | ||
elif isinstance(param.value, bool): # type: ignore | ||
return MacroType("bool") | ||
elif isinstance(param.value, int): # type: ignore | ||
return MacroType("int") | ||
elif isinstance(param.value, float): # type: ignore | ||
return MacroType("float") | ||
elif param.value is None: # type: ignore | ||
return None | ||
else: | ||
return None | ||
|
||
if isinstance(param, jinja2.nodes.Dict): | ||
return None | ||
|
||
return None | ||
|
||
def is_valid_type(self, t: MacroType) -> bool: | ||
if len(t.type_params) == 0 and t.name in PRIMITIVE_TYPES: | ||
return True | ||
elif ( | ||
t.name == "Dict" | ||
and len(t.type_params) == 2 | ||
and t.type_params[0].name in PRIMITIVE_TYPES | ||
and self.is_valid_type(t.type_params[1]) | ||
): | ||
return True | ||
elif ( | ||
t.name in ["List", "Optional"] | ||
and len(t.type_params) == 1 | ||
and self.is_valid_type(t.type_params[0]) | ||
): | ||
return True | ||
|
||
return False | ||
|
||
def check(self, macro_text: str) -> List[TypeCheckFailure]: | ||
failures: List[TypeCheckFailure] = [] | ||
template = get_environment(None, capture_macros=True).parse(macro_text) | ||
jinja_macro = template.body[0] | ||
|
||
for arg_type in jinja_macro.arg_types: | ||
if not self.is_valid_type(arg_type): | ||
failures.append(TypeCheckFailure(msg="Invalid type.")) | ||
|
||
for i, arg_type in enumerate(self.arg_types): | ||
expected_type = jinja_macro.arg_types[i] | ||
if arg_type != expected_type: | ||
failures.append(TypeCheckFailure(msg="Wrong type of parameter.")) | ||
|
||
return failures |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from dbt_common.clients.jinja import MacroType | ||
from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall | ||
|
||
|
||
single_param_macro_text = """{% macro call_me(param: TYPE) %} | ||
{% endmacro %}""" | ||
|
||
|
||
def test_primitive_type_checks() -> None: | ||
for type_name in PRIMITIVE_TYPES: | ||
macro_text = single_param_macro_text.replace("TYPE", type_name) | ||
call = DbtMacroCall("call_me", "call_me", [MacroType(type_name, [])], {}) | ||
assert not any(call.check(macro_text)) | ||
|
||
|
||
def test_primitive_type_checks_wrong() -> None: | ||
for type_name in PRIMITIVE_TYPES: | ||
macro_text = single_param_macro_text.replace("TYPE", type_name) | ||
wrong_type = next(t for t in PRIMITIVE_TYPES if t != type_name) | ||
call = DbtMacroCall("call_me", "call_me", [MacroType(wrong_type, [])], {}) | ||
assert any(call.check(macro_text)) | ||
|
||
|
||
def test_list_type_checks() -> None: | ||
for type_name in PRIMITIVE_TYPES: | ||
macro_text = single_param_macro_text.replace("TYPE", f"List[{type_name}]") | ||
expected_type = MacroType("List", [MacroType(type_name)]) | ||
call = DbtMacroCall("call_me", "call_me", [expected_type], {}) | ||
assert not any(call.check(macro_text)) | ||
|
||
|
||
def test_dict_type_checks() -> None: | ||
for type_name in PRIMITIVE_TYPES: | ||
macro_text = single_param_macro_text.replace("TYPE", f"Dict[{type_name}, {type_name}]") | ||
expected_type = MacroType("Dict", [MacroType(type_name), MacroType(type_name)]) | ||
call = DbtMacroCall("call_me", "call_me", [expected_type], {}) | ||
assert not any(call.check(macro_text)) | ||
|
||
|
||
def test_too_few_args() -> None: | ||
macro_text = "{% macro call_me(one: str, two: str, three: str) %}" | ||
|
||
|
||
def test_too_many_args() -> None: | ||
pass | ||
|
||
|
||
kwarg_param_macro_text = """{% macro call_me(param: int = 10, arg_one = "val1", arg_two: int = 2, arg_three: str = "val3" ) %} | ||
{% endmacro %}""" | ||
|
||
|
||
# Better structured exceptions | ||
# Test detection of macro called with too few positional args | ||
# Test detection of macro called with too many positional args | ||
# Test detection of macro called with keyword arg having wrong type | ||
# Test detection of macro called with non-existent keyword arg | ||
# Test detection of macro with invalid default value for param type |