Disable SSL verification for CloudFetch links (#414)
* Disable SSL verification for CloudFetch links

Signed-off-by: Levko Kravets <[email protected]>

* Use existing `_tls_no_verify` option in CloudFetch downloader

Signed-off-by: Levko Kravets <[email protected]>

* Update tests

Signed-off-by: Levko Kravets <[email protected]>

---------

Signed-off-by: Levko Kravets <[email protected]>
Authored by kravets-levko on Jul 16, 2024
Parent: 134b21d. Commit: dbf183b
Showing 9 changed files with 147 additions and 29 deletions.
src/databricks/sql/client.py (2 additions, 0 deletions)

@@ -171,6 +171,8 @@ def read(self) -> Optional[OAuthToken]:
         # Which port to connect to
         # _skip_routing_headers:
         # Don't set routing headers if set to True (for use when connecting directly to server)
+        # _tls_no_verify
+        # Set to True (Boolean) to completely disable SSL verification.
         # _tls_verify_hostname
         # Set to False (Boolean) to disable SSL hostname verification, but check certificate.
         # _tls_trusted_ca_file
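For context, `_tls_no_verify` is an existing connection option that this change now also honors for CloudFetch downloads. A minimal usage sketch follows; the hostname, HTTP path, and token are placeholders, not real values:

from databricks import sql

# Sketch only: server_hostname, http_path, and access_token are placeholders.
# With _tls_no_verify=True, both the Thrift transport and, after this change,
# the CloudFetch file downloads skip SSL certificate verification.
connection = sql.connect(
    server_hostname="example.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/0123456789abcdef",
    access_token="dapi-placeholder",
    _tls_no_verify=True,
)

cursor = connection.cursor()
cursor.execute("SELECT 1")
print(cursor.fetchall())
cursor.close()
connection.close()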
src/databricks/sql/cloudfetch/download_manager.py (8 additions, 1 deletion)

@@ -1,5 +1,6 @@
 import logging

+from ssl import SSLContext
 from concurrent.futures import ThreadPoolExecutor, Future
 from typing import List, Union

@@ -19,6 +20,7 @@ def __init__(
         links: List[TSparkArrowResultLink],
         max_download_threads: int,
         lz4_compressed: bool,
+        ssl_context: SSLContext,
     ):
         self._pending_links: List[TSparkArrowResultLink] = []
         for link in links:
@@ -36,6 +38,7 @@ def __init__(
         self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)

         self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
+        self._ssl_context = ssl_context

     def get_next_downloaded_file(
         self, next_row_offset: int
@@ -89,7 +92,11 @@ def _schedule_downloads(self):
             logger.debug(
                 "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
             )
-            handler = ResultSetDownloadHandler(self._downloadable_result_settings, link)
+            handler = ResultSetDownloadHandler(
+                settings=self._downloadable_result_settings,
+                link=link,
+                ssl_context=self._ssl_context,
+            )
             task = self._thread_pool.submit(handler.run)
             self._download_tasks.append(task)
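The manager now requires the context at construction time and forwards it to every handler it schedules. A hypothetical construction, assuming a default `ssl` context stands in for the one the connector builds from its `_tls_*` options:

import ssl

from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager

ssl_context = ssl.create_default_context()  # stand-in for the connector-built context

manager = ResultFileDownloadManager(
    links=[],                  # normally TSparkArrowResultLink objects from the Thrift response
    max_download_threads=10,   # illustrative value
    lz4_compressed=True,
    ssl_context=ssl_context,   # new required argument
)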
src/databricks/sql/cloudfetch/downloader.py (8 additions, 1 deletion)

@@ -3,6 +3,7 @@

 import requests
 from requests.adapters import HTTPAdapter, Retry
+from ssl import SSLContext, CERT_NONE
 import lz4.frame
 import time

@@ -65,9 +66,11 @@ def __init__(
         self,
         settings: DownloadableResultSettings,
         link: TSparkArrowResultLink,
+        ssl_context: SSLContext,
     ):
         self.settings = settings
         self.link = link
+        self._ssl_context = ssl_context

     def run(self) -> DownloadedFile:
         """
@@ -92,10 +95,14 @@ def run(self) -> DownloadedFile:
         session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
         session.mount("https://", HTTPAdapter(max_retries=retryPolicy))

+        ssl_verify = self._ssl_context.verify_mode != CERT_NONE
+
         try:
             # Get the file via HTTP request
             response = session.get(
-                self.link.fileLink, timeout=self.settings.download_timeout
+                self.link.fileLink,
+                timeout=self.settings.download_timeout,
+                verify=ssl_verify,
             )
             response.raise_for_status()
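The `ssl_verify` line above collapses the context into the boolean that `requests` expects for its `verify=` parameter. A small sketch of both states of that check:

import ssl

# Default context: verification enabled.
ctx = ssl.create_default_context()
print(ctx.verify_mode != ssl.CERT_NONE)  # True -> certificates are verified

# Context configured for _tls_no_verify=True: verification disabled.
# check_hostname must be turned off before verify_mode can become CERT_NONE.
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
print(ctx.verify_mode != ssl.CERT_NONE)  # False -> verify=False is passed to session.get()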
src/databricks/sql/thrift_backend.py (5 additions, 1 deletion)

@@ -184,6 +184,8 @@ def __init__(
                 password=tls_client_cert_key_password,
             )

+        self._ssl_context = ssl_context
+
         self._auth_provider = auth_provider

         # Connector version 3 retry approach
@@ -223,7 +225,7 @@ def __init__(
         self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
             auth_provider=self._auth_provider,
             uri_or_host=uri,
-            ssl_context=ssl_context,
+            ssl_context=self._ssl_context,
             **additional_transport_args,  # type: ignore
         )

@@ -774,6 +776,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
                 max_download_threads=self.max_download_threads,
                 lz4_compressed=lz4_compressed,
                 description=description,
+                ssl_context=self._ssl_context,
             )
         else:
             arrow_queue_opt = None
@@ -1005,6 +1008,7 @@ def fetch_results(
             max_download_threads=self.max_download_threads,
             lz4_compressed=lz4_compressed,
             description=description,
+            ssl_context=self._ssl_context,
         )

         return queue, resp.hasMoreRows
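Taken together, the backend now stores the context once and threads it down to every CloudFetch request. A simplified sketch of the propagation chain, with call sites condensed from this diff:

# Connection options (_tls_no_verify, _tls_verify_hostname, ...) build the SSLContext
#   -> ThriftBackend.__init__ keeps it as self._ssl_context
#     -> utils build_queue(..., ssl_context=self._ssl_context)
#       -> CloudFetchQueue(..., ssl_context=...)
#         -> ResultFileDownloadManager(..., ssl_context=...)
#           -> ResultSetDownloadHandler(..., ssl_context=...)
#             -> session.get(link.fileLink, verify=ssl_context.verify_mode != CERT_NONE)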
src/databricks/sql/utils.py (11 additions, 2 deletions)

@@ -9,6 +9,7 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 import re
+from ssl import SSLContext

 import lz4.frame
 import pyarrow
@@ -47,6 +48,7 @@ def build_queue(
         t_row_set: TRowSet,
         arrow_schema_bytes: bytes,
         max_download_threads: int,
+        ssl_context: SSLContext,
         lz4_compressed: bool = True,
         description: Optional[List[List[Any]]] = None,
     ) -> ResultSetQueue:
@@ -60,6 +62,7 @@ def build_queue(
             lz4_compressed (bool): Whether result data has been lz4 compressed.
             description (List[List[Any]]): Hive table schema description.
             max_download_threads (int): Maximum number of downloader thread pool threads.
+            ssl_context (SSLContext): SSLContext object for CloudFetchQueue

         Returns:
             ResultSetQueue
@@ -82,12 +85,13 @@ def build_queue(
             return ArrowQueue(converted_arrow_table, n_valid_rows)
         elif row_set_type == TSparkRowSetType.URL_BASED_SET:
             return CloudFetchQueue(
-                arrow_schema_bytes,
+                schema_bytes=arrow_schema_bytes,
                 start_row_offset=t_row_set.startRowOffset,
                 result_links=t_row_set.resultLinks,
                 lz4_compressed=lz4_compressed,
                 description=description,
                 max_download_threads=max_download_threads,
+                ssl_context=ssl_context,
             )
         else:
             raise AssertionError("Row set type is not valid")
@@ -133,6 +137,7 @@ def __init__(
         self,
         schema_bytes,
         max_download_threads: int,
+        ssl_context: SSLContext,
         start_row_offset: int = 0,
         result_links: Optional[List[TSparkArrowResultLink]] = None,
         lz4_compressed: bool = True,
@@ -155,6 +160,7 @@ def __init__(
         self.result_links = result_links
         self.lz4_compressed = lz4_compressed
         self.description = description
+        self._ssl_context = ssl_context

         logger.debug(
             "Initialize CloudFetch loader, row set start offset: {}, file list:".format(
@@ -169,7 +175,10 @@ def __init__(
             )
         )
         self.download_manager = ResultFileDownloadManager(
-            result_links or [], self.max_download_threads, self.lz4_compressed
+            links=result_links or [],
+            max_download_threads=self.max_download_threads,
+            lz4_compressed=self.lz4_compressed,
+            ssl_context=self._ssl_context,
         )

         self.table = self._create_next_table()
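With the call sites above switched to keyword arguments, constructing the queue directly looks like the following sketch. It builds an empty queue with placeholder values; real callers pass the Arrow schema bytes and result links from the Thrift response:

import ssl

from databricks.sql.utils import CloudFetchQueue

ssl_context = ssl.create_default_context()  # stand-in for the connector-built context

queue = CloudFetchQueue(
    schema_bytes=b"",          # placeholder; normally the Arrow schema from the server
    max_download_threads=10,   # illustrative value
    ssl_context=ssl_context,   # new required argument, forwarded to the download manager
    start_row_offset=0,
    result_links=[],           # placeholder; normally TSparkArrowResultLink objects
)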
tests/unit/test_client.py (1 addition, 1 deletion)

@@ -361,7 +361,7 @@ def test_cancel_command_calls_the_backend(self):
         mock_op_handle = Mock()
         cursor.active_op_handle = mock_op_handle
         cursor.cancel()
-        self.assertTrue(mock_thrift_backend.cancel_command.called_with(mock_op_handle))
+        mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle)

     @patch("databricks.sql.client.logger")
     def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command(
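The test change above fixes an assertion that could never fail: `called_with` is not part of the `unittest.mock` API, so accessing it on a `Mock` silently auto-creates a child mock, and a mock object is always truthy. A short demonstration:

from unittest.mock import Mock

backend = Mock()
backend.cancel_command("handle-1")

# Buggy pattern: `called_with` auto-creates a truthy child Mock, so the
# old assertTrue(...) passed even when the arguments did not match.
print(bool(backend.cancel_command.called_with("wrong-handle")))  # True (!)

# Correct pattern: assert_called_with raises AssertionError on a mismatch.
backend.cancel_command.assert_called_with("handle-1")  # passes silently
# backend.cancel_command.assert_called_with("oops")    # would raise AssertionError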