diff --git a/.changes/unreleased/Under the Hood-20240716-174655.yaml b/.changes/unreleased/Under the Hood-20240716-174655.yaml new file mode 100644 index 000000000..14c3c8d76 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240716-174655.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Add support for experimental record/replay testing. +time: 2024-07-16T17:46:55.11204-04:00 +custom: + Author: peterallenwebb + Issue: "1106" diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index 6b325ab9c..10bee30f0 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -44,6 +44,7 @@ DbtConfigError, ) from dbt_common.exceptions import DbtDatabaseError +from dbt_common.record import get_record_mode_from_env, RecorderMode from dbt.adapters.exceptions.connection import FailedToConnectError from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials from dbt.adapters.sql import SQLConnectionManager @@ -51,6 +52,7 @@ from dbt_common.events.functions import warn_or_error from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError from dbt_common.ui import line_wrap_message, warning_tag +from dbt.adapters.snowflake.record import SnowflakeRecordReplayHandle from dbt.adapters.snowflake.auth import private_key_from_file, private_key_from_string @@ -372,20 +374,32 @@ def connect(): if creds.query_tag: session_parameters.update({"QUERY_TAG": creds.query_tag}) + handle = None + + # In replay mode, we won't connect to a real database at all, while + # in record and diff modes we do, but insert an intermediate handle + # object which monitors native connection activity. + rec_mode = get_record_mode_from_env() + handle = None + if rec_mode != RecorderMode.REPLAY: + handle = snowflake.connector.connect( + account=creds.account, + database=creds.database, + schema=creds.schema, + warehouse=creds.warehouse, + role=creds.role, + autocommit=True, + client_session_keep_alive=creds.client_session_keep_alive, + application="dbt", + insecure_mode=creds.insecure_mode, + session_parameters=session_parameters, + **creds.auth_args(), + ) - handle = snowflake.connector.connect( - account=creds.account, - database=creds.database, - schema=creds.schema, - warehouse=creds.warehouse, - role=creds.role, - autocommit=True, - client_session_keep_alive=creds.client_session_keep_alive, - application="dbt", - insecure_mode=creds.insecure_mode, - session_parameters=session_parameters, - **creds.auth_args(), - ) + if rec_mode is not None: + # If using the record/replay mechanism, regardless of mode, we + # use a wrapper. + handle = SnowflakeRecordReplayHandle(handle, connection) return handle diff --git a/dbt/adapters/snowflake/record/__init__.py b/dbt/adapters/snowflake/record/__init__.py new file mode 100644 index 000000000..f763dc3a4 --- /dev/null +++ b/dbt/adapters/snowflake/record/__init__.py @@ -0,0 +1,2 @@ +from dbt.adapters.snowflake.record.cursor.cursor import SnowflakeRecordReplayCursor +from dbt.adapters.snowflake.record.handle import SnowflakeRecordReplayHandle diff --git a/dbt/adapters/snowflake/record/cursor/cursor.py b/dbt/adapters/snowflake/record/cursor/cursor.py new file mode 100644 index 000000000..a07468867 --- /dev/null +++ b/dbt/adapters/snowflake/record/cursor/cursor.py @@ -0,0 +1,21 @@ +from dbt_common.record import record_function + +from dbt.adapters.record import RecordReplayCursor +from dbt.adapters.snowflake.record.cursor.sfqid import CursorGetSfqidRecord +from dbt.adapters.snowflake.record.cursor.sqlstate import CursorGetSqlStateRecord + + +class SnowflakeRecordReplayCursor(RecordReplayCursor): + """A custom extension of RecordReplayCursor that adds the sqlstate + and sfqid properties which are specific to snowflake-connector.""" + + @property + @property + @record_function(CursorGetSqlStateRecord, method=True, id_field_name="connection_name") + def sqlstate(self): + return self.native_cursor.sqlstate + + @property + @record_function(CursorGetSfqidRecord, method=True, id_field_name="connection_name") + def sfqid(self): + return self.native_cursor.sfqid diff --git a/dbt/adapters/snowflake/record/cursor/sfqid.py b/dbt/adapters/snowflake/record/cursor/sfqid.py new file mode 100644 index 000000000..e39c857d3 --- /dev/null +++ b/dbt/adapters/snowflake/record/cursor/sfqid.py @@ -0,0 +1,21 @@ +import dataclasses +from typing import Optional + +from dbt_common.record import Record, Recorder + + +@dataclasses.dataclass +class CursorGetSfqidParams: + connection_name: str + + +@dataclasses.dataclass +class CursorGetSfqidResult: + msg: Optional[str] + + +@Recorder.register_record_type +class CursorGetSfqidRecord(Record): + params_cls = CursorGetSfqidParams + result_cls = CursorGetSfqidResult + group = "Database" diff --git a/dbt/adapters/snowflake/record/cursor/sqlstate.py b/dbt/adapters/snowflake/record/cursor/sqlstate.py new file mode 100644 index 000000000..5619058fd --- /dev/null +++ b/dbt/adapters/snowflake/record/cursor/sqlstate.py @@ -0,0 +1,21 @@ +import dataclasses +from typing import Optional + +from dbt_common.record import Record, Recorder + + +@dataclasses.dataclass +class CursorGetSqlStateParams: + connection_name: str + + +@dataclasses.dataclass +class CursorGetSqlStateResult: + msg: Optional[str] + + +@Recorder.register_record_type +class CursorGetSqlStateRecord(Record): + params_cls = CursorGetSqlStateParams + result_cls = CursorGetSqlStateResult + group = "Database" diff --git a/dbt/adapters/snowflake/record/handle.py b/dbt/adapters/snowflake/record/handle.py new file mode 100644 index 000000000..046bb911b --- /dev/null +++ b/dbt/adapters/snowflake/record/handle.py @@ -0,0 +1,12 @@ +from dbt.adapters.record import RecordReplayHandle + +from dbt.adapters.snowflake.record.cursor.cursor import SnowflakeRecordReplayCursor + + +class SnowflakeRecordReplayHandle(RecordReplayHandle): + """A custom extension of RecordReplayHandle that returns a + snowflake-connector-specific SnowflakeRecordReplayCursor object.""" + + def cursor(self): + cursor = None if self.native_handle is None else self.native_handle.cursor() + return SnowflakeRecordReplayCursor(cursor, self.connection)