From ea3c337f9f88d11e882d8bb96d6ec082e00c0604 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Fri, 11 Jul 2025 11:09:42 +0530 Subject: [PATCH 01/12] Added logs --- src/databricks/sql/client.py | 3 +++ src/databricks/sql/cloudfetch/download_manager.py | 2 +- src/databricks/sql/cloudfetch/downloader.py | 14 +++++++++++++- src/databricks/sql/utils.py | 10 +++++++++- 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1f409bb07..a1159f469 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1515,7 +1515,10 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows + print("Server side has more rows", self.has_more_rows) + while not self.has_been_closed_server_side and self.has_more_rows: + print(f"RESULT SIZE TOTAL {results.num_rows}") self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 7e96cd323..ad1fbbe7c 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -78,13 +78,13 @@ def get_next_downloaded_file( next_row_offset, file.start_row_offset, file.row_count ) ) - return file def _schedule_downloads(self): """ While download queue has a capacity, peek pending links and submit them to thread pool. """ + print("Schedule_downloads") logger.debug("ResultFileDownloadManager: schedule downloads") while (len(self._download_tasks) < self._max_download_threads) and ( len(self._pending_links) > 0 diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 228e07d6c..700e514e6 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -95,6 +95,10 @@ def run(self) -> DownloadedFile: session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) try: + print_text = [ + + ] + start_time = time.time() # Get the file via HTTP request response = session.get( self.link.fileLink, @@ -104,7 +108,8 @@ def run(self) -> DownloadedFile: # TODO: Pass cert from `self._ssl_options` ) response.raise_for_status() - + end_time = time.time() + print_text.append(f"Downloaded file in {end_time - start_time} seconds") # Save (and decompress if needed) the downloaded file compressed_data = response.content decompressed_data = ( @@ -127,6 +132,13 @@ def run(self) -> DownloadedFile: ) ) + print_text.append( + f"Downloaded file startRowOffset - {self.link.startRowOffset} - rowCount - {self.link.rowCount}" + ) + + for text in print_text: + print(text) + return DownloadedFile( decompressed_data, self.link.startRowOffset, diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 0ce2fa169..d9089de09 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -235,6 +235,8 @@ def __init__( result_link.startRowOffset, result_link.rowCount ) ) + print("Initial Setup Cloudfetch Queue") + print(f"No of result links - {len(result_links)}") self.download_manager = ResultFileDownloadManager( links=result_links or [], max_download_threads=self.max_download_threads, @@ -288,6 +290,9 @@ def remaining_rows(self) -> "pyarrow.Table": # Return empty pyarrow table to cause retry of fetch return self._create_empty_table() results = self.table.slice(0, 0) + + 
print("remaining_rows call") + print(f"self.table.num_rows - {self.table.num_rows}") while self.table: table_slice = self.table.slice( self.table_row_index, self.table.num_rows - self.table_row_index @@ -296,6 +301,7 @@ def remaining_rows(self) -> "pyarrow.Table": self.table_row_index += table_slice.num_rows self.table = self._create_next_table() self.table_row_index = 0 + print(f"results.num_rows - {results.num_rows}") return results def _create_next_table(self) -> Union["pyarrow.Table", None]: @@ -334,7 +340,9 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]: arrow_table.num_rows, self.start_row_index ) ) - + + print("_create_next_table") + print(f"arrow_table.num_rows - {arrow_table.num_rows}") return arrow_table def _create_empty_table(self) -> "pyarrow.Table": From c7492cc78c0ea7299d2de967a55b2ca53666c87f Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Sat, 12 Jul 2025 11:33:54 +0530 Subject: [PATCH 02/12] LAST CHECKPOINT --- src/databricks/sql/client.py | 51 +++++------- src/databricks/sql/cloudfetch/downloader.py | 89 ++++++++++++++++----- src/databricks/sql/common/http.py | 5 +- src/databricks/sql/utils.py | 43 ++++++++-- 4 files changed, 132 insertions(+), 56 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a1159f469..9d6a00c9f 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -31,6 +31,8 @@ transform_paramstyle, ColumnTable, ColumnQueue, + concat_chunked_tables, + merge_columnar, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -1454,36 +1456,25 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows + partial_result_chunks = [results] + TOTAL_SIZE = results.num_rows while ( n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows ): + print(f"TOTAL DATA ROWS {TOTAL_SIZE}") self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) + partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows + TOTAL_SIZE += partial_results.num_rows - return results - - def merge_columnar(self, result1, result2): - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) + return concat_chunked_tables(partial_result_chunks) + + def fetchmany_columnar(self, size: int): """ @@ -1504,7 +1495,7 @@ def fetchmany_columnar(self, size: int): ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) + results = merge_columnar(results, partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows @@ -1514,20 +1505,20 @@ def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows - - print("Server side has more rows", 
self.has_more_rows) + partial_result_chunks = [results] + print("Server side has more rows", self.has_more_rows) + TOTAL_SIZE = results.num_rows + while not self.has_been_closed_server_side and self.has_more_rows: - print(f"RESULT SIZE TOTAL {results.num_rows}") + print(f"TOTAL DATA ROWS {TOTAL_SIZE}") self._fill_results_buffer() partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - results = pyarrow.concat_tables([results, partial_results]) + partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows + TOTAL_SIZE += partial_results.num_rows + + results = concat_chunked_tables(partial_result_chunks) # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table # Valid only for metadata commands result set @@ -1547,7 +1538,7 @@ def fetchall_columnar(self): while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - results = self.merge_columnar(results, partial_results) + results = merge_columnar(results, partial_results) self._next_row_index += partial_results.num_rows return results diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 700e514e6..3bd337a58 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -9,6 +9,7 @@ from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.exc import Error from databricks.sql.types import SSLOptions +from databricks.sql.common.http import DatabricksHttpClient, HttpMethod logger = logging.getLogger(__name__) @@ -70,6 +71,7 @@ def __init__( self.settings = settings self.link = link self._ssl_options = ssl_options + self._http_client = DatabricksHttpClient.get_instance() def run(self) -> DownloadedFile: """ @@ -89,27 +91,20 @@ def run(self) -> DownloadedFile: ResultSetDownloadHandler._validate_link( self.link, self.settings.link_expiry_buffer_secs ) - - session = requests.Session() - session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) - session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) - - try: + + with self._http_client.execute( + method=HttpMethod.GET, + url=self.link.fileLink, + timeout=self.settings.download_timeout, + verify=self._ssl_options.tls_verify, + headers=self.link.httpHeaders + ) as response: print_text = [ ] - start_time = time.time() - # Get the file via HTTP request - response = session.get( - self.link.fileLink, - timeout=self.settings.download_timeout, - verify=self._ssl_options.tls_verify, - headers=self.link.httpHeaders - # TODO: Pass cert from `self._ssl_options` - ) + response.raise_for_status() - end_time = time.time() - print_text.append(f"Downloaded file in {end_time - start_time} seconds") + # Save (and decompress if needed) the downloaded file compressed_data = response.content decompressed_data = ( @@ -144,9 +139,63 @@ def run(self) -> DownloadedFile: self.link.startRowOffset, self.link.rowCount, ) - finally: - if session: - session.close() + # session = requests.Session() + # session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) + # session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) + + # try: + # print_text = [ + + # ] + # start_time = time.time() + # # Get the file via HTTP request + # response = session.get( + # 
self.link.fileLink, + # timeout=self.settings.download_timeout, + # verify=self._ssl_options.tls_verify, + # headers=self.link.httpHeaders + # # TODO: Pass cert from `self._ssl_options` + # ) + # response.raise_for_status() + # end_time = time.time() + # print_text.append(f"Downloaded file in {end_time - start_time} seconds") + # # Save (and decompress if needed) the downloaded file + # compressed_data = response.content + # decompressed_data = ( + # ResultSetDownloadHandler._decompress_data(compressed_data) + # if self.settings.is_lz4_compressed + # else compressed_data + # ) + + # # The size of the downloaded file should match the size specified from TSparkArrowResultLink + # if len(decompressed_data) != self.link.bytesNum: + # logger.debug( + # "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( + # len(decompressed_data), self.link.bytesNum + # ) + # ) + + # logger.debug( + # "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( + # self.link.startRowOffset, self.link.rowCount + # ) + # ) + + # print_text.append( + # f"Downloaded file startRowOffset - {self.link.startRowOffset} - rowCount - {self.link.rowCount}" + # ) + + # for text in print_text: + # print(text) + + # return DownloadedFile( + # decompressed_data, + # self.link.startRowOffset, + # self.link.rowCount, + # ) + # finally: + # if session: + # session.close() @staticmethod def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int): diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index ec4e3341a..92e80e7fb 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from typing import Generator import logging - +import time logger = logging.getLogger(__name__) @@ -70,7 +70,10 @@ def execute( logger.info("Executing HTTP request: %s with url: %s", method.value, url) response = None try: + start_time = time.time() response = self.session.request(method.value, url, **kwargs) + end_time = time.time() + print(f"Downloaded file in {end_time - start_time} seconds") yield response except Exception as e: logger.error("Error executing HTTP request in DatabricksHttpClient: %s", e) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index d9089de09..52dc2b1ce 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -137,6 +137,11 @@ def __eq__(self, other): ) +class ArrowStreamTable: + def __init__(self, arrow_stream, num_rows): + self.arrow_stream = arrow_stream + self.num_rows = num_rows + class ColumnQueue(ResultSetQueue): def __init__(self, column_table: ColumnTable): self.column_table = column_table @@ -263,11 +268,12 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": return self._create_empty_table() logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows)) results = self.table.slice(0, 0) + partial_result_chunks = [results] while num_rows > 0 and self.table: # Get remaining of num_rows or the rest of the current table, whichever is smaller length = min(num_rows, self.table.num_rows - self.table_row_index) table_slice = self.table.slice(self.table_row_index, length) - results = pyarrow.concat_tables([results, table_slice]) + partial_result_chunks.append(table_slice) self.table_row_index += table_slice.num_rows # Replace current table with the next table if we are at the end of the current table @@ -277,7 +283,7 @@ def next_n_rows(self, num_rows: int) -> 
"pyarrow.Table": num_rows -= table_slice.num_rows logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) - return results + return concat_chunked_tables(partial_result_chunks) def remaining_rows(self) -> "pyarrow.Table": """ @@ -290,19 +296,19 @@ def remaining_rows(self) -> "pyarrow.Table": # Return empty pyarrow table to cause retry of fetch return self._create_empty_table() results = self.table.slice(0, 0) - + partial_result_chunks = [results] print("remaining_rows call") print(f"self.table.num_rows - {self.table.num_rows}") while self.table: table_slice = self.table.slice( self.table_row_index, self.table.num_rows - self.table_row_index ) - results = pyarrow.concat_tables([results, table_slice]) + partial_result_chunks.append(table_slice) self.table_row_index += table_slice.num_rows self.table = self._create_next_table() self.table_row_index = 0 print(f"results.num_rows - {results.num_rows}") - return results + return concat_chunked_tables(partial_result_chunks) def _create_next_table(self) -> Union["pyarrow.Table", None]: logger.debug( @@ -771,3 +777,30 @@ def _create_python_tuple(t_col_value_wrapper): result[i] = None return tuple(result) + + +def concat_chunked_tables(tables: List[Union["pyarrow.Table", ColumnTable]]) -> Union["pyarrow.Table", ColumnTable]: + if isinstance(tables[0], ColumnTable): + base_table = tables[0] + for table in tables[1:]: + base_table = merge_columnar(base_table, table) + return base_table + else: + return pyarrow.concat_tables(tables) + +def merge_columnar(result1: ColumnTable, result2: ColumnTable) -> ColumnTable: + """ + Function to merge / combining the columnar results into a single result + :param result1: + :param result2: + :return: + """ + + if result1.column_names != result2.column_names: + raise ValueError("The columns in the results don't match") + + merged_result = [ + result1.column_table[i] + result2.column_table[i] + for i in range(result1.num_columns) + ] + return ColumnTable(merged_result, result1.column_names) \ No newline at end of file From f599ebc2e0b7aeb7f70550b170cb7462525ae75a Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Sat, 12 Jul 2025 16:57:05 +0530 Subject: [PATCH 03/12] WORKING fetchall_arrow --- src/databricks/sql/client.py | 27 ++--- src/databricks/sql/utils.py | 185 +++++++++++++++++++++++++++-------- 2 files changed, 156 insertions(+), 56 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9d6a00c9f..dd898232c 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1456,7 +1456,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows - partial_result_chunks = [results] + # partial_result_chunks = [results] TOTAL_SIZE = results.num_rows while ( @@ -1467,12 +1467,13 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": print(f"TOTAL DATA ROWS {TOTAL_SIZE}") self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - partial_result_chunks.append(partial_results) + results.append(partial_results) + # partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows TOTAL_SIZE += partial_results.num_rows - return concat_chunked_tables(partial_result_chunks) + return results.to_arrow_table() @@ -1506,7 +1507,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() 
self._next_row_index += results.num_rows - partial_result_chunks = [results] + # partial_result_chunks = [results] print("Server side has more rows", self.has_more_rows) TOTAL_SIZE = results.num_rows @@ -1514,21 +1515,21 @@ def fetchall_arrow(self) -> "pyarrow.Table": print(f"TOTAL DATA ROWS {TOTAL_SIZE}") self._fill_results_buffer() partial_results = self.results.remaining_rows() - partial_result_chunks.append(partial_results) + results.append(partial_results) self._next_row_index += partial_results.num_rows TOTAL_SIZE += partial_results.num_rows - results = concat_chunked_tables(partial_result_chunks) + # results = concat_chunked_tables(partial_result_chunks) # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: - data = { - name: col - for name, col in zip(results.column_names, results.column_table) - } - return pyarrow.Table.from_pydict(data) - return results + # if isinstance(results, ColumnTable) and pyarrow: + # data = { + # name: col + # for name, col in zip(results.column_names, results.column_table) + # } + # return pyarrow.Table.from_pydict(data) + return results.to_arrow_table() def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 52dc2b1ce..f5ca17ada 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -123,6 +123,23 @@ def num_columns(self): def get_item(self, col_index, row_index): return self.column_table[col_index][row_index] + + def append(self, other: ColumnTable): + if self.column_names != other.column_names: + raise ValueError("The columns in the results don't match") + + merged_result = [ + self.column_table[i] + other.column_table[i] + for i in range(self.num_columns) + ] + return ColumnTable(merged_result, self.column_names) + + def to_arrow_table(self): + data = { + name: col + for name, col in zip(self.column_names, self.column_table) + } + return pyarrow.Table.from_pydict(data) def slice(self, curr_index, length): sliced_column_table = [ @@ -138,10 +155,72 @@ def __eq__(self, other): class ArrowStreamTable: - def __init__(self, arrow_stream, num_rows): - self.arrow_stream = arrow_stream + def __init__(self, record_batches, num_rows, column_description): + self.record_batches = record_batches self.num_rows = num_rows - + self.column_description = column_description + self.curr_batch_index = 0 + + def append(self, other: ArrowStreamTable): + if self.column_description != other.column_description: + raise ValueError("ArrowStreamTable: Column descriptions do not match for the tables to be appended") + + self.record_batches.extend(other.record_batches) + self.num_rows += other.num_rows + + def next_n_rows(self, req_num_rows: int): + consumed_batches = [] + consumed_num_rows = 0 + while req_num_rows > 0 and self.record_batches: + current = self.record_batches[0] + if current.num_rows <= req_num_rows: + consumed_batches.append(current) + req_num_rows -= current.num_rows + consumed_num_rows += current.num_rows + self.num_rows -= current.num_rows + self.record_batches.pop(0) + else: + consumed_batches.append(current.slice(0, req_num_rows)) + self.record_batches[0] = current.slice(req_num_rows) + self.num_rows -= req_num_rows + consumed_num_rows += req_num_rows + req_num_rows = 0 + + return ArrowStreamTable(consumed_batches, consumed_num_rows, self.column_description) + + + def 
convert_decimals_in_record_batch(self,batch: "pyarrow.RecordBatch") -> "pyarrow.RecordBatch": + new_columns = [] + new_fields = [] + + for i, col in enumerate(batch.columns): + field = batch.schema.field(i) + + if self.column_description[i][1] == "decimal": + precision, scale = self.column_description[i][4], self.column_description[i][5] + assert scale is not None and precision is not None + dtype = pyarrow.decimal128(precision, scale) + + new_col = col.cast(dtype) + new_field = field.with_type(dtype) + + new_columns.append(new_col) + new_fields.append(new_field) + else: + new_columns.append(col) + new_fields.append(field) + + new_schema = pyarrow.schema(new_fields) + return pyarrow.RecordBatch.from_arrays(new_columns, schema=new_schema) + + def to_arrow_table(self) -> "pyarrow.Table": + def batch_generator(): + for batch in self.record_batches: + yield self.convert_decimals_in_record_batch(batch) + + return pyarrow.Table.from_batches(batch_generator()) + + class ColumnQueue(ResultSetQueue): def __init__(self, column_table: ColumnTable): self.column_table = column_table @@ -250,9 +329,9 @@ def __init__( ) self.table = self._create_next_table() - self.table_row_index = 0 + # self.table_row_index = 0 - def next_n_rows(self, num_rows: int) -> "pyarrow.Table": + def next_n_rows(self, num_rows: int): """ Get up to the next n rows of the cloud fetch Arrow dataframes. @@ -262,55 +341,62 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": Returns: pyarrow.Table """ + results = self._create_empty_table() if not self.table: logger.debug("CloudFetchQueue: no more rows available") # Return empty pyarrow table to cause retry of fetch - return self._create_empty_table() + return results logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows)) - results = self.table.slice(0, 0) - partial_result_chunks = [results] + + # results = self.table.slice(0, 0) + # partial_result_chunks = [results] while num_rows > 0 and self.table: # Get remaining of num_rows or the rest of the current table, whichever is smaller - length = min(num_rows, self.table.num_rows - self.table_row_index) - table_slice = self.table.slice(self.table_row_index, length) - partial_result_chunks.append(table_slice) - self.table_row_index += table_slice.num_rows + length = min(num_rows, self.table.num_rows) + nxt_result = self.table.next_n_rows(length) + results.append(nxt_result) + num_rows -= nxt_result.num_rows + # table_slice = self.table.slice(self.table_row_index, length) + # partial_result_chunks.append(table_slice) + # self.table_row_index += table_slice.num_rows # Replace current table with the next table if we are at the end of the current table - if self.table_row_index == self.table.num_rows: + if self.table.num_rows == 0: self.table = self._create_next_table() - self.table_row_index = 0 - num_rows -= table_slice.num_rows + # self.table_row_index = 0 + # num_rows -= table_slice.num_rows logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) - return concat_chunked_tables(partial_result_chunks) + return results - def remaining_rows(self) -> "pyarrow.Table": + def remaining_rows(self): """ Get all remaining rows of the cloud fetch Arrow dataframes. 
Returns: pyarrow.Table """ + result = self._create_empty_table() if not self.table: # Return empty pyarrow table to cause retry of fetch - return self._create_empty_table() - results = self.table.slice(0, 0) - partial_result_chunks = [results] + return result + # results = self.table.slice(0, 0) + # result = self._create_empty_table() + print("remaining_rows call") print(f"self.table.num_rows - {self.table.num_rows}") while self.table: - table_slice = self.table.slice( - self.table_row_index, self.table.num_rows - self.table_row_index - ) - partial_result_chunks.append(table_slice) - self.table_row_index += table_slice.num_rows + # table_slice = self.table.slice( + # self.table_row_index, self.table.num_rows - self.table_row_index + # ) + result.append(self.table) + # self.table_row_index += table_slice.num_rows self.table = self._create_next_table() - self.table_row_index = 0 - print(f"results.num_rows - {results.num_rows}") - return concat_chunked_tables(partial_result_chunks) + # self.table_row_index = 0 + print(f"result.num_rows - {result.num_rows}") + return result - def _create_next_table(self) -> Union["pyarrow.Table", None]: + def _create_next_table(self) -> ArrowStreamTable: logger.debug( "CloudFetchQueue: Trying to get downloaded file for row {}".format( self.start_row_index @@ -328,32 +414,41 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]: ) # None signals no more Arrow tables can be built from the remaining handlers if any remain return None - arrow_table = create_arrow_table_from_arrow_file( - downloaded_file.file_bytes, self.description - ) + + arrow_stream_table = ArrowStreamTable( + list(pyarrow.ipc.open_stream(downloaded_file.file_bytes)), + downloaded_file.row_count, + self.description) + # arrow_table = create_arrow_table_from_arrow_file( + # downloaded_file.file_bytes, self.description + # ) # The server rarely prepares the exact number of rows requested by the client in cloud fetch. 
# Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested - if arrow_table.num_rows > downloaded_file.row_count: - arrow_table = arrow_table.slice(0, downloaded_file.row_count) + # if arrow_table.num_rows > downloaded_file.row_count: + # arrow_table = arrow_table.slice(0, downloaded_file.row_count) # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows - assert downloaded_file.row_count == arrow_table.num_rows - self.start_row_index += arrow_table.num_rows + # assert downloaded_file.row_count == arrow_table.num_rows + # self.start_row_index += arrow_table.num_rows + self.start_row_index += arrow_stream_table.num_rows logger.debug( "CloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format( - arrow_table.num_rows, self.start_row_index + arrow_stream_table.num_rows, self.start_row_index ) ) print("_create_next_table") - print(f"arrow_table.num_rows - {arrow_table.num_rows}") - return arrow_table + print(f"arrow_stream_table.num_rows - {arrow_stream_table.num_rows}") + return arrow_stream_table - def _create_empty_table(self) -> "pyarrow.Table": + def _create_empty_table(self) -> ArrowStreamTable: # Create a 0-row table with just the schema bytes - return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) + return ArrowStreamTable( + list(pyarrow.ipc.open_stream(self.schema_bytes)), + 0, + self.description) ExecuteResponse = namedtuple( @@ -612,7 +707,6 @@ def create_arrow_table_from_arrow_file( arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) - def convert_arrow_based_file_to_arrow_table(file_bytes: bytes): try: return pyarrow.ipc.open_stream(file_bytes).read_all() @@ -779,12 +873,17 @@ def _create_python_tuple(t_col_value_wrapper): return tuple(result) -def concat_chunked_tables(tables: List[Union["pyarrow.Table", ColumnTable]]) -> Union["pyarrow.Table", ColumnTable]: +def concat_chunked_tables(tables: List[Union["pyarrow.Table", ColumnTable, ArrowStreamTable]]) -> Union["pyarrow.Table", ColumnTable, ArrowStreamTable]: if isinstance(tables[0], ColumnTable): base_table = tables[0] for table in tables[1:]: base_table = merge_columnar(base_table, table) return base_table + elif isinstance(tables[0], ArrowStreamTable): + base_table = tables[0] + for table in tables[1:]: + base_table = base_table.append(table) + return base_table else: return pyarrow.concat_tables(tables) From ddc2940feff32317750dce7ccb279533a1f1e6a9 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Sat, 12 Jul 2025 17:46:29 +0530 Subject: [PATCH 04/12] eff conv --- src/databricks/sql/client.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index dd898232c..5c89a95fa 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1403,14 +1403,13 @@ def _convert_columnar_table(self, table): return result - def _convert_arrow_table(self, table): + def _convert_arrow_table(self, table: "pyarrow.Table"): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] + columns_as_lists = [col.to_pylist() for col in table.itercolumns()] + return [ResultRow(*row) for row in zip(*columns_as_lists)] # Need to use nullable types, as otherwise type can change when 
there are missing values. # See https://arrow.apache.org/docs/python/pandas.html#nullable-types From ba72da07d6e2210f210c5e4b479233a06fe78d4d Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Sat, 12 Jul 2025 17:56:22 +0530 Subject: [PATCH 05/12] More timings --- src/databricks/sql/client.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 5c89a95fa..719edf65f 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1408,8 +1408,12 @@ def _convert_arrow_table(self, table: "pyarrow.Table"): ResultRow = Row(*column_names) if self.connection.disable_pandas is True: + start_time = time.time() columns_as_lists = [col.to_pylist() for col in table.itercolumns()] - return [ResultRow(*row) for row in zip(*columns_as_lists)] + res = [ResultRow(*row) for row in zip(*columns_as_lists)] + end_time = time.time() + print(f"Time taken to convert arrow table to list: {end_time - start_time} seconds") + return res # Need to use nullable types, as otherwise type can change when there are missing values. # See https://arrow.apache.org/docs/python/pandas.html#nullable-types From 45b85a86ad933d51467297073e11313638e21d12 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Sat, 12 Jul 2025 18:17:31 +0530 Subject: [PATCH 06/12] More opt --- src/databricks/sql/client.py | 72 ++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 719edf65f..6ed2b2cca 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1407,42 +1407,42 @@ def _convert_arrow_table(self, table: "pyarrow.Table"): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) - if self.connection.disable_pandas is True: - start_time = time.time() - columns_as_lists = [col.to_pylist() for col in table.itercolumns()] - res = [ResultRow(*row) for row in zip(*columns_as_lists)] - end_time = time.time() - print(f"Time taken to convert arrow table to list: {end_time - start_time} seconds") - return res - - # Need to use nullable types, as otherwise type can change when there are missing values. 
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] + # if self.connection.disable_pandas is True: + start_time = time.time() + columns_as_lists = [col.to_pylist() for col in table.itercolumns()] + res = [ResultRow(*row) for row in zip(*columns_as_lists)] + end_time = time.time() + print(f"Time taken to convert arrow table to list: {end_time - start_time} seconds") + return res + + # # Need to use nullable types, as otherwise type can change when there are missing values. + # # See https://arrow.apache.org/docs/python/pandas.html#nullable-types + # # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html + # dtype_mapping = { + # pyarrow.int8(): pandas.Int8Dtype(), + # pyarrow.int16(): pandas.Int16Dtype(), + # pyarrow.int32(): pandas.Int32Dtype(), + # pyarrow.int64(): pandas.Int64Dtype(), + # pyarrow.uint8(): pandas.UInt8Dtype(), + # pyarrow.uint16(): pandas.UInt16Dtype(), + # pyarrow.uint32(): pandas.UInt32Dtype(), + # pyarrow.uint64(): pandas.UInt64Dtype(), + # pyarrow.bool_(): pandas.BooleanDtype(), + # pyarrow.float32(): pandas.Float32Dtype(), + # pyarrow.float64(): pandas.Float64Dtype(), + # pyarrow.string(): pandas.StringDtype(), + # } + + # # Need to rename columns, as the to_pandas function cannot handle duplicate column names + # table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + # df = table_renamed.to_pandas( + # types_mapper=dtype_mapping.get, + # date_as_object=True, + # timestamp_as_object=True, + # ) + + # res = df.to_numpy(na_value=None, dtype="object") + # return [ResultRow(*v) for v in res] @property def rownumber(self): From 63661f23735613afa84b16c25dc9dcd4f7465100 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Sat, 12 Jul 2025 18:26:19 +0530 Subject: [PATCH 07/12] prev was better --- src/databricks/sql/client.py | 70 +++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 6ed2b2cca..7791cf8f1 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1407,42 +1407,46 @@ def _convert_arrow_table(self, table: "pyarrow.Table"): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) - # if self.connection.disable_pandas is True: + if self.connection.disable_pandas is True: + start_time = time.time() + columns_as_lists = [col.to_pylist() for col in table.itercolumns()] + res = [ResultRow(*row) for row in 
zip(*columns_as_lists)] + end_time = time.time() + print(f"Time taken to convert arrow table to list: {end_time - start_time} seconds") + return res + start_time = time.time() - columns_as_lists = [col.to_pylist() for col in table.itercolumns()] - res = [ResultRow(*row) for row in zip(*columns_as_lists)] + # Need to use nullable types, as otherwise type can change when there are missing values. + # See https://arrow.apache.org/docs/python/pandas.html#nullable-types + # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + # Need to rename columns, as the to_pandas function cannot handle duplicate column names + table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + df = table_renamed.to_pandas( + types_mapper=dtype_mapping.get, + date_as_object=True, + timestamp_as_object=True, + ) + + res = df.to_numpy(na_value=None, dtype="object") + tmp_res = [ResultRow(*v) for v in res] end_time = time.time() print(f"Time taken to convert arrow table to list: {end_time - start_time} seconds") - return res - - # # Need to use nullable types, as otherwise type can change when there are missing values. - # # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - # dtype_mapping = { - # pyarrow.int8(): pandas.Int8Dtype(), - # pyarrow.int16(): pandas.Int16Dtype(), - # pyarrow.int32(): pandas.Int32Dtype(), - # pyarrow.int64(): pandas.Int64Dtype(), - # pyarrow.uint8(): pandas.UInt8Dtype(), - # pyarrow.uint16(): pandas.UInt16Dtype(), - # pyarrow.uint32(): pandas.UInt32Dtype(), - # pyarrow.uint64(): pandas.UInt64Dtype(), - # pyarrow.bool_(): pandas.BooleanDtype(), - # pyarrow.float32(): pandas.Float32Dtype(), - # pyarrow.float64(): pandas.Float64Dtype(), - # pyarrow.string(): pandas.StringDtype(), - # } - - # # Need to rename columns, as the to_pandas function cannot handle duplicate column names - # table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - # df = table_renamed.to_pandas( - # types_mapper=dtype_mapping.get, - # date_as_object=True, - # timestamp_as_object=True, - # ) - - # res = df.to_numpy(na_value=None, dtype="object") - # return [ResultRow(*v) for v in res] + return tmp_res @property def rownumber(self): From 06551a0f20b604f1c2e7bb7644d60de19eda45c5 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Mon, 14 Jul 2025 13:42:12 +0530 Subject: [PATCH 08/12] more comments --- src/databricks/sql/client.py | 40 +++++++++++++++++-- .../sql/cloudfetch/download_manager.py | 4 +- src/databricks/sql/cloudfetch/downloader.py | 14 +++---- src/databricks/sql/common/http.py | 2 +- src/databricks/sql/utils.py | 23 +++++++---- 5 files changed, 62 insertions(+), 21 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 7791cf8f1..c5c3070e5 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1402,17 +1402,42 @@ 
def _convert_columnar_table(self, table): result.append(ResultRow(*curr_row)) return result + + def print_mem(self): + import os + import psutil + + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + total_mem_mb = mem_info.rss / 1024 / 1024 + cpu_percent = process.cpu_percent(interval=0.1) + print(f"Total memory usage: {total_mem_mb:.2f} MB") + print(f"CPU percent: {cpu_percent:.2f}%") + # total_size_bytes = table.get_total_buffer_size() + # total_size_mb = total_size_bytes / (1024 * 1024) + + # print(f"Total PyArrow table size: {total_size_bytes} bytes ({total_size_mb:.2f} MB)") def _convert_arrow_table(self, table: "pyarrow.Table"): + import sys + from pympler import asizeof + + self.print_mem() + print(f"Memory size table: {table.nbytes / (1024 ** 2):.2f} MB") + # Convert to MB for easier reading column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) if self.connection.disable_pandas is True: start_time = time.time() columns_as_lists = [col.to_pylist() for col in table.itercolumns()] + self.print_mem() + print(f"Memory size columns_as_lists: {sum(sys.getsizeof(col) for col in columns_as_lists) / (1024 ** 2):.2f} MB") res = [ResultRow(*row) for row in zip(*columns_as_lists)] + self.print_mem() end_time = time.time() print(f"Time taken to convert arrow table to list: {end_time - start_time} seconds") + print(f"Memory size res: {sum(sys.getsizeof(row) for row in res) / (1024 ** 2):.2f} MB") return res start_time = time.time() @@ -1436,14 +1461,23 @@ def _convert_arrow_table(self, table: "pyarrow.Table"): # Need to rename columns, as the to_pandas function cannot handle duplicate column names table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + print(f"Memory size table_renamed: {table_renamed.nbytes / (1024 ** 2):.2f} MB") df = table_renamed.to_pandas( types_mapper=dtype_mapping.get, date_as_object=True, timestamp_as_object=True, + self_destruct=True, ) + print(f"Memory size df: {df.memory_usage(deep=True).sum() / (1024 ** 2):.2f} MB") + self.print_mem() + # del table_renamed res = df.to_numpy(na_value=None, dtype="object") + print(f"Memory size res: {res.nbytes / (1024 ** 2):.2f} MB") + self.print_mem() + # del df tmp_res = [ResultRow(*v) for v in res] + self.print_mem() end_time = time.time() print(f"Time taken to convert arrow table to list: {end_time - start_time} seconds") return tmp_res @@ -1471,7 +1505,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": and not self.has_been_closed_server_side and self.has_more_rows ): - print(f"TOTAL DATA ROWS {TOTAL_SIZE}") + # print(f"TOTAL DATA ROWS {TOTAL_SIZE}") self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) results.append(partial_results) @@ -1515,11 +1549,11 @@ def fetchall_arrow(self) -> "pyarrow.Table": self._next_row_index += results.num_rows # partial_result_chunks = [results] - print("Server side has more rows", self.has_more_rows) + # print("Server side has more rows", self.has_more_rows) TOTAL_SIZE = results.num_rows while not self.has_been_closed_server_side and self.has_more_rows: - print(f"TOTAL DATA ROWS {TOTAL_SIZE}") + # print(f"TOTAL DATA ROWS {TOTAL_SIZE}") self._fill_results_buffer() partial_results = self.results.remaining_rows() results.append(partial_results) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index ad1fbbe7c..64401dc9c 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ 
b/src/databricks/sql/cloudfetch/download_manager.py @@ -84,8 +84,8 @@ def _schedule_downloads(self): """ While download queue has a capacity, peek pending links and submit them to thread pool. """ - print("Schedule_downloads") - logger.debug("ResultFileDownloadManager: schedule downloads") + # print("Schedule_downloads") + # logger.debug("ResultFileDownloadManager: schedule downloads") while (len(self._download_tasks) < self._max_download_threads) and ( len(self._pending_links) > 0 ): diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 3bd337a58..fddca7c21 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -99,9 +99,9 @@ def run(self) -> DownloadedFile: verify=self._ssl_options.tls_verify, headers=self.link.httpHeaders ) as response: - print_text = [ + # print_text = [ - ] + # ] response.raise_for_status() @@ -127,12 +127,12 @@ def run(self) -> DownloadedFile: ) ) - print_text.append( - f"Downloaded file startRowOffset - {self.link.startRowOffset} - rowCount - {self.link.rowCount}" - ) + # print_text.append( + # f"Downloaded file startRowOffset - {self.link.startRowOffset} - rowCount - {self.link.rowCount}" + # ) - for text in print_text: - print(text) + # for text in print_text: + # print(text) return DownloadedFile( decompressed_data, diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index 92e80e7fb..3437f1c9b 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -73,7 +73,7 @@ def execute( start_time = time.time() response = self.session.request(method.value, url, **kwargs) end_time = time.time() - print(f"Downloaded file in {end_time - start_time} seconds") + # print(f"Downloaded file in {end_time - start_time} seconds") yield response except Exception as e: logger.error("Error executing HTTP request in DatabricksHttpClient: %s", e) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index f5ca17ada..617857f19 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -219,7 +219,12 @@ def batch_generator(): yield self.convert_decimals_in_record_batch(batch) return pyarrow.Table.from_batches(batch_generator()) - + + def remove_extraneous_rows(self): + num_rows_in_data = sum(batch.num_rows for batch in self.record_batches) + if num_rows_in_data > self.num_rows: + self.record_batches = self.record_batches[:self.num_rows] + self.num_rows = self.num_rows class ColumnQueue(ResultSetQueue): def __init__(self, column_table: ColumnTable): @@ -319,8 +324,8 @@ def __init__( result_link.startRowOffset, result_link.rowCount ) ) - print("Initial Setup Cloudfetch Queue") - print(f"No of result links - {len(result_links)}") + # print("Initial Setup Cloudfetch Queue") + # print(f"No of result links - {len(result_links)}") self.download_manager = ResultFileDownloadManager( links=result_links or [], max_download_threads=self.max_download_threads, @@ -383,8 +388,8 @@ def remaining_rows(self): # results = self.table.slice(0, 0) # result = self._create_empty_table() - print("remaining_rows call") - print(f"self.table.num_rows - {self.table.num_rows}") + # print("remaining_rows call") + # print(f"self.table.num_rows - {self.table.num_rows}") while self.table: # table_slice = self.table.slice( # self.table_row_index, self.table.num_rows - self.table_row_index @@ -393,7 +398,7 @@ def remaining_rows(self): # self.table_row_index += table_slice.num_rows self.table = self._create_next_table() 
# self.table_row_index = 0 - print(f"result.num_rows - {result.num_rows}") + # print(f"result.num_rows - {result.num_rows}") return result def _create_next_table(self) -> ArrowStreamTable: @@ -419,6 +424,8 @@ def _create_next_table(self) -> ArrowStreamTable: list(pyarrow.ipc.open_stream(downloaded_file.file_bytes)), downloaded_file.row_count, self.description) + + arrow_stream_table.remove_extraneous_rows() # arrow_table = create_arrow_table_from_arrow_file( # downloaded_file.file_bytes, self.description # ) @@ -439,8 +446,8 @@ def _create_next_table(self) -> ArrowStreamTable: ) ) - print("_create_next_table") - print(f"arrow_stream_table.num_rows - {arrow_stream_table.num_rows}") + # print("_create_next_table") + # print(f"arrow_stream_table.num_rows - {arrow_stream_table.num_rows}") return arrow_stream_table def _create_empty_table(self) -> ArrowStreamTable: From e726c33d1d3510bb41093c7d36182b9947befc8b Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Mon, 14 Jul 2025 14:14:09 +0530 Subject: [PATCH 09/12] remove extra riws --- src/databricks/sql/utils.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 617857f19..fb218c45b 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -155,7 +155,7 @@ def __eq__(self, other): class ArrowStreamTable: - def __init__(self, record_batches, num_rows, column_description): + def __init__(self, record_batches: List["pyarrow.RecordBatch"], num_rows: int, column_description): self.record_batches = record_batches self.num_rows = num_rows self.column_description = column_description @@ -222,9 +222,17 @@ def batch_generator(): def remove_extraneous_rows(self): num_rows_in_data = sum(batch.num_rows for batch in self.record_batches) - if num_rows_in_data > self.num_rows: - self.record_batches = self.record_batches[:self.num_rows] - self.num_rows = self.num_rows + rows_to_delete = num_rows_in_data - self.num_rows + while rows_to_delete > 0 and self.record_batches: + last_batch = self.record_batches[-1] + if last_batch.num_rows <= rows_to_delete: + self.record_batches.pop() + rows_to_delete -= last_batch.num_rows + else: + keep_rows = last_batch.num_rows - rows_to_delete + self.record_batches[-1] = last_batch.slice(0, keep_rows) + rows_to_delete = 0 + class ColumnQueue(ResultSetQueue): def __init__(self, column_table: ColumnTable): From cfc047f5b59baa50824b14430f42ac1b4b446fec Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Mon, 14 Jul 2025 14:45:54 +0530 Subject: [PATCH 10/12] More fix --- src/databricks/sql/utils.py | 56 ++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index fb218c45b..806578aea 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -74,13 +74,14 @@ def build_queue( ResultSetQueue """ if row_set_type == TSparkRowSetType.ARROW_BASED_SET: - arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( + arrow_record_batches, n_valid_rows = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes ) - converted_arrow_table = convert_decimals_in_arrow_table( - arrow_table, description - ) - return ArrowQueue(converted_arrow_table, n_valid_rows) + # converted_arrow_table = convert_decimals_in_arrow_table( + # arrow_table, description + # ) + arrow_stream_table = ArrowStreamTable(arrow_record_batches, n_valid_rows, description) + return 
ArrowQueue(arrow_stream_table, n_valid_rows, description) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: column_table, column_names = convert_column_based_set_to_column_table( t_row_set.columns, description @@ -159,7 +160,6 @@ def __init__(self, record_batches: List["pyarrow.RecordBatch"], num_rows: int, c self.record_batches = record_batches self.num_rows = num_rows self.column_description = column_description - self.curr_batch_index = 0 def append(self, other: ArrowStreamTable): if self.column_description != other.column_description: @@ -187,7 +187,9 @@ def next_n_rows(self, req_num_rows: int): req_num_rows = 0 return ArrowStreamTable(consumed_batches, consumed_num_rows, self.column_description) - + + def remaining_rows(self): + return self def convert_decimals_in_record_batch(self,batch: "pyarrow.RecordBatch") -> "pyarrow.RecordBatch": new_columns = [] @@ -258,9 +260,9 @@ def remaining_rows(self): class ArrowQueue(ResultSetQueue): def __init__( self, - arrow_table: "pyarrow.Table", + arrow_stream_table: ArrowStreamTable, n_valid_rows: int, - start_row_index: int = 0, + column_description, ): """ A queue-like wrapper over an Arrow table @@ -269,25 +271,27 @@ def __init__( :param n_valid_rows: The index of the last valid row in the table :param start_row_index: The first row in the table we should start fetching from """ - self.cur_row_index = start_row_index - self.arrow_table = arrow_table + self.arrow_stream_table = arrow_stream_table self.n_valid_rows = n_valid_rows + self.column_description = column_description - def next_n_rows(self, num_rows: int) -> "pyarrow.Table": + def next_n_rows(self, num_rows: int): """Get upto the next n rows of the Arrow dataframe""" - length = min(num_rows, self.n_valid_rows - self.cur_row_index) - # Note that the table.slice API is not the same as Python's slice - # The second argument should be length, not end index - slice = self.arrow_table.slice(self.cur_row_index, length) - self.cur_row_index += slice.num_rows - return slice + return self.arrow_stream_table.next_n_rows(num_rows) + # length = min(num_rows, self.n_valid_rows - self.cur_row_index) + # # Note that the table.slice API is not the same as Python's slice + # # The second argument should be length, not end index + # slice = self.arrow_table.slice(self.cur_row_index, length) + # self.cur_row_index += slice.num_rows + # return slice - def remaining_rows(self) -> "pyarrow.Table": - slice = self.arrow_table.slice( - self.cur_row_index, self.n_valid_rows - self.cur_row_index - ) - self.cur_row_index += slice.num_rows - return slice + def remaining_rows(self): + return self.arrow_stream_table.remaining_rows() + # slice = self.arrow_table.slice( + # self.cur_row_index, self.n_valid_rows - self.cur_row_index + # ) + # self.cur_row_index += slice.num_rows + # return slice class CloudFetchQueue(ResultSetQueue): @@ -740,8 +744,8 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema if lz4_compressed else arrow_batch.batch ) - arrow_table = pyarrow.ipc.open_stream(ba).read_all() - return arrow_table, n_rows + arrow_record_batches = list(pyarrow.ipc.open_stream(ba)) + return arrow_record_batches, n_rows def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": From eb14f9502d296071f1f74ac1c10f78b4a4710fd7 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Mon, 14 Jul 2025 21:26:09 +0530 Subject: [PATCH 11/12] nit --- src/databricks/sql/client.py | 82 ++--------- src/databricks/sql/cloudfetch/downloader.py | 67 --------- src/databricks/sql/utils.py 
| 153 ++++++++------------ 3 files changed, 70 insertions(+), 232 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c5c3070e5..9a7673d84 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1391,7 +1391,7 @@ def _fill_results_buffer(self): self.results = results self.has_more_rows = has_more_rows - def _convert_columnar_table(self, table): + def _convert_columnar_table(self, table: ColumnTable): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) result = [] @@ -1402,45 +1402,16 @@ def _convert_columnar_table(self, table): result.append(ResultRow(*curr_row)) return result - - def print_mem(self): - import os - import psutil - - process = psutil.Process(os.getpid()) - mem_info = process.memory_info() - total_mem_mb = mem_info.rss / 1024 / 1024 - cpu_percent = process.cpu_percent(interval=0.1) - print(f"Total memory usage: {total_mem_mb:.2f} MB") - print(f"CPU percent: {cpu_percent:.2f}%") - # total_size_bytes = table.get_total_buffer_size() - # total_size_mb = total_size_bytes / (1024 * 1024) - - # print(f"Total PyArrow table size: {total_size_bytes} bytes ({total_size_mb:.2f} MB)") - + def _convert_arrow_table(self, table: "pyarrow.Table"): - import sys - from pympler import asizeof - - self.print_mem() - print(f"Memory size table: {table.nbytes / (1024 ** 2):.2f} MB") - # Convert to MB for easier reading + column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) if self.connection.disable_pandas is True: - start_time = time.time() columns_as_lists = [col.to_pylist() for col in table.itercolumns()] - self.print_mem() - print(f"Memory size columns_as_lists: {sum(sys.getsizeof(col) for col in columns_as_lists) / (1024 ** 2):.2f} MB") - res = [ResultRow(*row) for row in zip(*columns_as_lists)] - self.print_mem() - end_time = time.time() - print(f"Time taken to convert arrow table to list: {end_time - start_time} seconds") - print(f"Memory size res: {sum(sys.getsizeof(row) for row in res) / (1024 ** 2):.2f} MB") - return res + return [ResultRow(*row) for row in zip(*columns_as_lists)] - start_time = time.time() # Need to use nullable types, as otherwise type can change when there are missing values. 
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html @@ -1461,31 +1432,20 @@ def _convert_arrow_table(self, table: "pyarrow.Table"): # Need to rename columns, as the to_pandas function cannot handle duplicate column names table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - print(f"Memory size table_renamed: {table_renamed.nbytes / (1024 ** 2):.2f} MB") df = table_renamed.to_pandas( types_mapper=dtype_mapping.get, date_as_object=True, timestamp_as_object=True, self_destruct=True, ) - print(f"Memory size df: {df.memory_usage(deep=True).sum() / (1024 ** 2):.2f} MB") - self.print_mem() - # del table_renamed res = df.to_numpy(na_value=None, dtype="object") - print(f"Memory size res: {res.nbytes / (1024 ** 2):.2f} MB") - self.print_mem() - # del df - tmp_res = [ResultRow(*v) for v in res] - self.print_mem() - end_time = time.time() - print(f"Time taken to convert arrow table to list: {end_time - start_time} seconds") - return tmp_res + return [ResultRow(*v) for v in res] @property def rownumber(self): return self._next_row_index - + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. @@ -1497,26 +1457,18 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows - # partial_result_chunks = [results] - - TOTAL_SIZE = results.num_rows while ( n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows ): - # print(f"TOTAL DATA ROWS {TOTAL_SIZE}") self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) results.append(partial_results) - # partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - TOTAL_SIZE += partial_results.num_rows return results.to_arrow_table() - - def fetchmany_columnar(self, size: int): """ @@ -1537,39 +1489,23 @@ def fetchmany_columnar(self, size: int): ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = merge_columnar(results, partial_results) + results.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows return results - + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows - # partial_result_chunks = [results] - # print("Server side has more rows", self.has_more_rows) - TOTAL_SIZE = results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: - # print(f"TOTAL DATA ROWS {TOTAL_SIZE}") self._fill_results_buffer() partial_results = self.results.remaining_rows() results.append(partial_results) self._next_row_index += partial_results.num_rows - TOTAL_SIZE += partial_results.num_rows - - # results = concat_chunked_tables(partial_result_chunks) - - # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table - # Valid only for metadata commands result set - # if isinstance(results, ColumnTable) and pyarrow: - # data = { - # name: col - # for name, col in zip(results.column_names, results.column_table) - # } - # return pyarrow.Table.from_pydict(data) + return 
results.to_arrow_table() def fetchall_columnar(self): diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index fddca7c21..10f1f55dd 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -99,10 +99,7 @@ def run(self) -> DownloadedFile: verify=self._ssl_options.tls_verify, headers=self.link.httpHeaders ) as response: - # print_text = [ - # ] - response.raise_for_status() # Save (and decompress if needed) the downloaded file @@ -127,75 +124,11 @@ def run(self) -> DownloadedFile: ) ) - # print_text.append( - # f"Downloaded file startRowOffset - {self.link.startRowOffset} - rowCount - {self.link.rowCount}" - # ) - - # for text in print_text: - # print(text) - return DownloadedFile( decompressed_data, self.link.startRowOffset, self.link.rowCount, ) - # session = requests.Session() - # session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) - # session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) - - # try: - # print_text = [ - - # ] - # start_time = time.time() - # # Get the file via HTTP request - # response = session.get( - # self.link.fileLink, - # timeout=self.settings.download_timeout, - # verify=self._ssl_options.tls_verify, - # headers=self.link.httpHeaders - # # TODO: Pass cert from `self._ssl_options` - # ) - # response.raise_for_status() - # end_time = time.time() - # print_text.append(f"Downloaded file in {end_time - start_time} seconds") - # # Save (and decompress if needed) the downloaded file - # compressed_data = response.content - # decompressed_data = ( - # ResultSetDownloadHandler._decompress_data(compressed_data) - # if self.settings.is_lz4_compressed - # else compressed_data - # ) - - # # The size of the downloaded file should match the size specified from TSparkArrowResultLink - # if len(decompressed_data) != self.link.bytesNum: - # logger.debug( - # "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( - # len(decompressed_data), self.link.bytesNum - # ) - # ) - - # logger.debug( - # "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( - # self.link.startRowOffset, self.link.rowCount - # ) - # ) - - # print_text.append( - # f"Downloaded file startRowOffset - {self.link.startRowOffset} - rowCount - {self.link.rowCount}" - # ) - - # for text in print_text: - # print(text) - - # return DownloadedFile( - # decompressed_data, - # self.link.startRowOffset, - # self.link.rowCount, - # ) - # finally: - # if session: - # session.close() @staticmethod def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int): diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 806578aea..fb26ff029 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -77,11 +77,8 @@ def build_queue( arrow_record_batches, n_valid_rows = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes ) - # converted_arrow_table = convert_decimals_in_arrow_table( - # arrow_table, description - # ) arrow_stream_table = ArrowStreamTable(arrow_record_batches, n_valid_rows, description) - return ArrowQueue(arrow_stream_table, n_valid_rows, description) + return ArrowQueue(arrow_stream_table) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: column_table, column_names = convert_column_based_set_to_column_table( t_row_set.columns, description @@ -106,11 +103,34 @@ def build_queue( raise AssertionError("Row set type is 
not valid") -class ColumnTable: +class ResultTable(ABC): + + @abstractmethod + def next_n_rows(self, num_rows: int): + pass + + @abstractmethod + def remaining_rows(self): + pass + + @abstractmethod + def append(self, other: ResultTable): + pass + + @abstractmethod + def to_arrow_table(self) -> "pyarrow.Table": + pass + + @abstractmethod + def remove_extraneous_rows(self): + pass + +class ColumnTable(ResultTable): def __init__(self, column_table, column_names): self.column_table = column_table self.column_names = column_names - + self.curr_row_index = 0 + @property def num_rows(self): if len(self.column_table) == 0: @@ -121,9 +141,13 @@ def num_rows(self): @property def num_columns(self): return len(self.column_names) - - def get_item(self, col_index, row_index): - return self.column_table[col_index][row_index] + + def next_n_rows(self, num_rows: int): + sliced_column_table = [ + column[self.curr_row_index : self.curr_row_index + num_rows] for column in self.column_table + ] + self.curr_row_index += num_rows + return ColumnTable(sliced_column_table, self.column_names) def append(self, other: ColumnTable): if self.column_names != other.column_names: @@ -133,7 +157,17 @@ def append(self, other: ColumnTable): self.column_table[i] + other.column_table[i] for i in range(self.num_columns) ] - return ColumnTable(merged_result, self.column_names) + self.column_table = merged_result + + def remaining_rows(self): + sliced_column_table = [ + column[self.curr_row_index : ] for column in self.column_table + ] + self.curr_row_index = self.num_rows + return ColumnTable(sliced_column_table, self.column_names) + + def get_item(self, col_index, row_index): + return self.column_table[col_index][row_index] def to_arrow_table(self): data = { @@ -142,20 +176,10 @@ def to_arrow_table(self): } return pyarrow.Table.from_pydict(data) - def slice(self, curr_index, length): - sliced_column_table = [ - column[curr_index : curr_index + length] for column in self.column_table - ] - return ColumnTable(sliced_column_table, self.column_names) - - def __eq__(self, other): - return ( - self.column_table == other.column_table - and self.column_names == other.column_names - ) - + def remove_extraneous_rows(self): + pass -class ArrowStreamTable: +class ArrowStreamTable(ResultTable): def __init__(self, record_batches: List["pyarrow.RecordBatch"], num_rows: int, column_description): self.record_batches = record_batches self.num_rows = num_rows @@ -239,30 +263,18 @@ def remove_extraneous_rows(self): class ColumnQueue(ResultSetQueue): def __init__(self, column_table: ColumnTable): self.column_table = column_table - self.cur_row_index = 0 - self.n_valid_rows = column_table.num_rows def next_n_rows(self, num_rows): - length = min(num_rows, self.n_valid_rows - self.cur_row_index) - - slice = self.column_table.slice(self.cur_row_index, length) - self.cur_row_index += slice.num_rows - return slice + return self.column_table.next_n_rows(num_rows) def remaining_rows(self): - slice = self.column_table.slice( - self.cur_row_index, self.n_valid_rows - self.cur_row_index - ) - self.cur_row_index += slice.num_rows - return slice + return self.column_table.remaining_rows() class ArrowQueue(ResultSetQueue): def __init__( self, arrow_stream_table: ArrowStreamTable, - n_valid_rows: int, - column_description, ): """ A queue-like wrapper over an Arrow table @@ -271,27 +283,14 @@ def __init__( :param n_valid_rows: The index of the last valid row in the table :param start_row_index: The first row in the table we should start fetching from """ - 
self.arrow_stream_table = arrow_stream_table - self.n_valid_rows = n_valid_rows - self.column_description = column_description + self.arrow_stream_table = arrow_stream_table def next_n_rows(self, num_rows: int): """Get upto the next n rows of the Arrow dataframe""" return self.arrow_stream_table.next_n_rows(num_rows) - # length = min(num_rows, self.n_valid_rows - self.cur_row_index) - # # Note that the table.slice API is not the same as Python's slice - # # The second argument should be length, not end index - # slice = self.arrow_table.slice(self.cur_row_index, length) - # self.cur_row_index += slice.num_rows - # return slice def remaining_rows(self): return self.arrow_stream_table.remaining_rows() - # slice = self.arrow_table.slice( - # self.cur_row_index, self.n_valid_rows - self.cur_row_index - # ) - # self.cur_row_index += slice.num_rows - # return slice class CloudFetchQueue(ResultSetQueue): @@ -336,8 +335,7 @@ def __init__( result_link.startRowOffset, result_link.rowCount ) ) - # print("Initial Setup Cloudfetch Queue") - # print(f"No of result links - {len(result_links)}") + self.download_manager = ResultFileDownloadManager( links=result_links or [], max_download_threads=self.max_download_threads, @@ -346,7 +344,6 @@ def __init__( ) self.table = self._create_next_table() - # self.table_row_index = 0 def next_n_rows(self, num_rows: int): """ @@ -361,27 +358,19 @@ def next_n_rows(self, num_rows: int): results = self._create_empty_table() if not self.table: logger.debug("CloudFetchQueue: no more rows available") - # Return empty pyarrow table to cause retry of fetch return results logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows)) - # results = self.table.slice(0, 0) - # partial_result_chunks = [results] while num_rows > 0 and self.table: # Get remaining of num_rows or the rest of the current table, whichever is smaller length = min(num_rows, self.table.num_rows) nxt_result = self.table.next_n_rows(length) results.append(nxt_result) num_rows -= nxt_result.num_rows - # table_slice = self.table.slice(self.table_row_index, length) - # partial_result_chunks.append(table_slice) - # self.table_row_index += table_slice.num_rows # Replace current table with the next table if we are at the end of the current table if self.table.num_rows == 0: self.table = self._create_next_table() - # self.table_row_index = 0 - # num_rows -= table_slice.num_rows logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) return results @@ -397,23 +386,13 @@ def remaining_rows(self): if not self.table: # Return empty pyarrow table to cause retry of fetch return result - # results = self.table.slice(0, 0) - # result = self._create_empty_table() - - # print("remaining_rows call") - # print(f"self.table.num_rows - {self.table.num_rows}") + while self.table: - # table_slice = self.table.slice( - # self.table_row_index, self.table.num_rows - self.table_row_index - # ) result.append(self.table) - # self.table_row_index += table_slice.num_rows self.table = self._create_next_table() - # self.table_row_index = 0 - # print(f"result.num_rows - {result.num_rows}") return result - def _create_next_table(self) -> ArrowStreamTable: + def _create_next_table(self) -> ResultTable: logger.debug( "CloudFetchQueue: Trying to get downloaded file for row {}".format( self.start_row_index @@ -432,37 +411,27 @@ def _create_next_table(self) -> ArrowStreamTable: # None signals no more Arrow tables can be built from the remaining handlers if any remain return None - arrow_stream_table = 
ArrowStreamTable( + result_table = ArrowStreamTable( list(pyarrow.ipc.open_stream(downloaded_file.file_bytes)), downloaded_file.row_count, self.description) - arrow_stream_table.remove_extraneous_rows() - # arrow_table = create_arrow_table_from_arrow_file( - # downloaded_file.file_bytes, self.description - # ) - + # The server rarely prepares the exact number of rows requested by the client in cloud fetch. # Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested - # if arrow_table.num_rows > downloaded_file.row_count: - # arrow_table = arrow_table.slice(0, downloaded_file.row_count) + result_table.remove_extraneous_rows() - # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows - # assert downloaded_file.row_count == arrow_table.num_rows - # self.start_row_index += arrow_table.num_rows - self.start_row_index += arrow_stream_table.num_rows + self.start_row_index += result_table.num_rows logger.debug( "CloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format( - arrow_stream_table.num_rows, self.start_row_index + result_table.num_rows, self.start_row_index ) ) - - # print("_create_next_table") - # print(f"arrow_stream_table.num_rows - {arrow_stream_table.num_rows}") - return arrow_stream_table - def _create_empty_table(self) -> ArrowStreamTable: + return result_table + + def _create_empty_table(self) -> ResultTable: # Create a 0-row table with just the schema bytes return ArrowStreamTable( list(pyarrow.ipc.open_stream(self.schema_bytes)), From cec90b210f8d64fbc548d523879c2f991ed914a1 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Mon, 14 Jul 2025 21:45:39 +0530 Subject: [PATCH 12/12] More refractor --- src/databricks/sql/client.py | 17 +- .../sql/cloudfetch/download_manager.py | 3 +- src/databricks/sql/cloudfetch/downloader.py | 6 +- src/databricks/sql/common/http.py | 4 +- src/databricks/sql/thrift_backend.py | 20 --- src/databricks/sql/utils.py | 155 ++++++++---------- tests/unit/test_telemetry.py | 108 ++++++------ tests/unit/test_thrift_backend.py | 4 +- 8 files changed, 142 insertions(+), 175 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9a7673d84..0a1b0e24b 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -9,7 +9,6 @@ import requests import json import os -import decimal from uuid import UUID from databricks.sql import __version__ @@ -31,8 +30,6 @@ transform_paramstyle, ColumnTable, ColumnQueue, - concat_chunked_tables, - merge_columnar, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -1402,7 +1399,7 @@ def _convert_columnar_table(self, table: ColumnTable): result.append(ResultRow(*curr_row)) return result - + def _convert_arrow_table(self, table: "pyarrow.Table"): column_names = [c[0] for c in self.description] @@ -1411,7 +1408,7 @@ def _convert_arrow_table(self, table: "pyarrow.Table"): if self.connection.disable_pandas is True: columns_as_lists = [col.to_pylist() for col in table.itercolumns()] return [ResultRow(*row) for row in zip(*columns_as_lists)] - + # Need to use nullable types, as otherwise type can change when there are missing values. 
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html @@ -1445,7 +1442,7 @@ def _convert_arrow_table(self, table: "pyarrow.Table"): @property def rownumber(self): return self._next_row_index - + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. @@ -1494,18 +1491,18 @@ def fetchmany_columnar(self, size: int): self._next_row_index += partial_results.num_rows return results - + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows - + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results.append(partial_results) self._next_row_index += partial_results.num_rows - + return results.to_arrow_table() def fetchall_columnar(self): @@ -1516,7 +1513,7 @@ def fetchall_columnar(self): while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - results = merge_columnar(results, partial_results) + results.append(partial_results) self._next_row_index += partial_results.num_rows return results diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 64401dc9c..a8a163fa8 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -84,8 +84,7 @@ def _schedule_downloads(self): """ While download queue has a capacity, peek pending links and submit them to thread pool. 
""" - # print("Schedule_downloads") - # logger.debug("ResultFileDownloadManager: schedule downloads") + logger.debug("ResultFileDownloadManager: schedule downloads") while (len(self._download_tasks) < self._max_download_threads) and ( len(self._pending_links) > 0 ): diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 10f1f55dd..a30f78327 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -91,17 +91,17 @@ def run(self) -> DownloadedFile: ResultSetDownloadHandler._validate_link( self.link, self.settings.link_expiry_buffer_secs ) - + with self._http_client.execute( method=HttpMethod.GET, url=self.link.fileLink, timeout=self.settings.download_timeout, verify=self._ssl_options.tls_verify, headers=self.link.httpHeaders + # TODO: Pass cert from `self._ssl_options` ) as response: - response.raise_for_status() - + # Save (and decompress if needed) the downloaded file compressed_data = response.content decompressed_data = ( diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index 3437f1c9b..c0be9f3bf 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -8,6 +8,7 @@ from typing import Generator import logging import time + logger = logging.getLogger(__name__) @@ -70,10 +71,7 @@ def execute( logger.info("Executing HTTP request: %s with url: %s", method.value, url) response = None try: - start_time = time.time() response = self.session.request(method.value, url, **kwargs) - end_time = time.time() - # print(f"Downloaded file in {end_time - start_time} seconds") yield response except Exception as e: logger.error("Error executing HTTP request in DatabricksHttpClient: %s", e) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 78683ac31..a07f645a7 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -36,9 +36,6 @@ RequestErrorInfo, NoRetryReason, ResultSetQueueFactory, - convert_arrow_based_set_to_arrow_table, - convert_decimals_in_arrow_table, - convert_column_based_set_to_arrow_table, ) from databricks.sql.types import SSLOptions @@ -633,23 +630,6 @@ def _poll_for_status(self, op_handle): ) return self.make_request(self._client.GetOperationStatus, req) - def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description): - if t_row_set.columns is not None: - ( - arrow_table, - num_rows, - ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description) - elif t_row_set.arrowBatches is not None: - (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table( - t_row_set.arrowBatches, lz4_compressed, schema_bytes - ) - else: - raise OperationalError( - "Unsupported TRowSet instance {}".format(t_row_set), - session_id_hex=self._session_id_hex, - ) - return convert_decimals_in_arrow_table(arrow_table, description), num_rows - def _get_metadata_resp(self, op_handle): req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle) return self.make_request(self._client.GetResultSetMetadata, req) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index fb26ff029..26adcd77a 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Mapping from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Union, Sequence +from typing import Any, Dict, List, Optional, Union, Sequence, Tuple import re import 
lz4.frame @@ -74,10 +74,12 @@ def build_queue( ResultSetQueue """ if row_set_type == TSparkRowSetType.ARROW_BASED_SET: - arrow_record_batches, n_valid_rows = convert_arrow_based_set_to_arrow_table( + arrow_record_batches, n_valid_rows = convert_bytes_to_record_batches( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes ) - arrow_stream_table = ArrowStreamTable(arrow_record_batches, n_valid_rows, description) + arrow_stream_table = ArrowStreamTable( + arrow_record_batches, n_valid_rows, description + ) return ArrowQueue(arrow_stream_table) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: column_table, column_names = convert_column_based_set_to_column_table( @@ -104,11 +106,10 @@ def build_queue( class ResultTable(ABC): - @abstractmethod def next_n_rows(self, num_rows: int): pass - + @abstractmethod def remaining_rows(self): pass @@ -116,7 +117,7 @@ def remaining_rows(self): @abstractmethod def append(self, other: ResultTable): pass - + @abstractmethod def to_arrow_table(self) -> "pyarrow.Table": pass @@ -125,12 +126,13 @@ def to_arrow_table(self) -> "pyarrow.Table": def remove_extraneous_rows(self): pass + class ColumnTable(ResultTable): def __init__(self, column_table, column_names): self.column_table = column_table self.column_names = column_names self.curr_row_index = 0 - + @property def num_rows(self): if len(self.column_table) == 0: @@ -141,14 +143,15 @@ def num_rows(self): @property def num_columns(self): return len(self.column_names) - + def next_n_rows(self, num_rows: int): sliced_column_table = [ - column[self.curr_row_index : self.curr_row_index + num_rows] for column in self.column_table + column[self.curr_row_index : self.curr_row_index + num_rows] + for column in self.column_table ] self.curr_row_index += num_rows return ColumnTable(sliced_column_table, self.column_names) - + def append(self, other: ColumnTable): if self.column_names != other.column_names: raise ValueError("The columns in the results don't match") @@ -161,37 +164,42 @@ def append(self, other: ColumnTable): def remaining_rows(self): sliced_column_table = [ - column[self.curr_row_index : ] for column in self.column_table + column[self.curr_row_index :] for column in self.column_table ] self.curr_row_index = self.num_rows return ColumnTable(sliced_column_table, self.column_names) - + def get_item(self, col_index, row_index): return self.column_table[col_index][row_index] - + def to_arrow_table(self): - data = { - name: col - for name, col in zip(self.column_names, self.column_table) - } + data = {name: col for name, col in zip(self.column_names, self.column_table)} return pyarrow.Table.from_pydict(data) def remove_extraneous_rows(self): pass + class ArrowStreamTable(ResultTable): - def __init__(self, record_batches: List["pyarrow.RecordBatch"], num_rows: int, column_description): + def __init__( + self, + record_batches: List["pyarrow.RecordBatch"], + num_rows: int, + column_description, + ): self.record_batches = record_batches self.num_rows = num_rows self.column_description = column_description - + def append(self, other: ArrowStreamTable): if self.column_description != other.column_description: - raise ValueError("ArrowStreamTable: Column descriptions do not match for the tables to be appended") - + raise ValueError( + "ArrowStreamTable: Column descriptions do not match for the tables to be appended" + ) + self.record_batches.extend(other.record_batches) self.num_rows += other.num_rows - + def next_n_rows(self, req_num_rows: int): consumed_batches = [] consumed_num_rows = 0 @@ -202,20 +210,24 @@ 
def next_n_rows(self, req_num_rows: int): req_num_rows -= current.num_rows consumed_num_rows += current.num_rows self.num_rows -= current.num_rows - self.record_batches.pop(0) + self.record_batches.pop(0) else: consumed_batches.append(current.slice(0, req_num_rows)) - self.record_batches[0] = current.slice(req_num_rows) + self.record_batches[0] = current.slice(req_num_rows) self.num_rows -= req_num_rows consumed_num_rows += req_num_rows req_num_rows = 0 - - return ArrowStreamTable(consumed_batches, consumed_num_rows, self.column_description) - + + return ArrowStreamTable( + consumed_batches, consumed_num_rows, self.column_description + ) + def remaining_rows(self): return self - def convert_decimals_in_record_batch(self,batch: "pyarrow.RecordBatch") -> "pyarrow.RecordBatch": + def convert_decimals_in_record_batch( + self, batch: "pyarrow.RecordBatch" + ) -> "pyarrow.RecordBatch": new_columns = [] new_fields = [] @@ -223,7 +235,10 @@ def convert_decimals_in_record_batch(self,batch: "pyarrow.RecordBatch") -> "pyar field = batch.schema.field(i) if self.column_description[i][1] == "decimal": - precision, scale = self.column_description[i][4], self.column_description[i][5] + precision, scale = ( + self.column_description[i][4], + self.column_description[i][5], + ) assert scale is not None and precision is not None dtype = pyarrow.decimal128(precision, scale) @@ -245,7 +260,7 @@ def batch_generator(): yield self.convert_decimals_in_record_batch(batch) return pyarrow.Table.from_batches(batch_generator()) - + def remove_extraneous_rows(self): num_rows_in_data = sum(batch.num_rows for batch in self.record_batches) rows_to_delete = num_rows_in_data - self.num_rows @@ -259,22 +274,22 @@ def remove_extraneous_rows(self): self.record_batches[-1] = last_batch.slice(0, keep_rows) rows_to_delete = 0 - + class ColumnQueue(ResultSetQueue): - def __init__(self, column_table: ColumnTable): - self.column_table = column_table + def __init__(self, table: ColumnTable): + self.table = table def next_n_rows(self, num_rows): - return self.column_table.next_n_rows(num_rows) + return self.table.next_n_rows(num_rows) def remaining_rows(self): - return self.column_table.remaining_rows() + return self.table.remaining_rows() class ArrowQueue(ResultSetQueue): def __init__( self, - arrow_stream_table: ArrowStreamTable, + table: ArrowStreamTable, ): """ A queue-like wrapper over an Arrow table @@ -283,14 +298,14 @@ def __init__( :param n_valid_rows: The index of the last valid row in the table :param start_row_index: The first row in the table we should start fetching from """ - self.arrow_stream_table = arrow_stream_table + self.table = table def next_n_rows(self, num_rows: int): """Get upto the next n rows of the Arrow dataframe""" - return self.arrow_stream_table.next_n_rows(num_rows) + return self.table.next_n_rows(num_rows) def remaining_rows(self): - return self.arrow_stream_table.remaining_rows() + return self.table.remaining_rows() class CloudFetchQueue(ResultSetQueue): @@ -345,7 +360,7 @@ def __init__( self.table = self._create_next_table() - def next_n_rows(self, num_rows: int): + def next_n_rows(self, num_rows: int) -> ResultTable: """ Get up to the next n rows of the cloud fetch Arrow dataframes. 
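A minimal, standalone sketch of the slice-and-pop pattern that the new ArrowStreamTable.next_n_rows applies to its queued record batches, assuming only that pyarrow is installed; the sample batches and the take_n helper below are illustrative stand-ins, not part of the driver:

import pyarrow as pa

# Two small batches standing in for chunks decoded from an Arrow IPC stream.
batches = [
    pa.RecordBatch.from_pydict({"id": [1, 2, 3]}),
    pa.RecordBatch.from_pydict({"id": [4, 5, 6]}),
]

def take_n(batches, n):
    # Consume whole batches from the front while they fit, then split the
    # first batch that does not, leaving its remainder queued for later.
    taken = []
    while n > 0 and batches:
        head = batches[0]
        if head.num_rows <= n:
            taken.append(batches.pop(0))
            n -= head.num_rows
        else:
            taken.append(head.slice(0, n))
            batches[0] = head.slice(n)
            n = 0
    return pa.Table.from_batches(taken)

print(take_n(batches, 4).num_rows)  # 4: all of the first batch plus one row of the second
print(batches[0].num_rows)          # 2: the rest of the second batch stays queued

This avoids materializing a full pyarrow.Table up front, which is the point of keeping the data as a list of record batches until to_arrow_table() is called.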
@@ -360,7 +375,7 @@ def next_n_rows(self, num_rows: int): logger.debug("CloudFetchQueue: no more rows available") return results logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows)) - + while num_rows > 0 and self.table: # Get remaining of num_rows or the rest of the current table, whichever is smaller length = min(num_rows, self.table.num_rows) @@ -375,7 +390,7 @@ def next_n_rows(self, num_rows: int): logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) return results - def remaining_rows(self): + def remaining_rows(self) -> ResultTable: """ Get all remaining rows of the cloud fetch Arrow dataframes. @@ -386,7 +401,7 @@ def remaining_rows(self): if not self.table: # Return empty pyarrow table to cause retry of fetch return result - + while self.table: result.append(self.table) self.table = self._create_next_table() @@ -410,13 +425,13 @@ def _create_next_table(self) -> ResultTable: ) # None signals no more Arrow tables can be built from the remaining handlers if any remain return None - + result_table = ArrowStreamTable( - list(pyarrow.ipc.open_stream(downloaded_file.file_bytes)), - downloaded_file.row_count, - self.description) - - + list(pyarrow.ipc.open_stream(downloaded_file.file_bytes)), + downloaded_file.row_count, + self.description, + ) + # The server rarely prepares the exact number of rows requested by the client in cloud fetch. # Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested result_table.remove_extraneous_rows() @@ -434,9 +449,8 @@ def _create_next_table(self) -> ResultTable: def _create_empty_table(self) -> ResultTable: # Create a 0-row table with just the schema bytes return ArrowStreamTable( - list(pyarrow.ipc.open_stream(self.schema_bytes)), - 0, - self.description) + list(pyarrow.ipc.open_stream(self.schema_bytes)), 0, self.description + ) ExecuteResponse = namedtuple( @@ -695,6 +709,7 @@ def create_arrow_table_from_arrow_file( arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) + def convert_arrow_based_file_to_arrow_table(file_bytes: bytes): try: return pyarrow.ipc.open_stream(file_bytes).read_all() @@ -702,7 +717,9 @@ def convert_arrow_based_file_to_arrow_table(file_bytes: bytes): raise RuntimeError("Failure to convert arrow based file to arrow table", e) -def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes): +def convert_bytes_to_record_batches( + arrow_batches, lz4_compressed, schema_bytes +) -> Tuple[List["pyarrow.RecordBatch"], int]: ba = bytearray() ba += schema_bytes n_rows = 0 @@ -859,35 +876,3 @@ def _create_python_tuple(t_col_value_wrapper): result[i] = None return tuple(result) - - -def concat_chunked_tables(tables: List[Union["pyarrow.Table", ColumnTable, ArrowStreamTable]]) -> Union["pyarrow.Table", ColumnTable, ArrowStreamTable]: - if isinstance(tables[0], ColumnTable): - base_table = tables[0] - for table in tables[1:]: - base_table = merge_columnar(base_table, table) - return base_table - elif isinstance(tables[0], ArrowStreamTable): - base_table = tables[0] - for table in tables[1:]: - base_table = base_table.append(table) - return base_table - else: - return pyarrow.concat_tables(tables) - -def merge_columnar(result1: ColumnTable, result2: ColumnTable) -> ColumnTable: - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != 
result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) \ No newline at end of file diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index f57f75562..dc1c7d630 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -8,7 +8,7 @@ NoopTelemetryClient, TelemetryClientFactory, TelemetryHelper, - BaseTelemetryClient + BaseTelemetryClient, ) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow from databricks.sql.auth.authenticators import ( @@ -24,7 +24,7 @@ def mock_telemetry_client(): session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() - + return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -43,7 +43,7 @@ def test_noop_client_behavior(self): client1 = NoopTelemetryClient() client2 = NoopTelemetryClient() assert client1 is client2 - + # Test that all methods can be called without exceptions client1.export_initial_telemetry_log(MagicMock(), "test-agent") client1.export_failure_log("TestError", "Test message") @@ -58,61 +58,61 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): """Test the complete event batching and flushing flow.""" client = mock_telemetry_client client._batch_size = 3 # Small batch for testing - + # Mock the network call - with patch.object(client, '_send_telemetry') as mock_send: + with patch.object(client, "_send_telemetry") as mock_send: # Add events one by one - should not flush yet client._export_event("event1") client._export_event("event2") mock_send.assert_not_called() assert len(client._events_batch) == 2 - + # Third event should trigger flush client._export_event("event3") mock_send.assert_called_once() assert len(client._events_batch) == 0 # Batch cleared after flush - - @patch('requests.post') + + @patch("requests.post") def test_network_request_flow(self, mock_post, mock_telemetry_client): """Test the complete network request flow with authentication.""" mock_post.return_value.status_code = 200 client = mock_telemetry_client - + # Create mock events mock_events = [MagicMock() for _ in range(2)] for i, event in enumerate(mock_events): event.to_json.return_value = f'{{"event": "{i}"}}' - + # Send telemetry client._send_telemetry(mock_events) - + # Verify request was submitted to executor client._executor.submit.assert_called_once() args, kwargs = client._executor.submit.call_args - + # Verify correct function and URL assert args[0] == requests.post - assert args[1] == 'https://test-host.com/telemetry-ext' - assert kwargs['headers']['Authorization'] == 'Bearer test-token' - + assert args[1] == "https://test-host.com/telemetry-ext" + assert kwargs["headers"]["Authorization"] == "Bearer test-token" + # Verify request body structure - request_data = kwargs['data'] + request_data = kwargs["data"] assert '"uploadTime"' in request_data assert '"protoLogs"' in request_data def test_telemetry_logging_flows(self, mock_telemetry_client): """Test all telemetry logging methods work end-to-end.""" client = mock_telemetry_client - - with patch.object(client, '_export_event') as mock_export: + + with patch.object(client, "_export_event") as mock_export: # Test initial log client.export_initial_telemetry_log(MagicMock(), "test-agent") assert mock_export.call_count == 1 - + # Test failure log 
client.export_failure_log("TestError", "Error message") assert mock_export.call_count == 2 - + # Test latency log client.export_latency_log(150, "EXECUTE_STATEMENT", "stmt-123") assert mock_export.call_count == 3 @@ -120,14 +120,14 @@ def test_telemetry_logging_flows(self, mock_telemetry_client): def test_error_handling_resilience(self, mock_telemetry_client): """Test that telemetry errors don't break the client.""" client = mock_telemetry_client - + # Test that exceptions in telemetry don't propagate - with patch.object(client, '_export_event', side_effect=Exception("Test error")): + with patch.object(client, "_export_event", side_effect=Exception("Test error")): # These should not raise exceptions client.export_initial_telemetry_log(MagicMock(), "test-agent") client.export_failure_log("TestError", "Error message") client.export_latency_log(100, "EXECUTE_STATEMENT", "stmt-123") - + # Test executor submission failure client._executor.submit.side_effect = Exception("Thread pool error") client._send_telemetry([MagicMock()]) # Should not raise @@ -140,7 +140,7 @@ def test_system_configuration_caching(self): """Test that system configuration is cached and contains expected data.""" config1 = TelemetryHelper.get_driver_system_configuration() config2 = TelemetryHelper.get_driver_system_configuration() - + # Should be cached (same instance) assert config1 is config2 @@ -153,7 +153,7 @@ def test_auth_mechanism_detection(self): (MagicMock(), AuthMech.OTHER), # Unknown provider (None, None), ] - + for provider, expected in test_cases: assert TelemetryHelper.get_auth_mechanism(provider) == expected @@ -163,19 +163,25 @@ def test_auth_flow_detection(self): oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) oauth_with_tokens._access_token = "test-access-token" oauth_with_tokens._refresh_token = "test-refresh-token" - assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH - + assert ( + TelemetryHelper.get_auth_flow(oauth_with_tokens) + == AuthFlow.TOKEN_PASSTHROUGH + ) + # Test OAuth with browser-based auth oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) oauth_with_browser._access_token = None oauth_with_browser._refresh_token = None oauth_with_browser.oauth_manager = MagicMock() - assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION - + assert ( + TelemetryHelper.get_auth_flow(oauth_with_browser) + == AuthFlow.BROWSER_BASED_AUTHENTICATION + ) + # Test non-OAuth provider pat_auth = AccessTokenAuthProvider("test-token") assert TelemetryHelper.get_auth_flow(pat_auth) is None - + # Test None auth provider assert TelemetryHelper.get_auth_flow(None) is None @@ -202,24 +208,24 @@ def test_client_lifecycle_flow(self): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") - + # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, auth_provider=auth_provider, - host_url="test-host.com" + host_url="test-host.com", ) - + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex - + # Close client - with patch.object(client, 'close') as mock_close: + with patch.object(client, "close") as mock_close: TelemetryClientFactory.close(session_id_hex) mock_close.assert_called_once() - + # Should get NoopTelemetryClient after close client = 
TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) @@ -227,31 +233,33 @@ def test_client_lifecycle_flow(self): def test_disabled_telemetry_flow(self): """Test that disabled telemetry uses NoopTelemetryClient.""" session_id_hex = "test-session" - + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, session_id_hex=session_id_hex, auth_provider=None, - host_url="test-host.com" + host_url="test-host.com", ) - + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) def test_factory_error_handling(self): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" - + # Simulate initialization error - with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient', - side_effect=Exception("Init error")): + with patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient", + side_effect=Exception("Init error"), + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id, auth_provider=AccessTokenAuthProvider("token"), - host_url="test-host.com" + host_url="test-host.com", ) - + # Should fall back to NoopTelemetryClient client = TelemetryClientFactory.get_telemetry_client(session_id) assert isinstance(client, NoopTelemetryClient) @@ -260,25 +268,25 @@ def test_factory_shutdown_flow(self): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" - + # Initialize multiple clients for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session, auth_provider=AccessTokenAuthProvider("token"), - host_url="test-host.com" + host_url="test-host.com", ) - + # Factory should be initialized assert TelemetryClientFactory._initialized is True assert TelemetryClientFactory._executor is not None - + # Close first client - factory should stay initialized TelemetryClientFactory.close(session1) assert TelemetryClientFactory._initialized is True - + # Close second client - factory should shut down TelemetryClientFactory.close(session2) assert TelemetryClientFactory._initialized is False - assert TelemetryClientFactory._executor is None \ No newline at end of file + assert TelemetryClientFactory._executor is None diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 458ea9a82..579fe0a35 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1468,10 +1468,10 @@ def test_convert_arrow_based_set_to_arrow_table( ttypes.TSparkArrowBatch(batch=bytearray("Testing", "utf-8"), rowCount=1) for _ in range(10) ] - utils.convert_arrow_based_set_to_arrow_table(arrow_batches, False, schema) + utils.convert_bytes_to_record_batches(arrow_batches, False, schema) lz4_decompress_mock.assert_not_called() - utils.convert_arrow_based_set_to_arrow_table(arrow_batches, True, schema) + utils.convert_bytes_to_record_batches(arrow_batches, True, schema) lz4_decompress_mock.assert_called() def test_convert_column_based_set_to_arrow_table_without_nulls(self):
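For context on how the refactored fetch path consumes these queues, here is a hypothetical, self-contained sketch of the next_n_rows / append / to_arrow_table surface that the new ResultTable implementations expose; MiniColumnTable only mirrors the shape of the ColumnTable added in this patch and is not driver code:

import pyarrow

class MiniColumnTable:
    # Trimmed stand-in for ColumnTable: columns are plain Python lists.
    def __init__(self, columns, names):
        self.columns, self.names, self.row_index = columns, names, 0

    def next_n_rows(self, n):
        sliced = [c[self.row_index : self.row_index + n] for c in self.columns]
        self.row_index += n
        return MiniColumnTable(sliced, self.names)

    def append(self, other):
        # Mutate in place, matching the append-based accumulation in fetchmany_arrow.
        self.columns = [a + b for a, b in zip(self.columns, other.columns)]

    def to_arrow_table(self):
        return pyarrow.Table.from_pydict(dict(zip(self.names, self.columns)))

# The fetch loop only touches this interface, never the concrete layout:
chunks = MiniColumnTable([[1, 2, 3, 4], ["a", "b", "c", "d"]], ["id", "name"])
result = chunks.next_n_rows(2)
result.append(chunks.next_n_rows(2))      # same accumulate-then-convert pattern
print(result.to_arrow_table().num_rows)   # 4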