Skip to content

Commit

Permalink
Allow models to execute on different warehouses
Browse files Browse the repository at this point in the history
Signed-off-by: Raymond Cypher <[email protected]>
  • Loading branch information
rcypher-databricks committed Oct 26, 2023
1 parent b199f38 commit ef89e18
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 1 deletion.
102 changes: 101 additions & 1 deletion dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,15 @@
Connection,
ConnectionState,
DEFAULT_QUERY_COMMENT,
Identifier,
LazyHandle,
)
from dbt.events.types import (
NewConnection,
ConnectionReused,
)
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.events import AdapterLogger
from dbt.events.contextvars import get_node_info
from dbt.events.functions import fire_event
Expand Down Expand Up @@ -111,6 +118,10 @@ class DatabricksCredentials(Credentials):
connection_parameters: Optional[Dict[str, Any]] = None
auth_type: Optional[str] = None

# Named compute resources specified in the profile. Used for
# creating a connection when a model specifies a compute resource.
compute: Optional[Dict[str, Any]] = None

connect_retries: int = 1
connect_timeout: Optional[int] = None
retry_all: bool = False
Expand Down Expand Up @@ -741,6 +752,50 @@ def exception_handler(self, sql: str) -> Iterator[None]:
else:
raise dbt.exceptions.DbtRuntimeError(str(exc)) from exc

# override/overload
def set_connection_name(
self, name: Optional[str] = None, node: Optional[ResultNode] = None
) -> Connection:
"""Called by 'acquire_connection' in DatabricksAdapter, which is called by
'connection_named', called by 'connection_for(node)'.
Creates a connection for this thread if one doesn't already
exist, and will rename an existing connection."""

conn_name: str = "master" if name is None else name

# Get a connection for this thread
conn = self.get_if_exists()

if conn and conn.name == conn_name and conn.state == "open":
# Found a connection and nothing to do, so just return it
return conn

if conn is None:
# Create a new connection
conn = Connection(
type=Identifier(self.TYPE),
name=conn_name,
state=ConnectionState.INIT,
transaction_open=False,
handle=None,
credentials=self.profile.credentials,
)
conn.handle = LazyHandle(self.get_open_for_model(node))
# Add the connection to thread_connections for this thread
self.set_thread_connection(conn)
fire_event(
NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info())
)
else: # existing connection either wasn't open or didn't have the right name
if conn.state != "open":
conn.handle = LazyHandle(self.get_open_for_model(node))
if conn.name != conn_name:
orig_conn_name: str = conn.name or ""
conn.name = conn_name
fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name))

return conn

def add_query(
self,
sql: str,
Expand Down Expand Up @@ -849,8 +904,29 @@ def list_tables(self, database: str, schema: str, identifier: Optional[str] = No
),
)

@classmethod
def get_open_for_model(
cls, node: Optional[ResultNode] = None
) -> Callable[[Connection], Connection]:
# If there is no node we can simply return the exsting class method open.
# If there is a node create a closure that will call cls._open with the node.
if not node:
return cls.open

def _open(connection: Connection) -> Connection:
return cls._open(connection, node)

return _open

@classmethod
def open(cls, connection: Connection) -> Connection:
# Simply call _open with no ResultNode argument.
# Because this is an overridden method we can't just add
# a ResultNode parameter to open.
return cls._open(connection)

@classmethod
def _open(cls, connection: Connection, node: Optional[ResultNode] = None) -> Connection:
if connection.state == ConnectionState.OPEN:
logger.debug("Connection is already open, skipping open.")
return connection
Expand All @@ -873,12 +949,16 @@ def open(cls, connection: Connection) -> Connection:
creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
)

# If a model specifies a compute resource to use the http path
# may be different than the http_path property of creds.
http_path = get_http_path(node, creds)

def connect() -> DatabricksSQLConnectionWrapper:
try:
# TODO: what is the error when a user specifies a catalog they don't have access to
conn: DatabricksSQLConnection = dbsql.connect(
server_hostname=creds.host,
http_path=creds.http_path,
http_path=http_path,
credentials_provider=cls.credentials_provider,
http_headers=http_headers if http_headers else None,
session_configuration=creds.session_properties,
Expand Down Expand Up @@ -1016,3 +1096,23 @@ def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id:
msg = error_events[0].get("message", "")

return msg


def get_compute_name(node: Optional[ResultNode]) -> Optional[str]:
# Get the name of the specified compute resource from the node's
# config.
compute_name = None
if node and node.config and node.config.extra:
compute_name = node.config.extra.get("databricks_compute", None)
return compute_name


def get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]:
# Get the http path of the compute resource specified in the node's config.
# If none is specified return the default path from creds.
compute_name = get_compute_name(node)
http_path = creds.http_path
if compute_name and creds.compute:
http_path = creds.compute.get(compute_name, {}).get("http_path", creds.http_path)

return http_path
20 changes: 20 additions & 0 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER, empty_table
from dbt.contracts.connection import AdapterResponse, Connection
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.relation import RelationType
import dbt.exceptions
from dbt.events import AdapterLogger
Expand Down Expand Up @@ -118,6 +119,25 @@ class DatabricksAdapter(SparkAdapter):
}
)

# override/overload
def acquire_connection(
self, name: Optional[str] = None, node: Optional[ResultNode] = None
) -> Connection:
return self.connections.set_connection_name(name, node)

# override
@contextmanager
def connection_named(self, name: str, node: Optional[ResultNode] = None) -> Iterator[None]:
try:
if self.connections.query_header is not None:
self.connections.query_header.set(name, node)
self.acquire_connection(name, node)
yield
finally:
self.release_connection()
if self.connections.query_header is not None:
self.connections.query_header.reset()

@available.parse(lambda *a, **k: 0)
def compare_dbr_version(self, major: int, minor: int) -> int:
"""
Expand Down

0 comments on commit ef89e18

Please sign in to comment.