
Commit e0ca049

Arrow performance optimizations (#638)
* Minor fix
* Perf update
* more
* test fix
1 parent ba1eab3 commit e0ca049

4 files changed: +81 additions, -80 deletions
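The same pattern repeats across the three source files below: instead of calling `pyarrow.concat_tables` pairwise inside each fetch loop, chunks are collected in a list and concatenated once on return (the committed code also passes `use_threads=True`). A minimal sketch of the before/after shapes; `make_chunks` is a hypothetical stand-in for the per-fetch tables the driver accumulates:

```python
import pyarrow as pa

# Hypothetical stand-in for the per-fetch tables the driver accumulates.
def make_chunks(n: int, rows: int = 1000) -> list:
    return [pa.table({"x": pa.array(range(rows))}) for _ in range(n)]

# Before: each loop iteration concatenated the running result with the new
# chunk. Each concat is zero-copy, but it rebuilds the running table's chunk
# list every time, so total work grows quadratically with the number of chunks.
def concat_incrementally(chunks: list) -> pa.Table:
    results = chunks[0].slice(0, 0)  # empty, schema-preserving seed
    for chunk in chunks:
        results = pa.concat_tables([results, chunk])
    return results

# After: O(1) list appends in the loop and a single concatenation at the end
# (the commit additionally passes use_threads=True).
def concat_once(chunks: list) -> pa.Table:
    return pa.concat_tables(chunks)

assert concat_once(make_chunks(5)).equals(concat_incrementally(make_chunks(5)))
```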

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 11 additions & 19 deletions
```diff
@@ -1,11 +1,10 @@
 import logging
 from dataclasses import dataclass
 
-import requests
-from requests.adapters import HTTPAdapter, Retry
+from requests.adapters import Retry
 import lz4.frame
 import time
-
+from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
@@ -70,6 +69,7 @@ def __init__(
         self.settings = settings
         self.link = link
         self._ssl_options = ssl_options
+        self._http_client = DatabricksHttpClient.get_instance()
 
     def run(self) -> DownloadedFile:
         """
@@ -90,19 +90,14 @@ def run(self) -> DownloadedFile:
             self.link, self.settings.link_expiry_buffer_secs
         )
 
-        session = requests.Session()
-        session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
-        session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
-
-        try:
-            # Get the file via HTTP request
-            response = session.get(
-                self.link.fileLink,
-                timeout=self.settings.download_timeout,
-                verify=self._ssl_options.tls_verify,
-                headers=self.link.httpHeaders
-                # TODO: Pass cert from `self._ssl_options`
-            )
+        with self._http_client.execute(
+            method=HttpMethod.GET,
+            url=self.link.fileLink,
+            timeout=self.settings.download_timeout,
+            verify=self._ssl_options.tls_verify,
+            headers=self.link.httpHeaders
+            # TODO: Pass cert from `self._ssl_options`
+        ) as response:
             response.raise_for_status()
 
             # Save (and decompress if needed) the downloaded file
@@ -132,9 +127,6 @@ def run(self) -> DownloadedFile:
                 self.link.startRowOffset,
                 self.link.rowCount,
             )
-        finally:
-            if session:
-                session.close()
 
     @staticmethod
     def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
```
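The handler no longer builds a `requests.Session` per download; it pulls a process-wide client from `DatabricksHttpClient.get_instance()` and drives `execute(...)` as a context manager so the response is closed deterministically. The client's internals live in `databricks.sql.common.http` and are not part of this diff; the following is only a rough sketch of what such a singleton wrapper could look like, with the retry policy and all structural details assumed:

```python
# Hypothetical sketch of a shared HTTP client in the spirit of
# DatabricksHttpClient; the real implementation is not shown in this commit.
import threading
from contextlib import contextmanager
from enum import Enum

import requests
from requests.adapters import HTTPAdapter, Retry


class HttpMethod(Enum):
    GET = "GET"


class SharedHttpClient:
    _instance = None
    _lock = threading.Lock()

    def __init__(self) -> None:
        # One session (and one connection pool) reused by every download,
        # instead of a Session built and torn down per file.
        self._session = requests.Session()
        retry = Retry(total=3, backoff_factor=0.5)  # assumed policy
        self._session.mount("http://", HTTPAdapter(max_retries=retry))
        self._session.mount("https://", HTTPAdapter(max_retries=retry))

    @classmethod
    def get_instance(cls) -> "SharedHttpClient":
        with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    @contextmanager
    def execute(self, method: HttpMethod, url: str, **kwargs):
        # Yield the response and guarantee it is released afterwards,
        # mirroring how downloader.py now consumes `execute(...)`.
        response = self._session.request(method.value, url, **kwargs)
        try:
            yield response
        finally:
            response.close()
```

Reusing one pooled session amortizes TCP and TLS setup across the many presigned-URL downloads a CloudFetch result set can trigger, which appears to be the motivation for the change.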

src/databricks/sql/result_set.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -277,6 +277,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         if size < 0:
             raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
         results = self.results.next_n_rows(size)
+        partial_result_chunks = [results]
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
 
@@ -287,11 +288,11 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         ):
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            results = pyarrow.concat_tables([results, partial_results])
+            partial_result_chunks.append(partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
 
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def fetchmany_columnar(self, size: int):
         """
@@ -322,7 +323,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
-
+        partial_result_chunks = [results]
         while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
@@ -331,7 +332,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
             ):
                 results = self.merge_columnar(results, partial_results)
             else:
-                results = pyarrow.concat_tables([results, partial_results])
+                partial_result_chunks.append(partial_results)
             self._next_row_index += partial_results.num_rows
 
         # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
@@ -342,7 +343,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
                 for name, col in zip(results.column_names, results.column_table)
             }
             return pyarrow.Table.from_pydict(data)
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
```

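From the caller's perspective the fetch API is unchanged: `fetchmany_arrow` and `fetchall_arrow` still hand back a single `pyarrow.Table`, now assembled by one deferred concatenation. A hypothetical usage snippet (all connection parameters are placeholders):

```python
from databricks import sql

# Placeholders: supply real workspace values.
with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<warehouse-http-path>",
    access_token="<token>",
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT * FROM some_large_table")
        table = cursor.fetchmany_arrow(100_000)  # one pyarrow.Table
        print(table.num_rows, table.schema)
```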
src/databricks/sql/utils.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -276,11 +276,12 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
             return self._create_empty_table()
         logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
         results = self.table.slice(0, 0)
+        partial_result_chunks = [results]
         while num_rows > 0 and self.table:
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
             length = min(num_rows, self.table.num_rows - self.table_row_index)
             table_slice = self.table.slice(self.table_row_index, length)
-            results = pyarrow.concat_tables([results, table_slice])
+            partial_result_chunks.append(table_slice)
             self.table_row_index += table_slice.num_rows
 
             # Replace current table with the next table if we are at the end of the current table
@@ -290,7 +291,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
             num_rows -= table_slice.num_rows
 
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def remaining_rows(self) -> "pyarrow.Table":
         """
@@ -304,15 +305,16 @@ def remaining_rows(self) -> "pyarrow.Table":
             # Return empty pyarrow table to cause retry of fetch
             return self._create_empty_table()
         results = self.table.slice(0, 0)
+        partial_result_chunks = [results]
         while self.table:
             table_slice = self.table.slice(
                 self.table_row_index, self.table.num_rows - self.table_row_index
             )
-            results = pyarrow.concat_tables([results, table_slice])
+            partial_result_chunks.append(table_slice)
             self.table_row_index += table_slice.num_rows
             self.table = self._create_next_table()
             self.table_row_index = 0
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         logger.debug(
```
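One detail worth noting in both queue methods: the chunk list is seeded with `self.table.slice(0, 0)`, a zero-row slice that carries the table's schema. That keeps the final `concat_tables` call well-defined even when the loop never appends anything, as this small check illustrates:

```python
import pyarrow as pa

table = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())})

# Zero-row slice: no data, but the schema survives.
seed = table.slice(0, 0)
result = pa.concat_tables([seed])  # loop body never ran; still valid

assert result.num_rows == 0
assert result.schema == table.schema
```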

tests/unit/test_downloader.py

Lines changed: 58 additions & 52 deletions
```diff
@@ -1,9 +1,11 @@
+from contextlib import contextmanager
 import unittest
 from unittest.mock import Mock, patch, MagicMock
 
 import requests
 
 import databricks.sql.cloudfetch.downloader as downloader
+from databricks.sql.common.http import DatabricksHttpClient
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
 
@@ -12,6 +14,7 @@ def create_response(**kwargs) -> requests.Response:
     result = requests.Response()
     for k, v in kwargs.items():
         setattr(result, k, v)
+    result.close = Mock()
     return result
 
 
@@ -52,91 +55,94 @@ def test_run_link_past_expiry_buffer(self, mock_time):
 
         mock_time.assert_called_once()
 
-    @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None)))
     @patch("time.time", return_value=1000)
-    def test_run_get_response_not_ok(self, mock_time, mock_session):
-        mock_session.return_value.get.return_value = create_response(status_code=404)
-
+    def test_run_get_response_not_ok(self, mock_time):
+        http_client = DatabricksHttpClient.get_instance()
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0)
         settings.download_timeout = 0
         settings.use_proxy = False
        result_link = Mock(expiryTime=1001)
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        with self.assertRaises(requests.exceptions.HTTPError) as context:
-            d.run()
-        self.assertTrue("404" in str(context.exception))
+        with patch.object(
+            http_client,
+            "execute",
+            return_value=create_response(status_code=404, _content=b"1234"),
+        ):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            with self.assertRaises(requests.exceptions.HTTPError) as context:
+                d.run()
+            self.assertTrue("404" in str(context.exception))
 
-    @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None)))
     @patch("time.time", return_value=1000)
-    def test_run_uncompressed_successful(self, mock_time, mock_session):
+    def test_run_uncompressed_successful(self, mock_time):
+        http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
-        mock_session.return_value.get.return_value = create_response(
-            status_code=200, _content=file_bytes
-        )
-
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = False
         result_link = Mock(bytesNum=100, expiryTime=1001)
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        file = d.run()
+        with patch.object(
+            http_client,
+            "execute",
+            return_value=create_response(status_code=200, _content=file_bytes),
+        ):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            file = d.run()
 
-        assert file.file_bytes == b"1234567890" * 10
+            assert file.file_bytes == b"1234567890" * 10
 
-    @patch(
-        "requests.Session",
-        return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))),
-    )
     @patch("time.time", return_value=1000)
-    def test_run_compressed_successful(self, mock_time, mock_session):
+    def test_run_compressed_successful(self, mock_time):
+        http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
         compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
-        mock_session.return_value.get.return_value = create_response(
-            status_code=200, _content=compressed_bytes
-        )
 
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = True
         result_link = Mock(bytesNum=100, expiryTime=1001)
+        with patch.object(
+            http_client,
+            "execute",
+            return_value=create_response(status_code=200, _content=compressed_bytes),
+        ):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            file = d.run()
+
+            assert file.file_bytes == b"1234567890" * 10
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        file = d.run()
-
-        assert file.file_bytes == b"1234567890" * 10
-
-    @patch("requests.Session.get", side_effect=ConnectionError("foo"))
     @patch("time.time", return_value=1000)
-    def test_download_connection_error(self, mock_time, mock_session):
+    def test_download_connection_error(self, mock_time):
+
+        http_client = DatabricksHttpClient.get_instance()
         settings = Mock(
             link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
         )
         result_link = Mock(bytesNum=100, expiryTime=1001)
-        mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        with self.assertRaises(ConnectionError):
-            d.run()
+        with patch.object(http_client, "execute", side_effect=ConnectionError("foo")):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            with self.assertRaises(ConnectionError):
+                d.run()
 
-    @patch("requests.Session.get", side_effect=TimeoutError("foo"))
     @patch("time.time", return_value=1000)
-    def test_download_timeout(self, mock_time, mock_session):
+    def test_download_timeout(self, mock_time):
+        http_client = DatabricksHttpClient.get_instance()
         settings = Mock(
             link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
        )
         result_link = Mock(bytesNum=100, expiryTime=1001)
-        mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
-        d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
-        )
-        with self.assertRaises(TimeoutError):
-            d.run()
+        with patch.object(http_client, "execute", side_effect=TimeoutError("foo")):
+            d = downloader.ResultSetDownloadHandler(
+                settings, result_link, ssl_options=SSLOptions()
+            )
+            with self.assertRaises(TimeoutError):
+                d.run()
```
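Because the handler now resolves a process-wide singleton at runtime, the tests swap the `@patch("requests.Session", ...)` decorators for `patch.object(http_client, "execute", ...)` context managers: patching the method on the one shared instance scopes the mock to the `with` block and restores the real method afterwards. A minimal standalone illustration of the pattern (the `Client` class here is hypothetical):

```python
# Hypothetical singleton in the spirit of DatabricksHttpClient, used only
# to demonstrate the patch.object pattern the tests adopt.
from unittest.mock import patch


class Client:
    _instance = None

    @classmethod
    def get_instance(cls) -> "Client":
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def execute(self, url: str) -> str:
        raise RuntimeError("would hit the network")


def test_execute_is_mocked() -> None:
    client = Client.get_instance()
    # patch.object swaps the attribute only inside the with-block, so other
    # tests that share the singleton see the real method again afterwards.
    with patch.object(client, "execute", return_value="stub"):
        assert Client.get_instance().execute("http://x") == "stub"
```

Unlike decorator-based patching of `requests.Session`, this does not depend on when the module under test imported the name; it replaces the attribute on the live object the handler will actually call.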
