diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py
index 228e07d6c..4421c4770 100644
--- a/src/databricks/sql/cloudfetch/downloader.py
+++ b/src/databricks/sql/cloudfetch/downloader.py
@@ -1,11 +1,10 @@
 import logging
 from dataclasses import dataclass
 
-import requests
-from requests.adapters import HTTPAdapter, Retry
+from requests.adapters import Retry
 import lz4.frame
 import time
-
+from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
@@ -70,6 +69,7 @@ def __init__(
         self.settings = settings
         self.link = link
         self._ssl_options = ssl_options
+        self._http_client = DatabricksHttpClient.get_instance()
 
     def run(self) -> DownloadedFile:
         """
@@ -90,19 +90,14 @@ def run(self) -> DownloadedFile:
             self.link, self.settings.link_expiry_buffer_secs
         )
 
-        session = requests.Session()
-        session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
-        session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
-
-        try:
-            # Get the file via HTTP request
-            response = session.get(
-                self.link.fileLink,
-                timeout=self.settings.download_timeout,
-                verify=self._ssl_options.tls_verify,
-                headers=self.link.httpHeaders
-                # TODO: Pass cert from `self._ssl_options`
-            )
+        with self._http_client.execute(
+            method=HttpMethod.GET,
+            url=self.link.fileLink,
+            timeout=self.settings.download_timeout,
+            verify=self._ssl_options.tls_verify,
+            headers=self.link.httpHeaders
+            # TODO: Pass cert from `self._ssl_options`
+        ) as response:
             response.raise_for_status()
 
             # Save (and decompress if needed) the downloaded file
@@ -132,9 +127,6 @@ def run(self) -> DownloadedFile:
                 self.link.startRowOffset,
                 self.link.rowCount,
             )
-        finally:
-            if session:
-                session.close()
 
     @staticmethod
     def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index 2ffc3f257..074877d32 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -277,6 +277,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         if size < 0:
             raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
         results = self.results.next_n_rows(size)
+        partial_result_chunks = [results]
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
 
@@ -287,11 +288,11 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         ):
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            results = pyarrow.concat_tables([results, partial_results])
+            partial_result_chunks.append(partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
 
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def fetchmany_columnar(self, size: int):
         """
@@ -322,7 +323,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
-
+        partial_result_chunks = [results]
         while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
@@ -331,7 +332,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
             ):
                 results = self.merge_columnar(results, partial_results)
             else:
-                results = pyarrow.concat_tables([results, partial_results])
+                partial_result_chunks.append(partial_results)
             self._next_row_index += partial_results.num_rows
 
         # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
@@ -342,7 +343,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
                 for name, col in zip(results.column_names, results.column_table)
             }
             return pyarrow.Table.from_pydict(data)
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index f39885ac6..a3e3e1dd0 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -276,11 +276,12 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
             return self._create_empty_table()
         logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
         results = self.table.slice(0, 0)
+        partial_result_chunks = [results]
         while num_rows > 0 and self.table:
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
             length = min(num_rows, self.table.num_rows - self.table_row_index)
             table_slice = self.table.slice(self.table_row_index, length)
-            results = pyarrow.concat_tables([results, table_slice])
+            partial_result_chunks.append(table_slice)
             self.table_row_index += table_slice.num_rows
 
             # Replace current table with the next table if we are at the end of the current table
@@ -290,7 +291,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
             num_rows -= table_slice.num_rows
 
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def remaining_rows(self) -> "pyarrow.Table":
         """
@@ -304,15 +305,16 @@ def remaining_rows(self) -> "pyarrow.Table":
             # Return empty pyarrow table to cause retry of fetch
             return self._create_empty_table()
         results = self.table.slice(0, 0)
+        partial_result_chunks = [results]
         while self.table:
             table_slice = self.table.slice(
                 self.table_row_index, self.table.num_rows - self.table_row_index
             )
-            results = pyarrow.concat_tables([results, table_slice])
+            partial_result_chunks.append(table_slice)
             self.table_row_index += table_slice.num_rows
             self.table = self._create_next_table()
             self.table_row_index = 0
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         logger.debug(
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index 2a3b715b5..1013ba999 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -1,9 +1,11 @@
+from contextlib import contextmanager
 import unittest
 from unittest.mock import Mock, patch, MagicMock
 
 import requests
 
 import databricks.sql.cloudfetch.downloader as downloader
+from databricks.sql.common.http import DatabricksHttpClient
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
 
@@ -12,6 +14,7 @@ def create_response(**kwargs) -> requests.Response:
     result = requests.Response()
     for k, v in kwargs.items():
         setattr(result, k, v)
+    result.close = Mock()
     return result
 
 
@@ -52,91 +55,94 @@ def test_run_link_past_expiry_buffer(self, mock_time):
 
         mock_time.assert_called_once()
 
-    @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None)))
     @patch("time.time", return_value=1000)
-    def test_run_get_response_not_ok(self, mock_time, mock_session):
-        mock_session.return_value.get.return_value = create_response(status_code=404)
-
+    def test_run_get_response_not_ok(self, mock_time):
+        http_client = DatabricksHttpClient.get_instance()
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0)
         settings.download_timeout = 0
         settings.use_proxy = False
         result_link = Mock(expiryTime=1001)
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        with self.assertRaises(requests.exceptions.HTTPError) as context:
-            d.run()
-            self.assertTrue("404" in str(context.exception))
+        with patch.object(
+            http_client,
+            "execute",
+            return_value=create_response(status_code=404, _content=b"1234"),
+        ):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            with self.assertRaises(requests.exceptions.HTTPError) as context:
+                d.run()
+            self.assertTrue("404" in str(context.exception))
 
-    @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None)))
     @patch("time.time", return_value=1000)
-    def test_run_uncompressed_successful(self, mock_time, mock_session):
+    def test_run_uncompressed_successful(self, mock_time):
+        http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
-        mock_session.return_value.get.return_value = create_response(
-            status_code=200, _content=file_bytes
-        )
-
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = False
         result_link = Mock(bytesNum=100, expiryTime=1001)
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        file = d.run()
+        with patch.object(
+            http_client,
+            "execute",
+            return_value=create_response(status_code=200, _content=file_bytes),
+        ):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            file = d.run()
 
-        assert file.file_bytes == b"1234567890" * 10
+            assert file.file_bytes == b"1234567890" * 10
 
-    @patch(
-        "requests.Session",
-        return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))),
-    )
     @patch("time.time", return_value=1000)
-    def test_run_compressed_successful(self, mock_time, mock_session):
+    def test_run_compressed_successful(self, mock_time):
+        http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
         compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
-        mock_session.return_value.get.return_value = create_response(
-            status_code=200, _content=compressed_bytes
-        )
 
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = True
         result_link = Mock(bytesNum=100, expiryTime=1001)
 
+        with patch.object(
+            http_client,
+            "execute",
+            return_value=create_response(status_code=200, _content=compressed_bytes),
+        ):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            file = d.run()
+
+            assert file.file_bytes == b"1234567890" * 10
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        file = d.run()
-
-        assert file.file_bytes == b"1234567890" * 10
-
-    @patch("requests.Session.get", side_effect=ConnectionError("foo"))
     @patch("time.time", return_value=1000)
-    def test_download_connection_error(self, mock_time, mock_session):
+    def test_download_connection_error(self, mock_time):
+
+        http_client = DatabricksHttpClient.get_instance()
         settings = Mock(
             link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
         )
         result_link = Mock(bytesNum=100, expiryTime=1001)
-        mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        with self.assertRaises(ConnectionError):
-            d.run()
+        with patch.object(http_client, "execute", side_effect=ConnectionError("foo")):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            with self.assertRaises(ConnectionError):
+                d.run()
 
-    @patch("requests.Session.get", side_effect=TimeoutError("foo"))
     @patch("time.time", return_value=1000)
-    def test_download_timeout(self, mock_time, mock_session):
+    def test_download_timeout(self, mock_time):
+        http_client = DatabricksHttpClient.get_instance()
         settings = Mock(
             link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
         )
         result_link = Mock(bytesNum=100, expiryTime=1001)
-        mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        with self.assertRaises(TimeoutError):
-            d.run()
+        with patch.object(http_client, "execute", side_effect=TimeoutError("foo")):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            with self.assertRaises(TimeoutError):
+                d.run()
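Notes on the two patterns this diff applies, each with a minimal self-contained sketch. The helper names, sample data, and the FakeSingletonClient class below are illustrative assumptions for the sketches only; they are not code from this repository.

1. Buffered concatenation. The old code re-ran pyarrow.concat_tables([results, partial]) on every loop iteration, rebuilding the accumulated table each pass. The new code appends each partial chunk to a list and concatenates once at the end, passing use_threads=True so Arrow can parallelize that single combine. A sketch of the pattern:

import pyarrow as pa

def gather_chunks(chunks):
    # Appending is O(1) per chunk; a single concat_tables call at the end
    # avoids rebuilding the accumulated table on every iteration.
    buffered = []
    for chunk in chunks:
        buffered.append(chunk)
    return pa.concat_tables(buffered)

# Three small tables stand in for fetched result batches.
batches = [pa.table({"x": list(range(i, i + 3))}) for i in (0, 3, 6)]
assert gather_chunks(batches).num_rows == 9

2. Stubbing the shared HTTP client in tests. Because DatabricksHttpClient.get_instance() returns a process-wide singleton, the updated tests stub its execute method with patch.object instead of patching requests.Session. The sketch below uses a hypothetical FakeSingletonClient to show why this works: every caller of get_instance() sees the same object, so patching one reference patches them all, and the original method is restored when the context exits.

from unittest.mock import patch

class FakeSingletonClient:
    # Hypothetical stand-in for DatabricksHttpClient, purely to illustrate the pattern.
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def execute(self, method, url, **kwargs):
        raise RuntimeError("network disabled")

client = FakeSingletonClient.get_instance()
with patch.object(client, "execute", return_value="stubbed"):
    # Code under test that calls get_instance().execute(...) sees the stub.
    assert FakeSingletonClient.get_instance().execute("GET", "http://x") == "stubbed"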