Commit

small cleanup
benc-db committed Oct 24, 2023
1 parent 5ef63e1 commit 40ed74b
Showing 6 changed files with 65 additions and 25 deletions.
14 changes: 9 additions & 5 deletions dbt/adapters/databricks/impl.py
@@ -47,8 +47,12 @@
DbtDatabricksAllPurposeClusterPythonJobHelper,
DbtDatabricksJobClusterPythonJobHelper,
)
from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType
from dbt.adapters.databricks.utils import is_hive, redact_credentials, undefined_proof
from dbt.adapters.databricks.relation import is_hive_metastore, extract_identifiers
from dbt.adapters.databricks.relation import (
DatabricksRelation,
DatabricksRelationType,
)
from dbt.adapters.databricks.utils import redact_credentials, undefined_proof


logger = AdapterLogger("Databricks")
@@ -271,7 +275,7 @@ def typeFromNames(
if view_names[name]
else DatabricksRelationType.View
)
elif is_hive(database):
elif is_hive_metastore(database):
return DatabricksRelationType.Table
else:
# not a view so it might be a streaming table
@@ -434,7 +438,7 @@ def get_catalog(
futures: List[Future[Table]] = []

for info_schema, relations in relations_by_catalog.items():
if is_hive(info_schema.database):
if is_hive_metastore(info_schema.database):
schema_map = defaultdict(list)
for relation in relations:
schema_map[relation.schema].append(relation)
@@ -469,7 +473,7 @@ def _get_hive_catalog(
schema: str,
relations: List[BaseRelation],
) -> Table:
table_names: Set[str] = set([x.identifier for x in relations if x.identifier is not None])
table_names = extract_identifiers(relations)
columns: List[Dict[str, Any]] = []

if len(table_names) > 0:
12 changes: 10 additions & 2 deletions dbt/adapters/databricks/relation.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, List, Optional, Set, Type
from dbt.contracts.relation import (
ComponentName,
)
@@ -44,7 +44,7 @@ class DatabricksInformationSchema(InformationSchema):
quote_character: str = "`"

def is_hive_metastore(self) -> bool:
return self.database is None or self.database == "hive_metastore"
return is_hive_metastore(self.database)


@dataclass(frozen=True, eq=False, repr=False)
@@ -135,3 +135,11 @@ def information_schema(self, view_name: Optional[str] = None) -> InformationSchema:
# Instead address this as <database>.information_schema by default
info_schema = DatabricksInformationSchema.from_relation(self, view_name)
return info_schema.incorporate(path={"schema": None})


def is_hive_metastore(database: Optional[str]) -> bool:
return database is None or database.lower() == "hive_metastore"


def extract_identifiers(relations: List[BaseRelation]) -> Set[str]:
return {r.identifier for r in relations if r.identifier is not None}
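For reference, a minimal usage sketch of the two module-level helpers added above; the relation values below are illustrative and not taken from this commit:

    from dbt.adapters.databricks.relation import (
        DatabricksRelation,
        extract_identifiers,
        is_hive_metastore,
    )

    # None and "hive_metastore" (case-insensitive) are both treated as the Hive metastore.
    assert is_hive_metastore(None)
    assert is_hive_metastore("HIVE_METASTORE")
    assert not is_hive_metastore("main")

    # extract_identifiers collects the distinct, non-null identifiers from a list of relations.
    relations = [
        DatabricksRelation.create(schema="default", identifier="events"),
        DatabricksRelation.create(schema="default", identifier=None),
        DatabricksRelation.create(schema="default", identifier="events"),
    ]
    assert extract_identifiers(relations) == {"events"}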
6 changes: 1 addition & 5 deletions dbt/adapters/databricks/utils.py
@@ -1,7 +1,7 @@
import functools
import inspect
import re
from typing import Any, Callable, Optional, Type, TypeVar
from typing import Any, Callable, Type, TypeVar

from dbt.adapters.base import BaseAdapter
from jinja2.runtime import Undefined
@@ -77,7 +77,3 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
def remove_ansi(line: str) -> str:
ansi_escape = re.compile(r"(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]")
return ansi_escape.sub("", line)


def is_hive(database: Optional[str]) -> bool:
return database is None or database.lower() == "hive_metastore"
2 changes: 0 additions & 2 deletions dbt/include/databricks/macros/catalog.sql
@@ -34,8 +34,6 @@
{%- endset -%}

{{ return(run_query(query)) }}
{{ return(run_query(query)) }}

{%- endmacro %}

{% macro databricks__get_catalog_tables_sql(information_schema) -%}
22 changes: 11 additions & 11 deletions tests/unit/test_adapter.py
@@ -426,7 +426,7 @@ def test_parse_relation(self):
"table_type": rel_type,
"table_owner": "root",
"column": "col1",
"column_index": 1,
"column_index": 0,
"dtype": "decimal(22,0)",
"numeric_scale": None,
"numeric_precision": None,
@@ -443,7 +443,7 @@ def test_parse_relation(self):
"table_type": rel_type,
"table_owner": "root",
"column": "col2",
"column_index": 2,
"column_index": 1,
"dtype": "string",
"numeric_scale": None,
"numeric_precision": None,
@@ -460,7 +460,7 @@ def test_parse_relation(self):
"table_type": rel_type,
"table_owner": "root",
"column": "dt",
"column_index": 3,
"column_index": 2,
"dtype": "date",
"numeric_scale": None,
"numeric_precision": None,
@@ -477,7 +477,7 @@ def test_parse_relation(self):
"table_type": rel_type,
"table_owner": "root",
"column": "struct_col",
"column_index": 4,
"column_index": 3,
"dtype": "struct<struct_inner_col:string>",
"numeric_scale": None,
"numeric_precision": None,
@@ -575,7 +575,7 @@ def test_parse_relation_with_statistics(self):
"table_type": rel_type,
"table_owner": "root",
"column": "col1",
"column_index": 1,
"column_index": 0,
"dtype": "decimal(22,0)",
"numeric_scale": None,
"numeric_precision": None,
@@ -643,7 +643,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self):
"table_type": rel_type,
"table_owner": "root",
"column": "col1",
"column_index": 1,
"column_index": 0,
"dtype": "decimal(22,0)",
"numeric_scale": None,
"numeric_precision": None,
@@ -664,7 +664,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self):
"table_type": rel_type,
"table_owner": "root",
"column": "struct_col",
"column_index": 4,
"column_index": 3,
"dtype": "struct",
"numeric_scale": None,
"numeric_precision": None,
@@ -728,7 +728,7 @@ def test_parse_columns_from_information_with_view_type(self):
"table_type": rel_type,
"table_owner": "root",
"column": "col2",
"column_index": 2,
"column_index": 1,
"dtype": "string",
"numeric_scale": None,
"numeric_precision": None,
@@ -745,7 +745,7 @@ def test_parse_columns_from_information_with_view_type(self):
"table_type": rel_type,
"table_owner": "root",
"column": "struct_col",
"column_index": 4,
"column_index": 3,
"dtype": "struct",
"numeric_scale": None,
"numeric_precision": None,
@@ -794,7 +794,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(self):
"table_type": rel_type,
"table_owner": "root",
"column": "dt",
"column_index": 3,
"column_index": 2,
"dtype": "date",
"numeric_scale": None,
"numeric_precision": None,
@@ -819,7 +819,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(self):
"table_type": rel_type,
"table_owner": "root",
"column": "struct_col",
"column_index": 4,
"column_index": 3,
"dtype": "struct",
"numeric_scale": None,
"numeric_precision": None,
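The test expectation changes above shift "column_index" from 1-based to 0-based values; the adapter code that produces these indices is not part of this diff, so the snippet below is only an illustrative sketch of zero-based indexing using a hypothetical helper:

    def column_indices(column_names):
        # Zero-based positions, matching the updated "column_index" expectations above.
        return {name: idx for idx, name in enumerate(column_names)}

    assert column_indices(["col1", "col2", "dt", "struct_col"]) == {
        "col1": 0,
        "col2": 1,
        "dt": 2,
        "struct_col": 3,
    }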
34 changes: 34 additions & 0 deletions tests/unit/test_relation.py
@@ -1,7 +1,9 @@
import unittest

from jinja2.runtime import Undefined
import pytest

from dbt.adapters.databricks import relation
from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksQuotePolicy


@@ -177,3 +179,35 @@ def test_matches(self):

relation = DatabricksRelation.from_dict(data)
self.assertFalse(relation.matches("some_database", "some_schema", "table"))


class TestRelationsFunctions:
@pytest.mark.parametrize(
"database, expected", [(None, True), ("hive_metastore", True), ("not_hive", False)]
)
def test_is_hive_metastore(self, database, expected):
assert relation.is_hive_metastore(database) is expected

@pytest.mark.parametrize(
"input, expected",
[
([], set()),
([DatabricksRelation.create(identifier=None)], set()),
(
[
DatabricksRelation.create(identifier=None),
DatabricksRelation.create(identifier="test"),
],
{"test"},
),
(
[
DatabricksRelation.create(identifier="test"),
DatabricksRelation.create(identifier="test"),
],
{"test"},
),
],
)
def test_extract_identifiers(self, input, expected):
assert relation.extract_identifiers(input) == expected
