diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 4926b5de1..b41c8395b 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -47,8 +47,12 @@ DbtDatabricksAllPurposeClusterPythonJobHelper, DbtDatabricksJobClusterPythonJobHelper, ) -from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType -from dbt.adapters.databricks.utils import is_hive, redact_credentials, undefined_proof +from dbt.adapters.databricks.relation import is_hive_metastore, extract_identifiers +from dbt.adapters.databricks.relation import ( + DatabricksRelation, + DatabricksRelationType, +) +from dbt.adapters.databricks.utils import redact_credentials, undefined_proof logger = AdapterLogger("Databricks") @@ -271,7 +275,7 @@ def typeFromNames( if view_names[name] else DatabricksRelationType.View ) - elif is_hive(database): + elif is_hive_metastore(database): return DatabricksRelationType.Table else: # not a view so it might be a streaming table @@ -434,7 +438,7 @@ def get_catalog( futures: List[Future[Table]] = [] for info_schema, relations in relations_by_catalog.items(): - if is_hive(info_schema.database): + if is_hive_metastore(info_schema.database): schema_map = defaultdict(list) for relation in relations: schema_map[relation.schema].append(relation) @@ -469,7 +473,7 @@ def _get_hive_catalog( schema: str, relations: List[BaseRelation], ) -> Table: - table_names: Set[str] = set([x.identifier for x in relations if x.identifier is not None]) + table_names = extract_identifiers(relations) columns: List[Dict[str, Any]] = [] if len(table_names) > 0: diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py index 00d6720a4..63309e951 100644 --- a/dbt/adapters/databricks/relation.py +++ b/dbt/adapters/databricks/relation.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, List, Optional, Set, Type from dbt.contracts.relation import ( ComponentName, ) @@ -44,7 +44,7 @@ class DatabricksInformationSchema(InformationSchema): quote_character: str = "`" def is_hive_metastore(self) -> bool: - return self.database is None or self.database == "hive_metastore" + return is_hive_metastore(self.database) @dataclass(frozen=True, eq=False, repr=False) @@ -135,3 +135,11 @@ def information_schema(self, view_name: Optional[str] = None) -> InformationSche # Instead address this as .information_schema by default info_schema = DatabricksInformationSchema.from_relation(self, view_name) return info_schema.incorporate(path={"schema": None}) + + +def is_hive_metastore(database: Optional[str]) -> bool: + return database is None or database.lower() == "hive_metastore" + + +def extract_identifiers(relations: List[BaseRelation]) -> Set[str]: + return {r.identifier for r in relations if r.identifier is not None} diff --git a/dbt/adapters/databricks/utils.py b/dbt/adapters/databricks/utils.py index 5ed55da48..9fbbc411c 100644 --- a/dbt/adapters/databricks/utils.py +++ b/dbt/adapters/databricks/utils.py @@ -1,7 +1,7 @@ import functools import inspect import re -from typing import Any, Callable, Optional, Type, TypeVar +from typing import Any, Callable, Type, TypeVar from dbt.adapters.base import BaseAdapter from jinja2.runtime import Undefined @@ -77,7 +77,3 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: def remove_ansi(line: str) -> str: ansi_escape = re.compile(r"(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]") return ansi_escape.sub("", line) - - -def is_hive(database: Optional[str]) -> bool: - return database is None or database.lower() == "hive_metastore" diff --git a/dbt/include/databricks/macros/catalog.sql b/dbt/include/databricks/macros/catalog.sql index dbc712be5..0749bca95 100644 --- a/dbt/include/databricks/macros/catalog.sql +++ b/dbt/include/databricks/macros/catalog.sql @@ -34,8 +34,6 @@ {%- endset -%} {{ return(run_query(query)) }} - {{ return(run_query(query)) }} - {%- endmacro %} {% macro databricks__get_catalog_tables_sql(information_schema) -%} diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index a03b81bff..cb1f11545 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -426,7 +426,7 @@ def test_parse_relation(self): "table_type": rel_type, "table_owner": "root", "column": "col1", - "column_index": 1, + "column_index": 0, "dtype": "decimal(22,0)", "numeric_scale": None, "numeric_precision": None, @@ -443,7 +443,7 @@ def test_parse_relation(self): "table_type": rel_type, "table_owner": "root", "column": "col2", - "column_index": 2, + "column_index": 1, "dtype": "string", "numeric_scale": None, "numeric_precision": None, @@ -460,7 +460,7 @@ def test_parse_relation(self): "table_type": rel_type, "table_owner": "root", "column": "dt", - "column_index": 3, + "column_index": 2, "dtype": "date", "numeric_scale": None, "numeric_precision": None, @@ -477,7 +477,7 @@ def test_parse_relation(self): "table_type": rel_type, "table_owner": "root", "column": "struct_col", - "column_index": 4, + "column_index": 3, "dtype": "struct", "numeric_scale": None, "numeric_precision": None, @@ -575,7 +575,7 @@ def test_parse_relation_with_statistics(self): "table_type": rel_type, "table_owner": "root", "column": "col1", - "column_index": 1, + "column_index": 0, "dtype": "decimal(22,0)", "numeric_scale": None, "numeric_precision": None, @@ -643,7 +643,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) "table_type": rel_type, "table_owner": "root", "column": "col1", - "column_index": 1, + "column_index": 0, "dtype": "decimal(22,0)", "numeric_scale": None, "numeric_precision": None, @@ -664,7 +664,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) "table_type": rel_type, "table_owner": "root", "column": "struct_col", - "column_index": 4, + "column_index": 3, "dtype": "struct", "numeric_scale": None, "numeric_precision": None, @@ -728,7 +728,7 @@ def test_parse_columns_from_information_with_view_type(self): "table_type": rel_type, "table_owner": "root", "column": "col2", - "column_index": 2, + "column_index": 1, "dtype": "string", "numeric_scale": None, "numeric_precision": None, @@ -745,7 +745,7 @@ def test_parse_columns_from_information_with_view_type(self): "table_type": rel_type, "table_owner": "root", "column": "struct_col", - "column_index": 4, + "column_index": 3, "dtype": "struct", "numeric_scale": None, "numeric_precision": None, @@ -794,7 +794,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel "table_type": rel_type, "table_owner": "root", "column": "dt", - "column_index": 3, + "column_index": 2, "dtype": "date", "numeric_scale": None, "numeric_precision": None, @@ -819,7 +819,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel "table_type": rel_type, "table_owner": "root", "column": "struct_col", - "column_index": 4, + "column_index": 3, "dtype": "struct", "numeric_scale": None, "numeric_precision": None, diff --git a/tests/unit/test_relation.py b/tests/unit/test_relation.py index a7da9f27f..5a827bade 100644 --- a/tests/unit/test_relation.py +++ b/tests/unit/test_relation.py @@ -1,7 +1,9 @@ import unittest from jinja2.runtime import Undefined +import pytest +from dbt.adapters.databricks import relation from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksQuotePolicy @@ -177,3 +179,35 @@ def test_matches(self): relation = DatabricksRelation.from_dict(data) self.assertFalse(relation.matches("some_database", "some_schema", "table")) + + +class TestRelationsFunctions: + @pytest.mark.parametrize( + "database, expected", [(None, True), ("hive_metastore", True), ("not_hive", False)] + ) + def test_is_hive_metastore(self, database, expected): + assert relation.is_hive_metastore(database) is expected + + @pytest.mark.parametrize( + "input, expected", + [ + ([], set()), + ([DatabricksRelation.create(identifier=None)], set()), + ( + [ + DatabricksRelation.create(identifier=None), + DatabricksRelation.create(identifier="test"), + ], + {"test"}, + ), + ( + [ + DatabricksRelation.create(identifier="test"), + DatabricksRelation.create(identifier="test"), + ], + {"test"}, + ), + ], + ) + def test_extract_identifiers(self, input, expected): + assert relation.extract_identifiers(input) == expected