diff --git a/CHANGELOG.md b/CHANGELOG.md
index d9dc0441d..1a2844441 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
 
 ### Features
 
+- Added support for getting info only on specified relations to improve performance of gathering metadata ([486](https://github.com/databricks/dbt-databricks/pull/486))
 - Added support for getting freshness from metadata ([481](https://github.com/databricks/dbt-databricks/pull/481))
 
 ## dbt-databricks 1.7.0rc1 (October 13, 2023)
diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index c6a8c2053..22bb513d6 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -156,6 +156,8 @@ def __post_init__(self) -> None:
                     f"Invalid catalog name : `{self.database}`."
                 )
             self.database = database
+        else:
+            self.database = "hive_metastore"
 
         connection_parameters = self.connection_parameters or {}
         for key in (
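A note on the `connections.py` change: an unset catalog is now normalized to `hive_metastore` at credential-parsing time, instead of being special-cased downstream (see the `is_hive_metastore` helper later in this diff). A minimal sketch of the fallback, mirroring the new `__post_init__` branch rather than invoking the real credentials class:

```python
from typing import Optional


def resolve_catalog(database: Optional[str]) -> str:
    """Mirror of the new fallback: no catalog in the profile means hive_metastore."""
    return database if database is not None else "hive_metastore"


assert resolve_catalog(None) == "hive_metastore"
assert resolve_catalog("main") == "main"
```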
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index 88869fa1b..b41c8395b 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -1,6 +1,6 @@
+from collections import defaultdict
 from concurrent.futures import Future
 from contextlib import contextmanager
-from itertools import chain
 from dataclasses import dataclass
 import os
 import re
@@ -22,7 +22,7 @@
 from dbt.adapters.base import AdapterConfig, PythonJobHelper
 from dbt.adapters.base.impl import catch_as_completed
 from dbt.adapters.base.meta import available
-from dbt.adapters.base.relation import BaseRelation, InformationSchema
+from dbt.adapters.base.relation import BaseRelation
 from dbt.adapters.capability import CapabilityDict, CapabilitySupport, Support, Capability
 from dbt.adapters.spark.impl import (
     SparkAdapter,
@@ -36,7 +36,6 @@
 from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER, empty_table
 from dbt.contracts.connection import AdapterResponse, Connection
 from dbt.contracts.graph.manifest import Manifest
-from dbt.contracts.graph.nodes import ResultNode
 from dbt.contracts.relation import RelationType
 import dbt.exceptions
 from dbt.events import AdapterLogger
@@ -48,7 +47,11 @@
     DbtDatabricksAllPurposeClusterPythonJobHelper,
     DbtDatabricksJobClusterPythonJobHelper,
 )
-from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType
+from dbt.adapters.databricks.relation import is_hive_metastore, extract_identifiers
+from dbt.adapters.databricks.relation import (
+    DatabricksRelation,
+    DatabricksRelationType,
+)
 from dbt.adapters.databricks.utils import redact_credentials, undefined_proof
@@ -109,7 +112,10 @@ class DatabricksAdapter(SparkAdapter):
     AdapterSpecificConfigs = DatabricksConfig
 
     _capabilities = CapabilityDict(
-        {Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full)}
+        {
+            Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full),
+            Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full),
+        }
     )
 
     @available.parse(lambda *a, **k: 0)
@@ -269,7 +275,7 @@ def typeFromNames(
                     if view_names[name]
                     else DatabricksRelationType.View
                 )
-            elif database is None or database == "hive_metastore":
+            elif is_hive_metastore(database):
                 return DatabricksRelationType.Table
             else:
                 # not a view so it might be a streaming table
@@ -342,7 +348,7 @@ def parse_describe_extended(  # type: ignore[override]
                 table_owner=str(metadata.get(KEY_TABLE_OWNER)),
                 table_stats=table_stats,
                 column=column["col_name"],
-                column_index=(idx + 1),
+                column_index=idx,
                 dtype=column["data_type"],
             )
             for idx, column in enumerate(rows)
@@ -413,7 +419,7 @@ def parse_columns_from_information(  # type: ignore[override]
                 table_schema=relation.schema,
                 table_name=relation.table,
                 table_type=relation.type,
-                column_index=(match_num + 1),
+                column_index=match_num,
                 table_owner=owner,
                 column=column_name,
                 dtype=DatabricksColumn.translate_type(column_type),
@@ -425,60 +431,54 @@ def parse_columns_from_information(  # type: ignore[override]
     def get_catalog(
         self, manifest: Manifest, selected_nodes: Optional[Set] = None
    ) -> Tuple[Table, List[Exception]]:
-        schema_map = self._get_catalog_schemas(manifest)
-
         with executor(self.config) as tpe:
+            catalog_relations = self._get_catalog_relations(manifest, selected_nodes)
+            relations_by_catalog = self._get_catalog_relations_by_info_schema(catalog_relations)
+
             futures: List[Future[Table]] = []
-            for info, schemas in schema_map.items():
-                for schema in schemas:
-                    futures.append(
-                        tpe.submit_connected(
-                            self,
-                            schema,
-                            self._get_one_catalog,
-                            info,
-                            [schema],
-                            manifest,
+
+            for info_schema, relations in relations_by_catalog.items():
+                if is_hive_metastore(info_schema.database):
+                    schema_map = defaultdict(list)
+                    for relation in relations:
+                        schema_map[relation.schema].append(relation)
+
+                    for schema, schema_relations in schema_map.items():
+                        futures.append(
+                            tpe.submit_connected(
+                                self,
+                                "hive_metastore",
+                                self._get_hive_catalog,
+                                schema,
+                                schema_relations,
+                            )
                         )
+                else:
+                    name = ".".join([str(info_schema.database), "information_schema"])
+                    fut = tpe.submit_connected(
+                        self,
+                        name,
+                        self._get_one_catalog_by_relations,
+                        info_schema,
+                        relations,
+                        manifest,
                     )
+                    futures.append(fut)
+
             catalogs, exceptions = catch_as_completed(futures)
         return catalogs, exceptions
 
-    def _get_one_catalog(
+    def _get_hive_catalog(
         self,
-        information_schema: InformationSchema,
-        schemas: Set[str],
-        manifest: Manifest,
+        schema: str,
+        relations: List[BaseRelation],
     ) -> Table:
-        if len(schemas) != 1:
-            raise dbt.exceptions.CompilationError(
-                f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}"
-            )
-
-        database = information_schema.database
-        schema = list(schemas)[0]
-
-        nodes: Iterator[ResultNode] = chain(
-            (
-                node
-                for node in manifest.nodes.values()
-                if (node.is_relational and not node.is_ephemeral_model)
-            ),
-            manifest.sources.values(),
-        )
-
-        table_names: Set[str] = set()
-        for node in nodes:
-            if node.database == database and node.schema == schema:
-                relation = self.Relation.create_from(self.config, node)
-                if relation.identifier:
-                    table_names.add(relation.identifier)
-
+        table_names = extract_identifiers(relations)
         columns: List[Dict[str, Any]] = []
         if len(table_names) > 0:
             schema_relation = self.Relation.create(
-                database=database,
+                database="hive_metastore",
                 schema=schema,
                 identifier=get_identifier_list_string(table_names),
                 quote_policy=self.config.quoting,
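The reworked `get_catalog` routes metadata collection by catalog type: Unity Catalog relations are batched into one `information_schema` query per catalog (via `_get_one_catalog_by_relations`), while hive_metastore relations still fan out one future per schema. A standalone sketch of the grouping step, with tuples standing in for real `DatabricksRelation` objects:

```python
from collections import defaultdict

# (database, schema, identifier) stand-ins for DatabricksRelation instances.
relations = [
    ("hive_metastore", "sales", "orders"),
    ("hive_metastore", "sales", "customers"),
    ("main", "analytics", "daily_rollup"),
]

hive_by_schema = defaultdict(list)
for database, schema, identifier in relations:
    if database is None or database.lower() == "hive_metastore":
        hive_by_schema[schema].append(identifier)  # one future per schema
    # else: one batched information_schema query per catalog

assert dict(hive_by_schema) == {"sales": ["orders", "customers"]}
```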
diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py
index 00d6720a4..63309e951 100644
--- a/dbt/adapters/databricks/relation.py
+++ b/dbt/adapters/databricks/relation.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, List, Optional, Set, Type
 
 from dbt.contracts.relation import (
     ComponentName,
 )
@@ -44,7 +44,7 @@ class DatabricksInformationSchema(InformationSchema):
     quote_character: str = "`"
 
     def is_hive_metastore(self) -> bool:
-        return self.database is None or self.database == "hive_metastore"
+        return is_hive_metastore(self.database)
 
 
 @dataclass(frozen=True, eq=False, repr=False)
@@ -135,3 +135,11 @@ def information_schema(self, view_name: Optional[str] = None) -> InformationSche
         # Instead address this as <database>.information_schema by default
         info_schema = DatabricksInformationSchema.from_relation(self, view_name)
         return info_schema.incorporate(path={"schema": None})
+
+
+def is_hive_metastore(database: Optional[str]) -> bool:
+    return database is None or database.lower() == "hive_metastore"
+
+
+def extract_identifiers(relations: List[BaseRelation]) -> Set[str]:
+    return {r.identifier for r in relations if r.identifier is not None}
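The two new module-level helpers are small but load-bearing; their contracts are pinned down by the unit tests at the end of this diff. Their expected behavior, written out as runnable assertions:

```python
from dbt.adapters.databricks.relation import (
    DatabricksRelation,
    extract_identifiers,
    is_hive_metastore,
)

# None and any casing of "hive_metastore" (the comparison lowercases its
# input) both count as the legacy metastore.
assert is_hive_metastore(None)
assert is_hive_metastore("HIVE_METASTORE")
assert not is_hive_metastore("main")

# extract_identifiers dedupes and drops relations without an identifier.
rels = [
    DatabricksRelation.create(identifier="orders"),
    DatabricksRelation.create(identifier="orders"),
    DatabricksRelation.create(identifier=None),
]
assert extract_identifiers(rels) == {"orders"}
```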
diff --git a/dbt/include/databricks/macros/catalog.sql b/dbt/include/databricks/macros/catalog.sql
index 4cfd1c7ac..0749bca95 100644
--- a/dbt/include/databricks/macros/catalog.sql
+++ b/dbt/include/databricks/macros/catalog.sql
@@ -18,3 +18,84 @@
       use catalog {{ adapter.quote(catalog) }}
     {% endcall %}
 {% endmacro %}
+
+{% macro databricks__get_catalog_relations(information_schema, relations) -%}
+
+  {% set query %}
+      with tables as (
+        {{ databricks__get_catalog_tables_sql(information_schema) }}
+        {{ databricks__get_catalog_relations_where_clause_sql(relations) }}
+      ),
+      columns as (
+        {{ databricks__get_catalog_columns_sql(information_schema) }}
+        {{ databricks__get_catalog_relations_where_clause_sql(relations) }}
+      )
+      {{ databricks__get_catalog_results_sql() }}
+  {%- endset -%}
+
+  {{ return(run_query(query)) }}
+{%- endmacro %}
+
+{% macro databricks__get_catalog_tables_sql(information_schema) -%}
+    select
+        table_catalog as table_database,
+        table_schema,
+        table_name,
+        lower(if(table_type in ('MANAGED', 'EXTERNAL'), 'table', table_type)) as table_type,
+        comment as table_comment,
+        table_owner,
+        'Last Modified' as `stats:last_modified:label`,
+        last_altered as `stats:last_modified:value`,
+        'The timestamp for last update/change' as `stats:last_modified:description`,
+        (last_altered is not null and table_type not ilike '%VIEW%') as `stats:last_modified:include`
+    from {{ information_schema }}.tables
+{%- endmacro %}
+
+{% macro databricks__get_catalog_columns_sql(information_schema) -%}
+    select
+        table_catalog as table_database,
+        table_schema,
+        table_name,
+        column_name,
+        ordinal_position as column_index,
+        lower(data_type) as column_type,
+        comment as column_comment
+    from {{ information_schema }}.columns
+{%- endmacro %}
+
+{% macro databricks__get_catalog_results_sql() -%}
+    select *
+    from tables
+    join columns using (table_database, table_schema, table_name)
+    order by column_index
+{%- endmacro %}
+
+{% macro databricks__get_catalog_schemas_where_clause_sql(schemas) -%}
+    where ({%- for schema in schemas -%}
+        upper(table_schema) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
+    {%- endfor -%})
+{%- endmacro %}
+
+
+{% macro databricks__get_catalog_relations_where_clause_sql(relations) -%}
+    where (
+        {%- for relation in relations -%}
+            {% if relation.schema and relation.identifier %}
+                (
+                    upper(table_schema) = upper('{{ relation.schema }}')
+                    and upper(table_name) = upper('{{ relation.identifier }}')
+                )
+            {% elif relation.schema %}
+                (
+                    upper(table_schema) = upper('{{ relation.schema }}')
+                )
+            {% else %}
+                {% do exceptions.raise_compiler_error(
+                    '`get_catalog_relations` requires a list of relations, each with a schema'
+                ) %}
+            {% endif %}
+
+            {%- if not loop.last %} or {% endif -%}
+        {%- endfor -%}
+    )
+{%- endmacro %}
\ No newline at end of file
diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py
index f27a119ee..7fc27dce9 100644
--- a/tests/functional/adapter/test_basic.py
+++ b/tests/functional/adapter/test_basic.py
@@ -1,9 +1,6 @@
 import pytest
 
-from dbt.tests.adapter.basic.expected_catalog import (
-    base_expected_catalog,
-    expected_references_catalog,
-)
+from dbt.tests.adapter.basic.expected_catalog import base_expected_catalog
 from dbt.tests.util import AnyInteger, AnyString
 
 from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations
@@ -69,7 +66,7 @@ def expected_catalog(self, project):
         return base_expected_catalog(
             project,
             role=AnyString(),
-            id_type="bigint",
+            id_type=AnyLongType(),
             text_type="string",
             time_type="timestamp",
             view_type="view",
@@ -81,18 +78,143 @@
 class TestDocsGenReferencesDatabricks(BaseDocsGenReferences):
     @pytest.fixture(scope="class")
     def expected_catalog(self, project):
-        return expected_references_catalog(
+        return self.expected_references_catalog(
             project,
             role=AnyString(),
-            id_type="bigint",
+            id_type=AnyLongType(),
             text_type="string",
             time_type="timestamp",
-            bigint_type="bigint",
+            bigint_type=AnyLongType(),
             view_type="view",
             table_type="table",
             model_stats=_StatsLikeDict(),
         )
 
+    # Temporary until upstream fixes to allow 0-based indexing
+    def expected_references_catalog(
+        self,
+        project,
+        role,
+        id_type,
+        text_type,
+        time_type,
+        view_type,
+        table_type,
+        model_stats,
+        bigint_type=None,
+    ):
+        seed_stats = model_stats
+        view_summary_stats = model_stats
+
+        model_database = project.database
+        my_schema_name = project.test_schema
+
+        summary_columns = {
+            "first_name": {
+                "name": "first_name",
+                "index": 0,
+                "type": text_type,
+                "comment": None,
+            },
+            "ct": {
+                "name": "ct",
+                "index": 1,
+                "type": bigint_type,
+                "comment": None,
+            },
+        }
+
+        seed_columns = {
+            "id": {
+                "name": "id",
+                "index": 0,
+                "type": id_type,
+                "comment": None,
+            },
+            "first_name": {
+                "name": "first_name",
+                "index": 1,
+                "type": text_type,
+                "comment": None,
+            },
+            "email": {
+                "name": "email",
+                "index": 2,
+                "type": text_type,
+                "comment": None,
+            },
+            "ip_address": {
+                "name": "ip_address",
+                "index": 3,
+                "type": text_type,
+                "comment": None,
+            },
+            "updated_at": {
+                "name": "updated_at",
+                "index": 4,
+                "type": time_type,
+                "comment": None,
+            },
+        }
+        return {
+            "nodes": {
+                "seed.test.seed": {
+                    "unique_id": "seed.test.seed",
+                    "metadata": {
+                        "schema": my_schema_name,
+                        "database": project.database,
+                        "name": "seed",
+                        "type": table_type,
+                        "comment": None,
+                        "owner": role,
+                    },
+                    "stats": seed_stats,
+                    "columns": seed_columns,
+                },
+                "model.test.ephemeral_summary": {
+                    "unique_id": "model.test.ephemeral_summary",
+                    "metadata": {
+                        "schema": my_schema_name,
+                        "database": model_database,
+                        "name": "ephemeral_summary",
+                        "type": table_type,
+                        "comment": None,
+                        "owner": role,
+                    },
+                    "stats": model_stats,
+                    "columns": summary_columns,
+                },
+                "model.test.view_summary": {
+                    "unique_id": "model.test.view_summary",
+                    "metadata": {
+                        "schema": my_schema_name,
+                        "database": model_database,
+                        "name": "view_summary",
+                        "type": view_type,
+                        "comment": None,
+                        "owner": role,
+                    },
+                    "stats": view_summary_stats,
+                    "columns": summary_columns,
+                },
+            },
+            "sources": {
+                "source.test.my_source.my_table": {
+                    "unique_id": "source.test.my_source.my_table",
+                    "metadata": {
+                        "schema": my_schema_name,
+                        "database": project.database,
+                        "name": "seed",
+                        "type": table_type,
+                        "comment": None,
+                        "owner": role,
+                    },
+                    "stats": seed_stats,
+                    "columns": seed_columns,
+                },
+            },
+        }
+
 
 class _StatsLikeDict:
     """Any stats-like dict. Use this in assert calls"""
@@ -112,3 +234,10 @@ def __eq__(self, other):
                 }
             )
         )
+
+
+class AnyLongType:
+    """Allows bigint and long to be treated equivalently"""
+
+    def __eq__(self, other):
+        return isinstance(other, str) and other in ("bigint", "long")
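`AnyLongType` appears to exist because the same column can surface as `bigint` from one metadata path and `long` from another, and the catalog comparison only needs to accept either spelling. Copied here with a usage check for clarity:

```python
class AnyLongType:
    """Allows bigint and long to be treated equivalently"""

    def __eq__(self, other):
        return isinstance(other, str) and other in ("bigint", "long")


assert AnyLongType() == "bigint"
assert AnyLongType() == "long"
assert AnyLongType() != "int"
```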
diff --git a/tests/integration/debug/test_debug.py b/tests/integration/debug/test_debug.py
index 00c519486..2bcb0b27f 100644
--- a/tests/integration/debug/test_debug.py
+++ b/tests/integration/debug/test_debug.py
@@ -13,23 +13,23 @@ def schema(self):
     def models(self):
         return "models"
 
-    def run_and_test(self, contains_catalog: bool):
+    def run_and_test(self):
        with mock.patch("sys.stdout", new=StringIO()) as fake_out:
             self.run_dbt(["debug"])
         stdout = fake_out.getvalue()
         self.assertIn("host: ", stdout)
         self.assertIn("http_path: ", stdout)
         self.assertIn("schema: ", stdout)
-        (self.assertIn if contains_catalog else self.assertNotIn)("catalog: ", stdout)
+        self.assertIn("catalog: ", stdout)
 
     @use_profile("databricks_cluster")
     def test_debug_databricks_cluster(self):
-        self.run_and_test(contains_catalog=False)
+        self.run_and_test()
 
     @use_profile("databricks_uc_cluster")
     def test_debug_databricks_uc_cluster(self):
-        self.run_and_test(contains_catalog=True)
+        self.run_and_test()
 
     @use_profile("databricks_uc_sql_endpoint")
     def test_debug_databricks_uc_sql_endpoint(self):
-        self.run_and_test(contains_catalog=True)
+        self.run_and_test()
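With the catalog now defaulting to `hive_metastore`, `dbt debug` prints a `catalog:` line for every profile type, which is why the conditional assertion was dropped. The unit-test updates that follow track the other behavioral change: `column_index` is now 0-based. A quick sketch of why every expected index shifted down by one (`parse_describe_extended` now emits the raw `enumerate` index instead of `idx + 1`):

```python
# Rows as they might come back from DESCRIBE TABLE EXTENDED (shape assumed).
rows = [{"col_name": "col1"}, {"col_name": "col2"}, {"col_name": "dt"}, {"col_name": "struct_col"}]

indexes = [idx for idx, _ in enumerate(rows)]
assert indexes == [0, 1, 2, 3]  # previously [1, 2, 3, 4]
```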
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index a03b81bff..cb1f11545 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -426,7 +426,7 @@ def test_parse_relation(self):
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "col1",
-                "column_index": 1,
+                "column_index": 0,
                 "dtype": "decimal(22,0)",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -443,7 +443,7 @@ def test_parse_relation(self):
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "col2",
-                "column_index": 2,
+                "column_index": 1,
                 "dtype": "string",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -460,7 +460,7 @@ def test_parse_relation(self):
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "dt",
-                "column_index": 3,
+                "column_index": 2,
                 "dtype": "date",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -477,7 +477,7 @@ def test_parse_relation(self):
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "struct_col",
-                "column_index": 4,
+                "column_index": 3,
                 "dtype": "struct",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -575,7 +575,7 @@ def test_parse_relation_with_statistics(self):
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "col1",
-                "column_index": 1,
+                "column_index": 0,
                 "dtype": "decimal(22,0)",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -643,7 +643,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self)
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "col1",
-                "column_index": 1,
+                "column_index": 0,
                 "dtype": "decimal(22,0)",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -664,7 +664,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self)
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "struct_col",
-                "column_index": 4,
+                "column_index": 3,
                 "dtype": "struct",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -728,7 +728,7 @@ def test_parse_columns_from_information_with_view_type(self):
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "col2",
-                "column_index": 2,
+                "column_index": 1,
                 "dtype": "string",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -745,7 +745,7 @@ def test_parse_columns_from_information_with_view_type(self):
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "struct_col",
-                "column_index": 4,
+                "column_index": 3,
                 "dtype": "struct",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -794,7 +794,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "dt",
-                "column_index": 3,
+                "column_index": 2,
                 "dtype": "date",
                 "numeric_scale": None,
                 "numeric_precision": None,
@@ -819,7 +819,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "struct_col",
-                "column_index": 4,
+                "column_index": 3,
                 "dtype": "struct",
                 "numeric_scale": None,
                 "numeric_precision": None,
diff --git a/tests/unit/test_relation.py b/tests/unit/test_relation.py
index a7da9f27f..5a827bade 100644
--- a/tests/unit/test_relation.py
+++ b/tests/unit/test_relation.py
@@ -1,7 +1,9 @@
 import unittest
 
 from jinja2.runtime import Undefined
+import pytest
 
+from dbt.adapters.databricks import relation
 from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksQuotePolicy
@@ -177,3 +179,35 @@ def test_matches(self):
         relation = DatabricksRelation.from_dict(data)
 
         self.assertFalse(relation.matches("some_database", "some_schema", "table"))
+
+
+class TestRelationsFunctions:
+    @pytest.mark.parametrize(
+        "database, expected", [(None, True), ("hive_metastore", True), ("not_hive", False)]
+    )
+    def test_is_hive_metastore(self, database, expected):
+        assert relation.is_hive_metastore(database) is expected
+
+    @pytest.mark.parametrize(
+        "input, expected",
+        [
+            ([], set()),
+            ([DatabricksRelation.create(identifier=None)], set()),
+            (
+                [
+                    DatabricksRelation.create(identifier=None),
+                    DatabricksRelation.create(identifier="test"),
+                ],
+                {"test"},
+            ),
+            (
+                [
+                    DatabricksRelation.create(identifier="test"),
+                    DatabricksRelation.create(identifier="test"),
+                ],
+                {"test"},
+            ),
+        ],
+    )
+    def test_extract_identifiers(self, input, expected):
+        assert relation.extract_identifiers(input) == expected