
Commit 0cbfae6

Minor fix
1 parent 576eafc commit 0cbfae6

3 files changed: +25 -30 lines

src/databricks/sql/client.py

Lines changed: 7 additions & 6 deletions
@@ -1454,19 +1454,19 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
-
+        partial_result_chunks = [results]
         while (
             n_remaining_rows > 0
             and not self.has_been_closed_server_side
             and self.has_more_rows
         ):
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            results = pyarrow.concat_tables([results, partial_results])
+            partial_result_chunks.append(partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
 
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def merge_columnar(self, result1, result2):
         """
@@ -1514,7 +1514,8 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
-
+
+        partial_result_chunks = [results]
         while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
@@ -1523,7 +1524,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
             ):
                 results = self.merge_columnar(results, partial_results)
             else:
-                results = pyarrow.concat_tables([results, partial_results])
+                partial_result_chunks.append(partial_results)
             self._next_row_index += partial_results.num_rows
 
         # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
@@ -1534,7 +1535,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
                 for name, col in zip(results.column_names, results.column_table)
             }
             return pyarrow.Table.from_pydict(data)
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 12 additions & 20 deletions
@@ -1,11 +1,10 @@
 import logging
 from dataclasses import dataclass
 
-import requests
-from requests.adapters import HTTPAdapter, Retry
+from requests.adapters import Retry
 import lz4.frame
 import time
-
+from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
@@ -70,6 +69,7 @@ def __init__(
         self.settings = settings
         self.link = link
         self._ssl_options = ssl_options
+        self._http_client = DatabricksHttpClient.get_instance()
 
     def run(self) -> DownloadedFile:
         """
@@ -89,20 +89,15 @@ def run(self) -> DownloadedFile:
         ResultSetDownloadHandler._validate_link(
             self.link, self.settings.link_expiry_buffer_secs
         )
-
-        session = requests.Session()
-        session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
-        session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
-
-        try:
-            # Get the file via HTTP request
-            response = session.get(
-                self.link.fileLink,
-                timeout=self.settings.download_timeout,
-                verify=self._ssl_options.tls_verify,
-                headers=self.link.httpHeaders
-                # TODO: Pass cert from `self._ssl_options`
-            )
+
+        with self._http_client.execute(
+            method=HttpMethod.GET,
+            url=self.link.fileLink,
+            timeout=self.settings.download_timeout,
+            verify=self._ssl_options.tls_verify,
+            headers=self.link.httpHeaders
+            # TODO: Pass cert from `self._ssl_options`
+        ) as response:
             response.raise_for_status()
 
             # Save (and decompress if needed) the downloaded file
@@ -132,9 +127,6 @@ def run(self) -> DownloadedFile:
                 self.link.startRowOffset,
                 self.link.rowCount,
             )
-        finally:
-            if session:
-                session.close()
 
     @staticmethod
     def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
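Instead of building and tearing down a requests.Session per file, each download now borrows a shared client via DatabricksHttpClient.get_instance(). That class lives in databricks.sql.common.http and is not shown in this diff, so the following is only a rough sketch of what a shared, context-manager-style client of this shape could look like; the class name SharedHttpClient, the retry policy, and the execute signature are assumptions modeled on the call site above:

import threading
from contextlib import contextmanager
from enum import Enum

import requests
from requests.adapters import HTTPAdapter, Retry


class HttpMethod(Enum):
    GET = "GET"


class SharedHttpClient:
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls) -> "SharedHttpClient":
        # Double-checked locking: every download reuses one Session, and
        # therefore one connection pool, instead of paying a fresh TCP/TLS
        # handshake and teardown per file.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        self._session = requests.Session()
        retry = Retry(total=3, backoff_factor=0.5)  # assumed policy
        self._session.mount("http://", HTTPAdapter(max_retries=retry))
        self._session.mount("https://", HTTPAdapter(max_retries=retry))

    @contextmanager
    def execute(self, method: HttpMethod, url: str, **kwargs):
        # Mirrors the `with ... as response:` usage in run(): closing the
        # response returns its connection to the pool, while the session
        # itself stays alive for the next download.
        response = self._session.request(method.value, url, **kwargs)
        try:
            yield response
        finally:
            response.close()

Usage then matches the new call site: with SharedHttpClient.get_instance().execute(method=HttpMethod.GET, url=link, timeout=30) as response: response.raise_for_status().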

src/databricks/sql/utils.py

Lines changed: 6 additions & 4 deletions
@@ -271,11 +271,12 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
             return self._create_empty_table()
         logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
         results = self.table.slice(0, 0)
+        partial_result_chunks = [results]
         while num_rows > 0 and self.table:
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
             length = min(num_rows, self.table.num_rows - self.table_row_index)
             table_slice = self.table.slice(self.table_row_index, length)
-            results = pyarrow.concat_tables([results, table_slice])
+            partial_result_chunks.append(table_slice)
             self.table_row_index += table_slice.num_rows
 
             # Replace current table with the next table if we are at the end of the current table
@@ -285,7 +286,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
             num_rows -= table_slice.num_rows
 
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def remaining_rows(self) -> "pyarrow.Table":
         """
@@ -298,15 +299,16 @@ def remaining_rows(self) -> "pyarrow.Table":
             # Return empty pyarrow table to cause retry of fetch
             return self._create_empty_table()
         results = self.table.slice(0, 0)
+        partial_result_chunks = [results]
         while self.table:
             table_slice = self.table.slice(
                 self.table_row_index, self.table.num_rows - self.table_row_index
             )
-            results = pyarrow.concat_tables([results, table_slice])
+            partial_result_chunks.append(table_slice)
             self.table_row_index += table_slice.num_rows
             self.table = self._create_next_table()
             self.table_row_index = 0
-        return results
+        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
 
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         logger.debug(
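The queue methods apply the same accumulate-then-concat pattern, but over slices of already downloaded tables: Table.slice is zero-copy and returns a view into its parent, so the growing partial_result_chunks list does not duplicate row data. A toy version of the next_n_rows loop; the three-table split and the requested row count are invented for illustration, and use_threads=True is again mirrored from the diff:

import pyarrow

# Stand-in for the queue's sequence of downloaded cloud-fetch tables.
tables = [
    pyarrow.table({"id": [0, 1, 2]}),
    pyarrow.table({"id": [3, 4, 5]}),
    pyarrow.table({"id": [6, 7, 8]}),
]

def next_n_rows(num_rows: int) -> "pyarrow.Table":
    partial_result_chunks = [tables[0].slice(0, 0)]  # empty seed, as above
    for table in tables:
        if num_rows <= 0:
            break
        # Take the rest of num_rows or the whole current table, whichever
        # is smaller; slice() is zero-copy, so appending is cheap.
        length = min(num_rows, table.num_rows)
        partial_result_chunks.append(table.slice(0, length))
        num_rows -= length
    return pyarrow.concat_tables(partial_result_chunks, use_threads=True)

assert next_n_rows(7).num_rows == 7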
