diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1f409bb07..0a1b0e24b 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -9,7 +9,6 @@ import requests import json import os -import decimal from uuid import UUID from databricks.sql import __version__ @@ -1389,7 +1388,7 @@ def _fill_results_buffer(self): self.results = results self.has_more_rows = has_more_rows - def _convert_columnar_table(self, table): + def _convert_columnar_table(self, table: ColumnTable): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) result = [] @@ -1401,14 +1400,14 @@ def _convert_columnar_table(self, table): return result - def _convert_arrow_table(self, table): + def _convert_arrow_table(self, table: "pyarrow.Table"): + column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] + columns_as_lists = [col.to_pylist() for col in table.itercolumns()] + return [ResultRow(*row) for row in zip(*columns_as_lists)] # Need to use nullable types, as otherwise type can change when there are missing values. # See https://arrow.apache.org/docs/python/pandas.html#nullable-types @@ -1434,6 +1433,7 @@ def _convert_arrow_table(self, table): types_mapper=dtype_mapping.get, date_as_object=True, timestamp_as_object=True, + self_destruct=True, ) res = df.to_numpy(na_value=None, dtype="object") @@ -1454,7 +1454,6 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows - while ( n_remaining_rows > 0 and not self.has_been_closed_server_side @@ -1462,28 +1461,11 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) + results.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return results - - def merge_columnar(self, result1, result2): - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) + return results.to_arrow_table() def fetchmany_columnar(self, size: int): """ @@ -1504,7 +1486,7 @@ def fetchmany_columnar(self, size: int): ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) + results.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows @@ -1518,23 +1500,10 @@ def fetchall_arrow(self) -> "pyarrow.Table": while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - results = pyarrow.concat_tables([results, partial_results]) + results.append(partial_results) self._next_row_index += 
partial_results.num_rows - # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table - # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: - data = { - name: col - for name, col in zip(results.column_names, results.column_table) - } - return pyarrow.Table.from_pydict(data) - return results + return results.to_arrow_table() def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" @@ -1544,7 +1513,7 @@ def fetchall_columnar(self): while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - results = self.merge_columnar(results, partial_results) + results.append(partial_results) self._next_row_index += partial_results.num_rows return results diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 7e96cd323..a8a163fa8 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -78,7 +78,6 @@ def get_next_downloaded_file( next_row_offset, file.start_row_offset, file.row_count ) ) - return file def _schedule_downloads(self): diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 228e07d6c..a30f78327 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -9,6 +9,7 @@ from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.exc import Error from databricks.sql.types import SSLOptions +from databricks.sql.common.http import DatabricksHttpClient, HttpMethod logger = logging.getLogger(__name__) @@ -70,6 +71,7 @@ def __init__( self.settings = settings self.link = link self._ssl_options = ssl_options + self._http_client = DatabricksHttpClient.get_instance() def run(self) -> DownloadedFile: """ @@ -90,19 +92,14 @@ def run(self) -> DownloadedFile: self.link, self.settings.link_expiry_buffer_secs ) - session = requests.Session() - session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) - session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) - - try: - # Get the file via HTTP request - response = session.get( - self.link.fileLink, - timeout=self.settings.download_timeout, - verify=self._ssl_options.tls_verify, - headers=self.link.httpHeaders - # TODO: Pass cert from `self._ssl_options` - ) + with self._http_client.execute( + method=HttpMethod.GET, + url=self.link.fileLink, + timeout=self.settings.download_timeout, + verify=self._ssl_options.tls_verify, + headers=self.link.httpHeaders + # TODO: Pass cert from `self._ssl_options` + ) as response: response.raise_for_status() # Save (and decompress if needed) the downloaded file @@ -132,9 +129,6 @@ def run(self) -> DownloadedFile: self.link.startRowOffset, self.link.rowCount, ) - finally: - if session: - session.close() @staticmethod def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int): diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index ec4e3341a..c0be9f3bf 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from typing import Generator import logging +import time logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 
78683ac31..a07f645a7 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -36,9 +36,6 @@ RequestErrorInfo, NoRetryReason, ResultSetQueueFactory, - convert_arrow_based_set_to_arrow_table, - convert_decimals_in_arrow_table, - convert_column_based_set_to_arrow_table, ) from databricks.sql.types import SSLOptions @@ -633,23 +630,6 @@ def _poll_for_status(self, op_handle): ) return self.make_request(self._client.GetOperationStatus, req) - def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description): - if t_row_set.columns is not None: - ( - arrow_table, - num_rows, - ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description) - elif t_row_set.arrowBatches is not None: - (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table( - t_row_set.arrowBatches, lz4_compressed, schema_bytes - ) - else: - raise OperationalError( - "Unsupported TRowSet instance {}".format(t_row_set), - session_id_hex=self._session_id_hex, - ) - return convert_decimals_in_arrow_table(arrow_table, description), num_rows - def _get_metadata_resp(self, op_handle): req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle) return self.make_request(self._client.GetResultSetMetadata, req) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 0ce2fa169..26adcd77a 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Mapping from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Union, Sequence +from typing import Any, Dict, List, Optional, Union, Sequence, Tuple import re import lz4.frame @@ -74,13 +74,13 @@ def build_queue( ResultSetQueue """ if row_set_type == TSparkRowSetType.ARROW_BASED_SET: - arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( + arrow_record_batches, n_valid_rows = convert_bytes_to_record_batches( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes ) - converted_arrow_table = convert_decimals_in_arrow_table( - arrow_table, description + arrow_stream_table = ArrowStreamTable( + arrow_record_batches, n_valid_rows, description ) - return ArrowQueue(converted_arrow_table, n_valid_rows) + return ArrowQueue(arrow_stream_table) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: column_table, column_names = convert_column_based_set_to_column_table( t_row_set.columns, description @@ -105,10 +105,33 @@ def build_queue( raise AssertionError("Row set type is not valid") -class ColumnTable: +class ResultTable(ABC): + @abstractmethod + def next_n_rows(self, num_rows: int): + pass + + @abstractmethod + def remaining_rows(self): + pass + + @abstractmethod + def append(self, other: ResultTable): + pass + + @abstractmethod + def to_arrow_table(self) -> "pyarrow.Table": + pass + + @abstractmethod + def remove_extraneous_rows(self): + pass + + +class ColumnTable(ResultTable): def __init__(self, column_table, column_names): self.column_table = column_table self.column_names = column_names + self.curr_row_index = 0 @property def num_rows(self): @@ -121,49 +144,152 @@ def num_rows(self): def num_columns(self): return len(self.column_names) - def get_item(self, col_index, row_index): - return self.column_table[col_index][row_index] + def next_n_rows(self, num_rows: int): + sliced_column_table = [ + column[self.curr_row_index : self.curr_row_index + num_rows] + for column in self.column_table + ] + self.curr_row_index += num_rows + return ColumnTable(sliced_column_table, 
self.column_names) - def slice(self, curr_index, length): + def append(self, other: ColumnTable): + if self.column_names != other.column_names: + raise ValueError("The columns in the results don't match") + + merged_result = [ + self.column_table[i] + other.column_table[i] + for i in range(self.num_columns) + ] + self.column_table = merged_result + + def remaining_rows(self): sliced_column_table = [ - column[curr_index : curr_index + length] for column in self.column_table + column[self.curr_row_index :] for column in self.column_table ] + self.curr_row_index = self.num_rows return ColumnTable(sliced_column_table, self.column_names) - def __eq__(self, other): - return ( - self.column_table == other.column_table - and self.column_names == other.column_names + def get_item(self, col_index, row_index): + return self.column_table[col_index][row_index] + + def to_arrow_table(self): + data = {name: col for name, col in zip(self.column_names, self.column_table)} + return pyarrow.Table.from_pydict(data) + + def remove_extraneous_rows(self): + pass + + +class ArrowStreamTable(ResultTable): + def __init__( + self, + record_batches: List["pyarrow.RecordBatch"], + num_rows: int, + column_description, + ): + self.record_batches = record_batches + self.num_rows = num_rows + self.column_description = column_description + + def append(self, other: ArrowStreamTable): + if self.column_description != other.column_description: + raise ValueError( + "ArrowStreamTable: Column descriptions do not match for the tables to be appended" + ) + + self.record_batches.extend(other.record_batches) + self.num_rows += other.num_rows + + def next_n_rows(self, req_num_rows: int): + consumed_batches = [] + consumed_num_rows = 0 + while req_num_rows > 0 and self.record_batches: + current = self.record_batches[0] + if current.num_rows <= req_num_rows: + consumed_batches.append(current) + req_num_rows -= current.num_rows + consumed_num_rows += current.num_rows + self.num_rows -= current.num_rows + self.record_batches.pop(0) + else: + consumed_batches.append(current.slice(0, req_num_rows)) + self.record_batches[0] = current.slice(req_num_rows) + self.num_rows -= req_num_rows + consumed_num_rows += req_num_rows + req_num_rows = 0 + + return ArrowStreamTable( + consumed_batches, consumed_num_rows, self.column_description ) + def remaining_rows(self): + return self + + def convert_decimals_in_record_batch( + self, batch: "pyarrow.RecordBatch" + ) -> "pyarrow.RecordBatch": + new_columns = [] + new_fields = [] + + for i, col in enumerate(batch.columns): + field = batch.schema.field(i) + + if self.column_description[i][1] == "decimal": + precision, scale = ( + self.column_description[i][4], + self.column_description[i][5], + ) + assert scale is not None and precision is not None + dtype = pyarrow.decimal128(precision, scale) + + new_col = col.cast(dtype) + new_field = field.with_type(dtype) + + new_columns.append(new_col) + new_fields.append(new_field) + else: + new_columns.append(col) + new_fields.append(field) + + new_schema = pyarrow.schema(new_fields) + return pyarrow.RecordBatch.from_arrays(new_columns, schema=new_schema) + + def to_arrow_table(self) -> "pyarrow.Table": + def batch_generator(): + for batch in self.record_batches: + yield self.convert_decimals_in_record_batch(batch) + + return pyarrow.Table.from_batches(batch_generator()) + + def remove_extraneous_rows(self): + num_rows_in_data = sum(batch.num_rows for batch in self.record_batches) + rows_to_delete = num_rows_in_data - self.num_rows + while rows_to_delete > 0 and 
self.record_batches: + last_batch = self.record_batches[-1] + if last_batch.num_rows <= rows_to_delete: + self.record_batches.pop() + rows_to_delete -= last_batch.num_rows + else: + keep_rows = last_batch.num_rows - rows_to_delete + self.record_batches[-1] = last_batch.slice(0, keep_rows) + rows_to_delete = 0 + class ColumnQueue(ResultSetQueue): - def __init__(self, column_table: ColumnTable): - self.column_table = column_table - self.cur_row_index = 0 - self.n_valid_rows = column_table.num_rows + def __init__(self, table: ColumnTable): + self.table = table def next_n_rows(self, num_rows): - length = min(num_rows, self.n_valid_rows - self.cur_row_index) - - slice = self.column_table.slice(self.cur_row_index, length) - self.cur_row_index += slice.num_rows - return slice + return self.table.next_n_rows(num_rows) def remaining_rows(self): - slice = self.column_table.slice( - self.cur_row_index, self.n_valid_rows - self.cur_row_index - ) - self.cur_row_index += slice.num_rows - return slice + return self.table.remaining_rows() class ArrowQueue(ResultSetQueue): def __init__( self, - arrow_table: "pyarrow.Table", - n_valid_rows: int, - start_row_index: int = 0, + table: ArrowStreamTable, ): """ A queue-like wrapper over an Arrow table @@ -172,25 +298,14 @@ def __init__( :param n_valid_rows: The index of the last valid row in the table :param start_row_index: The first row in the table we should start fetching from """ - self.cur_row_index = start_row_index - self.arrow_table = arrow_table - self.n_valid_rows = n_valid_rows + self.table = table - def next_n_rows(self, num_rows: int) -> "pyarrow.Table": + def next_n_rows(self, num_rows: int): """Get upto the next n rows of the Arrow dataframe""" - length = min(num_rows, self.n_valid_rows - self.cur_row_index) - # Note that the table.slice API is not the same as Python's slice - # The second argument should be length, not end index - slice = self.arrow_table.slice(self.cur_row_index, length) - self.cur_row_index += slice.num_rows - return slice - - def remaining_rows(self) -> "pyarrow.Table": - slice = self.arrow_table.slice( - self.cur_row_index, self.n_valid_rows - self.cur_row_index - ) - self.cur_row_index += slice.num_rows - return slice + return self.table.next_n_rows(num_rows) + + def remaining_rows(self): + return self.table.remaining_rows() class CloudFetchQueue(ResultSetQueue): @@ -235,6 +350,7 @@ def __init__( result_link.startRowOffset, result_link.rowCount ) ) + self.download_manager = ResultFileDownloadManager( links=result_links or [], max_download_threads=self.max_download_threads, @@ -243,9 +359,8 @@ def __init__( ) self.table = self._create_next_table() - self.table_row_index = 0 - def next_n_rows(self, num_rows: int) -> "pyarrow.Table": + def next_n_rows(self, num_rows: int) -> ResultTable: """ Get up to the next n rows of the cloud fetch Arrow dataframes. 
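Illustrative aside (not part of the diff): ArrowStreamTable.next_n_rows and remove_extraneous_rows above split pyarrow record batches with RecordBatch.slice whenever a fetch boundary falls inside a batch. A minimal standalone sketch of that slicing behaviour, assuming only that pyarrow is installed:

    import pyarrow as pa

    # One 10-row batch, standing in for a batch read from an Arrow IPC stream.
    batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["id"])

    head = batch.slice(0, 4)  # first 4 rows; a zero-copy view over the same buffers
    tail = batch.slice(4)     # remaining 6 rows; also zero-copy

    assert head.num_rows == 4 and tail.num_rows == 6

    # Reassembling the pieces preserves row order, mirroring to_arrow_table().
    assert pa.Table.from_batches([head, tail]).num_rows == 10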
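A similarly hedged sketch of the per-batch decimal handling in convert_decimals_in_record_batch above: decimal columns arrive as strings, get cast to pyarrow.decimal128, and the corresponding schema field is retyped. The column names and the precision/scale below are made up for illustration:

    import pyarrow as pa

    batch = pa.RecordBatch.from_arrays(
        [pa.array(["1.10", "2.25"]), pa.array([1, 2])],
        names=["amount", "id"],
    )

    # Cast the "amount" column to decimal128(10, 2) and rebuild the schema to match.
    dtype = pa.decimal128(10, 2)
    columns = [
        col.cast(dtype) if batch.schema.field(i).name == "amount" else col
        for i, col in enumerate(batch.columns)
    ]
    fields = [f.with_type(dtype) if f.name == "amount" else f for f in batch.schema]
    converted = pa.RecordBatch.from_arrays(columns, schema=pa.schema(fields))

    assert converted.schema.field("amount").type == dtype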
@@ -255,50 +370,44 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": Returns: pyarrow.Table """ + results = self._create_empty_table() if not self.table: logger.debug("CloudFetchQueue: no more rows available") - # Return empty pyarrow table to cause retry of fetch - return self._create_empty_table() + return results logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows)) - results = self.table.slice(0, 0) + while num_rows > 0 and self.table: # Get remaining of num_rows or the rest of the current table, whichever is smaller - length = min(num_rows, self.table.num_rows - self.table_row_index) - table_slice = self.table.slice(self.table_row_index, length) - results = pyarrow.concat_tables([results, table_slice]) - self.table_row_index += table_slice.num_rows + length = min(num_rows, self.table.num_rows) + nxt_result = self.table.next_n_rows(length) + results.append(nxt_result) + num_rows -= nxt_result.num_rows # Replace current table with the next table if we are at the end of the current table - if self.table_row_index == self.table.num_rows: + if self.table.num_rows == 0: self.table = self._create_next_table() - self.table_row_index = 0 - num_rows -= table_slice.num_rows logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) return results - def remaining_rows(self) -> "pyarrow.Table": + def remaining_rows(self) -> ResultTable: """ Get all remaining rows of the cloud fetch Arrow dataframes. Returns: pyarrow.Table """ + result = self._create_empty_table() if not self.table: # Return empty pyarrow table to cause retry of fetch - return self._create_empty_table() - results = self.table.slice(0, 0) + return result + while self.table: - table_slice = self.table.slice( - self.table_row_index, self.table.num_rows - self.table_row_index - ) - results = pyarrow.concat_tables([results, table_slice]) - self.table_row_index += table_slice.num_rows + result.append(self.table) self.table = self._create_next_table() - self.table_row_index = 0 - return results + return result - def _create_next_table(self) -> Union["pyarrow.Table", None]: + def _create_next_table(self) -> ResultTable: logger.debug( "CloudFetchQueue: Trying to get downloaded file for row {}".format( self.start_row_index @@ -316,30 +425,32 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]: ) # None signals no more Arrow tables can be built from the remaining handlers if any remain return None - arrow_table = create_arrow_table_from_arrow_file( - downloaded_file.file_bytes, self.description + + result_table = ArrowStreamTable( + list(pyarrow.ipc.open_stream(downloaded_file.file_bytes)), + downloaded_file.row_count, + self.description, ) # The server rarely prepares the exact number of rows requested by the client in cloud fetch. 
# Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested - if arrow_table.num_rows > downloaded_file.row_count: - arrow_table = arrow_table.slice(0, downloaded_file.row_count) + result_table.remove_extraneous_rows() - # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows - assert downloaded_file.row_count == arrow_table.num_rows - self.start_row_index += arrow_table.num_rows + self.start_row_index += result_table.num_rows logger.debug( "CloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format( - arrow_table.num_rows, self.start_row_index + result_table.num_rows, self.start_row_index ) ) - return arrow_table + return result_table - def _create_empty_table(self) -> "pyarrow.Table": + def _create_empty_table(self) -> ResultTable: # Create a 0-row table with just the schema bytes - return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) + return ArrowStreamTable( + list(pyarrow.ipc.open_stream(self.schema_bytes)), 0, self.description + ) ExecuteResponse = namedtuple( @@ -606,7 +717,9 @@ def convert_arrow_based_file_to_arrow_table(file_bytes: bytes): raise RuntimeError("Failure to convert arrow based file to arrow table", e) -def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes): +def convert_bytes_to_record_batches( + arrow_batches, lz4_compressed, schema_bytes +) -> Tuple[List["pyarrow.RecordBatch"], int]: ba = bytearray() ba += schema_bytes n_rows = 0 @@ -617,8 +730,8 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema if lz4_compressed else arrow_batch.batch ) - arrow_table = pyarrow.ipc.open_stream(ba).read_all() - return arrow_table, n_rows + arrow_record_batches = list(pyarrow.ipc.open_stream(ba)) + return arrow_record_batches, n_rows def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index f57f75562..dc1c7d630 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -8,7 +8,7 @@ NoopTelemetryClient, TelemetryClientFactory, TelemetryHelper, - BaseTelemetryClient + BaseTelemetryClient, ) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow from databricks.sql.auth.authenticators import ( @@ -24,7 +24,7 @@ def mock_telemetry_client(): session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() - + return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -43,7 +43,7 @@ def test_noop_client_behavior(self): client1 = NoopTelemetryClient() client2 = NoopTelemetryClient() assert client1 is client2 - + # Test that all methods can be called without exceptions client1.export_initial_telemetry_log(MagicMock(), "test-agent") client1.export_failure_log("TestError", "Test message") @@ -58,61 +58,61 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): """Test the complete event batching and flushing flow.""" client = mock_telemetry_client client._batch_size = 3 # Small batch for testing - + # Mock the network call - with patch.object(client, '_send_telemetry') as mock_send: + with patch.object(client, "_send_telemetry") as mock_send: # Add events one by one - should not flush yet client._export_event("event1") client._export_event("event2") mock_send.assert_not_called() assert len(client._events_batch) == 2 - + # Third event should trigger flush 
client._export_event("event3") mock_send.assert_called_once() assert len(client._events_batch) == 0 # Batch cleared after flush - - @patch('requests.post') + + @patch("requests.post") def test_network_request_flow(self, mock_post, mock_telemetry_client): """Test the complete network request flow with authentication.""" mock_post.return_value.status_code = 200 client = mock_telemetry_client - + # Create mock events mock_events = [MagicMock() for _ in range(2)] for i, event in enumerate(mock_events): event.to_json.return_value = f'{{"event": "{i}"}}' - + # Send telemetry client._send_telemetry(mock_events) - + # Verify request was submitted to executor client._executor.submit.assert_called_once() args, kwargs = client._executor.submit.call_args - + # Verify correct function and URL assert args[0] == requests.post - assert args[1] == 'https://test-host.com/telemetry-ext' - assert kwargs['headers']['Authorization'] == 'Bearer test-token' - + assert args[1] == "https://test-host.com/telemetry-ext" + assert kwargs["headers"]["Authorization"] == "Bearer test-token" + # Verify request body structure - request_data = kwargs['data'] + request_data = kwargs["data"] assert '"uploadTime"' in request_data assert '"protoLogs"' in request_data def test_telemetry_logging_flows(self, mock_telemetry_client): """Test all telemetry logging methods work end-to-end.""" client = mock_telemetry_client - - with patch.object(client, '_export_event') as mock_export: + + with patch.object(client, "_export_event") as mock_export: # Test initial log client.export_initial_telemetry_log(MagicMock(), "test-agent") assert mock_export.call_count == 1 - + # Test failure log client.export_failure_log("TestError", "Error message") assert mock_export.call_count == 2 - + # Test latency log client.export_latency_log(150, "EXECUTE_STATEMENT", "stmt-123") assert mock_export.call_count == 3 @@ -120,14 +120,14 @@ def test_telemetry_logging_flows(self, mock_telemetry_client): def test_error_handling_resilience(self, mock_telemetry_client): """Test that telemetry errors don't break the client.""" client = mock_telemetry_client - + # Test that exceptions in telemetry don't propagate - with patch.object(client, '_export_event', side_effect=Exception("Test error")): + with patch.object(client, "_export_event", side_effect=Exception("Test error")): # These should not raise exceptions client.export_initial_telemetry_log(MagicMock(), "test-agent") client.export_failure_log("TestError", "Error message") client.export_latency_log(100, "EXECUTE_STATEMENT", "stmt-123") - + # Test executor submission failure client._executor.submit.side_effect = Exception("Thread pool error") client._send_telemetry([MagicMock()]) # Should not raise @@ -140,7 +140,7 @@ def test_system_configuration_caching(self): """Test that system configuration is cached and contains expected data.""" config1 = TelemetryHelper.get_driver_system_configuration() config2 = TelemetryHelper.get_driver_system_configuration() - + # Should be cached (same instance) assert config1 is config2 @@ -153,7 +153,7 @@ def test_auth_mechanism_detection(self): (MagicMock(), AuthMech.OTHER), # Unknown provider (None, None), ] - + for provider, expected in test_cases: assert TelemetryHelper.get_auth_mechanism(provider) == expected @@ -163,19 +163,25 @@ def test_auth_flow_detection(self): oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) oauth_with_tokens._access_token = "test-access-token" oauth_with_tokens._refresh_token = "test-refresh-token" - assert 
TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH - + assert ( + TelemetryHelper.get_auth_flow(oauth_with_tokens) + == AuthFlow.TOKEN_PASSTHROUGH + ) + # Test OAuth with browser-based auth oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) oauth_with_browser._access_token = None oauth_with_browser._refresh_token = None oauth_with_browser.oauth_manager = MagicMock() - assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION - + assert ( + TelemetryHelper.get_auth_flow(oauth_with_browser) + == AuthFlow.BROWSER_BASED_AUTHENTICATION + ) + # Test non-OAuth provider pat_auth = AccessTokenAuthProvider("test-token") assert TelemetryHelper.get_auth_flow(pat_auth) is None - + # Test None auth provider assert TelemetryHelper.get_auth_flow(None) is None @@ -202,24 +208,24 @@ def test_client_lifecycle_flow(self): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") - + # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, auth_provider=auth_provider, - host_url="test-host.com" + host_url="test-host.com", ) - + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex - + # Close client - with patch.object(client, 'close') as mock_close: + with patch.object(client, "close") as mock_close: TelemetryClientFactory.close(session_id_hex) mock_close.assert_called_once() - + # Should get NoopTelemetryClient after close client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) @@ -227,31 +233,33 @@ def test_client_lifecycle_flow(self): def test_disabled_telemetry_flow(self): """Test that disabled telemetry uses NoopTelemetryClient.""" session_id_hex = "test-session" - + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, session_id_hex=session_id_hex, auth_provider=None, - host_url="test-host.com" + host_url="test-host.com", ) - + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) def test_factory_error_handling(self): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" - + # Simulate initialization error - with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient', - side_effect=Exception("Init error")): + with patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient", + side_effect=Exception("Init error"), + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id, auth_provider=AccessTokenAuthProvider("token"), - host_url="test-host.com" + host_url="test-host.com", ) - + # Should fall back to NoopTelemetryClient client = TelemetryClientFactory.get_telemetry_client(session_id) assert isinstance(client, NoopTelemetryClient) @@ -260,25 +268,25 @@ def test_factory_shutdown_flow(self): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" - + # Initialize multiple clients for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session, auth_provider=AccessTokenAuthProvider("token"), - host_url="test-host.com" + host_url="test-host.com", ) - + # Factory should be initialized assert 
TelemetryClientFactory._initialized is True assert TelemetryClientFactory._executor is not None - + # Close first client - factory should stay initialized TelemetryClientFactory.close(session1) assert TelemetryClientFactory._initialized is True - + # Close second client - factory should shut down TelemetryClientFactory.close(session2) assert TelemetryClientFactory._initialized is False - assert TelemetryClientFactory._executor is None \ No newline at end of file + assert TelemetryClientFactory._executor is None diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 458ea9a82..579fe0a35 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1468,10 +1468,10 @@ def test_convert_arrow_based_set_to_arrow_table( ttypes.TSparkArrowBatch(batch=bytearray("Testing", "utf-8"), rowCount=1) for _ in range(10) ] - utils.convert_arrow_based_set_to_arrow_table(arrow_batches, False, schema) + utils.convert_bytes_to_record_batches(arrow_batches, False, schema) lz4_decompress_mock.assert_not_called() - utils.convert_arrow_based_set_to_arrow_table(arrow_batches, True, schema) + utils.convert_bytes_to_record_batches(arrow_batches, True, schema) lz4_decompress_mock.assert_called() def test_convert_column_based_set_to_arrow_table_without_nulls(self):
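Illustrative aside (not part of the diff): convert_bytes_to_record_batches and CloudFetchQueue._create_next_table above keep the individual record batches from the Arrow IPC stream rather than materialising a single table with read_all(). A minimal round-trip sketch of the two behaviours, assuming only pyarrow:

    import pyarrow as pa

    # Serialize two batches into an IPC stream, standing in for the
    # schema_bytes + arrowBatches payload assembled from the Thrift response.
    batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["id"])
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)
        writer.write_batch(batch)
    payload = sink.getvalue()

    # Old behaviour: one fully materialised pyarrow.Table.
    table = pa.ipc.open_stream(payload).read_all()
    # New behaviour: iterate the stream reader and keep the batches themselves.
    batches = list(pa.ipc.open_stream(payload))

    assert table.num_rows == sum(b.num_rows for b in batches) == 6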
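For the client.py change above that passes self_destruct=True to Table.to_pandas: this lets pyarrow release Arrow buffers as columns are converted, which can lower peak memory during the pandas conversion, at the cost that the source table must not be used afterwards. A hedged usage sketch (the toy table and the availability of pandas are assumptions):

    import pyarrow as pa

    table = pa.table({"id": [1, 2, 3]})
    df = table.to_pandas(self_destruct=True)  # Arrow buffers may be freed during conversion
    # `table` is consumed here and must not be touched again; keep working with `df` only.
    print(df.shape)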