Skip to content

Commit

Permalink
Rough draft of macro type check enforcement.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Dec 12, 2024
1 parent 03fdb4c commit 1d24734
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 29 deletions.
1 change: 1 addition & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def global_flags(func):
@p.warn_error_options
@p.write_json
@p.use_fast_test_edges
@p.type_check
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
Expand Down
7 changes: 7 additions & 0 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,10 @@ def _version_callback(ctx, _param, value):
default=False,
hidden=True,
)

# Hidden, off-by-default switch for the draft macro type check enforcement.
# It can also be set through the DBT_TYPE_CHECK environment variable.
type_check = click.option(
    "--type-check/--no-type-check",
    envvar="DBT_TYPE_CHECK",
    default=False,
    hidden=True,
)
106 changes: 106 additions & 0 deletions core/dbt/clients/jinja_macro_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import dataclasses
from typing import Any, Dict, List, Optional, Union

import jinja2

from dbt_common.clients.jinja import MacroType, get_environment

PRIMITIVE_TYPES = ["Any", "bool", "float", "int", "str"]


@dataclasses.dataclass
class TypeCheckFailure:
    """A single problem reported while type checking a macro call."""

    # Human-readable description of what failed.
    msg: str


@dataclasses.dataclass
class DbtMacroCall:
    """A jinja macro call found in a template, recorded for type checking.

    Instances are normally built with ``from_call`` from a parsed
    ``jinja2.nodes.Call`` and later compared against the signature of the
    macro being called with ``check``.
    """

    # Name of the called macro, as given by whoever recorded the call.
    name: str
    # NOTE(review): from_call() always stores "" here and nothing in this
    # class reads it; presumably reserved for the calling source text.
    source: str
    # Inferred type of each positional argument; None means "not inferred".
    arg_types: List[Union[MacroType, None]] = dataclasses.field(default_factory=list)
    # Inferred type of each keyword argument, keyed by parameter name.
    kwarg_types: Dict[str, Optional[MacroType]] = dataclasses.field(default_factory=dict)

    @classmethod
    def from_call(cls, call: jinja2.nodes.Call, name: str) -> "DbtMacroCall":
        """Build a DbtMacroCall from a parsed jinja call node.

        Args:
            call: The jinja call node whose arguments are inspected.
            name: The macro name to record for the call.

        Returns:
            A new instance holding the inferred type (or None) of every
            positional and keyword argument of ``call``.
        """
        return cls(
            name,
            "",
            [cls.get_type(arg) for arg in call.args],  # type: ignore
            {kwarg.key: cls.get_type(kwarg.value) for kwarg in call.kwargs},  # type: ignore
        )

    @classmethod
    def get_type(cls, param: Any) -> Optional[MacroType]:
        """Infer the type of an argument expression, if it can be inferred.

        Only constant literals are understood at the moment.

        Args:
            param: A jinja node used as a call argument.

        Returns:
            The MacroType of a str, bool, int or float constant; None for
            anything else (including a None constant).
        """
        # TODO: infer types from variable names, function/macro calls, the
        # . operator, concatenations and dict literals.
        if not isinstance(param, jinja2.nodes.Const):
            return None

        value = param.value
        if isinstance(value, str):
            return MacroType("str")
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return MacroType("bool")
        if isinstance(value, int):
            return MacroType("int")
        if isinstance(value, float):
            return MacroType("float")
        return None

    def is_valid_type(self, t: MacroType) -> bool:
        """Return True if ``t`` is a type annotation this checker supports.

        Supported forms are the bare names in PRIMITIVE_TYPES,
        ``List[X]`` and ``Optional[X]`` for a supported ``X``, and
        ``Dict[K, V]`` for a bare primitive ``K`` and a supported ``V``.
        """
        params = t.type_params
        if not params:
            return t.name in PRIMITIVE_TYPES

        if t.name == "Dict" and len(params) == 2:
            key_type, value_type = params
            return (
                key_type.name in PRIMITIVE_TYPES
                # A key must be a bare primitive, matching the rule applied
                # to primitives everywhere else.
                and not key_type.type_params
                and self.is_valid_type(value_type)
            )

        if t.name in ("List", "Optional") and len(params) == 1:
            return self.is_valid_type(params[0])

        return False

    def check(self, macro_text: str) -> List[TypeCheckFailure]:
        """Check this call against the signature of the macro it calls.

        Args:
            macro_text: Source text of the macro definition. The first
                top-level node of the parsed template is taken to be the
                macro.

        Returns:
            One TypeCheckFailure per problem found; an empty list means no
            problem was detected. Keyword arguments are not checked yet.
        """
        failures: List[TypeCheckFailure] = []
        template = get_environment(None, capture_macros=True).parse(macro_text)
        jinja_macro = template.body[0]
        # NOTE(review): assumes every entry of arg_types is a MacroType;
        # confirm how un-annotated macro parameters are represented.
        expected_types = jinja_macro.arg_types

        for expected_type in expected_types:
            if not self.is_valid_type(expected_type):
                failures.append(TypeCheckFailure(msg="Invalid type."))

        # More positional arguments than declared parameters is reported
        # as a failure rather than crashing with an IndexError below.
        if len(self.arg_types) > len(expected_types):
            failures.append(TypeCheckFailure(msg="Wrong number of parameters."))

        for arg_type, expected_type in zip(self.arg_types, expected_types):
            # None means the argument's type could not be inferred, which
            # is not evidence of a mismatch.
            if arg_type is None:
                continue
            if arg_type != expected_type:
                failures.append(TypeCheckFailure(msg="Wrong type of parameter."))

        return failures
