diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 11cc81ef70e..228c68d3746 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -141,6 +141,7 @@ def global_flags(func): @p.warn_error_options @p.write_json @p.use_fast_test_edges + @p.type_check @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) diff --git a/core/dbt/cli/params.py b/core/dbt/cli/params.py index 612728de222..2cd3e81940e 100644 --- a/core/dbt/cli/params.py +++ b/core/dbt/cli/params.py @@ -742,3 +742,10 @@ def _version_callback(ctx, _param, value): default=False, hidden=True, ) + +type_check = click.option( + "--type-check/--no-type-check", + envvar="DBT_TYPE_CHECK", + default=False, + hidden=True, +) diff --git a/core/dbt/clients/jinja_macro_call.py b/core/dbt/clients/jinja_macro_call.py new file mode 100644 index 00000000000..2cd8a63122a --- /dev/null +++ b/core/dbt/clients/jinja_macro_call.py @@ -0,0 +1,106 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Union + +import jinja2 + +from dbt_common.clients.jinja import MacroType, get_environment + +PRIMITIVE_TYPES = ["Any", "bool", "float", "int", "str"] + + +@dataclasses.dataclass +class TypeCheckFailure: + msg: str + + +@dataclasses.dataclass +class DbtMacroCall: + """An instance of this class represents a jinja macro call in a template + for the purposes of recording information for type checking.""" + + name: str + source: str + arg_types: List[Union[MacroType, None]] = dataclasses.field(default_factory=list) + kwarg_types: Dict[str, Optional[MacroType]] = dataclasses.field(default_factory=dict) + + @classmethod + def from_call(cls, call: jinja2.nodes.Call, name: str) -> "DbtMacroCall": + dbt_call = cls(name, "") + for arg in call.args: # type: ignore + dbt_call.arg_types.append(cls.get_type(arg)) + for arg in call.kwargs: # type: ignore + dbt_call.kwarg_types[arg.key] = cls.get_type(arg.value) + return dbt_call + + @classmethod + def get_type(cls, param: Any) -> Optional[MacroType]: + if isinstance(param, jinja2.nodes.Name): + return None # TODO: infer types from variable names + + if isinstance(param, jinja2.nodes.Call): + return None # TODO: infer types from function/macro calls + + if isinstance(param, jinja2.nodes.Getattr): + return None # TODO: infer types from . operator + + if isinstance(param, jinja2.nodes.Concat): + return None + + if isinstance(param, jinja2.nodes.Const): + if isinstance(param.value, str): # type: ignore + return MacroType("str") + elif isinstance(param.value, bool): # type: ignore + return MacroType("bool") + elif isinstance(param.value, int): # type: ignore + return MacroType("int") + elif isinstance(param.value, float): # type: ignore + return MacroType("float") + elif param.value is None: # type: ignore + return None + else: + return None + + if isinstance(param, jinja2.nodes.Dict): + return None + + return None + + def is_valid_type(self, t: MacroType) -> bool: + if len(t.type_params) == 0 and t.name in PRIMITIVE_TYPES: + return True + elif ( + t.name == "Dict" + and len(t.type_params) == 2 + and t.type_params[0].name in PRIMITIVE_TYPES + and self.is_valid_type(t.type_params[1]) + ): + return True + elif ( + t.name in ["List", "Optional"] + and len(t.type_params) == 1 + and self.is_valid_type(t.type_params[0]) + ): + return True + + return False + + def check(self, macro_text: str) -> List[TypeCheckFailure]: + failures: List[TypeCheckFailure] = [] + template = get_environment(None, capture_macros=True).parse(macro_text) + jinja_macro = template.body[0] + + for arg_type in jinja_macro.arg_types: + if not self.is_valid_type(arg_type): + failures.append(TypeCheckFailure(msg="Invalid type.")) + + for i, arg_type in enumerate(self.arg_types): + expected_type = jinja_macro.arg_types[i] + if arg_type != expected_type: + failures.append(TypeCheckFailure(msg="Wrong type of parameter.")) + + # if len(self.arg_types) + len(self.kwarg_types) > len(jinja_macro.args): + # failures.append( + # TypeCheckFailure(msg=f"Wrong number of parameters.") + # ) + + return failures diff --git a/core/dbt/clients/jinja_static.py b/core/dbt/clients/jinja_static.py index a89cbef9b26..7e488b286ae 100644 --- a/core/dbt/clients/jinja_static.py +++ b/core/dbt/clients/jinja_static.py @@ -4,6 +4,7 @@ import jinja2 from dbt.artifacts.resources import RefArgs +from dbt.clients.jinja_macro_call import DbtMacroCall from dbt.exceptions import MacroNamespaceNotStringError, ParsingError from dbt_common.clients.jinja import get_environment from dbt_common.exceptions.macros import MacroNameNotStringError @@ -31,7 +32,7 @@ def statically_extract_has_name_this(source: str) -> bool: def statically_extract_macro_calls( source: str, ctx: Dict[str, Any], db_wrapper: Optional["ParseDatabaseWrapper"] = None -) -> List[str]: +) -> List[DbtMacroCall]: # set 'capture_macros' to capture undefined env = get_environment(None, capture_macros=True) @@ -48,11 +49,11 @@ def statically_extract_macro_calls( setattr(parsed, "_dbt_cached_calls", func_calls) standard_calls = ["source", "ref", "config"] - possible_macro_calls = [] + possible_macro_calls: List[DbtMacroCall] = [] for func_call in func_calls: - func_name = None + macro_call: Optional[DbtMacroCall] = None if hasattr(func_call, "node") and hasattr(func_call.node, "name"): - func_name = func_call.node.name + macro_call = DbtMacroCall.from_call(func_call, func_call.node.name) else: if ( hasattr(func_call, "node") @@ -72,34 +73,31 @@ def statically_extract_macro_calls( # This skips calls such as adapter.parse_index continue else: - func_name = f"{package_name}.{macro_name}" + macro_call = DbtMacroCall.from_call(func_call, f"{package_name}.{macro_name}") else: continue - if not func_name: - continue - if func_name in standard_calls: - continue - elif ctx.get(func_name): + + if not macro_call or macro_call.name in standard_calls or ctx.get(macro_call.name): continue - else: - if func_name not in possible_macro_calls: - possible_macro_calls.append(func_name) + + possible_macro_calls.append(macro_call) return possible_macro_calls def statically_parse_adapter_dispatch( func_call, ctx: Dict[str, Any], db_wrapper: Optional["ParseDatabaseWrapper"] -) -> List[str]: - possible_macro_calls = [] +) -> List[DbtMacroCall]: + possible_macro_calls: List[DbtMacroCall] = [] # This captures an adapter.dispatch('') call. func_name = None # macro_name positional argument if len(func_call.args) > 0: func_name = func_call.args[0].value + if func_name: - possible_macro_calls.append(func_name) + possible_macro_calls.append(DbtMacroCall.from_call(func_call, func_name)) # packages positional argument macro_namespace = None @@ -118,7 +116,7 @@ def statically_parse_adapter_dispatch( # This will remain to enable static resolution if type(kwarg.value).__name__ == "Const": func_name = kwarg.value.value - possible_macro_calls.append(func_name) + possible_macro_calls.append(DbtMacroCall.from_call(func_call, func_name)) else: raise MacroNameNotStringError(kwarg_value=kwarg.value.value) elif kwarg.key == "macro_namespace": @@ -143,14 +141,16 @@ def statically_parse_adapter_dispatch( if db_wrapper: macro = db_wrapper.dispatch(func_name, macro_namespace=macro_namespace).macro func_name = f"{macro.package_name}.{macro.name}" # type: ignore[attr-defined] - possible_macro_calls.append(func_name) + possible_macro_calls.append(DbtMacroCall.from_call(func_call, func_name)) else: # this is only for tests/unit/test_macro_calls.py if macro_namespace: packages = [macro_namespace] else: packages = [] for package_name in packages: - possible_macro_calls.append(f"{package_name}.{func_name}") + possible_macro_calls.append( + DbtMacroCall.from_call(func_call, f"{package_name}.{func_name}") + ) return possible_macro_calls diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 72f328a0bbe..8ae3e6137ef 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -692,7 +692,7 @@ def load_and_parse_macros(self, project_parser_files): self.build_macro_resolver() # Look at changed macros and update the macro.depends_on.macros - self.macro_depends_on() + self.analyze_macros() # Parse the files in the 'parser_files' dictionary, for parsers listed in # 'parser_types' @@ -776,32 +776,39 @@ def build_macro_resolver(self): self.manifest.macros, self.root_project.project_name, internal_package_names ) - # Loop through macros in the manifest and statically parse - # the 'macro_sql' to find depends_on.macros - def macro_depends_on(self): + def analyze_macros(self): + """Loop through macros in the manifest and statically parse the + 'macro_sql' to find and set the value of depends_on.macros. Also, + perform type checking if flag is set. + """ macro_ctx = generate_macro_context(self.root_project) macro_namespace = TestMacroNamespace(self.macro_resolver, {}, None, MacroStack(), []) adapter = get_adapter(self.root_project) db_wrapper = ParseProvider().DatabaseWrapper(adapter, macro_namespace) + type_check = get_flags().TYPE_CHECK for macro in self.manifest.macros.values(): if macro.created_at < self.started_at: continue possible_macro_calls = statically_extract_macro_calls( macro.macro_sql, macro_ctx, db_wrapper ) - for macro_name in possible_macro_calls: + for macro_call in possible_macro_calls: # adapter.dispatch calls can generate a call with the same name as the macro # it ought to be an adapter prefix (postgres_) or default_ + macro_name = macro_call.name if macro_name == macro.name: continue package_name = macro.package_name if "." in macro_name: package_name, macro_name = macro_name.split(".") - dep_macro_id = self.macro_resolver.get_macro_id(package_name, macro_name) - if dep_macro_id: - macro.depends_on.add_macro(dep_macro_id) # will check for dupes + dep_macro = self.macro_resolver.get_macro(package_name, macro_name) + if dep_macro is not None and dep_macro.unique_id: + macro.depends_on.add_macro(dep_macro.unique_id) # will check for dupes + + if type_check: + macro_call.check(dep_macro) - def write_manifest_for_partial_parse(self): + def write_manifest_for_partial_parse(self) -> None: path = os.path.join(self.root_project.project_target_path, PARTIAL_PARSE_FILE_NAME) try: # This shouldn't be necessary, but we have gotten bug reports (#3757) of the diff --git a/tests/unit/clients/test_jinja_static.py b/tests/unit/clients/test_jinja_static.py index fb48c117b47..72ab0c6afd4 100644 --- a/tests/unit/clients/test_jinja_static.py +++ b/tests/unit/clients/test_jinja_static.py @@ -58,7 +58,7 @@ def test_extract_macro_calls(macro_string, expected_possible_macro_calls): ctx = generate_base_context(cli_vars) possible_macro_calls = statically_extract_macro_calls(macro_string, ctx) - assert possible_macro_calls == expected_possible_macro_calls + assert [c.name for c in possible_macro_calls] == expected_possible_macro_calls class TestStaticallyParseRefOrSource: diff --git a/tests/unit/clients/test_jinja_type_checking.py b/tests/unit/clients/test_jinja_type_checking.py new file mode 100644 index 00000000000..cc40645984e --- /dev/null +++ b/tests/unit/clients/test_jinja_type_checking.py @@ -0,0 +1,28 @@ +from dbt.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall +from dbt_common.clients.jinja import MacroType + +single_param_macro_text = """{% macro call_me(param: TYPE) %} + {% endmacro %}""" + + +def test_primitive_type_checks(): + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", type_name) + call = DbtMacroCall("call_me", "call_me", [MacroType(type_name, [])], {}) + assert not any(call.check(macro_text)) + + +def test_primitive_type_checks_wrong(): + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", type_name) + wrong_type = next(t for t in PRIMITIVE_TYPES if t != type_name) + call = DbtMacroCall("call_me", "call_me", [MacroType(wrong_type, [])], {}) + assert any(call.check(macro_text)) + + +def test_list_type_checks(): + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", f"List[{type_name}]") + expected_type = MacroType("List", [MacroType(type_name, [])]) + call = DbtMacroCall("call_me", "call_me", [expected_type], {}) + assert not any(call.check(macro_text))