From 0cbfae6be0c0003c3afe5880fffe723668bb606a Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 15 Jul 2025 11:36:12 +0530 Subject: [PATCH 1/4] Minor fix --- src/databricks/sql/client.py | 13 +++++---- src/databricks/sql/cloudfetch/downloader.py | 32 ++++++++------------- src/databricks/sql/utils.py | 10 ++++--- 3 files changed, 25 insertions(+), 30 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index b4cd78cf8..4425ec131 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1454,7 +1454,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows - + partial_result_chunks = [results] while ( n_remaining_rows > 0 and not self.has_been_closed_server_side @@ -1462,11 +1462,11 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) + partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return results + return pyarrow.concat_tables(partial_result_chunks, use_threads=True) def merge_columnar(self, result1, result2): """ @@ -1514,7 +1514,8 @@ def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows - + + partial_result_chunks = [results] while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() @@ -1523,7 +1524,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": ): results = self.merge_columnar(results, partial_results) else: - results = pyarrow.concat_tables([results, partial_results]) + partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table @@ -1534,7 +1535,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": for name, col in zip(results.column_names, results.column_table) } return pyarrow.Table.from_pydict(data) - return results + return pyarrow.concat_tables(partial_result_chunks, use_threads=True) def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 228e07d6c..905dab6b6 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -1,11 +1,10 @@ import logging from dataclasses import dataclass -import requests -from requests.adapters import HTTPAdapter, Retry +from requests.adapters import Retry import lz4.frame import time - +from databricks.sql.common.http import DatabricksHttpClient, HttpMethod from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.exc import Error from databricks.sql.types import SSLOptions @@ -70,6 +69,7 @@ def __init__( self.settings = settings self.link = link self._ssl_options = ssl_options + self._http_client = DatabricksHttpClient.get_instance() def run(self) -> DownloadedFile: """ @@ -89,20 +89,15 @@ def run(self) -> DownloadedFile: ResultSetDownloadHandler._validate_link( self.link, 
self.settings.link_expiry_buffer_secs ) - - session = requests.Session() - session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) - session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) - - try: - # Get the file via HTTP request - response = session.get( - self.link.fileLink, - timeout=self.settings.download_timeout, - verify=self._ssl_options.tls_verify, - headers=self.link.httpHeaders - # TODO: Pass cert from `self._ssl_options` - ) + + with self._http_client.execute( + method=HttpMethod.GET, + url=self.link.fileLink, + timeout=self.settings.download_timeout, + verify=self._ssl_options.tls_verify, + headers=self.link.httpHeaders + # TODO: Pass cert from `self._ssl_options` + ) as response: response.raise_for_status() # Save (and decompress if needed) the downloaded file @@ -132,9 +127,6 @@ def run(self) -> DownloadedFile: self.link.startRowOffset, self.link.rowCount, ) - finally: - if session: - session.close() @staticmethod def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int): diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 233808777..312b98b03 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -271,11 +271,12 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": return self._create_empty_table() logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows)) results = self.table.slice(0, 0) + partial_result_chunks = [results] while num_rows > 0 and self.table: # Get remaining of num_rows or the rest of the current table, whichever is smaller length = min(num_rows, self.table.num_rows - self.table_row_index) table_slice = self.table.slice(self.table_row_index, length) - results = pyarrow.concat_tables([results, table_slice]) + partial_result_chunks.append(table_slice) self.table_row_index += table_slice.num_rows # Replace current table with the next table if we are at the end of the current table @@ -285,7 +286,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": num_rows -= table_slice.num_rows logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) - return results + return pyarrow.concat_tables(partial_result_chunks, use_threads=True) def remaining_rows(self) -> "pyarrow.Table": """ @@ -298,15 +299,16 @@ def remaining_rows(self) -> "pyarrow.Table": # Return empty pyarrow table to cause retry of fetch return self._create_empty_table() results = self.table.slice(0, 0) + partial_result_chunks = [results] while self.table: table_slice = self.table.slice( self.table_row_index, self.table.num_rows - self.table_row_index ) - results = pyarrow.concat_tables([results, table_slice]) + partial_result_chunks.append(table_slice) self.table_row_index += table_slice.num_rows self.table = self._create_next_table() self.table_row_index = 0 - return results + return pyarrow.concat_tables(partial_result_chunks, use_threads=True) def _create_next_table(self) -> Union["pyarrow.Table", None]: logger.debug( From 8cdfd88014de53ba654a74f4f3a5c9a2bf0c4a7f Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 15 Jul 2025 12:39:42 +0530 Subject: [PATCH 2/4] Perf update --- src/databricks/sql/client.py | 2 +- src/databricks/sql/cloudfetch/downloader.py | 2 +- tests/unit/test_telemetry.py | 108 +++++++++++--------- 3 files changed, 60 insertions(+), 52 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 4425ec131..b4162113e 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1514,7 +1514,7 @@ def 
fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows - + partial_result_chunks = [results] while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 905dab6b6..4421c4770 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -89,7 +89,7 @@ def run(self) -> DownloadedFile: ResultSetDownloadHandler._validate_link( self.link, self.settings.link_expiry_buffer_secs ) - + with self._http_client.execute( method=HttpMethod.GET, url=self.link.fileLink, diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index f57f75562..dc1c7d630 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -8,7 +8,7 @@ NoopTelemetryClient, TelemetryClientFactory, TelemetryHelper, - BaseTelemetryClient + BaseTelemetryClient, ) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow from databricks.sql.auth.authenticators import ( @@ -24,7 +24,7 @@ def mock_telemetry_client(): session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() - + return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -43,7 +43,7 @@ def test_noop_client_behavior(self): client1 = NoopTelemetryClient() client2 = NoopTelemetryClient() assert client1 is client2 - + # Test that all methods can be called without exceptions client1.export_initial_telemetry_log(MagicMock(), "test-agent") client1.export_failure_log("TestError", "Test message") @@ -58,61 +58,61 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): """Test the complete event batching and flushing flow.""" client = mock_telemetry_client client._batch_size = 3 # Small batch for testing - + # Mock the network call - with patch.object(client, '_send_telemetry') as mock_send: + with patch.object(client, "_send_telemetry") as mock_send: # Add events one by one - should not flush yet client._export_event("event1") client._export_event("event2") mock_send.assert_not_called() assert len(client._events_batch) == 2 - + # Third event should trigger flush client._export_event("event3") mock_send.assert_called_once() assert len(client._events_batch) == 0 # Batch cleared after flush - - @patch('requests.post') + + @patch("requests.post") def test_network_request_flow(self, mock_post, mock_telemetry_client): """Test the complete network request flow with authentication.""" mock_post.return_value.status_code = 200 client = mock_telemetry_client - + # Create mock events mock_events = [MagicMock() for _ in range(2)] for i, event in enumerate(mock_events): event.to_json.return_value = f'{{"event": "{i}"}}' - + # Send telemetry client._send_telemetry(mock_events) - + # Verify request was submitted to executor client._executor.submit.assert_called_once() args, kwargs = client._executor.submit.call_args - + # Verify correct function and URL assert args[0] == requests.post - assert args[1] == 'https://test-host.com/telemetry-ext' - assert kwargs['headers']['Authorization'] == 'Bearer test-token' - + assert args[1] == "https://test-host.com/telemetry-ext" + assert kwargs["headers"]["Authorization"] == "Bearer test-token" + # Verify request body structure - request_data = kwargs['data'] + request_data = kwargs["data"] 
assert '"uploadTime"' in request_data assert '"protoLogs"' in request_data def test_telemetry_logging_flows(self, mock_telemetry_client): """Test all telemetry logging methods work end-to-end.""" client = mock_telemetry_client - - with patch.object(client, '_export_event') as mock_export: + + with patch.object(client, "_export_event") as mock_export: # Test initial log client.export_initial_telemetry_log(MagicMock(), "test-agent") assert mock_export.call_count == 1 - + # Test failure log client.export_failure_log("TestError", "Error message") assert mock_export.call_count == 2 - + # Test latency log client.export_latency_log(150, "EXECUTE_STATEMENT", "stmt-123") assert mock_export.call_count == 3 @@ -120,14 +120,14 @@ def test_telemetry_logging_flows(self, mock_telemetry_client): def test_error_handling_resilience(self, mock_telemetry_client): """Test that telemetry errors don't break the client.""" client = mock_telemetry_client - + # Test that exceptions in telemetry don't propagate - with patch.object(client, '_export_event', side_effect=Exception("Test error")): + with patch.object(client, "_export_event", side_effect=Exception("Test error")): # These should not raise exceptions client.export_initial_telemetry_log(MagicMock(), "test-agent") client.export_failure_log("TestError", "Error message") client.export_latency_log(100, "EXECUTE_STATEMENT", "stmt-123") - + # Test executor submission failure client._executor.submit.side_effect = Exception("Thread pool error") client._send_telemetry([MagicMock()]) # Should not raise @@ -140,7 +140,7 @@ def test_system_configuration_caching(self): """Test that system configuration is cached and contains expected data.""" config1 = TelemetryHelper.get_driver_system_configuration() config2 = TelemetryHelper.get_driver_system_configuration() - + # Should be cached (same instance) assert config1 is config2 @@ -153,7 +153,7 @@ def test_auth_mechanism_detection(self): (MagicMock(), AuthMech.OTHER), # Unknown provider (None, None), ] - + for provider, expected in test_cases: assert TelemetryHelper.get_auth_mechanism(provider) == expected @@ -163,19 +163,25 @@ def test_auth_flow_detection(self): oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) oauth_with_tokens._access_token = "test-access-token" oauth_with_tokens._refresh_token = "test-refresh-token" - assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH - + assert ( + TelemetryHelper.get_auth_flow(oauth_with_tokens) + == AuthFlow.TOKEN_PASSTHROUGH + ) + # Test OAuth with browser-based auth oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) oauth_with_browser._access_token = None oauth_with_browser._refresh_token = None oauth_with_browser.oauth_manager = MagicMock() - assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION - + assert ( + TelemetryHelper.get_auth_flow(oauth_with_browser) + == AuthFlow.BROWSER_BASED_AUTHENTICATION + ) + # Test non-OAuth provider pat_auth = AccessTokenAuthProvider("test-token") assert TelemetryHelper.get_auth_flow(pat_auth) is None - + # Test None auth provider assert TelemetryHelper.get_auth_flow(None) is None @@ -202,24 +208,24 @@ def test_client_lifecycle_flow(self): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") - + # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, auth_provider=auth_provider, - 
host_url="test-host.com" + host_url="test-host.com", ) - + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex - + # Close client - with patch.object(client, 'close') as mock_close: + with patch.object(client, "close") as mock_close: TelemetryClientFactory.close(session_id_hex) mock_close.assert_called_once() - + # Should get NoopTelemetryClient after close client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) @@ -227,31 +233,33 @@ def test_client_lifecycle_flow(self): def test_disabled_telemetry_flow(self): """Test that disabled telemetry uses NoopTelemetryClient.""" session_id_hex = "test-session" - + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, session_id_hex=session_id_hex, auth_provider=None, - host_url="test-host.com" + host_url="test-host.com", ) - + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) def test_factory_error_handling(self): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" - + # Simulate initialization error - with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient', - side_effect=Exception("Init error")): + with patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient", + side_effect=Exception("Init error"), + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id, auth_provider=AccessTokenAuthProvider("token"), - host_url="test-host.com" + host_url="test-host.com", ) - + # Should fall back to NoopTelemetryClient client = TelemetryClientFactory.get_telemetry_client(session_id) assert isinstance(client, NoopTelemetryClient) @@ -260,25 +268,25 @@ def test_factory_shutdown_flow(self): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" - + # Initialize multiple clients for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session, auth_provider=AccessTokenAuthProvider("token"), - host_url="test-host.com" + host_url="test-host.com", ) - + # Factory should be initialized assert TelemetryClientFactory._initialized is True assert TelemetryClientFactory._executor is not None - + # Close first client - factory should stay initialized TelemetryClientFactory.close(session1) assert TelemetryClientFactory._initialized is True - + # Close second client - factory should shut down TelemetryClientFactory.close(session2) assert TelemetryClientFactory._initialized is False - assert TelemetryClientFactory._executor is None \ No newline at end of file + assert TelemetryClientFactory._executor is None From 7c7b121fec0c47186ff9916ac00f11f1bcbb4c8e Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 15 Jul 2025 13:52:45 +0530 Subject: [PATCH 3/4] more --- src/databricks/sql/client.py | 2 +- src/databricks/sql/result_set.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9935ee20b..e4166f117 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1230,4 +1230,4 @@ def setinputsizes(self, sizes): def setoutputsize(self, size, column=None): """Does nothing by default""" - pass \ No newline at end of file + pass diff --git a/src/databricks/sql/result_set.py 
b/src/databricks/sql/result_set.py index 2ffc3f257..074877d32 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -277,6 +277,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) results = self.results.next_n_rows(size) + partial_result_chunks = [results] n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows @@ -287,11 +288,11 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) + partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return results + return pyarrow.concat_tables(partial_result_chunks, use_threads=True) def fetchmany_columnar(self, size: int): """ @@ -322,7 +323,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows - + partial_result_chunks = [results] while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() @@ -331,7 +332,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": ): results = self.merge_columnar(results, partial_results) else: - results = pyarrow.concat_tables([results, partial_results]) + partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table @@ -342,7 +343,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": for name, col in zip(results.column_names, results.column_table) } return pyarrow.Table.from_pydict(data) - return results + return pyarrow.concat_tables(partial_result_chunks, use_threads=True) def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" From e9040cb9efb94a18292186a3844b331caa35820b Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 15 Jul 2025 16:31:15 +0530 Subject: [PATCH 4/4] test fix --- tests/unit/test_downloader.py | 110 ++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 52 deletions(-) diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 2a3b715b5..1013ba999 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -1,9 +1,11 @@ +from contextlib import contextmanager import unittest from unittest.mock import Mock, patch, MagicMock import requests import databricks.sql.cloudfetch.downloader as downloader +from databricks.sql.common.http import DatabricksHttpClient from databricks.sql.exc import Error from databricks.sql.types import SSLOptions @@ -12,6 +14,7 @@ def create_response(**kwargs) -> requests.Response: result = requests.Response() for k, v in kwargs.items(): setattr(result, k, v) + result.close = Mock() return result @@ -52,91 +55,94 @@ def test_run_link_past_expiry_buffer(self, mock_time): mock_time.assert_called_once() - @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None))) @patch("time.time", return_value=1000) - def test_run_get_response_not_ok(self, mock_time, mock_session): - mock_session.return_value.get.return_value = create_response(status_code=404) - + def 
test_run_get_response_not_ok(self, mock_time): + http_client = DatabricksHttpClient.get_instance() settings = Mock(link_expiry_buffer_secs=0, download_timeout=0) settings.download_timeout = 0 settings.use_proxy = False result_link = Mock(expiryTime=1001) - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) - with self.assertRaises(requests.exceptions.HTTPError) as context: - d.run() - self.assertTrue("404" in str(context.exception)) + with patch.object( + http_client, + "execute", + return_value=create_response(status_code=404, _content=b"1234"), + ): + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) + with self.assertRaises(requests.exceptions.HTTPError) as context: + d.run() + self.assertTrue("404" in str(context.exception)) - @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None))) @patch("time.time", return_value=1000) - def test_run_uncompressed_successful(self, mock_time, mock_session): + def test_run_uncompressed_successful(self, mock_time): + http_client = DatabricksHttpClient.get_instance() file_bytes = b"1234567890" * 10 - mock_session.return_value.get.return_value = create_response( - status_code=200, _content=file_bytes - ) - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) - file = d.run() + with patch.object( + http_client, + "execute", + return_value=create_response(status_code=200, _content=file_bytes), + ): + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) + file = d.run() - assert file.file_bytes == b"1234567890" * 10 + assert file.file_bytes == b"1234567890" * 10 - @patch( - "requests.Session", - return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))), - ) @patch("time.time", return_value=1000) - def test_run_compressed_successful(self, mock_time, mock_session): + def test_run_compressed_successful(self, mock_time): + http_client = DatabricksHttpClient.get_instance() file_bytes = b"1234567890" * 10 compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - mock_session.return_value.get.return_value = create_response( - status_code=200, _content=compressed_bytes - ) settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) + with patch.object( + http_client, + "execute", + return_value=create_response(status_code=200, _content=compressed_bytes), + ): + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) + file = d.run() + + assert file.file_bytes == b"1234567890" * 10 - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) - file = d.run() - - assert file.file_bytes == b"1234567890" * 10 - - @patch("requests.Session.get", side_effect=ConnectionError("foo")) @patch("time.time", return_value=1000) - def test_download_connection_error(self, mock_time, mock_session): + def test_download_connection_error(self, mock_time): + + http_client = DatabricksHttpClient.get_instance() settings = Mock( link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True ) result_link = Mock(bytesNum=100, expiryTime=1001) - 
mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) - with self.assertRaises(ConnectionError): - d.run() + with patch.object(http_client, "execute", side_effect=ConnectionError("foo")): + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) + with self.assertRaises(ConnectionError): + d.run() - @patch("requests.Session.get", side_effect=TimeoutError("foo")) @patch("time.time", return_value=1000) - def test_download_timeout(self, mock_time, mock_session): + def test_download_timeout(self, mock_time): + http_client = DatabricksHttpClient.get_instance() settings = Mock( link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True ) result_link = Mock(bytesNum=100, expiryTime=1001) - mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) - with self.assertRaises(TimeoutError): - d.run() + with patch.object(http_client, "execute", side_effect=TimeoutError("foo")): + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) + with self.assertRaises(TimeoutError): + d.run()
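
Note on the recurring pattern in patches 1-3: each fetch loop stops calling pyarrow.concat_tables pairwise on every iteration and instead collects the partial results in a list that is concatenated once after the loop (with use_threads=True). A minimal standalone sketch of that pattern, outside the driver code and with illustrative chunk data that is not part of the patch:

    import pyarrow as pa

    # Stand-ins for the tables returned by successive next_n_rows() /
    # remaining_rows() calls; the data here is made up for illustration.
    chunks = [pa.table({"id": list(range(i * 1000, (i + 1) * 1000))}) for i in range(5)]

    # Old pattern: every iteration wraps everything fetched so far into a new
    # intermediate Table, so the accumulated chunk list is rebuilt n times.
    results = chunks[0].slice(0, 0)
    for chunk in chunks:
        results = pa.concat_tables([results, chunk])

    # New pattern: collect the partial results and concatenate once at the end
    # (the patch additionally passes use_threads=True to this final concat).
    partial_result_chunks = [chunks[0].slice(0, 0)]
    for chunk in chunks:
        partial_result_chunks.append(chunk)
    results = pa.concat_tables(partial_result_chunks)

    assert results.num_rows == 5 * 1000

Since concat_tables generally stitches chunk references together rather than copying buffers, the saving comes mainly from not rebuilding ever-larger intermediate tables on each pass. Patch 4 mirrors the downloader's switch to the shared DatabricksHttpClient singleton: rather than mocking requests.Session, the tests patch the singleton's execute method directly, e.g. with patch.object(DatabricksHttpClient.get_instance(), "execute", return_value=...) or side_effect=... for the error cases.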