From 2b01804e78d56e7903bba5999d11ec19b1b2b482 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Thu, 31 Oct 2024 16:57:02 -0400 Subject: [PATCH 01/41] fix imports --- dbt/adapters/bigquery/gcloud.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 dbt/adapters/bigquery/gcloud.py diff --git a/dbt/adapters/bigquery/gcloud.py b/dbt/adapters/bigquery/gcloud.py new file mode 100644 index 000000000..e69de29bb From 8b455940cfde2ac163dff495200f898ad5f882bb Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Fri, 1 Nov 2024 16:04:44 -0400 Subject: [PATCH 02/41] create a retry factory and move relevant objects from connections --- dbt/adapters/bigquery/connections.py | 42 +--------- dbt/adapters/bigquery/retry.py | 110 +++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 41 deletions(-) create mode 100644 dbt/adapters/bigquery/retry.py diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index bda54080b..486c1556f 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -41,6 +41,7 @@ get_bigquery_defaults, setup_default_credentials, ) +from dbt.adapters.bigquery.retry import _BufferedPredicate as _ErrorCounter from dbt.adapters.bigquery.utility import is_base64, base64_to_string if TYPE_CHECKING: @@ -60,14 +61,6 @@ ConnectionError, ) -RETRYABLE_ERRORS = ( - google.cloud.exceptions.ServerError, - google.cloud.exceptions.BadRequest, - google.cloud.exceptions.BadGateway, - ConnectionResetError, - ConnectionError, -) - @dataclass class BigQueryAdapterResponse(AdapterResponse): @@ -693,39 +686,6 @@ def _labels_from_query_comment(self, comment: str) -> Dict: } -class _ErrorCounter(object): - """Counts errors seen up to a threshold then raises the next error.""" - - def __init__(self, retries): - self.retries = retries - self.error_count = 0 - - def count_error(self, error): - if self.retries == 0: - return False # Don't log - self.error_count += 1 - if _is_retryable(error) and self.error_count <= self.retries: - logger.debug( - "Retry attempt {} of {} after error: {}".format( - self.error_count, self.retries, repr(error) - ) - ) - return True - else: - return False - - -def _is_retryable(error): - """Return true for errors that are unlikely to occur again if retried.""" - if isinstance(error, RETRYABLE_ERRORS): - return True - elif isinstance(error, google.api_core.exceptions.Forbidden) and any( - e["reason"] == "rateLimitExceeded" for e in error.errors - ): - return True - return False - - _SANITIZE_LABEL_PATTERN = re.compile(r"[^a-z0-9_-]") _VALIDATE_LABEL_LENGTH_LIMIT = 63 diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py new file mode 100644 index 000000000..a2998ae97 --- /dev/null +++ b/dbt/adapters/bigquery/retry.py @@ -0,0 +1,110 @@ +from typing import Callable + +from google.api_core import retry +from google.api_core.exceptions import ClientError, Forbidden +from google.cloud.exceptions import BadGateway, BadRequest, ServerError + +from dbt.adapters.events.logging import AdapterLogger + +from dbt.adapters.bigquery.connections import logger +from dbt.adapters.bigquery.credentials import BigQueryCredentials + + +_logger = AdapterLogger("BigQuery") + + +RETRYABLE_ERRORS = ( + ServerError, + BadRequest, + BadGateway, + ConnectionResetError, + ConnectionError, +) + + +class RetryFactory: + + DEFAULT_INITIAL_DELAY = 1.0 # seconds + DEFAULT_MAXIMUM_DELAY = 3.0 # seconds + + def __init__(self, credentials: BigQueryCredentials) -> None: + self._retries = 
credentials.job_retries or 0 + self._deadline = credentials.job_retry_deadline_seconds + + def deadline(self, on_error: Callable[[Exception], None]) -> retry.Retry: + """ + This strategy mimics what was accomplished with _retry_and_handle + """ + return retry.Retry( + predicate=self._buffered_predicate(), + initial=self.DEFAULT_INITIAL_DELAY, + maximum=self.DEFAULT_MAXIMUM_DELAY, + deadline=self.deadline, + on_error=on_error, + ) + + def _buffered_predicate(self) -> Callable[[Exception], bool]: + class BufferedPredicate: + """ + Count ALL errors, not just retryable errors, up to a threshold + then raises the next error, regardless of whether it is retryable. + + Was previously called _ErrorCounter. + """ + + def __init__(self, retries: int) -> None: + self._retries: int = retries + self._error_count = 0 + + def __call__(self, error: Exception) -> bool: + # exit immediately if the user does not want retries + if self._retries == 0: + return False + + # count all errors + self._error_count += 1 + + # if the error is retryable and we haven't breached the threshold, log and continue + if _is_retryable(error) and self._error_count <= self._retries: + _logger.debug( + f"Retry attempt {self._error_count} of { self._retries} after error: {repr(error)}" + ) + return True + + # otherwise raise + return False + + return BufferedPredicate(self._retries) + + +def _is_retryable(error: ClientError) -> bool: + """Return true for errors that are unlikely to occur again if retried.""" + if isinstance(error, RETRYABLE_ERRORS): + return True + elif isinstance(error, Forbidden) and any( + e["reason"] == "rateLimitExceeded" for e in error.errors + ): + return True + return False + + +class _BufferedPredicate: + """Counts errors seen up to a threshold then raises the next error.""" + + def __init__(self, retries: int) -> None: + self._retries = retries + self._error_count = 0 + + def count_error(self, error): + if self._retries == 0: + return False # Don't log + self._error_count += 1 + if _is_retryable(error) and self.error_count <= self._retries: + logger.debug( + "Retry attempt {} of {} after error: {}".format( + self._error_count, self._retries, repr(error) + ) + ) + return True + else: + return False From 391099d8e5498dc46c3ebd5d68a408946bddfa8c Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Fri, 1 Nov 2024 16:30:40 -0400 Subject: [PATCH 03/41] add on_error method for deadline retries --- dbt/adapters/bigquery/connections.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 486c1556f..6aa3907fb 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -5,7 +5,7 @@ import json from multiprocessing.context import SpawnContext import re -from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING +from typing import Callable, Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING import uuid from google.api_core import client_info, client_options, retry @@ -28,6 +28,7 @@ from dbt.adapters.contracts.connection import ( AdapterRequiredConfig, AdapterResponse, + Connection, ConnectionState, ) from dbt.adapters.events.logging import AdapterLogger @@ -41,7 +42,7 @@ get_bigquery_defaults, setup_default_credentials, ) -from dbt.adapters.bigquery.retry import _BufferedPredicate as _ErrorCounter +from dbt.adapters.bigquery.retry import _BufferedPredicate as _ErrorCounter, RetryFactory from dbt.adapters.bigquery.utility import is_base64, 
base64_to_string if TYPE_CHECKING: @@ -56,10 +57,7 @@ WRITE_TRUNCATE = google.cloud.bigquery.job.WriteDisposition.WRITE_TRUNCATE -REOPENABLE_ERRORS = ( - ConnectionResetError, - ConnectionError, -) +REOPENABLE_ERRORS = (ConnectionError,) @dataclass @@ -81,6 +79,18 @@ class BigQueryConnectionManager(BaseConnectionManager): def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): super().__init__(profile, mp_context) self.jobs_by_thread: Dict[Hashable, List[str]] = defaultdict(list) + self._retry = RetryFactory(profile.credentials) + + def _reopen_on_error(self, connection: Connection) -> Callable[[Exception], None]: + + def _on_error(error: Exception): + if isinstance(error, ConnectionError): + logger.warning("Reopening connection after {!r}".format(error)) + self.close(connection) + self.open(connection) + return + + return _on_error @classmethod def handle_error(cls, error, message): From 7872a584a2e30a59d3bc5ca4a654e77129503ebe Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Fri, 1 Nov 2024 16:49:41 -0400 Subject: [PATCH 04/41] remove dependency on retry_and_handle from cancel_open --- dbt/adapters/bigquery/connections.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 6aa3907fb..86e780f85 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -156,15 +156,15 @@ def cancel_open(self): for thread_id, connection in self.thread_connections.items(): if connection is this_connection: continue + if connection.handle is not None and connection.state == ConnectionState.OPEN: client = connection.handle for job_id in self.jobs_by_thread.get(thread_id, []): - - def fn(): - return client.cancel_job(job_id) - - self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn) - + with self.exception_handler(f"Cancel job: {job_id}"): + client.cancel_job( + job_id, + retry=self._retry.deadline(self._reopen_on_error(connection)), + ) self.close(connection) if connection.name is not None: From 42a88694bba3fb78bc18fb3c29e1a05b1d920f31 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Fri, 1 Nov 2024 17:48:19 -0400 Subject: [PATCH 05/41] remove dependencies on retry_and_handle --- dbt/adapters/bigquery/connections.py | 165 ++++++++---------- dbt/adapters/bigquery/retry.py | 42 ++++- .../unit/test_bigquery_connection_manager.py | 39 +---- 3 files changed, 109 insertions(+), 137 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 86e780f85..9266ff1c2 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -18,7 +18,16 @@ credentials as GoogleCredentials, service_account as GoogleServiceAccountCredentials, ) -from requests.exceptions import ConnectionError +from google.cloud.bigquery import ( + Client, + CopyJobConfig, + DatasetReference, + QueryJobConfig, + QueryPriority, + TableReference, + WriteDisposition, +) +import google.cloud.exceptions from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event @@ -42,7 +51,7 @@ get_bigquery_defaults, setup_default_credentials, ) -from dbt.adapters.bigquery.retry import _BufferedPredicate as _ErrorCounter, RetryFactory +from dbt.adapters.bigquery.retry import RetryFactory from dbt.adapters.bigquery.utility import is_base64, base64_to_string if TYPE_CHECKING: @@ -55,9 +64,12 @@ BQ_QUERY_JOB_SPLIT = "-----Query Job SQL Follows-----" 
-WRITE_TRUNCATE = google.cloud.bigquery.job.WriteDisposition.WRITE_TRUNCATE +WRITE_TRUNCATE = WriteDisposition.WRITE_TRUNCATE -REOPENABLE_ERRORS = (ConnectionError,) +REOPENABLE_ERRORS = ( + ConnectionError, + ConnectionResetError, +) @dataclass @@ -84,7 +96,7 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): def _reopen_on_error(self, connection: Connection) -> Callable[[Exception], None]: def _on_error(error: Exception): - if isinstance(error, ConnectionError): + if isinstance(error, REOPENABLE_ERRORS): logger.warning("Reopening connection after {!r}".format(error)) self.close(connection) self.open(connection) @@ -264,7 +276,7 @@ def get_bigquery_client(cls, profile_credentials): info = client_info.ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}") options = client_options.ClientOptions(quota_project_id=quota_project) - return google.cloud.bigquery.Client( + return Client( execution_project, creds, location=location, @@ -360,7 +372,6 @@ def raw_execute( dry_run: bool = False, ): conn = self.get_thread_connection() - client = conn.handle fire_event(SQLQuery(conn_name=conn.name, sql=sql, node_info=get_node_info())) @@ -376,34 +387,25 @@ def raw_execute( priority = conn.credentials.priority if priority == Priority.Batch: - job_params["priority"] = google.cloud.bigquery.QueryPriority.BATCH + job_params["priority"] = QueryPriority.BATCH else: - job_params["priority"] = google.cloud.bigquery.QueryPriority.INTERACTIVE + job_params["priority"] = QueryPriority.INTERACTIVE maximum_bytes_billed = conn.credentials.maximum_bytes_billed if maximum_bytes_billed is not None and maximum_bytes_billed != 0: job_params["maximum_bytes_billed"] = maximum_bytes_billed - job_creation_timeout = self.get_job_creation_timeout_seconds(conn) - job_execution_timeout = self.get_job_execution_timeout_seconds(conn) - - def fn(): + with self.exception_handler(sql): job_id = self.generate_job_id() return self._query_and_results( - client, + conn, sql, job_params, job_id, - job_creation_timeout=job_creation_timeout, - job_execution_timeout=job_execution_timeout, limit=limit, ) - query_job, iterator = self._retry_and_handle(msg=sql, conn=conn, fn=fn) - - return query_job, iterator - def execute( self, sql, auto_begin=False, fetch=None, limit: Optional[int] = None ) -> Tuple[BigQueryAdapterResponse, "agate.Table"]: @@ -533,7 +535,7 @@ def standard_to_legacy(table): def copy_bq_table(self, source, destination, write_disposition): conn = self.get_thread_connection() - client = conn.handle + client: Client = conn.handle # ------------------------------------------------------------------------------- # BigQuery allows to use copy API using two different formats: @@ -561,30 +563,27 @@ def copy_bq_table(self, source, destination, write_disposition): write_disposition, ) - def copy_and_results(): - job_config = google.cloud.bigquery.CopyJobConfig(write_disposition=write_disposition) - copy_job = client.copy_table(source_ref_array, destination_ref, job_config=job_config) - timeout = self.get_job_execution_timeout_seconds(conn) or 300 - iterator = copy_job.result(timeout=timeout) - return copy_job, iterator - - self._retry_and_handle( - msg='copy table "{}" to "{}"'.format( - ", ".join(source_ref.path for source_ref in source_ref_array), - destination_ref.path, - ), - conn=conn, - fn=copy_and_results, + msg = 'copy table "{}" to "{}"'.format( + ", ".join(source_ref.path for source_ref in source_ref_array), + destination_ref.path, ) + with self.exception_handler(msg): + copy_job = 
client.copy_table( + source_ref_array, + destination_ref, + job_config=CopyJobConfig(write_disposition=write_disposition), + retry=self._retry.deadline(self._reopen_on_error(conn)), + ) + copy_job.result(retry=self._retry.job_execution_capped(self._reopen_on_error(conn))) @staticmethod def dataset_ref(database, schema): - return google.cloud.bigquery.DatasetReference(project=database, dataset_id=schema) + return DatasetReference(project=database, dataset_id=schema) @staticmethod def table_ref(database, schema, table_name): - dataset_ref = google.cloud.bigquery.DatasetReference(database, schema) - return google.cloud.bigquery.TableReference(dataset_ref, table_name) + dataset_ref = DatasetReference(database, schema) + return TableReference(dataset_ref, table_name) def get_bq_table(self, database, schema, identifier): """Get a bigquery table for a schema/model.""" @@ -597,53 +596,56 @@ def get_bq_table(self, database, schema, identifier): def drop_dataset(self, database, schema): conn = self.get_thread_connection() - dataset_ref = self.dataset_ref(database, schema) - client = conn.handle - - def fn(): - return client.delete_dataset(dataset_ref, delete_contents=True, not_found_ok=True) - - self._retry_and_handle(msg="drop dataset", conn=conn, fn=fn) + client: Client = conn.handle + with self.exception_handler("drop dataset"): + return client.delete_dataset( + dataset=self.dataset_ref(database, schema), + delete_contents=True, + not_found_ok=True, + retry=self._retry.deadline(self._reopen_on_error(conn)), + ) def create_dataset(self, database, schema): conn = self.get_thread_connection() - client = conn.handle - dataset_ref = self.dataset_ref(database, schema) - - def fn(): - return client.create_dataset(dataset_ref, exists_ok=True) - - self._retry_and_handle(msg="create dataset", conn=conn, fn=fn) + client: Client = conn.handle + with self.exception_handler("create dataset"): + return client.create_dataset( + dataset=self.dataset_ref(database, schema), + exists_ok=True, + retry=self._retry.deadline(self._reopen_on_error(conn)), + ) def list_dataset(self, database: str): - # the database string we get here is potentially quoted. Strip that off - # for the API call. - database = database.strip("`") + # The database string we get here is potentially quoted. + # Strip that off for the API call. 
conn = self.get_thread_connection() - client = conn.handle - - def query_schemas(): + client: Client = conn.handle + with self.exception_handler("list dataset"): # this is similar to how we have to deal with listing tables - all_datasets = client.list_datasets(project=database, max_results=10000) + all_datasets = client.list_datasets( + project=database.strip("`"), + max_results=10000, + retry=self._retry.deadline(self._reopen_on_error(conn)), + ) return [ds.dataset_id for ds in all_datasets] - return self._retry_and_handle(msg="list dataset", conn=conn, fn=query_schemas) - def _query_and_results( self, - client, + conn, sql, job_params, job_id, - job_creation_timeout=None, - job_execution_timeout=None, limit: Optional[int] = None, ): + client: Client = conn.handle """Query the client and wait for results.""" # Cannot reuse job_config if destination is set and ddl is used - job_config = google.cloud.bigquery.QueryJobConfig(**job_params) + job_config = QueryJobConfig(**job_params) query_job = client.query( - query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout + query=sql, + job_config=job_config, + job_id=job_id, # note, this disables retry since the job_id will have been used + timeout=self._retry.job_creation_timeout, ) if ( query_job.location is not None @@ -654,37 +656,14 @@ def _query_and_results( self._bq_job_link(query_job.location, query_job.project, query_job.job_id) ) try: - iterator = query_job.result(max_results=limit, timeout=job_execution_timeout) + iterator = query_job.result( + max_results=limit, timeout=self._retry.job_execution_timeout + ) return query_job, iterator except TimeoutError: - exc = f"Operation did not complete within the designated timeout of {job_execution_timeout} seconds." + exc = f"Operation did not complete within the designated timeout of {self._retry.job_execution_timeout} seconds." 
raise TimeoutError(exc) - def _retry_and_handle(self, msg, conn, fn): - """retry a function call within the context of exception_handler.""" - - def reopen_conn_on_error(error): - if isinstance(error, REOPENABLE_ERRORS): - logger.warning("Reopening connection after {!r}".format(error)) - self.close(conn) - self.open(conn) - return - - with self.exception_handler(msg): - return retry.retry_target( - target=fn, - predicate=_ErrorCounter(self.get_job_retries(conn)).count_error, - sleep_generator=self._retry_generator(), - deadline=self.get_job_retry_deadline_seconds(conn), - on_error=reopen_conn_on_error, - ) - - def _retry_generator(self): - """Generates retry intervals that exponentially back off.""" - return retry.exponential_sleep_generator( - initial=self.DEFAULT_INITIAL_DELAY, maximum=self.DEFAULT_MAXIMUM_DELAY - ) - def _labels_from_query_comment(self, comment: str) -> Dict: try: comment_labels = json.loads(comment) diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index a2998ae97..54cb293f3 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -1,7 +1,7 @@ from typing import Callable from google.api_core import retry -from google.api_core.exceptions import ClientError, Forbidden +from google.api_core.exceptions import Forbidden from google.cloud.exceptions import BadGateway, BadRequest, ServerError from dbt.adapters.events.logging import AdapterLogger @@ -24,12 +24,14 @@ class RetryFactory: - DEFAULT_INITIAL_DELAY = 1.0 # seconds - DEFAULT_MAXIMUM_DELAY = 3.0 # seconds + _DEFAULT_INITIAL_DELAY = 1.0 # seconds + _DEFAULT_MAXIMUM_DELAY = 3.0 # seconds def __init__(self, credentials: BigQueryCredentials) -> None: self._retries = credentials.job_retries or 0 - self._deadline = credentials.job_retry_deadline_seconds + self.job_creation_timeout = credentials.job_creation_timeout_seconds + self.job_execution_timeout = credentials.job_execution_timeout_seconds + self.job_deadline = credentials.job_retry_deadline_seconds def deadline(self, on_error: Callable[[Exception], None]) -> retry.Retry: """ @@ -37,9 +39,31 @@ def deadline(self, on_error: Callable[[Exception], None]) -> retry.Retry: """ return retry.Retry( predicate=self._buffered_predicate(), - initial=self.DEFAULT_INITIAL_DELAY, - maximum=self.DEFAULT_MAXIMUM_DELAY, - deadline=self.deadline, + initial=self._DEFAULT_INITIAL_DELAY, + maximum=self._DEFAULT_MAXIMUM_DELAY, + timeout=self.job_deadline, + on_error=on_error, + ) + + def job_execution(self, on_error: Callable[[Exception], None]) -> retry.Retry: + """ + This strategy mimics what was accomplished with _retry_and_handle + """ + return retry.Retry( + predicate=self._buffered_predicate(), + initial=self._DEFAULT_INITIAL_DELAY, + maximum=self._DEFAULT_MAXIMUM_DELAY, + timeout=self.job_execution_timeout, + on_error=on_error, + ) + + def job_execution_capped(self, on_error: Callable[[Exception], None]) -> retry.Retry: + """ + This strategy mimics what was accomplished with _retry_and_handle + """ + return retry.Retry( + predicate=self._buffered_predicate(), + timeout=self.job_execution_timeout or 300, on_error=on_error, ) @@ -77,7 +101,7 @@ def __call__(self, error: Exception) -> bool: return BufferedPredicate(self._retries) -def _is_retryable(error: ClientError) -> bool: +def _is_retryable(error: Exception) -> bool: """Return true for errors that are unlikely to occur again if retried.""" if isinstance(error, RETRYABLE_ERRORS): return True @@ -99,7 +123,7 @@ def count_error(self, error): if self._retries == 0: return False # 
Don't log self._error_count += 1 - if _is_retryable(error) and self.error_count <= self._retries: + if _is_retryable(error) and self._error_count <= self._retries: logger.debug( "Retry attempt {} of {} after error: {}".format( self._error_count, self._retries, repr(error) diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 1c14100f6..4fa457240 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -26,38 +26,11 @@ def setUp(self): self.connections.get_job_retry_deadline_seconds = lambda x: None self.connections.get_job_retries = lambda x: 1 - @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True) - def test_retry_and_handle(self, is_retryable): - self.connections.DEFAULT_MAXIMUM_DELAY = 2.0 - - @contextmanager - def dummy_handler(msg): - yield - - self.connections.exception_handler = dummy_handler - - class DummyException(Exception): - """Count how many times this exception is raised""" - - count = 0 - - def __init__(self): - DummyException.count += 1 - - def raiseDummyException(): - raise DummyException() - - with self.assertRaises(DummyException): - self.connections._retry_and_handle( - "some sql", Mock(credentials=Mock(retries=8)), raiseDummyException - ) - self.assertEqual(DummyException.count, 9) - - @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True) + @patch("dbt.adapters.bigquery.retry._is_retryable", return_value=True) def test_retry_connection_reset(self, is_retryable): self.connections.open = MagicMock() self.connections.close = MagicMock() - self.connections.DEFAULT_MAXIMUM_DELAY = 2.0 + self.connections._retry.DEFAULT_MAXIMUM_DELAY = 2.0 @contextmanager def dummy_handler(msg): @@ -65,17 +38,13 @@ def dummy_handler(msg): self.connections.exception_handler = dummy_handler - def raiseConnectionResetError(): - raise ConnectionResetError("Connection broke") - mock_conn = Mock(credentials=Mock(retries=1)) - with self.assertRaises(ConnectionResetError): - self.connections._retry_and_handle("some sql", mock_conn, raiseConnectionResetError) + # do something that will raise a ConnectionResetError self.connections.close.assert_called_once_with(mock_conn) self.connections.open.assert_called_once_with(mock_conn) def test_is_retryable(self): - _is_retryable = dbt.adapters.bigquery.connections._is_retryable + _is_retryable = dbt.adapters.bigquery.retry._is_retryable exceptions = dbt.adapters.bigquery.impl.google.cloud.exceptions internal_server_error = exceptions.InternalServerError("code broke") bad_request_error = exceptions.BadRequest("code broke") From 900dcac69b9aa51b2b4bf0e27947a93254d21737 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Fri, 1 Nov 2024 17:56:28 -0400 Subject: [PATCH 06/41] remove timeout methods from connection manager --- dbt/adapters/bigquery/connections.py | 20 ------------------- dbt/adapters/bigquery/impl.py | 4 ++-- .../unit/test_bigquery_connection_manager.py | 1 - 3 files changed, 2 insertions(+), 23 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 9266ff1c2..8e8f5a563 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -313,26 +313,6 @@ def open(cls, connection): connection.state = "open" return connection - @classmethod - def get_job_execution_timeout_seconds(cls, conn): - credentials = conn.credentials - return credentials.job_execution_timeout_seconds - - @classmethod - def 
get_job_retries(cls, conn) -> int: - credentials = conn.credentials - return credentials.job_retries - - @classmethod - def get_job_creation_timeout_seconds(cls, conn): - credentials = conn.credentials - return credentials.job_creation_timeout_seconds - - @classmethod - def get_job_retry_deadline_seconds(cls, conn): - credentials = conn.credentials - return credentials.job_retry_deadline_seconds - @classmethod def get_table_from_response(cls, resp) -> "agate.Table": from dbt_common.clients import agate_helper diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index cf5800fd3..ec9afb08f 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -698,7 +698,7 @@ def load_dataframe( f, table_ref, rewind=True, job_config=load_config, job_id=job_id ) - timeout = self.connections.get_job_execution_timeout_seconds(conn) or 300 + timeout = conn.credentials.job_execution_timeout_seconds or 300 with self.connections.exception_handler("LOAD TABLE"): self.poll_until_job_completes(job, timeout) @@ -721,7 +721,7 @@ def upload_file( with open(local_file_path, "rb") as f: job = client.load_table_from_file(f, table_ref, rewind=True, job_config=load_config) - timeout = self.connections.get_job_execution_timeout_seconds(conn) or 300 + timeout = conn.credentials.job_execution_timeout_seconds or 300 with self.connections.exception_handler("LOAD TABLE"): self.poll_until_job_completes(job, timeout) diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 4fa457240..45dc85136 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -23,7 +23,6 @@ def setUp(self): self.mock_connection.handle = self.mock_client self.connections.get_thread_connection = lambda: self.mock_connection - self.connections.get_job_retry_deadline_seconds = lambda x: None self.connections.get_job_retries = lambda x: 1 @patch("dbt.adapters.bigquery.retry._is_retryable", return_value=True) From 81bfa0ca25a7c03909e75d62c4aea617e0066795 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Mon, 4 Nov 2024 13:34:10 -0500 Subject: [PATCH 07/41] add retry to get_bq_table --- dbt/adapters/bigquery/connections.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 8e8f5a563..68a30bef0 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -27,7 +27,6 @@ TableReference, WriteDisposition, ) -import google.cloud.exceptions from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event @@ -568,11 +567,14 @@ def table_ref(database, schema, table_name): def get_bq_table(self, database, schema, identifier): """Get a bigquery table for a schema/model.""" conn = self.get_thread_connection() + client: Client = conn.handle # backwards compatibility: fill in with defaults if not specified database = database or conn.credentials.database schema = schema or conn.credentials.schema - table_ref = self.table_ref(database, schema, identifier) - return conn.handle.get_table(table_ref) + return client.get_table( + table=self.table_ref(database, schema, identifier), + retry=self._retry.deadline(self._reopen_on_error(conn)), + ) def drop_dataset(self, database, schema): conn = self.get_thread_connection() From 3e32872e65c898fd98dd7e1641fd8d61d7a132bc Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Mon, 4 Nov 2024 17:36:53 -0500 
Subject: [PATCH 08/41] fix mocks in unit tests --- dbt/adapters/bigquery/connections.py | 15 ++++--- dbt/adapters/bigquery/retry.py | 4 +- tests/unit/test_bigquery_adapter.py | 13 +++--- .../unit/test_bigquery_connection_manager.py | 45 ++++++++++++------- 4 files changed, 46 insertions(+), 31 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 68a30bef0..89cb89640 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -12,12 +12,6 @@ import google.auth from google.auth import impersonated_credentials import google.auth.exceptions -import google.cloud.bigquery -import google.cloud.exceptions -from google.oauth2 import ( - credentials as GoogleCredentials, - service_account as GoogleServiceAccountCredentials, -) from google.cloud.bigquery import ( Client, CopyJobConfig, @@ -27,6 +21,12 @@ TableReference, WriteDisposition, ) +import google.cloud.exceptions +from google.oauth2 import ( + credentials as GoogleCredentials, + service_account as GoogleServiceAccountCredentials, +) +from requests.exceptions import ConnectionError from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event @@ -622,7 +622,8 @@ def _query_and_results( client: Client = conn.handle """Query the client and wait for results.""" # Cannot reuse job_config if destination is set and ddl is used - job_config = QueryJobConfig(**job_params) + job_factory = QueryJobConfig + job_config = job_factory(**job_params) query_job = client.query( query=sql, job_config=job_config, diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 54cb293f3..2d1a839b7 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -3,10 +3,10 @@ from google.api_core import retry from google.api_core.exceptions import Forbidden from google.cloud.exceptions import BadGateway, BadRequest, ServerError +from requests.exceptions import ConnectionError from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.bigquery.connections import logger from dbt.adapters.bigquery.credentials import BigQueryCredentials @@ -124,7 +124,7 @@ def count_error(self, error): return False # Don't log self._error_count += 1 if _is_retryable(error) and self._error_count <= self._retries: - logger.debug( + _logger.debug( "Retry attempt {} of {} after error: {}".format( self._error_count, self._retries, repr(error) ) diff --git a/tests/unit/test_bigquery_adapter.py b/tests/unit/test_bigquery_adapter.py index ca3bfc24c..50eafb59c 100644 --- a/tests/unit/test_bigquery_adapter.py +++ b/tests/unit/test_bigquery_adapter.py @@ -386,21 +386,20 @@ def test_cancel_open_connections_single(self): adapter.connections.thread_connections.update({key: master, 1: model}) self.assertEqual(len(list(adapter.cancel_open_connections())), 1) - @patch("dbt.adapters.bigquery.impl.google.api_core.client_options.ClientOptions") - @patch("dbt.adapters.bigquery.impl.google.auth.default") - @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") - def test_location_user_agent(self, mock_bq, mock_auth_default, MockClientOptions): + @patch("dbt.adapters.bigquery.connections.client_options.ClientOptions") + @patch("dbt.adapters.bigquery.credentials.google.auth.default") + @patch("dbt.adapters.bigquery.connections.Client") + def test_location_user_agent(self, MockClient, mock_auth_default, MockClientOptions): creds = MagicMock() mock_auth_default.return_value = (creds, MagicMock()) adapter = 
self.get_adapter("loc") connection = adapter.acquire_connection("dummy") - mock_client = mock_bq.Client mock_client_options = MockClientOptions.return_value - mock_client.assert_not_called() + MockClient.assert_not_called() connection.handle - mock_client.assert_called_once_with( + MockClient.assert_called_once_with( "dbt-unit-000000", creds, location="Luna Station", diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 45dc85136..54db54429 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -5,25 +5,27 @@ from unittest.mock import patch, MagicMock, Mock, ANY import dbt.adapters +import google.cloud.bigquery from dbt.adapters.bigquery import BigQueryCredentials from dbt.adapters.bigquery import BigQueryRelation from dbt.adapters.bigquery.connections import BigQueryConnectionManager +from dbt.adapters.bigquery.retry import RetryFactory class TestBigQueryConnectionManager(unittest.TestCase): def setUp(self): - credentials = Mock(BigQueryCredentials) - profile = Mock(query_comment=None, credentials=credentials) + self.credentials = Mock(BigQueryCredentials) + self.credentials.job_retries = 1 + profile = Mock(query_comment=None, credentials=self.credentials) self.connections = BigQueryConnectionManager(profile=profile, mp_context=Mock()) - self.mock_client = Mock(dbt.adapters.bigquery.impl.google.cloud.bigquery.Client) + self.mock_client = Mock(google.cloud.bigquery.Client) self.mock_connection = MagicMock() self.mock_connection.handle = self.mock_client self.connections.get_thread_connection = lambda: self.mock_connection - self.connections.get_job_retries = lambda x: 1 @patch("dbt.adapters.bigquery.retry._is_retryable", return_value=True) def test_retry_connection_reset(self, is_retryable): @@ -37,8 +39,18 @@ def dummy_handler(msg): self.connections.exception_handler = dummy_handler - mock_conn = Mock(credentials=Mock(retries=1)) - # do something that will raise a ConnectionResetError + retry = RetryFactory(Mock(job_retries=1, job_execution_timeout_seconds=60)) + mock_conn = Mock() + + on_error = self.connections._reopen_on_error(mock_conn) + + @retry.job_execution(on_error) + def generate_connection_reset_error(): + raise ConnectionResetError + + with self.assertRaises(ConnectionResetError): + # this will always raise the error, we just want to test that the connection was reopening in between + generate_connection_reset_error() self.connections.close.assert_called_once_with(mock_conn) self.connections.open.assert_called_once_with(mock_conn) @@ -72,20 +84,21 @@ def test_drop_dataset(self): self.mock_client.delete_table.assert_not_called() self.mock_client.delete_dataset.assert_called_once() - @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") - def test_query_and_results(self, mock_bq): + @patch("dbt.adapters.bigquery.connections.QueryJobConfig") + def test_query_and_results(self, MockQueryJobConfig): self.connections._query_and_results( - self.mock_client, + self.mock_connection, "sql", - {"job_param_1": "blah"}, + {"dry_run": True}, job_id=1, - job_creation_timeout=15, - job_execution_timeout=100, ) - mock_bq.QueryJobConfig.assert_called_once() + MockQueryJobConfig.assert_called_once() self.mock_client.query.assert_called_once_with( - query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15 + query="sql", + job_config=MockQueryJobConfig(), + job_id=1, + timeout=self.credentials.job_creation_timeout_seconds, ) def 
test_copy_bq_table_appends(self): @@ -95,6 +108,7 @@ def test_copy_bq_table_appends(self): [self._table_ref("project", "dataset", "table1")], self._table_ref("project", "dataset", "table2"), job_config=ANY, + retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual( @@ -108,6 +122,7 @@ def test_copy_bq_table_truncates(self): [self._table_ref("project", "dataset", "table1")], self._table_ref("project", "dataset", "table2"), job_config=ANY, + retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual( @@ -129,7 +144,7 @@ def test_list_dataset_correctly_calls_lists_datasets(self): self.mock_client.list_datasets = mock_list_dataset result = self.connections.list_dataset("project") self.mock_client.list_datasets.assert_called_once_with( - project="project", max_results=10000 + project="project", max_results=10000, retry=ANY ) assert result == ["d1"] From 89e2a5055d5da89a621ee08d5688bacd76f57b33 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 5 Nov 2024 14:13:37 -0500 Subject: [PATCH 09/41] rebase on main --- dbt/adapters/bigquery/gcloud.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 dbt/adapters/bigquery/gcloud.py diff --git a/dbt/adapters/bigquery/gcloud.py b/dbt/adapters/bigquery/gcloud.py deleted file mode 100644 index e69de29bb..000000000 From 3f79642086f79f1f1999e0578a307ac43cd10e34 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 5 Nov 2024 14:24:35 -0500 Subject: [PATCH 10/41] reorder this tuple to make the pr review easier to understand --- dbt/adapters/bigquery/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 89cb89640..ea73728ff 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -66,8 +66,8 @@ WRITE_TRUNCATE = WriteDisposition.WRITE_TRUNCATE REOPENABLE_ERRORS = ( - ConnectionError, ConnectionResetError, + ConnectionError, ) From f30008017619984d100b7912a782f02740692a9f Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 5 Nov 2024 19:22:04 -0500 Subject: [PATCH 11/41] move client factory to credentials module so that on_error can be moved to the retry factory in the retry module --- dbt/adapters/bigquery/connections.py | 117 ++-------- dbt/adapters/bigquery/credentials.py | 200 ++++++++++++++---- dbt/adapters/bigquery/python_submissions.py | 5 +- dbt/adapters/bigquery/utility.py | 40 +--- tests/conftest.py | 8 +- tests/functional/adapter/test_json_keyfile.py | 7 +- tests/unit/test_bigquery_adapter.py | 14 +- tests/unit/test_configure_dataproc_batch.py | 4 +- 8 files changed, 194 insertions(+), 201 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index ea73728ff..244845b16 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -8,24 +8,21 @@ from typing import Callable, Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING import uuid -from google.api_core import client_info, client_options, retry +from google.api_core import retry import google.auth -from google.auth import impersonated_credentials import google.auth.exceptions from google.cloud.bigquery import ( Client, CopyJobConfig, + Dataset, DatasetReference, QueryJobConfig, QueryPriority, + Table, TableReference, WriteDisposition, ) import google.cloud.exceptions -from google.oauth2 import ( - credentials as GoogleCredentials, - service_account as GoogleServiceAccountCredentials, -) from 
requests.exceptions import ConnectionError from dbt_common.events.contextvars import get_node_info @@ -43,15 +40,8 @@ from dbt.adapters.events.types import SQLQuery from dbt.adapters.exceptions.connection import FailedToConnectError -import dbt.adapters.bigquery.__version__ as dbt_version -from dbt.adapters.bigquery.credentials import ( - BigQueryConnectionMethod, - Priority, - get_bigquery_defaults, - setup_default_credentials, -) +from dbt.adapters.bigquery.credentials import BigQueryCredentials, Priority, get_bigquery_client from dbt.adapters.bigquery.retry import RetryFactory -from dbt.adapters.bigquery.utility import is_base64, base64_to_string if TYPE_CHECKING: # Indirectly imported via agate_helper, which is lazy loaded further downfile. @@ -217,101 +207,28 @@ def format_rows_number(self, rows_number): rows_number *= 1000.0 return f"{rows_number:3.1f}{unit}".strip() - @classmethod - def get_google_credentials(cls, profile_credentials) -> GoogleCredentials: - method = profile_credentials.method - creds = GoogleServiceAccountCredentials.Credentials - - if method == BigQueryConnectionMethod.OAUTH: - credentials, _ = get_bigquery_defaults(scopes=profile_credentials.scopes) - return credentials - - elif method == BigQueryConnectionMethod.SERVICE_ACCOUNT: - keyfile = profile_credentials.keyfile - return creds.from_service_account_file(keyfile, scopes=profile_credentials.scopes) - - elif method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON: - details = profile_credentials.keyfile_json - if is_base64(profile_credentials.keyfile_json): - details = base64_to_string(details) - return creds.from_service_account_info(details, scopes=profile_credentials.scopes) - - elif method == BigQueryConnectionMethod.OAUTH_SECRETS: - return GoogleCredentials.Credentials( - token=profile_credentials.token, - refresh_token=profile_credentials.refresh_token, - client_id=profile_credentials.client_id, - client_secret=profile_credentials.client_secret, - token_uri=profile_credentials.token_uri, - scopes=profile_credentials.scopes, - ) - - error = 'Invalid `method` in profile: "{}"'.format(method) - raise FailedToConnectError(error) - - @classmethod - def get_impersonated_credentials(cls, profile_credentials): - source_credentials = cls.get_google_credentials(profile_credentials) - return impersonated_credentials.Credentials( - source_credentials=source_credentials, - target_principal=profile_credentials.impersonate_service_account, - target_scopes=list(profile_credentials.scopes), - ) - - @classmethod - def get_credentials(cls, profile_credentials): - if profile_credentials.impersonate_service_account: - return cls.get_impersonated_credentials(profile_credentials) - else: - return cls.get_google_credentials(profile_credentials) - @classmethod @retry.Retry() # google decorator. 
retries on transient errors with exponential backoff - def get_bigquery_client(cls, profile_credentials): - creds = cls.get_credentials(profile_credentials) - execution_project = profile_credentials.execution_project - quota_project = profile_credentials.quota_project - location = getattr(profile_credentials, "location", None) - - info = client_info.ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}") - options = client_options.ClientOptions(quota_project_id=quota_project) - return Client( - execution_project, - creds, - location=location, - client_info=info, - client_options=options, - ) + def bigquery_client(cls, credentials: BigQueryCredentials) -> Client: + return get_bigquery_client(credentials) @classmethod def open(cls, connection): - if connection.state == "open": + if connection.state == ConnectionState.OPEN: logger.debug("Connection is already open, skipping open.") return connection try: - handle = cls.get_bigquery_client(connection.credentials) - - except google.auth.exceptions.DefaultCredentialsError: - logger.info("Please log into GCP to continue") - setup_default_credentials() - - handle = cls.get_bigquery_client(connection.credentials) + connection.handle = cls.bigquery_client(connection.credentials) + connection.state = ConnectionState.OPEN + return connection except Exception as e: - logger.debug( - "Got an error when attempting to create a bigquery " "client: '{}'".format(e) - ) - + logger.debug(f"""Got an error when attempting to create a bigquery " "client: '{e}'""") connection.handle = None - connection.state = "fail" - + connection.state = ConnectionState.FAIL raise FailedToConnectError(str(e)) - connection.handle = handle - connection.state = "open" - return connection - @classmethod def get_table_from_response(cls, resp) -> "agate.Table": from dbt_common.clients import agate_helper @@ -512,7 +429,7 @@ def standard_to_legacy(table): _, iterator = self.raw_execute(sql, use_legacy_sql=True) return self.get_table_from_response(iterator) - def copy_bq_table(self, source, destination, write_disposition): + def copy_bq_table(self, source, destination, write_disposition) -> None: conn = self.get_thread_connection() client: Client = conn.handle @@ -564,7 +481,7 @@ def table_ref(database, schema, table_name): dataset_ref = DatasetReference(database, schema) return TableReference(dataset_ref, table_name) - def get_bq_table(self, database, schema, identifier): + def get_bq_table(self, database, schema, identifier) -> Table: """Get a bigquery table for a schema/model.""" conn = self.get_thread_connection() client: Client = conn.handle @@ -576,18 +493,18 @@ def get_bq_table(self, database, schema, identifier): retry=self._retry.deadline(self._reopen_on_error(conn)), ) - def drop_dataset(self, database, schema): + def drop_dataset(self, database, schema) -> None: conn = self.get_thread_connection() client: Client = conn.handle with self.exception_handler("drop dataset"): - return client.delete_dataset( + client.delete_dataset( dataset=self.dataset_ref(database, schema), delete_contents=True, not_found_ok=True, retry=self._retry.deadline(self._reopen_on_error(conn)), ) - def create_dataset(self, database, schema): + def create_dataset(self, database, schema) -> Dataset: conn = self.get_thread_connection() client: Client = conn.handle with self.exception_handler("create dataset"): diff --git a/dbt/adapters/bigquery/credentials.py b/dbt/adapters/bigquery/credentials.py index 32f172dac..4af817153 100644 --- a/dbt/adapters/bigquery/credentials.py +++ 
b/dbt/adapters/bigquery/credentials.py @@ -1,9 +1,17 @@ +import base64 +import binascii from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union -import google.auth +from google.api_core.client_info import ClientInfo +from google.api_core.client_options import ClientOptions +from google.auth import default from google.auth.exceptions import DefaultCredentialsError +from google.auth.impersonated_credentials import Credentials as ImpersonatedCredentials +from google.cloud.bigquery.client import Client as BigQueryClient +from google.oauth2.credentials import Credentials as GoogleCredentials +from google.oauth2.service_account import Credentials as ServiceAccountCredentials from mashumaro import pass_through from dbt_common.clients.system import run_cmd @@ -11,6 +19,9 @@ from dbt_common.exceptions import DbtConfigError, DbtRuntimeError from dbt.adapters.contracts.connection import Credentials from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.exceptions.connection import FailedToConnectError + +import dbt.adapters.bigquery.__version__ as dbt_version _logger = AdapterLogger("BigQuery") @@ -21,59 +32,22 @@ class Priority(StrEnum): Batch = "batch" -class BigQueryConnectionMethod(StrEnum): - OAUTH = "oauth" - SERVICE_ACCOUNT = "service-account" - SERVICE_ACCOUNT_JSON = "service-account-json" - OAUTH_SECRETS = "oauth-secrets" - - @dataclass class DataprocBatchConfig(ExtensibleDbtClassMixin): def __init__(self, batch_config): self.batch_config = batch_config -@lru_cache() -def get_bigquery_defaults(scopes=None) -> Tuple[Any, Optional[str]]: - """ - Returns (credentials, project_id) - - project_id is returned available from the environment; otherwise None - """ - # Cached, because the underlying implementation shells out, taking ~1s - try: - credentials, _ = google.auth.default(scopes=scopes) - return credentials, _ - except DefaultCredentialsError as e: - raise DbtConfigError(f"Failed to authenticate with supplied credentials\nerror:\n{e}") - - -def setup_default_credentials(): - if _gcloud_installed(): - run_cmd(".", ["gcloud", "auth", "application-default", "login"]) - else: - msg = """ - dbt requires the gcloud SDK to be installed to authenticate with BigQuery. - Please download and install the SDK, or use a Service Account instead. 
- - https://cloud.google.com/sdk/ - """ - raise DbtRuntimeError(msg) - - -def _gcloud_installed(): - try: - run_cmd(".", ["gcloud", "--version"]) - return True - except OSError as e: - _logger.debug(e) - return False +class _BigQueryConnectionMethod(StrEnum): + OAUTH = "oauth" + SERVICE_ACCOUNT = "service-account" + SERVICE_ACCOUNT_JSON = "service-account-json" + OAUTH_SECRETS = "oauth-secrets" @dataclass class BigQueryCredentials(Credentials): - method: BigQueryConnectionMethod = None # type: ignore + method: _BigQueryConnectionMethod = None # type: ignore # BigQuery allows an empty database / project, where it defers to the # environment for the project @@ -179,9 +153,143 @@ def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]: # `database` is an alias of `project` in BigQuery if "database" not in d: - _, database = get_bigquery_defaults() + _, database = _bigquery_defaults() d["database"] = database # `execution_project` default to dataset/project if "execution_project" not in d: d["execution_project"] = d["database"] return d + + +def get_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: + try: + return _bigquery_client(credentials) + except DefaultCredentialsError: + _logger.info("Please log into GCP to continue") + _setup_default_credentials() + return _bigquery_client(credentials) + + +def _bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: + return BigQueryClient( + credentials.execution_project, + get_credentials(credentials), + location=getattr(credentials, "location", None), + client_info=ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}"), + client_options=ClientOptions(quota_project_id=credentials.quota_project), + ) + + +def _setup_default_credentials() -> None: + try: + run_cmd(".", ["gcloud", "--version"]) + except OSError as e: + _logger.debug(e) + msg = """ + dbt requires the gcloud SDK to be installed to authenticate with BigQuery. + Please download and install the SDK, or use a Service Account instead. 
+ + https://cloud.google.com/sdk/ + """ + raise DbtRuntimeError(msg) + + run_cmd(".", ["gcloud", "auth", "application-default", "login"]) + + +def get_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: + if credentials.impersonate_service_account: + return _impersonated_credentials(credentials) + return _google_credentials(credentials) + + +def _impersonated_credentials(credentials: BigQueryCredentials) -> ImpersonatedCredentials: + if scopes := credentials.scopes: + target_scopes = list(scopes) + else: + target_scopes = [] + + return ImpersonatedCredentials( + source_credentials=_google_credentials(credentials), + target_principal=credentials.impersonate_service_account, + target_scopes=target_scopes, + ) + + +def _google_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: + + if credentials.method == _BigQueryConnectionMethod.OAUTH: + creds, _ = _bigquery_defaults(scopes=credentials.scopes) + + elif credentials.method == _BigQueryConnectionMethod.SERVICE_ACCOUNT: + creds = ServiceAccountCredentials.from_service_account_file( + credentials.keyfile, scopes=credentials.scopes + ) + + elif credentials.method == _BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON: + details = credentials.keyfile_json + if _is_base64(details): # type:ignore + details = _base64_to_string(details) + creds = ServiceAccountCredentials.from_service_account_info( + details, scopes=credentials.scopes + ) + + elif credentials.method == _BigQueryConnectionMethod.OAUTH_SECRETS: + creds = GoogleCredentials( + token=credentials.token, + refresh_token=credentials.refresh_token, + client_id=credentials.client_id, + client_secret=credentials.client_secret, + token_uri=credentials.token_uri, + scopes=credentials.scopes, + ) + + else: + raise FailedToConnectError(f"Invalid `method` in profile: '{credentials.method}'") + + return creds + + +@lru_cache() +def _bigquery_defaults(scopes=None) -> Tuple[Any, Optional[str]]: + """ + Returns (credentials, project_id) + + project_id is returned available from the environment; otherwise None + """ + # Cached, because the underlying implementation shells out, taking ~1s + try: + return default(scopes=scopes) + except DefaultCredentialsError as e: + raise DbtConfigError(f"Failed to authenticate with supplied credentials\nerror:\n{e}") + + +def _is_base64(s: Union[str, bytes]) -> bool: + """ + Checks if the given string or bytes object is valid Base64 encoded. + + Args: + s: The string or bytes object to check. + + Returns: + True if the input is valid Base64, False otherwise. 
+ """ + + if isinstance(s, str): + # For strings, ensure they consist only of valid Base64 characters + if not s.isascii(): + return False + # Convert to bytes for decoding + s = s.encode("ascii") + + try: + # Use the 'validate' parameter to enforce strict Base64 decoding rules + base64.b64decode(s, validate=True) + return True + except TypeError: + return False + except binascii.Error: # Catch specific errors from the base64 module + return False + + +def _base64_to_string(b): + return base64.b64decode(b).decode("utf-8") diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 93c82ca92..432cc6303 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -10,8 +10,7 @@ from dbt.adapters.base import PythonJobHelper from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.bigquery.connections import BigQueryConnectionManager -from dbt.adapters.bigquery.credentials import BigQueryCredentials +from dbt.adapters.bigquery.credentials import BigQueryCredentials, get_credentials from dbt.adapters.bigquery.dataproc.batch import ( DEFAULT_JAR_FILE_URI, create_batch_request, @@ -45,7 +44,7 @@ def __init__(self, parsed_model: Dict, credential: BigQueryCredentials) -> None: ) self.model_file_name = f"{schema}/{identifier}.py" self.credential = credential - self.GoogleCredentials = BigQueryConnectionManager.get_credentials(credential) + self.GoogleCredentials = get_credentials(credential) self.storage_client = storage.Client( project=self.credential.execution_project, credentials=self.GoogleCredentials ) diff --git a/dbt/adapters/bigquery/utility.py b/dbt/adapters/bigquery/utility.py index 557986b38..5914280a3 100644 --- a/dbt/adapters/bigquery/utility.py +++ b/dbt/adapters/bigquery/utility.py @@ -1,7 +1,5 @@ -import base64 -import binascii import json -from typing import Any, Optional, Union +from typing import Any, Optional import dbt_common.exceptions @@ -45,39 +43,3 @@ def sql_escape(string): if not isinstance(string, str): raise dbt_common.exceptions.CompilationError(f"cannot escape a non-string: {string}") return json.dumps(string)[1:-1] - - -def is_base64(s: Union[str, bytes]) -> bool: - """ - Checks if the given string or bytes object is valid Base64 encoded. - - Args: - s: The string or bytes object to check. - - Returns: - True if the input is valid Base64, False otherwise. 
- """ - - if isinstance(s, str): - # For strings, ensure they consist only of valid Base64 characters - if not s.isascii(): - return False - # Convert to bytes for decoding - s = s.encode("ascii") - - try: - # Use the 'validate' parameter to enforce strict Base64 decoding rules - base64.b64decode(s, validate=True) - return True - except TypeError: - return False - except binascii.Error: # Catch specific errors from the base64 module - return False - - -def base64_to_string(b): - return base64.b64decode(b).decode("utf-8") - - -def string_to_base64(s): - return base64.b64encode(s.encode("utf-8")) diff --git a/tests/conftest.py b/tests/conftest.py index 6dc9e6443..33f7f9d17 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,9 @@ import pytest import os import json -from dbt.adapters.bigquery.utility import is_base64, base64_to_string +from dbt.adapters.bigquery.credentials import _is_base64, _base64_to_string -# Import the fuctional fixtures as a plugin +# Import the functional fixtures as a plugin # Note: fixtures with session scope need to be local pytest_plugins = ["dbt.tests.fixtures.project"] @@ -39,8 +39,8 @@ def oauth_target(): def service_account_target(): credentials_json_str = os.getenv("BIGQUERY_TEST_SERVICE_ACCOUNT_JSON").replace("'", "") - if is_base64(credentials_json_str): - credentials_json_str = base64_to_string(credentials_json_str) + if _is_base64(credentials_json_str): + credentials_json_str = _base64_to_string(credentials_json_str) credentials = json.loads(credentials_json_str) project_id = credentials.get("project_id") return { diff --git a/tests/functional/adapter/test_json_keyfile.py b/tests/functional/adapter/test_json_keyfile.py index 91e41a3f1..43928555e 100644 --- a/tests/functional/adapter/test_json_keyfile.py +++ b/tests/functional/adapter/test_json_keyfile.py @@ -1,6 +1,11 @@ +import base64 import json import pytest -from dbt.adapters.bigquery.utility import string_to_base64, is_base64 +from dbt.adapters.bigquery.credentials import is_base64 + + +def string_to_base64(s): + return base64.b64encode(s.encode("utf-8")) @pytest.fixture diff --git a/tests/unit/test_bigquery_adapter.py b/tests/unit/test_bigquery_adapter.py index 50eafb59c..57e676cc4 100644 --- a/tests/unit/test_bigquery_adapter.py +++ b/tests/unit/test_bigquery_adapter.py @@ -203,7 +203,7 @@ def get_adapter(self, target) -> BigQueryAdapter: class TestBigQueryAdapterAcquire(BaseTestBigQueryAdapter): @patch( - "dbt.adapters.bigquery.credentials.get_bigquery_defaults", + "dbt.adapters.bigquery.credentials._bigquery_defaults", return_value=("credentials", "project_id"), ) @patch("dbt.adapters.bigquery.BigQueryConnectionManager.open", return_value=_bq_conn()) @@ -244,10 +244,12 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection): mock_open_connection.assert_called_once() @patch( - "dbt.adapters.bigquery.credentials.get_bigquery_defaults", + "dbt.adapters.bigquery.credentials._bigquery_defaults", return_value=("credentials", "project_id"), ) - @patch("dbt.adapters.bigquery.BigQueryConnectionManager.open", return_value=_bq_conn()) + @patch( + "dbt.adapters.bigquery.connections.BigQueryConnectionManager.open", return_value=_bq_conn() + ) def test_acquire_connection_dataproc_serverless( self, mock_open_connection, mock_get_bigquery_defaults ): @@ -386,9 +388,9 @@ def test_cancel_open_connections_single(self): adapter.connections.thread_connections.update({key: master, 1: model}) self.assertEqual(len(list(adapter.cancel_open_connections())), 1) - 
@patch("dbt.adapters.bigquery.connections.client_options.ClientOptions") - @patch("dbt.adapters.bigquery.credentials.google.auth.default") - @patch("dbt.adapters.bigquery.connections.Client") + @patch("dbt.adapters.bigquery.credentials.ClientOptions") + @patch("dbt.adapters.bigquery.credentials.default") + @patch("dbt.adapters.bigquery.credentials.BigQueryClient") def test_location_user_agent(self, MockClient, mock_auth_default, MockClientOptions): creds = MagicMock() mock_auth_default.return_value = (creds, MagicMock()) diff --git a/tests/unit/test_configure_dataproc_batch.py b/tests/unit/test_configure_dataproc_batch.py index f56aee129..19a0d3012 100644 --- a/tests/unit/test_configure_dataproc_batch.py +++ b/tests/unit/test_configure_dataproc_batch.py @@ -12,7 +12,7 @@ # parsed credentials class TestConfigureDataprocBatch(BaseTestBigQueryAdapter): @patch( - "dbt.adapters.bigquery.credentials.get_bigquery_defaults", + "dbt.adapters.bigquery.credentials._bigquery_defaults", return_value=("credentials", "project_id"), ) def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults): @@ -64,7 +64,7 @@ def to_str_values(d): ) @patch( - "dbt.adapters.bigquery.credentials.get_bigquery_defaults", + "dbt.adapters.bigquery.credentials._bigquery_defaults", return_value=("credentials", "project_id"), ) def test_default_dataproc_serverless_batch(self, mock_get_bigquery_defaults): From c3065e5f9c19285b225ef5232f30fd5fb2dff29f Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 5 Nov 2024 21:18:16 -0500 Subject: [PATCH 12/41] move on_error factory to retry module --- dbt/adapters/bigquery/connections.py | 34 ++++---------- dbt/adapters/bigquery/credentials.py | 6 +-- dbt/adapters/bigquery/retry.py | 44 ++++++++++++++++--- .../unit/test_bigquery_connection_manager.py | 41 ++++++++--------- 4 files changed, 66 insertions(+), 59 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 244845b16..5540ca8ee 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -5,7 +5,7 @@ import json from multiprocessing.context import SpawnContext import re -from typing import Callable, Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING +from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING import uuid from google.api_core import retry @@ -23,7 +23,6 @@ WriteDisposition, ) import google.cloud.exceptions -from requests.exceptions import ConnectionError from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event @@ -33,7 +32,6 @@ from dbt.adapters.contracts.connection import ( AdapterRequiredConfig, AdapterResponse, - Connection, ConnectionState, ) from dbt.adapters.events.logging import AdapterLogger @@ -55,11 +53,6 @@ WRITE_TRUNCATE = WriteDisposition.WRITE_TRUNCATE -REOPENABLE_ERRORS = ( - ConnectionResetError, - ConnectionError, -) - @dataclass class BigQueryAdapterResponse(AdapterResponse): @@ -82,17 +75,6 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): self.jobs_by_thread: Dict[Hashable, List[str]] = defaultdict(list) self._retry = RetryFactory(profile.credentials) - def _reopen_on_error(self, connection: Connection) -> Callable[[Exception], None]: - - def _on_error(error: Exception): - if isinstance(error, REOPENABLE_ERRORS): - logger.warning("Reopening connection after {!r}".format(error)) - self.close(connection) - self.open(connection) - return - - return _on_error - @classmethod def handle_error(cls, error, 
message): error_msg = "\n".join([item["message"] for item in error.errors]) @@ -164,7 +146,7 @@ def cancel_open(self): with self.exception_handler(f"Cancel job: {job_id}"): client.cancel_job( job_id, - retry=self._retry.deadline(self._reopen_on_error(connection)), + retry=self._retry.deadline(connection), ) self.close(connection) @@ -468,9 +450,9 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: source_ref_array, destination_ref, job_config=CopyJobConfig(write_disposition=write_disposition), - retry=self._retry.deadline(self._reopen_on_error(conn)), + retry=self._retry.deadline(conn), ) - copy_job.result(retry=self._retry.job_execution_capped(self._reopen_on_error(conn))) + copy_job.result(retry=self._retry.job_execution_capped(conn)) @staticmethod def dataset_ref(database, schema): @@ -490,7 +472,7 @@ def get_bq_table(self, database, schema, identifier) -> Table: schema = schema or conn.credentials.schema return client.get_table( table=self.table_ref(database, schema, identifier), - retry=self._retry.deadline(self._reopen_on_error(conn)), + retry=self._retry.deadline(conn), ) def drop_dataset(self, database, schema) -> None: @@ -501,7 +483,7 @@ def drop_dataset(self, database, schema) -> None: dataset=self.dataset_ref(database, schema), delete_contents=True, not_found_ok=True, - retry=self._retry.deadline(self._reopen_on_error(conn)), + retry=self._retry.deadline(conn), ) def create_dataset(self, database, schema) -> Dataset: @@ -511,7 +493,7 @@ def create_dataset(self, database, schema) -> Dataset: return client.create_dataset( dataset=self.dataset_ref(database, schema), exists_ok=True, - retry=self._retry.deadline(self._reopen_on_error(conn)), + retry=self._retry.deadline(conn), ) def list_dataset(self, database: str): @@ -524,7 +506,7 @@ def list_dataset(self, database: str): all_datasets = client.list_datasets( project=database.strip("`"), max_results=10000, - retry=self._retry.deadline(self._reopen_on_error(conn)), + retry=self._retry.deadline(conn), ) return [ds.dataset_id for ds in all_datasets] diff --git a/dbt/adapters/bigquery/credentials.py b/dbt/adapters/bigquery/credentials.py index 4af817153..4be9a996a 100644 --- a/dbt/adapters/bigquery/credentials.py +++ b/dbt/adapters/bigquery/credentials.py @@ -2,7 +2,7 @@ import binascii from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union, Iterable from google.api_core.client_info import ClientInfo from google.api_core.client_options import ClientOptions @@ -203,8 +203,8 @@ def get_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: def _impersonated_credentials(credentials: BigQueryCredentials) -> ImpersonatedCredentials: - if scopes := credentials.scopes: - target_scopes = list(scopes) + if credentials.scopes and isinstance(credentials.scopes, Iterable): + target_scopes = list(credentials.scopes) else: target_scopes = [] diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 2d1a839b7..6cc0384e3 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -5,14 +5,22 @@ from google.cloud.exceptions import BadGateway, BadRequest, ServerError from requests.exceptions import ConnectionError +from dbt.adapters.contracts.connection import Connection, ConnectionState from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.exceptions.connection import FailedToConnectError -from 
dbt.adapters.bigquery.credentials import BigQueryCredentials +from dbt.adapters.bigquery.credentials import BigQueryCredentials, get_bigquery_client _logger = AdapterLogger("BigQuery") +REOPENABLE_ERRORS = ( + ConnectionResetError, + ConnectionError, +) + + RETRYABLE_ERRORS = ( ServerError, BadRequest, @@ -33,7 +41,7 @@ def __init__(self, credentials: BigQueryCredentials) -> None: self.job_execution_timeout = credentials.job_execution_timeout_seconds self.job_deadline = credentials.job_retry_deadline_seconds - def deadline(self, on_error: Callable[[Exception], None]) -> retry.Retry: + def deadline(self, connection: Connection) -> retry.Retry: """ This strategy mimics what was accomplished with _retry_and_handle """ @@ -42,10 +50,10 @@ def deadline(self, on_error: Callable[[Exception], None]) -> retry.Retry: initial=self._DEFAULT_INITIAL_DELAY, maximum=self._DEFAULT_MAXIMUM_DELAY, timeout=self.job_deadline, - on_error=on_error, + on_error=_on_error(connection), ) - def job_execution(self, on_error: Callable[[Exception], None]) -> retry.Retry: + def job_execution(self, connection: Connection) -> retry.Retry: """ This strategy mimics what was accomplished with _retry_and_handle """ @@ -54,17 +62,17 @@ def job_execution(self, on_error: Callable[[Exception], None]) -> retry.Retry: initial=self._DEFAULT_INITIAL_DELAY, maximum=self._DEFAULT_MAXIMUM_DELAY, timeout=self.job_execution_timeout, - on_error=on_error, + on_error=_on_error(connection), ) - def job_execution_capped(self, on_error: Callable[[Exception], None]) -> retry.Retry: + def job_execution_capped(self, connection: Connection) -> retry.Retry: """ This strategy mimics what was accomplished with _retry_and_handle """ return retry.Retry( predicate=self._buffered_predicate(), timeout=self.job_execution_timeout or 300, - on_error=on_error, + on_error=_on_error(connection), ) def _buffered_predicate(self) -> Callable[[Exception], bool]: @@ -101,6 +109,28 @@ def __call__(self, error: Exception) -> bool: return BufferedPredicate(self._retries) +def _on_error(connection: Connection) -> Callable[[Exception], None]: + + def on_error(error: Exception): + if isinstance(error, REOPENABLE_ERRORS): + _logger.warning("Reopening connection after {!r}".format(error)) + connection.handle.close() + + try: + connection.handle = get_bigquery_client(connection.credentials) + connection.state = ConnectionState.OPEN + + except Exception as e: + _logger.debug( + f"""Got an error when attempting to create a bigquery " "client: '{e}'""" + ) + connection.handle = None + connection.state = ConnectionState.FAIL + raise FailedToConnectError(str(e)) + + return on_error + + def _is_retryable(error: Exception) -> bool: """Return true for errors that are unlikely to occur again if retried.""" if isinstance(error, RETRYABLE_ERRORS): diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 54db54429..8be5e19a3 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -1,6 +1,5 @@ import json import unittest -from contextlib import contextmanager from requests.exceptions import ConnectionError from unittest.mock import patch, MagicMock, Mock, ANY @@ -16,43 +15,39 @@ class TestBigQueryConnectionManager(unittest.TestCase): def setUp(self): self.credentials = Mock(BigQueryCredentials) + self.credentials.method = "oauth" self.credentials.job_retries = 1 - profile = Mock(query_comment=None, credentials=self.credentials) - self.connections = 
BigQueryConnectionManager(profile=profile, mp_context=Mock()) + self.credentials.job_execution_timeout_seconds = 1 + self.credentials.scopes = tuple() self.mock_client = Mock(google.cloud.bigquery.Client) - self.mock_connection = MagicMock() + self.mock_connection = MagicMock() self.mock_connection.handle = self.mock_client + self.mock_connection.credentials = self.credentials + self.connections = BigQueryConnectionManager( + profile=Mock(credentials=self.credentials, query_comment=None), + mp_context=Mock(), + ) self.connections.get_thread_connection = lambda: self.mock_connection - @patch("dbt.adapters.bigquery.retry._is_retryable", return_value=True) - def test_retry_connection_reset(self, is_retryable): - self.connections.open = MagicMock() - self.connections.close = MagicMock() - self.connections._retry.DEFAULT_MAXIMUM_DELAY = 2.0 - - @contextmanager - def dummy_handler(msg): - yield + @patch( + "dbt.adapters.bigquery.retry.get_bigquery_client", + return_value=Mock(google.cloud.bigquery.Client), + ) + def test_retry_connection_reset(self, mock_bigquery_client): + original_handle = self.mock_connection.handle - self.connections.exception_handler = dummy_handler - - retry = RetryFactory(Mock(job_retries=1, job_execution_timeout_seconds=60)) - mock_conn = Mock() - - on_error = self.connections._reopen_on_error(mock_conn) - - @retry.job_execution(on_error) + @self.connections._retry.job_execution(self.mock_connection) def generate_connection_reset_error(): raise ConnectionResetError with self.assertRaises(ConnectionResetError): # this will always raise the error, we just want to test that the connection was reopening in between generate_connection_reset_error() - self.connections.close.assert_called_once_with(mock_conn) - self.connections.open.assert_called_once_with(mock_conn) + + assert not self.mock_connection.handle is original_handle def test_is_retryable(self): _is_retryable = dbt.adapters.bigquery.retry._is_retryable From ad74114027865f52d0f75b76cd91dc366a5e968f Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 5 Nov 2024 21:55:46 -0500 Subject: [PATCH 13/41] move client factories from python_submissions module to credentials module --- dbt/adapters/bigquery/connections.py | 4 +- dbt/adapters/bigquery/credentials.py | 37 +++++++++++++++++-- dbt/adapters/bigquery/python_submissions.py | 37 +++++++++---------- dbt/adapters/bigquery/retry.py | 4 +- .../unit/test_bigquery_connection_manager.py | 2 +- 5 files changed, 56 insertions(+), 28 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 5540ca8ee..effc9c87e 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -38,7 +38,7 @@ from dbt.adapters.events.types import SQLQuery from dbt.adapters.exceptions.connection import FailedToConnectError -from dbt.adapters.bigquery.credentials import BigQueryCredentials, Priority, get_bigquery_client +from dbt.adapters.bigquery.credentials import BigQueryCredentials, Priority, bigquery_client from dbt.adapters.bigquery.retry import RetryFactory if TYPE_CHECKING: @@ -192,7 +192,7 @@ def format_rows_number(self, rows_number): @classmethod @retry.Retry() # google decorator. 
retries on transient errors with exponential backoff def bigquery_client(cls, credentials: BigQueryCredentials) -> Client: - return get_bigquery_client(credentials) + return bigquery_client(credentials) @classmethod def open(cls, connection): diff --git a/dbt/adapters/bigquery/credentials.py b/dbt/adapters/bigquery/credentials.py index 4be9a996a..c92c01d8b 100644 --- a/dbt/adapters/bigquery/credentials.py +++ b/dbt/adapters/bigquery/credentials.py @@ -2,7 +2,7 @@ import binascii from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, Dict, Optional, Tuple, Union, Iterable +from typing import Any, Dict, Iterable, Optional, Tuple, Union from google.api_core.client_info import ClientInfo from google.api_core.client_options import ClientOptions @@ -10,6 +10,8 @@ from google.auth.exceptions import DefaultCredentialsError from google.auth.impersonated_credentials import Credentials as ImpersonatedCredentials from google.cloud.bigquery.client import Client as BigQueryClient +from google.cloud.dataproc_v1 import JobControllerClient, BatchControllerClient +from google.cloud.storage.client import Client as StorageClient from google.oauth2.credentials import Credentials as GoogleCredentials from google.oauth2.service_account import Credentials as ServiceAccountCredentials from mashumaro import pass_through @@ -161,7 +163,7 @@ def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]: return d -def get_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: +def bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: try: return _bigquery_client(credentials) except DefaultCredentialsError: @@ -170,10 +172,37 @@ def get_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: return _bigquery_client(credentials) +def storage_client(credentials: BigQueryCredentials) -> StorageClient: + return StorageClient( + project=credentials.execution_project, + credentials=_get_credentials(credentials), + ) + + +def job_controller_client(credentials: BigQueryCredentials) -> JobControllerClient: + options = ClientOptions( + api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", + ) + return JobControllerClient( + credentials=_get_credentials(credentials), + client_options=options, + ) + + +def batch_controller_client(credentials: BigQueryCredentials) -> BatchControllerClient: + options = ClientOptions( + api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", + ) + return BatchControllerClient( + credentials=_get_credentials(credentials), + client_options=options, + ) + + def _bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: return BigQueryClient( credentials.execution_project, - get_credentials(credentials), + _get_credentials(credentials), location=getattr(credentials, "location", None), client_info=ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}"), client_options=ClientOptions(quota_project_id=credentials.quota_project), @@ -196,7 +225,7 @@ def _setup_default_credentials() -> None: run_cmd(".", ["gcloud", "auth", "application-default", "login"]) -def get_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: +def _get_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: if credentials.impersonate_service_account: return _impersonated_credentials(credentials) return _google_credentials(credentials) diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 432cc6303..8fe018d58 
100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -4,13 +4,19 @@ from google.api_core import retry from google.api_core.client_options import ClientOptions from google.api_core.future.polling import POLLING_PREDICATE -from google.cloud import storage, dataproc_v1 -from google.cloud.dataproc_v1.types.batches import Batch +from google.cloud import dataproc_v1 +from google.cloud.dataproc_v1 import BatchControllerClient, JobControllerClient +from google.cloud.dataproc_v1.types import Batch, Job from dbt.adapters.base import PythonJobHelper from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.bigquery.credentials import BigQueryCredentials, get_credentials +from dbt.adapters.bigquery.credentials import ( + BigQueryCredentials, + batch_controller_client, + job_controller_client, + storage_client, +) from dbt.adapters.bigquery.dataproc.batch import ( DEFAULT_JAR_FILE_URI, create_batch_request, @@ -44,10 +50,7 @@ def __init__(self, parsed_model: Dict, credential: BigQueryCredentials) -> None: ) self.model_file_name = f"{schema}/{identifier}.py" self.credential = credential - self.GoogleCredentials = get_credentials(credential) - self.storage_client = storage.Client( - project=self.credential.execution_project, credentials=self.GoogleCredentials - ) + self.storage_client = storage_client(self.credential) self.gcs_location = "gs://{}/{}".format(self.credential.gcs_bucket, self.model_file_name) # set retry policy, default to timeout after 24 hours @@ -67,7 +70,7 @@ def _upload_to_gcs(self, filename: str, compiled_code: str) -> None: blob = bucket.blob(filename) blob.upload_from_string(compiled_code) - def submit(self, compiled_code: str) -> dataproc_v1.types.jobs.Job: + def submit(self, compiled_code: str) -> Job: # upload python file to GCS self._upload_to_gcs(self.model_file_name, compiled_code) # submit dataproc job @@ -75,29 +78,27 @@ def submit(self, compiled_code: str) -> dataproc_v1.types.jobs.Job: def _get_job_client( self, - ) -> Union[dataproc_v1.JobControllerClient, dataproc_v1.BatchControllerClient]: + ) -> Union[JobControllerClient, BatchControllerClient]: raise NotImplementedError("_get_job_client not implemented") - def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job: + def _submit_dataproc_job(self) -> Job: raise NotImplementedError("_submit_dataproc_job not implemented") class ClusterDataprocHelper(BaseDataProcHelper): - def _get_job_client(self) -> dataproc_v1.JobControllerClient: + def _get_job_client(self) -> JobControllerClient: if not self._get_cluster_name(): raise ValueError( "Need to supply dataproc_cluster_name in profile or config to submit python job with cluster submission method" ) - return dataproc_v1.JobControllerClient( - client_options=self.client_options, credentials=self.GoogleCredentials - ) + return job_controller_client(self.credential) def _get_cluster_name(self) -> str: return self.parsed_model["config"].get( "dataproc_cluster_name", self.credential.dataproc_cluster_name ) - def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job: + def _submit_dataproc_job(self) -> Job: job = { "placement": {"cluster_name": self._get_cluster_name()}, "pyspark_job": { @@ -119,10 +120,8 @@ def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job: class ServerlessDataProcHelper(BaseDataProcHelper): - def _get_job_client(self) -> dataproc_v1.BatchControllerClient: - return dataproc_v1.BatchControllerClient( - client_options=self.client_options, credentials=self.GoogleCredentials - ) 
+ def _get_job_client(self) -> BatchControllerClient: + return batch_controller_client(self.credential) def _get_batch_id(self) -> str: model = self.parsed_model diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 6cc0384e3..84a026f38 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -9,7 +9,7 @@ from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.exceptions.connection import FailedToConnectError -from dbt.adapters.bigquery.credentials import BigQueryCredentials, get_bigquery_client +from dbt.adapters.bigquery.credentials import BigQueryCredentials, bigquery_client _logger = AdapterLogger("BigQuery") @@ -117,7 +117,7 @@ def on_error(error: Exception): connection.handle.close() try: - connection.handle = get_bigquery_client(connection.credentials) + connection.handle = bigquery_client(connection.credentials) connection.state = ConnectionState.OPEN except Exception as e: diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 8be5e19a3..19e0e1ab4 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -33,7 +33,7 @@ def setUp(self): self.connections.get_thread_connection = lambda: self.mock_connection @patch( - "dbt.adapters.bigquery.retry.get_bigquery_client", + "dbt.adapters.bigquery.retry.bigquery_client", return_value=Mock(google.cloud.bigquery.Client), ) def test_retry_connection_reset(self, mock_bigquery_client): From 9029c495e30fd27ba5345e5ad534bb91893df6d6 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 5 Nov 2024 22:28:08 -0500 Subject: [PATCH 14/41] create a clients module --- dbt/adapters/bigquery/clients.py | 64 +++++++++++++++++++++ dbt/adapters/bigquery/connections.py | 3 +- dbt/adapters/bigquery/credentials.py | 58 +------------------ dbt/adapters/bigquery/python_submissions.py | 4 +- dbt/adapters/bigquery/retry.py | 4 +- tests/unit/test_bigquery_adapter.py | 4 +- 6 files changed, 74 insertions(+), 63 deletions(-) create mode 100644 dbt/adapters/bigquery/clients.py diff --git a/dbt/adapters/bigquery/clients.py b/dbt/adapters/bigquery/clients.py new file mode 100644 index 000000000..5fe66dc9d --- /dev/null +++ b/dbt/adapters/bigquery/clients.py @@ -0,0 +1,64 @@ +from google.api_core.client_info import ClientInfo +from google.api_core.client_options import ClientOptions +from google.auth.exceptions import DefaultCredentialsError +from google.cloud.bigquery import Client as BigQueryClient +from google.cloud.dataproc_v1 import BatchControllerClient, JobControllerClient +from google.cloud.storage import Client as StorageClient + +from dbt.adapters.events.logging import AdapterLogger + +import dbt.adapters.bigquery.__version__ as dbt_version +from dbt.adapters.bigquery.credentials import ( + BigQueryCredentials, + google_credentials, + setup_default_credentials, +) + + +_logger = AdapterLogger("BigQuery") + + +def bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: + try: + return _bigquery_client(credentials) + except DefaultCredentialsError: + _logger.info("Please log into GCP to continue") + setup_default_credentials() + return _bigquery_client(credentials) + + +def storage_client(credentials: BigQueryCredentials) -> StorageClient: + return StorageClient( + project=credentials.execution_project, + credentials=google_credentials(credentials), + ) + + +def job_controller_client(credentials: BigQueryCredentials) -> JobControllerClient: + options = 
ClientOptions( + api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", + ) + return JobControllerClient( + credentials=google_credentials(credentials), + client_options=options, + ) + + +def batch_controller_client(credentials: BigQueryCredentials) -> BatchControllerClient: + options = ClientOptions( + api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", + ) + return BatchControllerClient( + credentials=google_credentials(credentials), + client_options=options, + ) + + +def _bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: + return BigQueryClient( + credentials.execution_project, + google_credentials(credentials), + location=getattr(credentials, "location", None), + client_info=ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}"), + client_options=ClientOptions(quota_project_id=credentials.quota_project), + ) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index effc9c87e..d6e473b43 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -38,7 +38,8 @@ from dbt.adapters.events.types import SQLQuery from dbt.adapters.exceptions.connection import FailedToConnectError -from dbt.adapters.bigquery.credentials import BigQueryCredentials, Priority, bigquery_client +from dbt.adapters.bigquery.clients import bigquery_client +from dbt.adapters.bigquery.credentials import BigQueryCredentials, Priority from dbt.adapters.bigquery.retry import RetryFactory if TYPE_CHECKING: diff --git a/dbt/adapters/bigquery/credentials.py b/dbt/adapters/bigquery/credentials.py index c92c01d8b..e00e17781 100644 --- a/dbt/adapters/bigquery/credentials.py +++ b/dbt/adapters/bigquery/credentials.py @@ -4,14 +4,9 @@ from functools import lru_cache from typing import Any, Dict, Iterable, Optional, Tuple, Union -from google.api_core.client_info import ClientInfo -from google.api_core.client_options import ClientOptions from google.auth import default from google.auth.exceptions import DefaultCredentialsError from google.auth.impersonated_credentials import Credentials as ImpersonatedCredentials -from google.cloud.bigquery.client import Client as BigQueryClient -from google.cloud.dataproc_v1 import JobControllerClient, BatchControllerClient -from google.cloud.storage.client import Client as StorageClient from google.oauth2.credentials import Credentials as GoogleCredentials from google.oauth2.service_account import Credentials as ServiceAccountCredentials from mashumaro import pass_through @@ -23,9 +18,6 @@ from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.exceptions.connection import FailedToConnectError -import dbt.adapters.bigquery.__version__ as dbt_version - - _logger = AdapterLogger("BigQuery") @@ -163,53 +155,7 @@ def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]: return d -def bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: - try: - return _bigquery_client(credentials) - except DefaultCredentialsError: - _logger.info("Please log into GCP to continue") - _setup_default_credentials() - return _bigquery_client(credentials) - - -def storage_client(credentials: BigQueryCredentials) -> StorageClient: - return StorageClient( - project=credentials.execution_project, - credentials=_get_credentials(credentials), - ) - - -def job_controller_client(credentials: BigQueryCredentials) -> JobControllerClient: - options = ClientOptions( - api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", - ) - return 
JobControllerClient( - credentials=_get_credentials(credentials), - client_options=options, - ) - - -def batch_controller_client(credentials: BigQueryCredentials) -> BatchControllerClient: - options = ClientOptions( - api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", - ) - return BatchControllerClient( - credentials=_get_credentials(credentials), - client_options=options, - ) - - -def _bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: - return BigQueryClient( - credentials.execution_project, - _get_credentials(credentials), - location=getattr(credentials, "location", None), - client_info=ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}"), - client_options=ClientOptions(quota_project_id=credentials.quota_project), - ) - - -def _setup_default_credentials() -> None: +def setup_default_credentials() -> None: try: run_cmd(".", ["gcloud", "--version"]) except OSError as e: @@ -225,7 +171,7 @@ def _setup_default_credentials() -> None: run_cmd(".", ["gcloud", "auth", "application-default", "login"]) -def _get_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: +def google_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: if credentials.impersonate_service_account: return _impersonated_credentials(credentials) return _google_credentials(credentials) diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 8fe018d58..d5b7c1115 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -11,8 +11,8 @@ from dbt.adapters.base import PythonJobHelper from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.bigquery.credentials import ( - BigQueryCredentials, +from dbt.adapters.bigquery.credentials import BigQueryCredentials +from dbt.adapters.bigquery.clients import ( batch_controller_client, job_controller_client, storage_client, diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 84a026f38..0b1541805 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -9,8 +9,8 @@ from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.exceptions.connection import FailedToConnectError -from dbt.adapters.bigquery.credentials import BigQueryCredentials, bigquery_client - +from dbt.adapters.bigquery.credentials import BigQueryCredentials +from dbt.adapters.bigquery.clients import bigquery_client _logger = AdapterLogger("BigQuery") diff --git a/tests/unit/test_bigquery_adapter.py b/tests/unit/test_bigquery_adapter.py index 57e676cc4..3d7e9e77e 100644 --- a/tests/unit/test_bigquery_adapter.py +++ b/tests/unit/test_bigquery_adapter.py @@ -388,9 +388,9 @@ def test_cancel_open_connections_single(self): adapter.connections.thread_connections.update({key: master, 1: model}) self.assertEqual(len(list(adapter.cancel_open_connections())), 1) - @patch("dbt.adapters.bigquery.credentials.ClientOptions") + @patch("dbt.adapters.bigquery.clients.ClientOptions") @patch("dbt.adapters.bigquery.credentials.default") - @patch("dbt.adapters.bigquery.credentials.BigQueryClient") + @patch("dbt.adapters.bigquery.clients.BigQueryClient") def test_location_user_agent(self, MockClient, mock_auth_default, MockClientOptions): creds = MagicMock() mock_auth_default.return_value = (creds, MagicMock()) From bc0fbea641dec036e857dc4d7759166b2dcca6bb Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 5 Nov 2024 22:34:16 -0500 Subject: [PATCH 15/41] retry all client factories 
by default --- dbt/adapters/bigquery/clients.py | 5 +++++ dbt/adapters/bigquery/connections.py | 13 ++----------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/dbt/adapters/bigquery/clients.py b/dbt/adapters/bigquery/clients.py index 5fe66dc9d..948d47f65 100644 --- a/dbt/adapters/bigquery/clients.py +++ b/dbt/adapters/bigquery/clients.py @@ -1,5 +1,6 @@ from google.api_core.client_info import ClientInfo from google.api_core.client_options import ClientOptions +from google.api_core.retry import Retry from google.auth.exceptions import DefaultCredentialsError from google.cloud.bigquery import Client as BigQueryClient from google.cloud.dataproc_v1 import BatchControllerClient, JobControllerClient @@ -27,6 +28,7 @@ def bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: return _bigquery_client(credentials) +@Retry() # google decorator. retries on transient errors with exponential backoff def storage_client(credentials: BigQueryCredentials) -> StorageClient: return StorageClient( project=credentials.execution_project, @@ -34,6 +36,7 @@ def storage_client(credentials: BigQueryCredentials) -> StorageClient: ) +@Retry() # google decorator. retries on transient errors with exponential backoff def job_controller_client(credentials: BigQueryCredentials) -> JobControllerClient: options = ClientOptions( api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", @@ -44,6 +47,7 @@ def job_controller_client(credentials: BigQueryCredentials) -> JobControllerClie ) +@Retry() # google decorator. retries on transient errors with exponential backoff def batch_controller_client(credentials: BigQueryCredentials) -> BatchControllerClient: options = ClientOptions( api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", @@ -54,6 +58,7 @@ def batch_controller_client(credentials: BigQueryCredentials) -> BatchController ) +@Retry() # google decorator. retries on transient errors with exponential backoff def _bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: return BigQueryClient( credentials.execution_project, diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index d6e473b43..e522da77f 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -8,7 +8,6 @@ from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING import uuid -from google.api_core import retry import google.auth import google.auth.exceptions from google.cloud.bigquery import ( @@ -39,7 +38,7 @@ from dbt.adapters.exceptions.connection import FailedToConnectError from dbt.adapters.bigquery.clients import bigquery_client -from dbt.adapters.bigquery.credentials import BigQueryCredentials, Priority +from dbt.adapters.bigquery.credentials import Priority from dbt.adapters.bigquery.retry import RetryFactory if TYPE_CHECKING: @@ -68,9 +67,6 @@ class BigQueryAdapterResponse(AdapterResponse): class BigQueryConnectionManager(BaseConnectionManager): TYPE = "bigquery" - DEFAULT_INITIAL_DELAY = 1.0 # Seconds - DEFAULT_MAXIMUM_DELAY = 3.0 # Seconds - def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): super().__init__(profile, mp_context) self.jobs_by_thread: Dict[Hashable, List[str]] = defaultdict(list) @@ -190,11 +186,6 @@ def format_rows_number(self, rows_number): rows_number *= 1000.0 return f"{rows_number:3.1f}{unit}".strip() - @classmethod - @retry.Retry() # google decorator. 
retries on transient errors with exponential backoff - def bigquery_client(cls, credentials: BigQueryCredentials) -> Client: - return bigquery_client(credentials) - @classmethod def open(cls, connection): if connection.state == ConnectionState.OPEN: @@ -202,7 +193,7 @@ def open(cls, connection): return connection try: - connection.handle = cls.bigquery_client(connection.credentials) + connection.handle = bigquery_client(connection.credentials) connection.state = ConnectionState.OPEN return connection From 9a9f87ec08eceaf7d13dfcf7a413835a243e209f Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 6 Nov 2024 13:22:25 -0500 Subject: [PATCH 16/41] move polling from manual check in python_submissions module into retry_factory --- dbt/adapters/bigquery/dataproc/__init__.py | 0 dbt/adapters/bigquery/dataproc/batch.py | 68 ------- dbt/adapters/bigquery/python_submissions.py | 196 ++++++++++---------- dbt/adapters/bigquery/retry.py | 14 +- tests/unit/test_configure_dataproc_batch.py | 4 +- 5 files changed, 111 insertions(+), 171 deletions(-) delete mode 100644 dbt/adapters/bigquery/dataproc/__init__.py delete mode 100644 dbt/adapters/bigquery/dataproc/batch.py diff --git a/dbt/adapters/bigquery/dataproc/__init__.py b/dbt/adapters/bigquery/dataproc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/dbt/adapters/bigquery/dataproc/batch.py b/dbt/adapters/bigquery/dataproc/batch.py deleted file mode 100644 index 59f40d246..000000000 --- a/dbt/adapters/bigquery/dataproc/batch.py +++ /dev/null @@ -1,68 +0,0 @@ -from datetime import datetime -import time -from typing import Dict, Union - -from google.cloud.dataproc_v1 import ( - Batch, - BatchControllerClient, - CreateBatchRequest, - GetBatchRequest, -) -from google.protobuf.json_format import ParseDict - -from dbt.adapters.bigquery.credentials import DataprocBatchConfig - - -_BATCH_RUNNING_STATES = [Batch.State.PENDING, Batch.State.RUNNING] -DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.13-0.34.0.jar" - - -def create_batch_request( - batch: Batch, batch_id: str, project: str, region: str -) -> CreateBatchRequest: - return CreateBatchRequest( - parent=f"projects/{project}/locations/{region}", - batch_id=batch_id, - batch=batch, - ) - - -def poll_batch_job( - parent: str, batch_id: str, job_client: BatchControllerClient, timeout: int -) -> Batch: - batch_name = "".join([parent, "/batches/", batch_id]) - state = Batch.State.PENDING - response = None - run_time = 0 - while state in _BATCH_RUNNING_STATES and run_time < timeout: - time.sleep(1) - response = job_client.get_batch( - request=GetBatchRequest(name=batch_name), - ) - run_time = datetime.now().timestamp() - response.create_time.timestamp() - state = response.state - if not response: - raise ValueError("No response from Dataproc") - if state != Batch.State.SUCCEEDED: - if run_time >= timeout: - raise ValueError( - f"Operation did not complete within the designated timeout of {timeout} seconds." - ) - else: - raise ValueError(response.state_message) - return response - - -def update_batch_from_config(config_dict: Union[Dict, DataprocBatchConfig], target: Batch): - try: - # updates in place - ParseDict(config_dict, target._pb) - except Exception as e: - docurl = ( - "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1" - "#google.cloud.dataproc.v1.Batch" - ) - raise ValueError( - f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. 
{str(e)}" - ) from e - return target diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index d5b7c1115..e76f6dc13 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -1,150 +1,131 @@ -import uuid from typing import Dict, Union +import uuid -from google.api_core import retry -from google.api_core.client_options import ClientOptions -from google.api_core.future.polling import POLLING_PREDICATE -from google.cloud import dataproc_v1 -from google.cloud.dataproc_v1 import BatchControllerClient, JobControllerClient -from google.cloud.dataproc_v1.types import Batch, Job +from google.cloud.dataproc_v1 import ( + Batch, + CreateBatchRequest, + GetBatchRequest, + Job, + RuntimeConfig, +) from dbt.adapters.base import PythonJobHelper from dbt.adapters.events.logging import AdapterLogger +from google.protobuf.json_format import ParseDict -from dbt.adapters.bigquery.credentials import BigQueryCredentials +from dbt.adapters.bigquery.credentials import BigQueryCredentials, DataprocBatchConfig from dbt.adapters.bigquery.clients import ( batch_controller_client, job_controller_client, storage_client, ) -from dbt.adapters.bigquery.dataproc.batch import ( - DEFAULT_JAR_FILE_URI, - create_batch_request, - poll_batch_job, - update_batch_from_config, -) +from dbt.adapters.bigquery.retry import RetryFactory + -OPERATION_RETRY_TIME = 10 -logger = AdapterLogger("BigQuery") +_logger = AdapterLogger("BigQuery") + + +_DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.13-0.34.0.jar" class BaseDataProcHelper(PythonJobHelper): - def __init__(self, parsed_model: Dict, credential: BigQueryCredentials) -> None: + def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: """_summary_ Args: - credential (_type_): _description_ + credentials (_type_): _description_ """ # validate all additional stuff for python is set - schema = parsed_model["schema"] - identifier = parsed_model["alias"] - self.parsed_model = parsed_model - python_required_configs = [ - "dataproc_region", - "gcs_bucket", - ] - for required_config in python_required_configs: - if not getattr(credential, required_config): + for required_config in ["dataproc_region", "gcs_bucket"]: + if not getattr(credentials, required_config): raise ValueError( f"Need to supply {required_config} in profile to submit python job" ) - self.model_file_name = f"{schema}/{identifier}.py" - self.credential = credential - self.storage_client = storage_client(self.credential) - self.gcs_location = "gs://{}/{}".format(self.credential.gcs_bucket, self.model_file_name) + + self._storage_client = storage_client(credentials) + self._project = credentials.execution_project + self._region = credentials.dataproc_region + + schema = parsed_model["schema"] + identifier = parsed_model["alias"] + self._model_file_name = f"{schema}/{identifier}.py" + self._gcs_bucket = credentials.gcs_bucket + self._gcs_path = f"gs://{credentials.gcs_bucket}/{self._model_file_name}" # set retry policy, default to timeout after 24 hours - self.timeout = self.parsed_model["config"].get( - "timeout", self.credential.job_execution_timeout_seconds or 60 * 60 * 24 - ) - self.result_polling_policy = retry.Retry( - predicate=POLLING_PREDICATE, maximum=10.0, timeout=self.timeout - ) - self.client_options = ClientOptions( - api_endpoint="{}-dataproc.googleapis.com:443".format(self.credential.dataproc_region) - ) - self.job_client = self._get_job_client() + 
retry = RetryFactory(credentials) + timeout = parsed_model["config"].get("timeout") + self._polling_retry = retry.polling(timeout) - def _upload_to_gcs(self, filename: str, compiled_code: str) -> None: - bucket = self.storage_client.get_bucket(self.credential.gcs_bucket) - blob = bucket.blob(filename) + def _upload_to_gcs(self, compiled_code: str) -> None: + bucket = self._storage_client.get_bucket(self._gcs_bucket) + blob = bucket.blob(self._model_file_name) blob.upload_from_string(compiled_code) def submit(self, compiled_code: str) -> Job: - # upload python file to GCS - self._upload_to_gcs(self.model_file_name, compiled_code) - # submit dataproc job + self._upload_to_gcs(compiled_code) return self._submit_dataproc_job() - def _get_job_client( - self, - ) -> Union[JobControllerClient, BatchControllerClient]: - raise NotImplementedError("_get_job_client not implemented") - def _submit_dataproc_job(self) -> Job: raise NotImplementedError("_submit_dataproc_job not implemented") class ClusterDataprocHelper(BaseDataProcHelper): - def _get_job_client(self) -> JobControllerClient: - if not self._get_cluster_name(): + def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: + super().__init__(parsed_model, credentials) + self._job_controller_client = job_controller_client(credentials) + self._cluster_name = parsed_model["config"].get( + "dataproc_cluster_name", credentials.dataproc_cluster_name + ) + + if not self._cluster_name: raise ValueError( "Need to supply dataproc_cluster_name in profile or config to submit python job with cluster submission method" ) - return job_controller_client(self.credential) - - def _get_cluster_name(self) -> str: - return self.parsed_model["config"].get( - "dataproc_cluster_name", self.credential.dataproc_cluster_name - ) def _submit_dataproc_job(self) -> Job: job = { - "placement": {"cluster_name": self._get_cluster_name()}, + "placement": {"cluster_name": self._cluster_name}, "pyspark_job": { - "main_python_file_uri": self.gcs_location, + "main_python_file_uri": self._gcs_path, }, } - operation = self.job_client.submit_job_as_operation( + operation = self._job_controller_client.submit_job_as_operation( request={ - "project_id": self.credential.execution_project, - "region": self.credential.dataproc_region, + "project_id": self._project, + "region": self._region, "job": job, } ) # check if job failed - response = operation.result(polling=self.result_polling_policy) + response = operation.result(polling=self._polling_retry) if response.status.state == 6: raise ValueError(response.status.details) return response class ServerlessDataProcHelper(BaseDataProcHelper): - def _get_job_client(self) -> BatchControllerClient: - return batch_controller_client(self.credential) - - def _get_batch_id(self) -> str: - model = self.parsed_model - default_batch_id = str(uuid.uuid4()) - return model["config"].get("batch_id", default_batch_id) + def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: + super().__init__(parsed_model, credentials) + self._batch_controller_client = batch_controller_client(credentials) + self._batch_id = parsed_model["config"].get("batch_id", str(uuid.uuid4())) + self._jar_file_uri = parsed_model["config"].get("jar_file_uri", _DEFAULT_JAR_FILE_URI) + self._dataproc_batch = credentials.dataproc_batch def _submit_dataproc_job(self) -> Batch: - batch_id = self._get_batch_id() - logger.info(f"Submitting batch job with id: {batch_id}") - request = create_batch_request( - batch=self._configure_batch(), - 
batch_id=batch_id, - region=self.credential.dataproc_region, # type: ignore - project=self.credential.execution_project, # type: ignore - ) + _logger.info(f"Submitting batch job with id: {self._batch_id}") + # make the request - self.job_client.create_batch(request=request) - return poll_batch_job( - parent=request.parent, - batch_id=batch_id, - job_client=self.job_client, - timeout=self.timeout, + request = CreateBatchRequest( + parent=f"projects/{self._project}/locations/{self._region}", + batch=self._configure_batch(), + batch_id=self._batch_id, ) + self._batch_controller_client.create_batch(request=request) + + # return the response + batch = GetBatchRequest(f"{request.parent}/batches/{self._batch_id}") + return self._batch_controller_client.get_batch(batch, retry=self._polling_retry) # there might be useful results here that we can parse and return # Dataproc job output is saved to the Cloud Storage bucket # allocated to the job. Use regex to obtain the bucket and blob info. @@ -156,30 +137,45 @@ def _submit_dataproc_job(self) -> Batch: # .download_as_string() # ) - def _configure_batch(self): + def _configure_batch(self) -> Batch: # create the Dataproc Serverless job config # need to pin dataproc version to 1.1 as it now defaults to 2.0 # https://cloud.google.com/dataproc-serverless/docs/concepts/properties # https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#runtimeconfig - batch = dataproc_v1.Batch( + batch = Batch( { - "runtime_config": dataproc_v1.RuntimeConfig( + "runtime_config": RuntimeConfig( version="1.1", properties={ "spark.executor.instances": "2", }, - ) + ), + "pyspark_batch": { + "main_python_file_uri": self._gcs_path, + "jar_file_uris": [self._jar_file_uri], + }, } ) - # Apply defaults - batch.pyspark_batch.main_python_file_uri = self.gcs_location - jar_file_uri = self.parsed_model["config"].get( - "jar_file_uri", - DEFAULT_JAR_FILE_URI, - ) - batch.pyspark_batch.jar_file_uris = [jar_file_uri] # Apply configuration from dataproc_batch key, possibly overriding defaults. - if self.credential.dataproc_batch: - batch = update_batch_from_config(self.credential.dataproc_batch, batch) + if self._dataproc_batch: + batch = _update_batch_from_config(self._dataproc_batch, batch) + return batch + + +def _update_batch_from_config( + config_dict: Union[Dict, DataprocBatchConfig], target: Batch +) -> Batch: + try: + # updates in place + ParseDict(config_dict, target._pb) + except Exception as e: + docurl = ( + "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1" + "#google.cloud.dataproc.v1.Batch" + ) + raise ValueError( + f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. 
{str(e)}" + ) from e + return target diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 0b1541805..4143c54e5 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -1,7 +1,8 @@ -from typing import Callable +from typing import Callable, Optional from google.api_core import retry from google.api_core.exceptions import Forbidden +from google.api_core.future.polling import POLLING_PREDICATE from google.cloud.exceptions import BadGateway, BadRequest, ServerError from requests.exceptions import ConnectionError @@ -75,6 +76,17 @@ def job_execution_capped(self, connection: Connection) -> retry.Retry: on_error=_on_error(connection), ) + def polling(self, timeout: Optional[float] = None) -> retry.Retry: + """ + This strategy mimics what was accomplished with _retry_and_handle + """ + return retry.Retry( + predicate=POLLING_PREDICATE, + minimum=1.0, + maximum=10.0, + timeout=timeout or self.job_execution_timeout or 60 * 60 * 24, + ) + def _buffered_predicate(self) -> Callable[[Exception], bool]: class BufferedPredicate: """ diff --git a/tests/unit/test_configure_dataproc_batch.py b/tests/unit/test_configure_dataproc_batch.py index 19a0d3012..e73e5b845 100644 --- a/tests/unit/test_configure_dataproc_batch.py +++ b/tests/unit/test_configure_dataproc_batch.py @@ -1,6 +1,6 @@ from unittest.mock import patch -from dbt.adapters.bigquery.dataproc.batch import update_batch_from_config +from dbt.adapters.bigquery.python_submissions import _update_batch_from_config from google.cloud import dataproc_v1 from .test_bigquery_adapter import BaseTestBigQueryAdapter @@ -39,7 +39,7 @@ def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults): batch = dataproc_v1.Batch() - batch = update_batch_from_config(raw_batch_config, batch) + batch = _update_batch_from_config(raw_batch_config, batch) def to_str_values(d): """google's protobuf types expose maps as dict[str, str]""" From 136ea7749f456f66c9eeb8880a6fbb305f5c1597 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 6 Nov 2024 13:52:57 -0500 Subject: [PATCH 17/41] move load_dataframe logic from adapter to connection manager, use the built-in timeout argument instead of a manual polling method --- dbt/adapters/bigquery/connections.py | 37 +++++++++++++++++++++ dbt/adapters/bigquery/impl.py | 49 +++++++++++++--------------- 2 files changed, 59 insertions(+), 27 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index e522da77f..041028f63 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -15,8 +15,10 @@ CopyJobConfig, Dataset, DatasetReference, + LoadJobConfig, QueryJobConfig, QueryPriority, + SchemaField, Table, TableReference, WriteDisposition, @@ -446,6 +448,41 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: ) copy_job.result(retry=self._retry.job_execution_capped(conn)) + def load_dataframe( + self, + client: Client, + database: str, + schema: str, + table_name: str, + table_schema: List[SchemaField], + field_delimiter: str, + file_path: str, + ) -> None: + + load_config = LoadJobConfig( + skip_leading_rows=1, + schema=table_schema, + field_delimiter=field_delimiter, + ) + + with self.exception_handler("LOAD TABLE"): + with open(file_path, "rb") as f: + job = client.load_table_from_file( + f, + self.table_ref(database, schema, table_name), + rewind=True, + job_config=load_config, + job_id=self.generate_job_id(), + timeout=self._retry.job_execution_timeout or 300, + 
) + + if job.state != "DONE": + raise DbtRuntimeError("BigQuery Timeout Exceeded") + + elif job.error_result: + message = "\n".join(error["message"].strip() for error in job.errors) + raise DbtRuntimeError(message) + @staticmethod def dataset_ref(database, schema): return DatasetReference(project=database, dataset_id=schema) diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index ec9afb08f..d3767cf7d 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -22,7 +22,7 @@ import google.auth import google.oauth2 import google.cloud.bigquery -from google.cloud.bigquery import AccessEntry, SchemaField, Table as BigQueryTable +from google.cloud.bigquery import AccessEntry, Client, SchemaField, Table as BigQueryTable import google.cloud.exceptions import pytz @@ -675,32 +675,27 @@ def alter_table_add_columns(self, relation, columns): @available.parse_none def load_dataframe( self, - database, - schema, - table_name, + database: str, + schema: str, + table_name: str, agate_table: "agate.Table", - column_override, - field_delimiter, - ): - bq_schema = self._agate_to_schema(agate_table, column_override) - conn = self.connections.get_thread_connection() - client = conn.handle - - table_ref = self.connections.table_ref(database, schema, table_name) - - load_config = google.cloud.bigquery.LoadJobConfig() - load_config.skip_leading_rows = 1 - load_config.schema = bq_schema - load_config.field_delimiter = field_delimiter - job_id = self.connections.generate_job_id() - with open(agate_table.original_abspath, "rb") as f: # type: ignore - job = client.load_table_from_file( - f, table_ref, rewind=True, job_config=load_config, job_id=job_id - ) - - timeout = conn.credentials.job_execution_timeout_seconds or 300 - with self.connections.exception_handler("LOAD TABLE"): - self.poll_until_job_completes(job, timeout) + column_override: Dict[str, str], + field_delimiter: str, + ) -> None: + connection = self.connections.get_thread_connection() + client: Client = connection.handle + table_schema = self._agate_to_schema(agate_table, column_override) + file_path = agate_table.original_abspath # type: ignore + + self.connections.load_dataframe( + client, + database, + schema, + table_name, + table_schema, + field_delimiter, + file_path, + ) @available.parse_none def upload_file( @@ -759,7 +754,7 @@ def calculate_freshness_from_metadata( macro_resolver: Optional[MacroResolverProtocol] = None, ) -> Tuple[Optional[AdapterResponse], FreshnessResponse]: conn = self.connections.get_thread_connection() - client: google.cloud.bigquery.Client = conn.handle + client: Client = conn.handle table_ref = self.get_table_ref_from_relation(source) table = client.get_table(table_ref) From 90d5308c8be7289d433f7459d11da02f573221e6 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 6 Nov 2024 15:17:00 -0500 Subject: [PATCH 18/41] move upload_file logic from adapter to connection manager, use the built-in timeout argument instead of a manual polling method, remove the manual polling method --- dbt/adapters/bigquery/connections.py | 38 ++++++++++++++++++-- dbt/adapters/bigquery/impl.py | 54 +++++++++------------------- 2 files changed, 52 insertions(+), 40 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 041028f63..cb989ae41 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -451,12 +451,12 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: def load_dataframe( 
self, client: Client, + file_path: str, database: str, schema: str, - table_name: str, + identifier: str, table_schema: List[SchemaField], field_delimiter: str, - file_path: str, ) -> None: load_config = LoadJobConfig( @@ -469,7 +469,7 @@ def load_dataframe( with open(file_path, "rb") as f: job = client.load_table_from_file( f, - self.table_ref(database, schema, table_name), + self.table_ref(database, schema, identifier), rewind=True, job_config=load_config, job_id=self.generate_job_id(), @@ -483,6 +483,38 @@ def load_dataframe( message = "\n".join(error["message"].strip() for error in job.errors) raise DbtRuntimeError(message) + def upload_file( + self, + client: Client, + file_path: str, + database: str, + schema: str, + identifier: str, + **kwargs, + ) -> None: + + config = kwargs["kwargs"] + if "schema" in config: + config["schema"] = json.load(config["schema"]) + load_config = LoadJobConfig(**config) + + with self.exception_handler("LOAD TABLE"): + with open(file_path, "rb") as f: + job = client.load_table_from_file( + f, + self.table_ref(database, schema, identifier), + rewind=True, + job_config=load_config, + timeout=self._retry.job_execution_timeout or 300, + ) + + if job.state != "DONE": + raise DbtRuntimeError("BigQuery Timeout Exceeded") + + elif job.error_result: + message = "\n".join(error["message"].strip() for error in job.errors) + raise DbtRuntimeError(message) + @staticmethod def dataset_ref(database, schema): return DatasetReference(project=database, dataset_id=schema) diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index d3767cf7d..f4883f1d8 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -1,9 +1,7 @@ from dataclasses import dataclass from datetime import datetime -import json from multiprocessing.context import SpawnContext import threading -import time from typing import ( Any, Dict, @@ -460,22 +458,6 @@ def get_columns_in_select_sql(self, select_sql: str) -> List[BigQueryColumn]: logger.debug("get_columns_in_select_sql error: {}".format(e)) return [] - @classmethod - def poll_until_job_completes(cls, job, timeout): - retry_count = timeout - - while retry_count > 0 and job.state != "DONE": - retry_count -= 1 - time.sleep(1) - job.reload() - - if job.state != "DONE": - raise dbt_common.exceptions.DbtRuntimeError("BigQuery Timeout Exceeded") - - elif job.error_result: - message = "\n".join(error["message"].strip() for error in job.errors) - raise dbt_common.exceptions.DbtRuntimeError(message) - def _bq_table_to_relation(self, bq_table) -> Union[BigQueryRelation, None]: if bq_table is None: return None @@ -689,36 +671,34 @@ def load_dataframe( self.connections.load_dataframe( client, + file_path, database, schema, table_name, table_schema, field_delimiter, - file_path, ) @available.parse_none def upload_file( - self, local_file_path: str, database: str, table_schema: str, table_name: str, **kwargs + self, + local_file_path: str, + database: str, + table_schema: str, + table_name: str, + **kwargs, ) -> None: - conn = self.connections.get_thread_connection() - client = conn.handle - - table_ref = self.connections.table_ref(database, table_schema, table_name) - - load_config = google.cloud.bigquery.LoadJobConfig() - for k, v in kwargs["kwargs"].items(): - if k == "schema": - setattr(load_config, k, json.loads(v)) - else: - setattr(load_config, k, v) - - with open(local_file_path, "rb") as f: - job = client.load_table_from_file(f, table_ref, rewind=True, job_config=load_config) + connection = 
self.connections.get_thread_connection() + client: Client = connection.handle - timeout = conn.credentials.job_execution_timeout_seconds or 300 - with self.connections.exception_handler("LOAD TABLE"): - self.poll_until_job_completes(job, timeout) + self.connections.upload_file( + client, + local_file_path, + database, + table_schema, + table_name, + **kwargs, + ) @classmethod def _catalog_filter_table( From 9211e1c6aa10b8f9f18e3a86c4f231b6262e7faa Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 6 Nov 2024 18:33:40 -0500 Subject: [PATCH 19/41] move the retry to polling for done instead of create --- dbt/adapters/bigquery/connections.py | 45 ++++++++------------- dbt/adapters/bigquery/impl.py | 2 + dbt/adapters/bigquery/python_submissions.py | 3 +- dbt/adapters/bigquery/retry.py | 41 ++++++++++++------- 4 files changed, 47 insertions(+), 44 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index cb989ae41..856e69485 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -457,31 +457,15 @@ def load_dataframe( identifier: str, table_schema: List[SchemaField], field_delimiter: str, + fallback_timeout: Optional[float] = None, ) -> None: - load_config = LoadJobConfig( skip_leading_rows=1, schema=table_schema, field_delimiter=field_delimiter, ) - - with self.exception_handler("LOAD TABLE"): - with open(file_path, "rb") as f: - job = client.load_table_from_file( - f, - self.table_ref(database, schema, identifier), - rewind=True, - job_config=load_config, - job_id=self.generate_job_id(), - timeout=self._retry.job_execution_timeout or 300, - ) - - if job.state != "DONE": - raise DbtRuntimeError("BigQuery Timeout Exceeded") - - elif job.error_result: - message = "\n".join(error["message"].strip() for error in job.errors) - raise DbtRuntimeError(message) + table = self.table_ref(database, schema, identifier) + self._load_table_from_file(client, file_path, table, load_config, fallback_timeout) def upload_file( self, @@ -490,25 +474,30 @@ def upload_file( database: str, schema: str, identifier: str, + fallback_timeout: Optional[float] = None, **kwargs, ) -> None: - config = kwargs["kwargs"] if "schema" in config: config["schema"] = json.load(config["schema"]) load_config = LoadJobConfig(**config) + table = self.table_ref(database, schema, identifier) + self._load_table_from_file(client, file_path, table, load_config, fallback_timeout) + + def _load_table_from_file( + self, + client: Client, + file_path: str, + table: TableReference, + config: LoadJobConfig, + fallback_timeout: Optional[float] = None, + ) -> None: with self.exception_handler("LOAD TABLE"): with open(file_path, "rb") as f: - job = client.load_table_from_file( - f, - self.table_ref(database, schema, identifier), - rewind=True, - job_config=load_config, - timeout=self._retry.job_execution_timeout or 300, - ) + job = client.load_table_from_file(f, table, rewind=True, job_config=config) - if job.state != "DONE": + if not job.done(retry=self._retry.polling_done(fallback_timeout=fallback_timeout)): raise DbtRuntimeError("BigQuery Timeout Exceeded") elif job.error_result: diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index f4883f1d8..e0a6a38db 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -677,6 +677,7 @@ def load_dataframe( table_name, table_schema, field_delimiter, + fallback_timeout=300, ) @available.parse_none @@ -697,6 +698,7 @@ def upload_file( database, table_schema, table_name, 
+ fallback_timeout=300, **kwargs, ) diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index e76f6dc13..98a8bee25 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -54,8 +54,7 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None # set retry policy, default to timeout after 24 hours retry = RetryFactory(credentials) - timeout = parsed_model["config"].get("timeout") - self._polling_retry = retry.polling(timeout) + self._polling_retry = retry.polling(timeout=parsed_model["config"].get("timeout")) def _upload_to_gcs(self, compiled_code: str) -> None: bucket = self._storage_client.get_bucket(self._gcs_bucket) diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 4143c54e5..e5c6abbb2 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -1,8 +1,9 @@ from typing import Callable, Optional -from google.api_core import retry +from google.api_core.retry import Retry from google.api_core.exceptions import Forbidden from google.api_core.future.polling import POLLING_PREDICATE +from google.cloud.bigquery.retry import DEFAULT_RETRY from google.cloud.exceptions import BadGateway, BadRequest, ServerError from requests.exceptions import ConnectionError @@ -16,13 +17,13 @@ _logger = AdapterLogger("BigQuery") -REOPENABLE_ERRORS = ( +_REOPENABLE_ERRORS = ( ConnectionResetError, ConnectionError, ) -RETRYABLE_ERRORS = ( +_RETRYABLE_ERRORS = ( ServerError, BadRequest, BadGateway, @@ -31,6 +32,9 @@ ) +_ONE_DAY = 60 * 60 * 24 + + class RetryFactory: _DEFAULT_INITIAL_DELAY = 1.0 # seconds @@ -42,11 +46,11 @@ def __init__(self, credentials: BigQueryCredentials) -> None: self.job_execution_timeout = credentials.job_execution_timeout_seconds self.job_deadline = credentials.job_retry_deadline_seconds - def deadline(self, connection: Connection) -> retry.Retry: + def deadline(self, connection: Connection) -> Retry: """ This strategy mimics what was accomplished with _retry_and_handle """ - return retry.Retry( + return Retry( predicate=self._buffered_predicate(), initial=self._DEFAULT_INITIAL_DELAY, maximum=self._DEFAULT_MAXIMUM_DELAY, @@ -54,11 +58,11 @@ def deadline(self, connection: Connection) -> retry.Retry: on_error=_on_error(connection), ) - def job_execution(self, connection: Connection) -> retry.Retry: + def job_execution(self, connection: Connection) -> Retry: """ This strategy mimics what was accomplished with _retry_and_handle """ - return retry.Retry( + return Retry( predicate=self._buffered_predicate(), initial=self._DEFAULT_INITIAL_DELAY, maximum=self._DEFAULT_MAXIMUM_DELAY, @@ -66,25 +70,34 @@ def job_execution(self, connection: Connection) -> retry.Retry: on_error=_on_error(connection), ) - def job_execution_capped(self, connection: Connection) -> retry.Retry: + def job_execution_capped(self, connection: Connection) -> Retry: """ This strategy mimics what was accomplished with _retry_and_handle """ - return retry.Retry( + return Retry( predicate=self._buffered_predicate(), timeout=self.job_execution_timeout or 300, on_error=_on_error(connection), ) - def polling(self, timeout: Optional[float] = None) -> retry.Retry: + def polling( + self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None + ) -> Retry: """ This strategy mimics what was accomplished with _retry_and_handle """ - return retry.Retry( + return Retry( predicate=POLLING_PREDICATE, minimum=1.0, maximum=10.0, - timeout=timeout 
or self.job_execution_timeout or 60 * 60 * 24, + timeout=timeout or self.job_execution_timeout or fallback_timeout or _ONE_DAY, + ) + + def polling_done( + self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None + ) -> Retry: + return DEFAULT_RETRY.with_timeout( + timeout or self.job_execution_timeout or fallback_timeout or _ONE_DAY ) def _buffered_predicate(self) -> Callable[[Exception], bool]: @@ -124,7 +137,7 @@ def __call__(self, error: Exception) -> bool: def _on_error(connection: Connection) -> Callable[[Exception], None]: def on_error(error: Exception): - if isinstance(error, REOPENABLE_ERRORS): + if isinstance(error, _REOPENABLE_ERRORS): _logger.warning("Reopening connection after {!r}".format(error)) connection.handle.close() @@ -145,7 +158,7 @@ def on_error(error: Exception): def _is_retryable(error: Exception) -> bool: """Return true for errors that are unlikely to occur again if retried.""" - if isinstance(error, RETRYABLE_ERRORS): + if isinstance(error, _RETRYABLE_ERRORS): return True elif isinstance(error, Forbidden) and any( e["reason"] == "rateLimitExceeded" for e in error.errors From e90c24dfe573fa3748e55246a91324926e1df1ab Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 6 Nov 2024 18:39:29 -0500 Subject: [PATCH 20/41] fix broken import in tests from code migration --- tests/functional/adapter/test_json_keyfile.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/functional/adapter/test_json_keyfile.py b/tests/functional/adapter/test_json_keyfile.py index 43928555e..a5caaebdf 100644 --- a/tests/functional/adapter/test_json_keyfile.py +++ b/tests/functional/adapter/test_json_keyfile.py @@ -1,7 +1,7 @@ import base64 import json import pytest -from dbt.adapters.bigquery.credentials import is_base64 +from dbt.adapters.bigquery.credentials import _is_base64 def string_to_base64(s): @@ -58,7 +58,7 @@ def test_valid_base64_strings(example_json_keyfile_b64): ] for s in valid_strings: - assert is_base64(s) is True + assert _is_base64(s) is True def test_valid_base64_bytes(example_json_keyfile_b64): @@ -70,7 +70,7 @@ def test_valid_base64_bytes(example_json_keyfile_b64): example_json_keyfile_b64, ] for s in valid_bytes: - assert is_base64(s) is True + assert _is_base64(s) is True def test_invalid_base64(example_json_keyfile): @@ -84,4 +84,4 @@ def test_invalid_base64(example_json_keyfile): example_json_keyfile, ] for s in invalid_inputs: - assert is_base64(s) is False + assert _is_base64(s) is False From a2db35badbe8ab82a7fcf4e8a72a02269657fd36 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 6 Nov 2024 22:31:23 -0500 Subject: [PATCH 21/41] align new retries with original methods, simplify retry factory --- dbt/adapters/bigquery/connections.py | 27 ++- dbt/adapters/bigquery/python_submissions.py | 63 +++---- dbt/adapters/bigquery/retry.py | 157 +++++++----------- .../unit/test_bigquery_connection_manager.py | 16 +- 4 files changed, 102 insertions(+), 161 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 856e69485..952a83b04 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -140,12 +140,12 @@ def cancel_open(self): continue if connection.handle is not None and connection.state == ConnectionState.OPEN: - client = connection.handle + client: Client = connection.handle for job_id in self.jobs_by_thread.get(thread_id, []): with self.exception_handler(f"Cancel job: {job_id}"): client.cancel_job( job_id, - 
retry=self._retry.deadline(connection), + retry=self._retry.reopen_with_deadline(connection), ) self.close(connection) @@ -444,9 +444,8 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: source_ref_array, destination_ref, job_config=CopyJobConfig(write_disposition=write_disposition), - retry=self._retry.deadline(conn), ) - copy_job.result(retry=self._retry.job_execution_capped(conn)) + copy_job.result(timeout=self._retry.job_execution_timeout(300)) def load_dataframe( self, @@ -497,7 +496,7 @@ def _load_table_from_file( with open(file_path, "rb") as f: job = client.load_table_from_file(f, table, rewind=True, job_config=config) - if not job.done(retry=self._retry.polling_done(fallback_timeout=fallback_timeout)): + if not job.done(retry=self._retry.retry(fallback_timeout=fallback_timeout)): raise DbtRuntimeError("BigQuery Timeout Exceeded") elif job.error_result: @@ -522,7 +521,7 @@ def get_bq_table(self, database, schema, identifier) -> Table: schema = schema or conn.credentials.schema return client.get_table( table=self.table_ref(database, schema, identifier), - retry=self._retry.deadline(conn), + retry=self._retry.reopen_with_deadline(conn), ) def drop_dataset(self, database, schema) -> None: @@ -533,7 +532,7 @@ def drop_dataset(self, database, schema) -> None: dataset=self.dataset_ref(database, schema), delete_contents=True, not_found_ok=True, - retry=self._retry.deadline(conn), + retry=self._retry.reopen_with_deadline(conn), ) def create_dataset(self, database, schema) -> Dataset: @@ -543,7 +542,7 @@ def create_dataset(self, database, schema) -> Dataset: return client.create_dataset( dataset=self.dataset_ref(database, schema), exists_ok=True, - retry=self._retry.deadline(conn), + retry=self._retry.reopen_with_deadline(conn), ) def list_dataset(self, database: str): @@ -556,7 +555,7 @@ def list_dataset(self, database: str): all_datasets = client.list_datasets( project=database.strip("`"), max_results=10000, - retry=self._retry.deadline(conn), + retry=self._retry.reopen_with_deadline(conn), ) return [ds.dataset_id for ds in all_datasets] @@ -571,13 +570,11 @@ def _query_and_results( client: Client = conn.handle """Query the client and wait for results.""" # Cannot reuse job_config if destination is set and ddl is used - job_factory = QueryJobConfig - job_config = job_factory(**job_params) query_job = client.query( query=sql, - job_config=job_config, + job_config=QueryJobConfig(**job_params), job_id=job_id, # note, this disables retry since the job_id will have been used - timeout=self._retry.job_creation_timeout, + timeout=self._retry.job_creation_timeout(), ) if ( query_job.location is not None @@ -589,11 +586,11 @@ def _query_and_results( ) try: iterator = query_job.result( - max_results=limit, timeout=self._retry.job_execution_timeout + max_results=limit, timeout=self._retry.job_execution_timeout() ) return query_job, iterator except TimeoutError: - exc = f"Operation did not complete within the designated timeout of {self._retry.job_execution_timeout} seconds." + exc = f"Operation did not complete within the designated timeout of {self._retry.job_execution_timeout()} seconds." 
raise TimeoutError(exc) def _labels_from_query_comment(self, comment: str) -> Dict: diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 98a8bee25..7118a67cc 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -4,7 +4,6 @@ from google.cloud.dataproc_v1 import ( Batch, CreateBatchRequest, - GetBatchRequest, Job, RuntimeConfig, ) @@ -30,11 +29,6 @@ class BaseDataProcHelper(PythonJobHelper): def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: - """_summary_ - - Args: - credentials (_type_): _description_ - """ # validate all additional stuff for python is set for required_config in ["dataproc_region", "gcs_bucket"]: if not getattr(credentials, required_config): @@ -83,23 +77,26 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None ) def _submit_dataproc_job(self) -> Job: - job = { - "placement": {"cluster_name": self._cluster_name}, - "pyspark_job": { - "main_python_file_uri": self._gcs_path, + request = { + "project_id": self._project, + "region": self._region, + "job": { + "placement": {"cluster_name": self._cluster_name}, + "pyspark_job": { + "main_python_file_uri": self._gcs_path, + }, }, } - operation = self._job_controller_client.submit_job_as_operation( - request={ - "project_id": self._project, - "region": self._region, - "job": job, - } - ) - # check if job failed + + # submit the job + operation = self._job_controller_client.submit_job_as_operation(request) + + # wait for the job to complete response = operation.result(polling=self._polling_retry) + if response.status.state == 6: raise ValueError(response.status.details) + return response @@ -114,29 +111,21 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None def _submit_dataproc_job(self) -> Batch: _logger.info(f"Submitting batch job with id: {self._batch_id}") - # make the request request = CreateBatchRequest( parent=f"projects/{self._project}/locations/{self._region}", - batch=self._configure_batch(), + batch=self._batch(), batch_id=self._batch_id, ) - self._batch_controller_client.create_batch(request=request) - - # return the response - batch = GetBatchRequest(f"{request.parent}/batches/{self._batch_id}") - return self._batch_controller_client.get_batch(batch, retry=self._polling_retry) - # there might be useful results here that we can parse and return - # Dataproc job output is saved to the Cloud Storage bucket - # allocated to the job. Use regex to obtain the bucket and blob info. 
- # matches = re.match("gs://(.*?)/(.*)", response.driver_output_resource_uri) - # output = ( - # self.storage_client - # .get_bucket(matches.group(1)) - # .blob(f"{matches.group(2)}.000000000") - # .download_as_string() - # ) - - def _configure_batch(self) -> Batch: + + # submit the batch + operation = self._batch_controller_client.create_batch(request) + + # wait for the batch to complete + response = operation.result(polling=self._polling_retry) + + return response + + def _batch(self) -> Batch: # create the Dataproc Serverless job config # need to pin dataproc version to 1.1 as it now defaults to 2.0 # https://cloud.google.com/dataproc-serverless/docs/concepts/properties diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index e5c6abbb2..0f1be6f81 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -1,8 +1,8 @@ from typing import Callable, Optional -from google.api_core.retry import Retry from google.api_core.exceptions import Forbidden -from google.api_core.future.polling import POLLING_PREDICATE +from google.api_core.future.polling import DEFAULT_POLLING +from google.api_core.retry import Retry from google.cloud.bigquery.retry import DEFAULT_RETRY from google.cloud.exceptions import BadGateway, BadRequest, ServerError from requests.exceptions import ConnectionError @@ -17,6 +17,14 @@ _logger = AdapterLogger("BigQuery") +_ONE_DAY = 60 * 60 * 24 # seconds + + +_DEFAULT_INITIAL_DELAY = 1.0 # seconds +_DEFAULT_MAXIMUM_DELAY = 3.0 # seconds +_DEFAULT_POLLING_MAXIMUM_DELAY = 10.0 # seconds + + _REOPENABLE_ERRORS = ( ConnectionResetError, ConnectionError, @@ -32,53 +40,24 @@ ) -_ONE_DAY = 60 * 60 * 24 - - class RetryFactory: - _DEFAULT_INITIAL_DELAY = 1.0 # seconds - _DEFAULT_MAXIMUM_DELAY = 3.0 # seconds - def __init__(self, credentials: BigQueryCredentials) -> None: self._retries = credentials.job_retries or 0 - self.job_creation_timeout = credentials.job_creation_timeout_seconds - self.job_execution_timeout = credentials.job_execution_timeout_seconds - self.job_deadline = credentials.job_retry_deadline_seconds + self._job_creation_timeout = credentials.job_creation_timeout_seconds + self._job_execution_timeout = credentials.job_execution_timeout_seconds + self._job_deadline = credentials.job_retry_deadline_seconds - def deadline(self, connection: Connection) -> Retry: - """ - This strategy mimics what was accomplished with _retry_and_handle - """ - return Retry( - predicate=self._buffered_predicate(), - initial=self._DEFAULT_INITIAL_DELAY, - maximum=self._DEFAULT_MAXIMUM_DELAY, - timeout=self.job_deadline, - on_error=_on_error(connection), - ) + def job_creation_timeout(self, fallback: Optional[float] = None) -> Optional[float]: + return self._job_creation_timeout or fallback or _ONE_DAY - def job_execution(self, connection: Connection) -> Retry: - """ - This strategy mimics what was accomplished with _retry_and_handle - """ - return Retry( - predicate=self._buffered_predicate(), - initial=self._DEFAULT_INITIAL_DELAY, - maximum=self._DEFAULT_MAXIMUM_DELAY, - timeout=self.job_execution_timeout, - on_error=_on_error(connection), - ) + def job_execution_timeout(self, fallback: Optional[float] = None) -> Optional[float]: + return self._job_execution_timeout or fallback or _ONE_DAY - def job_execution_capped(self, connection: Connection) -> Retry: - """ - This strategy mimics what was accomplished with _retry_and_handle - """ - return Retry( - predicate=self._buffered_predicate(), - timeout=self.job_execution_timeout or 300, - 
on_error=_on_error(connection), - ) + def retry( + self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None + ) -> Retry: + return DEFAULT_RETRY.with_timeout(timeout or self.job_execution_timeout(fallback_timeout)) def polling( self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None @@ -86,55 +65,53 @@ def polling( """ This strategy mimics what was accomplished with _retry_and_handle """ - return Retry( - predicate=POLLING_PREDICATE, - minimum=1.0, - maximum=10.0, - timeout=timeout or self.job_execution_timeout or fallback_timeout or _ONE_DAY, + return DEFAULT_POLLING.with_timeout( + timeout or self.job_execution_timeout(fallback_timeout) ) - def polling_done( - self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None - ) -> Retry: - return DEFAULT_RETRY.with_timeout( - timeout or self.job_execution_timeout or fallback_timeout or _ONE_DAY + def reopen_with_deadline(self, connection: Connection) -> Retry: + """ + This strategy mimics what was accomplished with _retry_and_handle + """ + return Retry( + predicate=_BufferedPredicate(self._retries), + initial=_DEFAULT_INITIAL_DELAY, + maximum=_DEFAULT_MAXIMUM_DELAY, + deadline=self._job_deadline, + on_error=_reopen_on_error(connection), ) - def _buffered_predicate(self) -> Callable[[Exception], bool]: - class BufferedPredicate: - """ - Count ALL errors, not just retryable errors, up to a threshold - then raises the next error, regardless of whether it is retryable. - - Was previously called _ErrorCounter. - """ - def __init__(self, retries: int) -> None: - self._retries: int = retries - self._error_count = 0 +class _BufferedPredicate: + """ + Count ALL errors, not just retryable errors, up to a threshold. + Raise the next error, regardless of whether it is retryable. 
+ """ - def __call__(self, error: Exception) -> bool: - # exit immediately if the user does not want retries - if self._retries == 0: - return False + def __init__(self, retries: int) -> None: + self._retries: int = retries + self._error_count = 0 - # count all errors - self._error_count += 1 + def __call__(self, error: Exception) -> bool: + # exit immediately if the user does not want retries + if self._retries == 0: + return False - # if the error is retryable and we haven't breached the threshold, log and continue - if _is_retryable(error) and self._error_count <= self._retries: - _logger.debug( - f"Retry attempt {self._error_count} of { self._retries} after error: {repr(error)}" - ) - return True + # count all errors + self._error_count += 1 - # otherwise raise - return False + # if the error is retryable, and we haven't breached the threshold, log and continue + if _is_retryable(error) and self._error_count <= self._retries: + _logger.debug( + f"Retry attempt {self._error_count} of { self._retries} after error: {repr(error)}" + ) + return True - return BufferedPredicate(self._retries) + # otherwise raise + return False -def _on_error(connection: Connection) -> Callable[[Exception], None]: +def _reopen_on_error(connection: Connection) -> Callable[[Exception], None]: def on_error(error: Exception): if isinstance(error, _REOPENABLE_ERRORS): @@ -165,25 +142,3 @@ def _is_retryable(error: Exception) -> bool: ): return True return False - - -class _BufferedPredicate: - """Counts errors seen up to a threshold then raises the next error.""" - - def __init__(self, retries: int) -> None: - self._retries = retries - self._error_count = 0 - - def count_error(self, error): - if self._retries == 0: - return False # Don't log - self._error_count += 1 - if _is_retryable(error) and self._error_count <= self._retries: - _logger.debug( - "Retry attempt {} of {} after error: {}".format( - self._error_count, self._retries, repr(error) - ) - ) - return True - else: - return False diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 19e0e1ab4..6775445b9 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -17,7 +17,7 @@ def setUp(self): self.credentials = Mock(BigQueryCredentials) self.credentials.method = "oauth" self.credentials.job_retries = 1 - self.credentials.job_execution_timeout_seconds = 1 + self.credentials.job_retry_deadline_seconds = 1 self.credentials.scopes = tuple() self.mock_client = Mock(google.cloud.bigquery.Client) @@ -36,18 +36,21 @@ def setUp(self): "dbt.adapters.bigquery.retry.bigquery_client", return_value=Mock(google.cloud.bigquery.Client), ) - def test_retry_connection_reset(self, mock_bigquery_client): - original_handle = self.mock_connection.handle + def test_retry_connection_reset(self, mock_client_factory): + new_mock_client = mock_client_factory.return_value - @self.connections._retry.job_execution(self.mock_connection) + @self.connections._retry.reopen_with_deadline(self.mock_connection) def generate_connection_reset_error(): raise ConnectionResetError + assert self.mock_connection.handle is self.mock_client + with self.assertRaises(ConnectionResetError): # this will always raise the error, we just want to test that the connection was reopening in between generate_connection_reset_error() - assert not self.mock_connection.handle is original_handle + assert self.mock_connection.handle is new_mock_client + assert new_mock_client is not self.mock_client def 
test_is_retryable(self): _is_retryable = dbt.adapters.bigquery.retry._is_retryable @@ -98,12 +101,10 @@ def test_query_and_results(self, MockQueryJobConfig): def test_copy_bq_table_appends(self): self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND) - args, kwargs = self.mock_client.copy_table.call_args self.mock_client.copy_table.assert_called_once_with( [self._table_ref("project", "dataset", "table1")], self._table_ref("project", "dataset", "table2"), job_config=ANY, - retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual( @@ -117,7 +118,6 @@ def test_copy_bq_table_truncates(self): [self._table_ref("project", "dataset", "table1")], self._table_ref("project", "dataset", "table2"), job_config=ANY, - retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual( From b8408c2c87ae4f8d0699920bbd1930dc734dc723 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 6 Nov 2024 23:28:04 -0500 Subject: [PATCH 22/41] fix seed load result --- dbt/adapters/bigquery/connections.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 952a83b04..e1aaba935 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -496,11 +496,13 @@ def _load_table_from_file( with open(file_path, "rb") as f: job = client.load_table_from_file(f, table, rewind=True, job_config=config) - if not job.done(retry=self._retry.retry(fallback_timeout=fallback_timeout)): + response = job.result(retry=self._retry.retry(fallback_timeout=fallback_timeout)) + + if response.state != "DONE": raise DbtRuntimeError("BigQuery Timeout Exceeded") - elif job.error_result: - message = "\n".join(error["message"].strip() for error in job.errors) + elif response.error_result: + message = "\n".join(error["message"].strip() for error in response.errors) raise DbtRuntimeError(message) @staticmethod From 5b896ee55a96261f18366359e1a415bcb06f15d6 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Thu, 7 Nov 2024 12:36:44 -0500 Subject: [PATCH 23/41] create a method for the dataproc endpoint --- dbt/adapters/bigquery/clients.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/dbt/adapters/bigquery/clients.py b/dbt/adapters/bigquery/clients.py index 948d47f65..e5aa50b33 100644 --- a/dbt/adapters/bigquery/clients.py +++ b/dbt/adapters/bigquery/clients.py @@ -38,23 +38,17 @@ def storage_client(credentials: BigQueryCredentials) -> StorageClient: @Retry() # google decorator. retries on transient errors with exponential backoff def job_controller_client(credentials: BigQueryCredentials) -> JobControllerClient: - options = ClientOptions( - api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", - ) return JobControllerClient( credentials=google_credentials(credentials), - client_options=options, + client_options=ClientOptions(api_endpoint=_dataproc_endpoint(credentials)), ) @Retry() # google decorator. 
retries on transient errors with exponential backoff def batch_controller_client(credentials: BigQueryCredentials) -> BatchControllerClient: - options = ClientOptions( - api_endpoint=f"{credentials.dataproc_region}-dataproc.googleapis.com:443", - ) return BatchControllerClient( credentials=google_credentials(credentials), - client_options=options, + client_options=ClientOptions(api_endpoint=_dataproc_endpoint(credentials)), ) @@ -67,3 +61,7 @@ def _bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: client_info=ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}"), client_options=ClientOptions(quota_project_id=credentials.quota_project), ) + + +def _dataproc_endpoint(credentials: BigQueryCredentials) -> str: + return f"{credentials.dataproc_region}-dataproc.googleapis.com:443" From 43c10f10bb6ccc6cc5b0abd0144b31fc9ae35b60 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Thu, 7 Nov 2024 13:29:28 -0500 Subject: [PATCH 24/41] add some readability updates --- dbt/adapters/bigquery/retry.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 0f1be6f81..d1c59c800 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -11,18 +11,20 @@ from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.exceptions.connection import FailedToConnectError -from dbt.adapters.bigquery.credentials import BigQueryCredentials from dbt.adapters.bigquery.clients import bigquery_client - -_logger = AdapterLogger("BigQuery") +from dbt.adapters.bigquery.credentials import BigQueryCredentials -_ONE_DAY = 60 * 60 * 24 # seconds +_logger = AdapterLogger("BigQuery") -_DEFAULT_INITIAL_DELAY = 1.0 # seconds -_DEFAULT_MAXIMUM_DELAY = 3.0 # seconds -_DEFAULT_POLLING_MAXIMUM_DELAY = 10.0 # seconds +_SECOND = 1.0 +_MINUTE = 60 * _SECOND +_HOUR = 60 * _MINUTE +_DAY = 24 * _HOUR +_DEFAULT_INITIAL_DELAY = _SECOND +_DEFAULT_MAXIMUM_DELAY = 3 * _SECOND +_DEFAULT_POLLING_MAXIMUM_DELAY = 10 * _SECOND _REOPENABLE_ERRORS = ( @@ -49,10 +51,14 @@ def __init__(self, credentials: BigQueryCredentials) -> None: self._job_deadline = credentials.job_retry_deadline_seconds def job_creation_timeout(self, fallback: Optional[float] = None) -> Optional[float]: - return self._job_creation_timeout or fallback or _ONE_DAY + return ( + self._job_creation_timeout or fallback or _MINUTE + ) # keep _MINUTE here so it's not overridden by passing fallback=None def job_execution_timeout(self, fallback: Optional[float] = None) -> Optional[float]: - return self._job_execution_timeout or fallback or _ONE_DAY + return ( + self._job_execution_timeout or fallback or _DAY + ) # keep _DAY here so it's not overridden by passing fallback=None def retry( self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None @@ -62,9 +68,6 @@ def retry( def polling( self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None ) -> Retry: - """ - This strategy mimics what was accomplished with _retry_and_handle - """ return DEFAULT_POLLING.with_timeout( timeout or self.job_execution_timeout(fallback_timeout) ) From 42566827fa41806053fe832c0dfe3a81bdc388ce Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Thu, 7 Nov 2024 13:36:10 -0500 Subject: [PATCH 25/41] add some readability updates --- dbt/adapters/bigquery/credentials.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dbt/adapters/bigquery/credentials.py 
b/dbt/adapters/bigquery/credentials.py index e00e17781..3147b6e95 100644 --- a/dbt/adapters/bigquery/credentials.py +++ b/dbt/adapters/bigquery/credentials.py @@ -18,6 +18,7 @@ from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.exceptions.connection import FailedToConnectError + _logger = AdapterLogger("BigQuery") @@ -34,9 +35,9 @@ def __init__(self, batch_config): class _BigQueryConnectionMethod(StrEnum): OAUTH = "oauth" + OAUTH_SECRETS = "oauth-secrets" SERVICE_ACCOUNT = "service-account" SERVICE_ACCOUNT_JSON = "service-account-json" - OAUTH_SECRETS = "oauth-secrets" @dataclass @@ -260,9 +261,7 @@ def _is_base64(s: Union[str, bytes]) -> bool: # Use the 'validate' parameter to enforce strict Base64 decoding rules base64.b64decode(s, validate=True) return True - except TypeError: - return False - except binascii.Error: # Catch specific errors from the base64 module + except (TypeError, binascii.Error): return False From 5644509369998736317a95a95c84ed87db3d0d10 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Thu, 7 Nov 2024 14:31:03 -0500 Subject: [PATCH 26/41] add some readability updates, simplify submit methods --- dbt/adapters/bigquery/python_submissions.py | 34 +++++++++------------ 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 7118a67cc..067bba7bb 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -1,12 +1,7 @@ from typing import Dict, Union import uuid -from google.cloud.dataproc_v1 import ( - Batch, - CreateBatchRequest, - Job, - RuntimeConfig, -) +from google.cloud.dataproc_v1 import Batch, CreateBatchRequest, Job, RuntimeConfig from dbt.adapters.base import PythonJobHelper from dbt.adapters.events.logging import AdapterLogger @@ -27,7 +22,7 @@ _DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.13-0.34.0.jar" -class BaseDataProcHelper(PythonJobHelper): +class _BaseDataProcHelper(PythonJobHelper): def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: # validate all additional stuff for python is set for required_config in ["dataproc_region", "gcs_bucket"]: @@ -55,15 +50,8 @@ def _upload_to_gcs(self, compiled_code: str) -> None: blob = bucket.blob(self._model_file_name) blob.upload_from_string(compiled_code) - def submit(self, compiled_code: str) -> Job: - self._upload_to_gcs(compiled_code) - return self._submit_dataproc_job() - - def _submit_dataproc_job(self) -> Job: - raise NotImplementedError("_submit_dataproc_job not implemented") - -class ClusterDataprocHelper(BaseDataProcHelper): +class ClusterDataprocHelper(_BaseDataProcHelper): def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: super().__init__(parsed_model, credentials) self._job_controller_client = job_controller_client(credentials) @@ -76,7 +64,11 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None "Need to supply dataproc_cluster_name in profile or config to submit python job with cluster submission method" ) - def _submit_dataproc_job(self) -> Job: + def submit(self, compiled_code: str) -> Job: + _logger.info(f"Submitting cluster job to: {self._cluster_name}") + + self._upload_to_gcs(compiled_code) + request = { "project_id": self._project, "region": self._region, @@ -92,7 +84,7 @@ def _submit_dataproc_job(self) -> Job: operation = self._job_controller_client.submit_job_as_operation(request) # wait for the 
job to complete - response = operation.result(polling=self._polling_retry) + response: Job = operation.result(polling=self._polling_retry) if response.status.state == 6: raise ValueError(response.status.details) @@ -100,7 +92,7 @@ def _submit_dataproc_job(self) -> Job: return response -class ServerlessDataProcHelper(BaseDataProcHelper): +class ServerlessDataProcHelper(_BaseDataProcHelper): def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: super().__init__(parsed_model, credentials) self._batch_controller_client = batch_controller_client(credentials) @@ -108,9 +100,11 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None self._jar_file_uri = parsed_model["config"].get("jar_file_uri", _DEFAULT_JAR_FILE_URI) self._dataproc_batch = credentials.dataproc_batch - def _submit_dataproc_job(self) -> Batch: + def submit(self, compiled_code: str) -> Batch: _logger.info(f"Submitting batch job with id: {self._batch_id}") + self._upload_to_gcs(compiled_code) + request = CreateBatchRequest( parent=f"projects/{self._project}/locations/{self._region}", batch=self._batch(), @@ -121,7 +115,7 @@ def _submit_dataproc_job(self) -> Batch: operation = self._batch_controller_client.create_batch(request) # wait for the batch to complete - response = operation.result(polling=self._polling_retry) + response: Batch = operation.result(polling=self._polling_retry) return response From df2971bad178e3637f6edfc49d43bbd64aaea22b Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Thu, 7 Nov 2024 14:36:58 -0500 Subject: [PATCH 27/41] make imports explicit, remove unused constant --- dbt/adapters/bigquery/connections.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index e1aaba935..833bedd09 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -8,8 +8,7 @@ from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING import uuid -import google.auth -import google.auth.exceptions +from google.auth.exceptions import RefreshError from google.cloud.bigquery import ( Client, CopyJobConfig, @@ -21,9 +20,8 @@ SchemaField, Table, TableReference, - WriteDisposition, ) -import google.cloud.exceptions +from google.cloud.exceptions import BadRequest, Forbidden, NotFound from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event @@ -51,9 +49,8 @@ logger = AdapterLogger("BigQuery") -BQ_QUERY_JOB_SPLIT = "-----Query Job SQL Follows-----" -WRITE_TRUNCATE = WriteDisposition.WRITE_TRUNCATE +BQ_QUERY_JOB_SPLIT = "-----Query Job SQL Follows-----" @dataclass @@ -93,19 +90,19 @@ def exception_handler(self, sql): try: yield - except google.cloud.exceptions.BadRequest as e: + except BadRequest as e: message = "Bad request while running query" self.handle_error(e, message) - except google.cloud.exceptions.Forbidden as e: + except Forbidden as e: message = "Access denied while running query" self.handle_error(e, message) - except google.cloud.exceptions.NotFound as e: + except NotFound as e: message = "Not found while running query" self.handle_error(e, message) - except google.auth.exceptions.RefreshError as e: + except RefreshError as e: message = ( "Unable to generate access token, if you're using " "impersonate_service_account, make sure your " From 0beaac602e9635210eb10fc834b8df874f49665d Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Thu, 7 Nov 2024 14:39:06 -0500 Subject: 
[PATCH 28/41] changelog --- .changes/unreleased/Under the Hood-20241107-143856.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Under the Hood-20241107-143856.yaml diff --git a/.changes/unreleased/Under the Hood-20241107-143856.yaml b/.changes/unreleased/Under the Hood-20241107-143856.yaml new file mode 100644 index 000000000..918123248 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20241107-143856.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Create a retry factory to simplify retry strategies across dbt-bigquery +time: 2024-11-07T14:38:56.210445-05:00 +custom: + Author: mikealfare + Issue: "1395" From 6e2f4b45211f0b6e2ae7dbb269ea200982553bce Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Fri, 8 Nov 2024 20:20:47 -0500 Subject: [PATCH 29/41] add community member who contributed a solution and research to the changelog --- .changes/unreleased/Under the Hood-20241107-143856.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.changes/unreleased/Under the Hood-20241107-143856.yaml b/.changes/unreleased/Under the Hood-20241107-143856.yaml index 918123248..db8557bf0 100644 --- a/.changes/unreleased/Under the Hood-20241107-143856.yaml +++ b/.changes/unreleased/Under the Hood-20241107-143856.yaml @@ -2,5 +2,5 @@ kind: Under the Hood body: Create a retry factory to simplify retry strategies across dbt-bigquery time: 2024-11-07T14:38:56.210445-05:00 custom: - Author: mikealfare + Author: mikealfare osalama Issue: "1395" From f72da43652201ee6b9b3ce4ae513d57629d753fd Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 11:06:05 -0500 Subject: [PATCH 30/41] update names in clients.py to follow the naming convention --- dbt/adapters/bigquery/clients.py | 14 +++++++------- dbt/adapters/bigquery/connections.py | 4 ++-- dbt/adapters/bigquery/python_submissions.py | 12 ++++++------ dbt/adapters/bigquery/retry.py | 4 ++-- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/dbt/adapters/bigquery/clients.py b/dbt/adapters/bigquery/clients.py index e5aa50b33..edbe30faf 100644 --- a/dbt/adapters/bigquery/clients.py +++ b/dbt/adapters/bigquery/clients.py @@ -19,17 +19,17 @@ _logger = AdapterLogger("BigQuery") -def bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: +def create_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: try: - return _bigquery_client(credentials) + return _create_bigquery_client(credentials) except DefaultCredentialsError: _logger.info("Please log into GCP to continue") setup_default_credentials() - return _bigquery_client(credentials) + return _create_bigquery_client(credentials) @Retry() # google decorator. retries on transient errors with exponential backoff -def storage_client(credentials: BigQueryCredentials) -> StorageClient: +def create_gcs_client(credentials: BigQueryCredentials) -> StorageClient: return StorageClient( project=credentials.execution_project, credentials=google_credentials(credentials), @@ -37,7 +37,7 @@ def storage_client(credentials: BigQueryCredentials) -> StorageClient: @Retry() # google decorator. 
retries on transient errors with exponential backoff -def job_controller_client(credentials: BigQueryCredentials) -> JobControllerClient: +def create_dataproc_job_controller_client(credentials: BigQueryCredentials) -> JobControllerClient: return JobControllerClient( credentials=google_credentials(credentials), client_options=ClientOptions(api_endpoint=_dataproc_endpoint(credentials)), @@ -45,7 +45,7 @@ def job_controller_client(credentials: BigQueryCredentials) -> JobControllerClie @Retry() # google decorator. retries on transient errors with exponential backoff -def batch_controller_client(credentials: BigQueryCredentials) -> BatchControllerClient: +def create_dataproc_batch_controller_client(credentials: BigQueryCredentials) -> BatchControllerClient: return BatchControllerClient( credentials=google_credentials(credentials), client_options=ClientOptions(api_endpoint=_dataproc_endpoint(credentials)), @@ -53,7 +53,7 @@ def batch_controller_client(credentials: BigQueryCredentials) -> BatchController @Retry() # google decorator. retries on transient errors with exponential backoff -def _bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: +def _create_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: return BigQueryClient( credentials.execution_project, google_credentials(credentials), diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 833bedd09..d76d89ca6 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -37,7 +37,7 @@ from dbt.adapters.events.types import SQLQuery from dbt.adapters.exceptions.connection import FailedToConnectError -from dbt.adapters.bigquery.clients import bigquery_client +from dbt.adapters.bigquery.clients import create_bigquery_client from dbt.adapters.bigquery.credentials import Priority from dbt.adapters.bigquery.retry import RetryFactory @@ -192,7 +192,7 @@ def open(cls, connection): return connection try: - connection.handle = bigquery_client(connection.credentials) + connection.handle = create_bigquery_client(connection.credentials) connection.state = ConnectionState.OPEN return connection diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 067bba7bb..9d3eaa5dc 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -9,9 +9,9 @@ from dbt.adapters.bigquery.credentials import BigQueryCredentials, DataprocBatchConfig from dbt.adapters.bigquery.clients import ( - batch_controller_client, - job_controller_client, - storage_client, + create_dataproc_batch_controller_client, + create_dataproc_job_controller_client, + create_gcstorage_client, ) from dbt.adapters.bigquery.retry import RetryFactory @@ -31,7 +31,7 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None f"Need to supply {required_config} in profile to submit python job" ) - self._storage_client = storage_client(credentials) + self._storage_client = create_gcstorage_client(credentials) self._project = credentials.execution_project self._region = credentials.dataproc_region @@ -54,7 +54,7 @@ def _upload_to_gcs(self, compiled_code: str) -> None: class ClusterDataprocHelper(_BaseDataProcHelper): def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: super().__init__(parsed_model, credentials) - self._job_controller_client = job_controller_client(credentials) + self._job_controller_client = 
create_dataproc_job_controller_client(credentials) self._cluster_name = parsed_model["config"].get( "dataproc_cluster_name", credentials.dataproc_cluster_name ) @@ -95,7 +95,7 @@ def submit(self, compiled_code: str) -> Job: class ServerlessDataProcHelper(_BaseDataProcHelper): def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: super().__init__(parsed_model, credentials) - self._batch_controller_client = batch_controller_client(credentials) + self._batch_controller_client = create_dataproc_batch_controller_client(credentials) self._batch_id = parsed_model["config"].get("batch_id", str(uuid.uuid4())) self._jar_file_uri = parsed_model["config"].get("jar_file_uri", _DEFAULT_JAR_FILE_URI) self._dataproc_batch = credentials.dataproc_batch diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index d1c59c800..b1b1f3a6e 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -11,7 +11,7 @@ from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.exceptions.connection import FailedToConnectError -from dbt.adapters.bigquery.clients import bigquery_client +from dbt.adapters.bigquery.clients import create_bigquery_client from dbt.adapters.bigquery.credentials import BigQueryCredentials @@ -122,7 +122,7 @@ def on_error(error: Exception): connection.handle.close() try: - connection.handle = bigquery_client(connection.credentials) + connection.handle = create_bigquery_client(connection.credentials) connection.state = ConnectionState.OPEN except Exception as e: From 9fb25bce878063bb3a19799873c37c03149d0cfb Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 11:11:20 -0500 Subject: [PATCH 31/41] update names in connections.py to follow the naming convention --- dbt/adapters/bigquery/connections.py | 10 +++++----- dbt/adapters/bigquery/impl.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index d76d89ca6..f84c85a50 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -444,7 +444,7 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: ) copy_job.result(timeout=self._retry.job_execution_timeout(300)) - def load_dataframe( + def write_dataframe_to_table( self, client: Client, file_path: str, @@ -461,9 +461,9 @@ def load_dataframe( field_delimiter=field_delimiter, ) table = self.table_ref(database, schema, identifier) - self._load_table_from_file(client, file_path, table, load_config, fallback_timeout) + self._write_file_to_table(client, file_path, table, load_config, fallback_timeout) - def upload_file( + def write_file_to_table( self, client: Client, file_path: str, @@ -478,9 +478,9 @@ def upload_file( config["schema"] = json.load(config["schema"]) load_config = LoadJobConfig(**config) table = self.table_ref(database, schema, identifier) - self._load_table_from_file(client, file_path, table, load_config, fallback_timeout) + self._write_file_to_table(client, file_path, table, load_config, fallback_timeout) - def _load_table_from_file( + def _write_file_to_table( self, client: Client, file_path: str, diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index 3fc6e8417..51c457129 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -663,7 +663,7 @@ def load_dataframe( table_schema = self._agate_to_schema(agate_table, column_override) file_path = agate_table.original_abspath # type: ignore - 
self.connections.load_dataframe( + self.connections.write_dataframe_to_table( client, file_path, database, @@ -686,7 +686,7 @@ def upload_file( connection = self.connections.get_thread_connection() client: Client = connection.handle - self.connections.upload_file( + self.connections.write_file_to_table( client, local_file_path, database, From e99d857820a66ca1ced436f00c0c4edff38b8bd3 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 11:16:08 -0500 Subject: [PATCH 32/41] update names in credentials.py to follow the naming convention --- dbt/adapters/bigquery/clients.py | 14 +++++++------- dbt/adapters/bigquery/credentials.py | 20 ++++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/dbt/adapters/bigquery/clients.py b/dbt/adapters/bigquery/clients.py index edbe30faf..1c0570731 100644 --- a/dbt/adapters/bigquery/clients.py +++ b/dbt/adapters/bigquery/clients.py @@ -11,8 +11,8 @@ import dbt.adapters.bigquery.__version__ as dbt_version from dbt.adapters.bigquery.credentials import ( BigQueryCredentials, - google_credentials, - setup_default_credentials, + create_google_credentials, + set_default_credentials, ) @@ -24,7 +24,7 @@ def create_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: return _create_bigquery_client(credentials) except DefaultCredentialsError: _logger.info("Please log into GCP to continue") - setup_default_credentials() + set_default_credentials() return _create_bigquery_client(credentials) @@ -32,14 +32,14 @@ def create_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: def create_gcs_client(credentials: BigQueryCredentials) -> StorageClient: return StorageClient( project=credentials.execution_project, - credentials=google_credentials(credentials), + credentials=create_google_credentials(credentials), ) @Retry() # google decorator. retries on transient errors with exponential backoff def create_dataproc_job_controller_client(credentials: BigQueryCredentials) -> JobControllerClient: return JobControllerClient( - credentials=google_credentials(credentials), + credentials=create_google_credentials(credentials), client_options=ClientOptions(api_endpoint=_dataproc_endpoint(credentials)), ) @@ -47,7 +47,7 @@ def create_dataproc_job_controller_client(credentials: BigQueryCredentials) -> J @Retry() # google decorator. 
retries on transient errors with exponential backoff def create_dataproc_batch_controller_client(credentials: BigQueryCredentials) -> BatchControllerClient: return BatchControllerClient( - credentials=google_credentials(credentials), + credentials=create_google_credentials(credentials), client_options=ClientOptions(api_endpoint=_dataproc_endpoint(credentials)), ) @@ -56,7 +56,7 @@ def create_dataproc_batch_controller_client(credentials: BigQueryCredentials) -> def _create_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: return BigQueryClient( credentials.execution_project, - google_credentials(credentials), + create_google_credentials(credentials), location=getattr(credentials, "location", None), client_info=ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}"), client_options=ClientOptions(quota_project_id=credentials.quota_project), diff --git a/dbt/adapters/bigquery/credentials.py b/dbt/adapters/bigquery/credentials.py index 3147b6e95..94d70a931 100644 --- a/dbt/adapters/bigquery/credentials.py +++ b/dbt/adapters/bigquery/credentials.py @@ -148,7 +148,7 @@ def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]: # `database` is an alias of `project` in BigQuery if "database" not in d: - _, database = _bigquery_defaults() + _, database = _create_bigquery_defaults() d["database"] = database # `execution_project` default to dataset/project if "execution_project" not in d: @@ -156,7 +156,7 @@ def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]: return d -def setup_default_credentials() -> None: +def set_default_credentials() -> None: try: run_cmd(".", ["gcloud", "--version"]) except OSError as e: @@ -172,29 +172,29 @@ def setup_default_credentials() -> None: run_cmd(".", ["gcloud", "auth", "application-default", "login"]) -def google_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: +def create_google_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: if credentials.impersonate_service_account: - return _impersonated_credentials(credentials) - return _google_credentials(credentials) + return _create_impersonated_credentials(credentials) + return _create_google_credentials(credentials) -def _impersonated_credentials(credentials: BigQueryCredentials) -> ImpersonatedCredentials: +def _create_impersonated_credentials(credentials: BigQueryCredentials) -> ImpersonatedCredentials: if credentials.scopes and isinstance(credentials.scopes, Iterable): target_scopes = list(credentials.scopes) else: target_scopes = [] return ImpersonatedCredentials( - source_credentials=_google_credentials(credentials), + source_credentials=_create_google_credentials(credentials), target_principal=credentials.impersonate_service_account, target_scopes=target_scopes, ) -def _google_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: +def _create_google_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: if credentials.method == _BigQueryConnectionMethod.OAUTH: - creds, _ = _bigquery_defaults(scopes=credentials.scopes) + creds, _ = _create_bigquery_defaults(scopes=credentials.scopes) elif credentials.method == _BigQueryConnectionMethod.SERVICE_ACCOUNT: creds = ServiceAccountCredentials.from_service_account_file( @@ -226,7 +226,7 @@ def _google_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: @lru_cache() -def _bigquery_defaults(scopes=None) -> Tuple[Any, Optional[str]]: +def _create_bigquery_defaults(scopes=None) -> Tuple[Any, Optional[str]]: """ Returns (credentials, project_id) From 
f8ad9534d5857ff587ae6f0758b042ab6b708230 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 11:19:06 -0500 Subject: [PATCH 33/41] update names in python_submissions.py to follow the naming convention --- dbt/adapters/bigquery/python_submissions.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 9d3eaa5dc..993c605f3 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -11,7 +11,7 @@ from dbt.adapters.bigquery.clients import ( create_dataproc_batch_controller_client, create_dataproc_job_controller_client, - create_gcstorage_client, + create_gcs_client, ) from dbt.adapters.bigquery.retry import RetryFactory @@ -31,7 +31,7 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None f"Need to supply {required_config} in profile to submit python job" ) - self._storage_client = create_gcstorage_client(credentials) + self._storage_client = create_gcs_client(credentials) self._project = credentials.execution_project self._region = credentials.dataproc_region @@ -45,7 +45,7 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None retry = RetryFactory(credentials) self._polling_retry = retry.polling(timeout=parsed_model["config"].get("timeout")) - def _upload_to_gcs(self, compiled_code: str) -> None: + def _write_to_gcs(self, compiled_code: str) -> None: bucket = self._storage_client.get_bucket(self._gcs_bucket) blob = bucket.blob(self._model_file_name) blob.upload_from_string(compiled_code) @@ -65,9 +65,9 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None ) def submit(self, compiled_code: str) -> Job: - _logger.info(f"Submitting cluster job to: {self._cluster_name}") + _logger.debug(f"Submitting cluster job to: {self._cluster_name}") - self._upload_to_gcs(compiled_code) + self._write_to_gcs(compiled_code) request = { "project_id": self._project, @@ -101,13 +101,13 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None self._dataproc_batch = credentials.dataproc_batch def submit(self, compiled_code: str) -> Batch: - _logger.info(f"Submitting batch job with id: {self._batch_id}") + _logger.debug(f"Submitting batch job with id: {self._batch_id}") - self._upload_to_gcs(compiled_code) + self._write_to_gcs(compiled_code) request = CreateBatchRequest( parent=f"projects/{self._project}/locations/{self._region}", - batch=self._batch(), + batch=self._create_batch(), batch_id=self._batch_id, ) @@ -119,7 +119,7 @@ def submit(self, compiled_code: str) -> Batch: return response - def _batch(self) -> Batch: + def _create_batch(self) -> Batch: # create the Dataproc Serverless job config # need to pin dataproc version to 1.1 as it now defaults to 2.0 # https://cloud.google.com/dataproc-serverless/docs/concepts/properties From 5f3a456cf89f2d8fcd3810b194c977f5b7ccd07b Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 11:23:46 -0500 Subject: [PATCH 34/41] update names in retry.py to follow the naming convention --- dbt/adapters/bigquery/connections.py | 20 +++++++++---------- dbt/adapters/bigquery/python_submissions.py | 2 +- dbt/adapters/bigquery/retry.py | 18 ++++++++--------- .../unit/test_bigquery_connection_manager.py | 2 +- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 
f84c85a50..14f32383b 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -142,7 +142,7 @@ def cancel_open(self): with self.exception_handler(f"Cancel job: {job_id}"): client.cancel_job( job_id, - retry=self._retry.reopen_with_deadline(connection), + retry=self._retry.create_reopen_with_deadline(connection), ) self.close(connection) @@ -442,7 +442,7 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: destination_ref, job_config=CopyJobConfig(write_disposition=write_disposition), ) - copy_job.result(timeout=self._retry.job_execution_timeout(300)) + copy_job.result(timeout=self._retry.create_job_execution_timeout(300)) def write_dataframe_to_table( self, @@ -493,7 +493,7 @@ def _write_file_to_table( with open(file_path, "rb") as f: job = client.load_table_from_file(f, table, rewind=True, job_config=config) - response = job.result(retry=self._retry.retry(fallback_timeout=fallback_timeout)) + response = job.result(retry=self._retry.create_retry(fallback_timeout=fallback_timeout)) if response.state != "DONE": raise DbtRuntimeError("BigQuery Timeout Exceeded") @@ -520,7 +520,7 @@ def get_bq_table(self, database, schema, identifier) -> Table: schema = schema or conn.credentials.schema return client.get_table( table=self.table_ref(database, schema, identifier), - retry=self._retry.reopen_with_deadline(conn), + retry=self._retry.create_reopen_with_deadline(conn), ) def drop_dataset(self, database, schema) -> None: @@ -531,7 +531,7 @@ def drop_dataset(self, database, schema) -> None: dataset=self.dataset_ref(database, schema), delete_contents=True, not_found_ok=True, - retry=self._retry.reopen_with_deadline(conn), + retry=self._retry.create_reopen_with_deadline(conn), ) def create_dataset(self, database, schema) -> Dataset: @@ -541,7 +541,7 @@ def create_dataset(self, database, schema) -> Dataset: return client.create_dataset( dataset=self.dataset_ref(database, schema), exists_ok=True, - retry=self._retry.reopen_with_deadline(conn), + retry=self._retry.create_reopen_with_deadline(conn), ) def list_dataset(self, database: str): @@ -554,7 +554,7 @@ def list_dataset(self, database: str): all_datasets = client.list_datasets( project=database.strip("`"), max_results=10000, - retry=self._retry.reopen_with_deadline(conn), + retry=self._retry.create_reopen_with_deadline(conn), ) return [ds.dataset_id for ds in all_datasets] @@ -573,7 +573,7 @@ def _query_and_results( query=sql, job_config=QueryJobConfig(**job_params), job_id=job_id, # note, this disables retry since the job_id will have been used - timeout=self._retry.job_creation_timeout(), + timeout=self._retry.create_job_creation_timeout(), ) if ( query_job.location is not None @@ -585,11 +585,11 @@ def _query_and_results( ) try: iterator = query_job.result( - max_results=limit, timeout=self._retry.job_execution_timeout() + max_results=limit, timeout=self._retry.create_job_execution_timeout() ) return query_job, iterator except TimeoutError: - exc = f"Operation did not complete within the designated timeout of {self._retry.job_execution_timeout()} seconds." + exc = f"Operation did not complete within the designated timeout of {self._retry.create_job_execution_timeout()} seconds." 
raise TimeoutError(exc) def _labels_from_query_comment(self, comment: str) -> Dict: diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 993c605f3..76471da83 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -43,7 +43,7 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None # set retry policy, default to timeout after 24 hours retry = RetryFactory(credentials) - self._polling_retry = retry.polling(timeout=parsed_model["config"].get("timeout")) + self._polling_retry = retry.create_polling(timeout=parsed_model["config"].get("timeout")) def _write_to_gcs(self, compiled_code: str) -> None: bucket = self._storage_client.get_bucket(self._gcs_bucket) diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index b1b1f3a6e..919124b2a 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -50,29 +50,29 @@ def __init__(self, credentials: BigQueryCredentials) -> None: self._job_execution_timeout = credentials.job_execution_timeout_seconds self._job_deadline = credentials.job_retry_deadline_seconds - def job_creation_timeout(self, fallback: Optional[float] = None) -> Optional[float]: + def create_job_creation_timeout(self, fallback: Optional[float] = None) -> Optional[float]: return ( self._job_creation_timeout or fallback or _MINUTE ) # keep _MINUTE here so it's not overridden by passing fallback=None - def job_execution_timeout(self, fallback: Optional[float] = None) -> Optional[float]: + def create_job_execution_timeout(self, fallback: Optional[float] = None) -> Optional[float]: return ( self._job_execution_timeout or fallback or _DAY ) # keep _DAY here so it's not overridden by passing fallback=None - def retry( + def create_retry( self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None ) -> Retry: - return DEFAULT_RETRY.with_timeout(timeout or self.job_execution_timeout(fallback_timeout)) + return DEFAULT_RETRY.with_timeout(timeout or self.create_job_execution_timeout(fallback_timeout)) - def polling( + def create_polling( self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None ) -> Retry: return DEFAULT_POLLING.with_timeout( - timeout or self.job_execution_timeout(fallback_timeout) + timeout or self.create_job_execution_timeout(fallback_timeout) ) - def reopen_with_deadline(self, connection: Connection) -> Retry: + def create_reopen_with_deadline(self, connection: Connection) -> Retry: """ This strategy mimics what was accomplished with _retry_and_handle """ @@ -81,7 +81,7 @@ def reopen_with_deadline(self, connection: Connection) -> Retry: initial=_DEFAULT_INITIAL_DELAY, maximum=_DEFAULT_MAXIMUM_DELAY, deadline=self._job_deadline, - on_error=_reopen_on_error(connection), + on_error=_create_reopen_on_error(connection), ) @@ -114,7 +114,7 @@ def __call__(self, error: Exception) -> bool: return False -def _reopen_on_error(connection: Connection) -> Callable[[Exception], None]: +def _create_reopen_on_error(connection: Connection) -> Callable[[Exception], None]: def on_error(error: Exception): if isinstance(error, _REOPENABLE_ERRORS): diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 6775445b9..1aa1f24ac 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -39,7 +39,7 @@ def setUp(self): def test_retry_connection_reset(self, 
mock_client_factory): new_mock_client = mock_client_factory.return_value - @self.connections._retry.reopen_with_deadline(self.mock_connection) + @self.connections._retry.create_reopen_with_deadline(self.mock_connection) def generate_connection_reset_error(): raise ConnectionResetError From 7c4388fcef2e047933e58e4256485f7c87e16042 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 11:26:59 -0500 Subject: [PATCH 35/41] run linter and update unit test mocks --- dbt/adapters/bigquery/clients.py | 4 +++- dbt/adapters/bigquery/retry.py | 4 +++- tests/unit/test_bigquery_adapter.py | 4 ++-- tests/unit/test_bigquery_connection_manager.py | 2 +- tests/unit/test_configure_dataproc_batch.py | 4 ++-- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dbt/adapters/bigquery/clients.py b/dbt/adapters/bigquery/clients.py index 1c0570731..18c59fc12 100644 --- a/dbt/adapters/bigquery/clients.py +++ b/dbt/adapters/bigquery/clients.py @@ -45,7 +45,9 @@ def create_dataproc_job_controller_client(credentials: BigQueryCredentials) -> J @Retry() # google decorator. retries on transient errors with exponential backoff -def create_dataproc_batch_controller_client(credentials: BigQueryCredentials) -> BatchControllerClient: +def create_dataproc_batch_controller_client( + credentials: BigQueryCredentials, +) -> BatchControllerClient: return BatchControllerClient( credentials=create_google_credentials(credentials), client_options=ClientOptions(api_endpoint=_dataproc_endpoint(credentials)), diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 919124b2a..13fb5ff3c 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -63,7 +63,9 @@ def create_job_execution_timeout(self, fallback: Optional[float] = None) -> Opti def create_retry( self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None ) -> Retry: - return DEFAULT_RETRY.with_timeout(timeout or self.create_job_execution_timeout(fallback_timeout)) + return DEFAULT_RETRY.with_timeout( + timeout or self.create_job_execution_timeout(fallback_timeout) + ) def create_polling( self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None diff --git a/tests/unit/test_bigquery_adapter.py b/tests/unit/test_bigquery_adapter.py index 3d7e9e77e..e57db9a62 100644 --- a/tests/unit/test_bigquery_adapter.py +++ b/tests/unit/test_bigquery_adapter.py @@ -203,7 +203,7 @@ def get_adapter(self, target) -> BigQueryAdapter: class TestBigQueryAdapterAcquire(BaseTestBigQueryAdapter): @patch( - "dbt.adapters.bigquery.credentials._bigquery_defaults", + "dbt.adapters.bigquery.credentials._create_bigquery_defaults", return_value=("credentials", "project_id"), ) @patch("dbt.adapters.bigquery.BigQueryConnectionManager.open", return_value=_bq_conn()) @@ -244,7 +244,7 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection): mock_open_connection.assert_called_once() @patch( - "dbt.adapters.bigquery.credentials._bigquery_defaults", + "dbt.adapters.bigquery.credentials._create_bigquery_defaults", return_value=("credentials", "project_id"), ) @patch( diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 1aa1f24ac..580ff4422 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -33,7 +33,7 @@ def setUp(self): self.connections.get_thread_connection = lambda: self.mock_connection @patch( - "dbt.adapters.bigquery.retry.bigquery_client", + 
"dbt.adapters.bigquery.retry.create_bigquery_client", return_value=Mock(google.cloud.bigquery.Client), ) def test_retry_connection_reset(self, mock_client_factory): diff --git a/tests/unit/test_configure_dataproc_batch.py b/tests/unit/test_configure_dataproc_batch.py index e73e5b845..6e5757589 100644 --- a/tests/unit/test_configure_dataproc_batch.py +++ b/tests/unit/test_configure_dataproc_batch.py @@ -12,7 +12,7 @@ # parsed credentials class TestConfigureDataprocBatch(BaseTestBigQueryAdapter): @patch( - "dbt.adapters.bigquery.credentials._bigquery_defaults", + "dbt.adapters.bigquery.credentials._create_bigquery_defaults", return_value=("credentials", "project_id"), ) def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults): @@ -64,7 +64,7 @@ def to_str_values(d): ) @patch( - "dbt.adapters.bigquery.credentials._bigquery_defaults", + "dbt.adapters.bigquery.credentials._create_bigquery_defaults", return_value=("credentials", "project_id"), ) def test_default_dataproc_serverless_batch(self, mock_get_bigquery_defaults): From 5928098e4c7e2ce4eb7b3c9b43a63416045374f4 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 21:03:33 -0500 Subject: [PATCH 36/41] update types on retry factory --- dbt/adapters/bigquery/connections.py | 14 +++++----- dbt/adapters/bigquery/retry.py | 41 ++++++++-------------------- 2 files changed, 19 insertions(+), 36 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 14f32383b..17dd0a53d 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -442,7 +442,7 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: destination_ref, job_config=CopyJobConfig(write_disposition=write_disposition), ) - copy_job.result(timeout=self._retry.create_job_execution_timeout(300)) + copy_job.result(timeout=self._retry.create_job_execution_timeout(fallback=300)) def write_dataframe_to_table( self, @@ -493,7 +493,7 @@ def _write_file_to_table( with open(file_path, "rb") as f: job = client.load_table_from_file(f, table, rewind=True, job_config=config) - response = job.result(retry=self._retry.create_retry(fallback_timeout=fallback_timeout)) + response = job.result(retry=self._retry.create_retry(fallback=fallback_timeout)) if response.state != "DONE": raise DbtRuntimeError("BigQuery Timeout Exceeded") @@ -583,14 +583,14 @@ def _query_and_results( logger.debug( self._bq_job_link(query_job.location, query_job.project, query_job.job_id) ) + + timeout = self._retry.create_job_execution_timeout() try: - iterator = query_job.result( - max_results=limit, timeout=self._retry.create_job_execution_timeout() - ) - return query_job, iterator + iterator = query_job.result(max_results=limit, timeout=timeout) except TimeoutError: - exc = f"Operation did not complete within the designated timeout of {self._retry.create_job_execution_timeout()} seconds." + exc = f"Operation did not complete within the designated timeout of {timeout} seconds." 
raise TimeoutError(exc) + return query_job, iterator def _labels_from_query_comment(self, comment: str) -> Dict: try: diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 13fb5ff3c..324bcf588 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -27,21 +27,6 @@ _DEFAULT_POLLING_MAXIMUM_DELAY = 10 * _SECOND -_REOPENABLE_ERRORS = ( - ConnectionResetError, - ConnectionError, -) - - -_RETRYABLE_ERRORS = ( - ServerError, - BadRequest, - BadGateway, - ConnectionResetError, - ConnectionError, -) - - class RetryFactory: def __init__(self, credentials: BigQueryCredentials) -> None: @@ -50,29 +35,25 @@ def __init__(self, credentials: BigQueryCredentials) -> None: self._job_execution_timeout = credentials.job_execution_timeout_seconds self._job_deadline = credentials.job_retry_deadline_seconds - def create_job_creation_timeout(self, fallback: Optional[float] = None) -> Optional[float]: + def create_job_creation_timeout(self, fallback: float = _MINUTE) -> float: return ( - self._job_creation_timeout or fallback or _MINUTE + self._job_creation_timeout or fallback ) # keep _MINUTE here so it's not overridden by passing fallback=None - def create_job_execution_timeout(self, fallback: Optional[float] = None) -> Optional[float]: + def create_job_execution_timeout(self, fallback: float = _DAY) -> float: return ( - self._job_execution_timeout or fallback or _DAY + self._job_execution_timeout or fallback ) # keep _DAY here so it's not overridden by passing fallback=None def create_retry( - self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None + self, timeout: Optional[float] = None, fallback: Optional[float] = None ) -> Retry: return DEFAULT_RETRY.with_timeout( - timeout or self.create_job_execution_timeout(fallback_timeout) + timeout or self._job_execution_timeout or fallback or _DAY ) - def create_polling( - self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None - ) -> Retry: - return DEFAULT_POLLING.with_timeout( - timeout or self.create_job_execution_timeout(fallback_timeout) - ) + def create_polling(self, timeout: Optional[float] = None, fallback: float = _DAY) -> Retry: + return DEFAULT_POLLING.with_timeout(timeout or self._job_execution_timeout or fallback) def create_reopen_with_deadline(self, connection: Connection) -> Retry: """ @@ -119,7 +100,7 @@ def __call__(self, error: Exception) -> bool: def _create_reopen_on_error(connection: Connection) -> Callable[[Exception], None]: def on_error(error: Exception): - if isinstance(error, _REOPENABLE_ERRORS): + if isinstance(error, (ConnectionResetError, ConnectionError)): _logger.warning("Reopening connection after {!r}".format(error)) connection.handle.close() @@ -140,7 +121,9 @@ def on_error(error: Exception): def _is_retryable(error: Exception) -> bool: """Return true for errors that are unlikely to occur again if retried.""" - if isinstance(error, _RETRYABLE_ERRORS): + if isinstance( + error, (BadGateway, BadRequest, ConnectionError, ConnectionResetError, ServerError) + ): return True elif isinstance(error, Forbidden) and any( e["reason"] == "rateLimitExceeded" for e in error.errors From 02385bb5a3f1e48fba294e47a4b7843dd059f4e2 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 21:11:28 -0500 Subject: [PATCH 37/41] update inputs on retry factory --- dbt/adapters/bigquery/python_submissions.py | 2 +- dbt/adapters/bigquery/retry.py | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git 
a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 76471da83..9c69c9e7c 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -43,7 +43,7 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None # set retry policy, default to timeout after 24 hours retry = RetryFactory(credentials) - self._polling_retry = retry.create_polling(timeout=parsed_model["config"].get("timeout")) + self._polling_retry = retry.create_polling(model_timeout=parsed_model["config"].get("timeout")) def _write_to_gcs(self, compiled_code: str) -> None: bucket = self._storage_client.get_bucket(self._gcs_bucket) diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 324bcf588..f275cf76b 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -45,15 +45,11 @@ def create_job_execution_timeout(self, fallback: float = _DAY) -> float: self._job_execution_timeout or fallback ) # keep _DAY here so it's not overridden by passing fallback=None - def create_retry( - self, timeout: Optional[float] = None, fallback: Optional[float] = None - ) -> Retry: - return DEFAULT_RETRY.with_timeout( - timeout or self._job_execution_timeout or fallback or _DAY - ) + def create_retry(self, fallback: Optional[float] = None) -> Retry: + return DEFAULT_RETRY.with_timeout(self._job_execution_timeout or fallback or _DAY) - def create_polling(self, timeout: Optional[float] = None, fallback: float = _DAY) -> Retry: - return DEFAULT_POLLING.with_timeout(timeout or self._job_execution_timeout or fallback) + def create_polling(self, model_timeout: Optional[float] = None) -> Retry: + return DEFAULT_POLLING.with_timeout(model_timeout or self._job_execution_timeout or _DAY) def create_reopen_with_deadline(self, connection: Connection) -> Retry: """ From 51cc87ff8fb3717e7d8199eac1e94b3c5bb0361b Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 21:13:41 -0500 Subject: [PATCH 38/41] update predicate class name --- dbt/adapters/bigquery/retry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index f275cf76b..02aa81edf 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -56,7 +56,7 @@ def create_reopen_with_deadline(self, connection: Connection) -> Retry: This strategy mimics what was accomplished with _retry_and_handle """ return Retry( - predicate=_BufferedPredicate(self._retries), + predicate=_DeferredException(self._retries), initial=_DEFAULT_INITIAL_DELAY, maximum=_DEFAULT_MAXIMUM_DELAY, deadline=self._job_deadline, @@ -64,7 +64,7 @@ def create_reopen_with_deadline(self, connection: Connection) -> Retry: ) -class _BufferedPredicate: +class _DeferredException: """ Count ALL errors, not just retryable errors, up to a threshold. Raise the next error, regardless of whether it is retryable. 
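For readers tracking the rename above: the sketch below is not part of the patch series; it only illustrates how a deferred-exception predicate of this shape plugs into google.api_core.retry.Retry. The class name, the flaky_call function, and the retry counts are illustrative assumptions rather than adapter code.

    from google.api_core.retry import Retry


    class DeferredException:
        """Illustrative predicate: count every error, but keep retrying only
        while the error looks transient and the threshold is not exceeded."""

        def __init__(self, retries: int) -> None:
            self._retries = retries
            self._error_count = 0

        def __call__(self, error: Exception) -> bool:
            if self._retries == 0:
                return False  # retries disabled; surface the first error
            self._error_count += 1
            retryable = isinstance(error, (ConnectionError, ConnectionResetError))
            if retryable and self._error_count <= self._retries:
                print(f"retry attempt {self._error_count} of {self._retries} after {error!r}")
                return True
            return False  # Retry gives up and re-raises this error

    # Retry instances work as decorators, as in clients.py above.
    @Retry(predicate=DeferredException(retries=3), initial=1.0, maximum=3.0, deadline=30.0)
    def flaky_call() -> str:
        raise ConnectionResetError("simulated transient failure")

    # Calling flaky_call() would retry three times, then raise ConnectionResetError.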
From eaab97629d7130746f0258eaf73d6caa8df3f8a3 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 21:46:41 -0500 Subject: [PATCH 39/41] add retry strategy back to copy table --- dbt/adapters/bigquery/connections.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 17dd0a53d..61fa87d40 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -441,6 +441,7 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: source_ref_array, destination_ref, job_config=CopyJobConfig(write_disposition=write_disposition), + retry=self._retry.create_reopen_with_deadline(conn), ) copy_job.result(timeout=self._retry.create_job_execution_timeout(fallback=300)) From a81289fe3020144b54a7ac1e1bd0566604e6d775 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 23:05:50 -0500 Subject: [PATCH 40/41] linting and fix unit test for new argument --- dbt/adapters/bigquery/python_submissions.py | 4 +++- tests/unit/test_bigquery_connection_manager.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 9c69c9e7c..cd7f7d86f 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -43,7 +43,9 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None # set retry policy, default to timeout after 24 hours retry = RetryFactory(credentials) - self._polling_retry = retry.create_polling(model_timeout=parsed_model["config"].get("timeout")) + self._polling_retry = retry.create_polling( + model_timeout=parsed_model["config"].get("timeout") + ) def _write_to_gcs(self, compiled_code: str) -> None: bucket = self._storage_client.get_bucket(self._gcs_bucket) diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 580ff4422..d4c95792e 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -105,6 +105,7 @@ def test_copy_bq_table_appends(self): [self._table_ref("project", "dataset", "table1")], self._table_ref("project", "dataset", "table2"), job_config=ANY, + retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual( @@ -118,6 +119,7 @@ def test_copy_bq_table_truncates(self): [self._table_ref("project", "dataset", "table1")], self._table_ref("project", "dataset", "table2"), job_config=ANY, + retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual( From 76d6979ec78b7452a4df1c63a9c191bc0319dfd2 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 19 Nov 2024 23:24:44 -0500 Subject: [PATCH 41/41] fix whitespace --- dbt/adapters/bigquery/retry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index 02aa81edf..391c00e46 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -85,7 +85,7 @@ def __call__(self, error: Exception) -> bool: # if the error is retryable, and we haven't breached the threshold, log and continue if _is_retryable(error) and self._error_count <= self._retries: _logger.debug( - f"Retry attempt {self._error_count} of { self._retries} after error: {repr(error)}" + f"Retry attempt {self._error_count} of {self._retries} after error: {repr(error)}" ) return True
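Closing illustration, not taken from the patch series: the pattern these patches converge on is to build one google.api_core Retry and one timeout up front and pass them explicitly to each BigQuery client call. A minimal standalone sketch follows; the project id, timeout values, and helper names are placeholders, not the adapter's API.

    from google.api_core.retry import Retry
    from google.cloud.bigquery import Client
    from google.cloud.exceptions import ServerError

    _MINUTE = 60.0


    def _is_transient(error: Exception) -> bool:
        # same spirit as the predicate above: retry server and connection hiccups
        return isinstance(error, (ServerError, ConnectionError, ConnectionResetError))


    def run_query(client: Client, sql: str, execution_timeout: float = 5 * _MINUTE):
        retry = Retry(
            predicate=_is_transient, initial=1.0, maximum=3.0, deadline=execution_timeout
        )
        job = client.query(sql, timeout=_MINUTE)  # bound job creation
        return job.result(retry=retry, timeout=execution_timeout)  # bound job execution

    # client = Client(project="my-gcp-project")  # placeholder project id
    # rows = run_query(client, "SELECT 1")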