38 changes: 19 additions & 19 deletions core/dbt/clients/jinja_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jinja2

from dbt.artifacts.resources import RefArgs
from dbt.clients.jinja_macro_call import DbtMacroCall
from dbt.exceptions import MacroNamespaceNotStringError, ParsingError
from dbt_common.clients.jinja import get_environment
from dbt_common.exceptions.macros import MacroNameNotStringError
Expand Down Expand Up @@ -31,7 +32,7 @@ def statically_extract_has_name_this(source: str) -> bool:

def statically_extract_macro_calls(
source: str, ctx: Dict[str, Any], db_wrapper: Optional["ParseDatabaseWrapper"] = None
) -> List[str]:
) -> List[DbtMacroCall]:
# set 'capture_macros' to capture undefined
env = get_environment(None, capture_macros=True)

Expand All @@ -48,11 +49,11 @@ def statically_extract_macro_calls(
setattr(parsed, "_dbt_cached_calls", func_calls)

standard_calls = ["source", "ref", "config"]
possible_macro_calls = []
possible_macro_calls: List[DbtMacroCall] = []
for func_call in func_calls:
func_name = None
macro_call: Optional[DbtMacroCall] = None
if hasattr(func_call, "node") and hasattr(func_call.node, "name"):
func_name = func_call.node.name
macro_call = DbtMacroCall.from_call(func_call, func_call.node.name)
else:
if (
hasattr(func_call, "node")
Expand All @@ -72,34 +73,31 @@ def statically_extract_macro_calls(
# This skips calls such as adapter.parse_index
continue
else:
func_name = f"{package_name}.{macro_name}"
macro_call = DbtMacroCall.from_call(func_call, f"{package_name}.{macro_name}")
else:
continue
if not func_name:
continue
if func_name in standard_calls:
continue
elif ctx.get(func_name):

if not macro_call or macro_call.name in standard_calls or ctx.get(macro_call.name):
continue
else:
if func_name not in possible_macro_calls:
possible_macro_calls.append(func_name)

possible_macro_calls.append(macro_call)

return possible_macro_calls


def statically_parse_adapter_dispatch(
func_call, ctx: Dict[str, Any], db_wrapper: Optional["ParseDatabaseWrapper"]
) -> List[str]:
possible_macro_calls = []
) -> List[DbtMacroCall]:
possible_macro_calls: List[DbtMacroCall] = []
# This captures an adapter.dispatch('<macro_name>') call.

func_name = None
# macro_name positional argument
if len(func_call.args) > 0:
func_name = func_call.args[0].value

if func_name:
possible_macro_calls.append(func_name)
possible_macro_calls.append(DbtMacroCall.from_call(func_call, func_name))

# packages positional argument
macro_namespace = None
Expand All @@ -118,7 +116,7 @@ def statically_parse_adapter_dispatch(
# This will remain to enable static resolution
if type(kwarg.value).__name__ == "Const":
func_name = kwarg.value.value
possible_macro_calls.append(func_name)
possible_macro_calls.append(DbtMacroCall.from_call(func_call, func_name))

Check warning on line 119 in core/dbt/clients/jinja_static.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/clients/jinja_static.py#L119

Added line #L119 was not covered by tests
else:
raise MacroNameNotStringError(kwarg_value=kwarg.value.value)
elif kwarg.key == "macro_namespace":
Expand All @@ -143,14 +141,16 @@ def statically_parse_adapter_dispatch(
if db_wrapper:
macro = db_wrapper.dispatch(func_name, macro_namespace=macro_namespace).macro
func_name = f"{macro.package_name}.{macro.name}" # type: ignore[attr-defined]
possible_macro_calls.append(func_name)
possible_macro_calls.append(DbtMacroCall.from_call(func_call, func_name))
else: # this is only for tests/unit/test_macro_calls.py
if macro_namespace:
packages = [macro_namespace]
else:
packages = []
for package_name in packages:
possible_macro_calls.append(f"{package_name}.{func_name}")
possible_macro_calls.append(
DbtMacroCall.from_call(func_call, f"{package_name}.{func_name}")
)

return possible_macro_calls

Expand Down
25 changes: 16 additions & 9 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def load_and_parse_macros(self, project_parser_files):

self.build_macro_resolver()
# Look at changed macros and update the macro.depends_on.macros
self.macro_depends_on()
self.analyze_macros()

# Parse the files in the 'parser_files' dictionary, for parsers listed in
# 'parser_types'
Expand Down Expand Up @@ -776,32 +776,39 @@ def build_macro_resolver(self):
self.manifest.macros, self.root_project.project_name, internal_package_names
)

# Loop through macros in the manifest and statically parse
# the 'macro_sql' to find depends_on.macros
def macro_depends_on(self):
def analyze_macros(self):
"""Loop through macros in the manifest and statically parse the
'macro_sql' to find and set the value of depends_on.macros. Also,
perform type checking if flag is set.
"""
macro_ctx = generate_macro_context(self.root_project)
macro_namespace = TestMacroNamespace(self.macro_resolver, {}, None, MacroStack(), [])
adapter = get_adapter(self.root_project)
db_wrapper = ParseProvider().DatabaseWrapper(adapter, macro_namespace)
type_check = get_flags().TYPE_CHECK
for macro in self.manifest.macros.values():
if macro.created_at < self.started_at:
continue
possible_macro_calls = statically_extract_macro_calls(
macro.macro_sql, macro_ctx, db_wrapper
)
for macro_name in possible_macro_calls:
for macro_call in possible_macro_calls:
# adapter.dispatch calls can generate a call with the same name as the macro
# it ought to be an adapter prefix (postgres_) or default_
macro_name = macro_call.name
if macro_name == macro.name:
continue
package_name = macro.package_name
if "." in macro_name:
package_name, macro_name = macro_name.split(".")
dep_macro_id = self.macro_resolver.get_macro_id(package_name, macro_name)
if dep_macro_id:
macro.depends_on.add_macro(dep_macro_id) # will check for dupes
dep_macro = self.macro_resolver.get_macro(package_name, macro_name)
if dep_macro is not None and dep_macro.unique_id:
macro.depends_on.add_macro(dep_macro.unique_id) # will check for dupes

if type_check:
macro_call.check(dep_macro)

Check warning on line 809 in core/dbt/parser/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/manifest.py#L809

Added line #L809 was not covered by tests

def write_manifest_for_partial_parse(self):
def write_manifest_for_partial_parse(self) -> None:
path = os.path.join(self.root_project.project_target_path, PARTIAL_PARSE_FILE_NAME)
try:
# This shouldn't be necessary, but we have gotten bug reports (#3757) of the
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/clients/test_jinja_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_extract_macro_calls(macro_string, expected_possible_macro_calls):
ctx = generate_base_context(cli_vars)

possible_macro_calls = statically_extract_macro_calls(macro_string, ctx)
assert possible_macro_calls == expected_possible_macro_calls
assert [c.name for c in possible_macro_calls] == expected_possible_macro_calls


class TestStaticallyParseRefOrSource:
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/clients/test_jinja_type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from dbt.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall
from dbt_common.clients.jinja import MacroType

single_param_macro_text = """{% macro call_me(param: TYPE) %}
{% endmacro %}"""


def test_primitive_type_checks():
    """A primitive argument passed to a parameter of the same type is accepted."""
    for primitive in PRIMITIVE_TYPES:
        call = DbtMacroCall("call_me", "call_me", [MacroType(primitive, [])], {})
        failures = call.check(single_param_macro_text.replace("TYPE", primitive))
        assert not any(failures)


def test_primitive_type_checks_wrong():
    """A primitive argument passed to a parameter of another type is rejected."""
    for declared in PRIMITIVE_TYPES:
        # Pick the first primitive that differs from the declared one.
        mismatched = next(name for name in PRIMITIVE_TYPES if name != declared)
        call = DbtMacroCall("call_me", "call_me", [MacroType(mismatched, [])], {})
        failures = call.check(single_param_macro_text.replace("TYPE", declared))
        assert any(failures)


def test_list_type_checks():
    """A List[primitive] argument matches a parameter declared with that type."""
    for primitive in PRIMITIVE_TYPES:
        list_type = MacroType("List", [MacroType(primitive, [])])
        call = DbtMacroCall("call_me", "call_me", [list_type], {})
        failures = call.check(single_param_macro_text.replace("TYPE", f"List[{primitive}]"))
        assert not any(failures)

0 comments on commit 1d24734

Please sign in to comment.