diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py
index df6d6a801..e9764ce76 100644
--- a/src/databricks/sql/backend/sea/queue.py
+++ b/src/databricks/sql/backend/sea/queue.py
@@ -4,6 +4,7 @@
 from typing import List, Optional, Tuple, Union, TYPE_CHECKING
 
 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
+from databricks.sql.telemetry.models.enums import StatementType
 
 try:
     import pyarrow
@@ -134,9 +135,13 @@ def __init__(
         super().__init__(
             max_download_threads=max_download_threads,
             ssl_options=ssl_options,
+            statement_id=statement_id,
             schema_bytes=None,
             lz4_compressed=lz4_compressed,
             description=description,
+            # TODO: fix these arguments when telemetry is implemented in SEA
+            session_id_hex=None,
+            chunk_id=0,
         )
         self._sea_client = sea_client
diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index 50a256f48..84679cb33 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -6,9 +6,10 @@
 import time
 import threading
 from typing import List, Optional, Union, Any, TYPE_CHECKING
+from uuid import UUID
 
 from databricks.sql.result_set import ThriftResultSet
-
+from databricks.sql.telemetry.models.event import StatementType
 
 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
@@ -900,6 +901,7 @@ def get_execution_result(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -1037,6 +1039,7 @@ def execute_command(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def get_catalogs(
@@ -1077,6 +1080,7 @@ def get_catalogs(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def get_schemas(
@@ -1123,6 +1127,7 @@ def get_schemas(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def get_tables(
@@ -1173,6 +1178,7 @@ def get_tables(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def get_columns(
@@ -1223,6 +1229,7 @@ def get_columns(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def _handle_execute_response(self, resp, cursor):
@@ -1257,6 +1264,7 @@ def fetch_results(
         lz4_compressed: bool,
         arrow_schema_bytes,
         description,
+        chunk_id: int,
         use_cloud_fetch=True,
     ):
         thrift_handle = command_id.to_thrift_handle()
@@ -1294,9 +1302,16 @@ def fetch_results(
             lz4_compressed=lz4_compressed,
             description=description,
             ssl_options=self._ssl_options,
+            session_id_hex=self._session_id_hex,
+            statement_id=command_id.to_hex_guid(),
+            chunk_id=chunk_id,
         )
 
-        return queue, resp.hasMoreRows
+        return (
+            queue,
+            resp.hasMoreRows,
+            len(resp.results.resultLinks) if resp.results.resultLinks else 0,
+        )
 
     def cancel_command(self, command_id: CommandId) -> None:
         thrift_handle = command_id.to_thrift_handle()
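Note on the contract change above: `fetch_results` now returns a three-element tuple, appending the number of cloud-fetch links in the batch so the caller can advance its chunk counter. A minimal sketch of that loop, with `fetch_page` and `links_per_page` as hypothetical stand-ins for the real Thrift client and response:

```python
from typing import Any, List, Tuple

def fetch_page(links_per_page: List[int], page: int) -> Tuple[Any, bool, int]:
    """Stand-in for ThriftDatabricksClient.fetch_results: returns a queue,
    a has-more flag, and how many cloud-fetch links this batch carried."""
    queue = object()  # placeholder for the real ResultSetQueue
    has_more = page + 1 < len(links_per_page)
    return queue, has_more, links_per_page[page]

# The caller threads a running counter through each call, mirroring
# ThriftResultSet.num_downloaded_chunks.
num_downloaded_chunks = 0
page, has_more = 0, True
while has_more:
    _, has_more, link_count = fetch_page([3, 2, 0], page)
    num_downloaded_chunks += link_count
    page += 1
assert num_downloaded_chunks == 5
```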
diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py
index f6428a187..a4ec307d4 100644
--- a/src/databricks/sql/backend/types.py
+++ b/src/databricks/sql/backend/types.py
@@ -4,6 +4,7 @@
 import logging
 
 from databricks.sql.backend.utils.guid_utils import guid_to_hex_id
+from databricks.sql.telemetry.models.enums import StatementType
 from databricks.sql.thrift_api.TCLIService import ttypes
 
 logger = logging.getLogger(__name__)
diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 75e89d92a..c279f2c1f 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -280,7 +280,9 @@ def read(self) -> Optional[OAuthToken]:
 
         driver_connection_params = DriverConnectionParameters(
             http_path=http_path,
-            mode=DatabricksClientType.THRIFT,
+            mode=DatabricksClientType.SEA
+            if self.session.use_sea
+            else DatabricksClientType.THRIFT,
             host_info=HostDetails(host_url=server_hostname, port=self.session.port),
             auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
             auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),
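For context, the mode reported in connection telemetry now follows the backend choice instead of being hard-coded to Thrift. A minimal sketch, with `connection_mode` as a hypothetical distillation of the conditional above:

```python
from enum import Enum

class DatabricksClientType(Enum):  # mirrors the enum name used in the diff
    THRIFT = "THRIFT"
    SEA = "SEA"

def connection_mode(use_sea: bool) -> DatabricksClientType:
    # Mirrors the conditional now feeding DriverConnectionParameters.mode
    return DatabricksClientType.SEA if use_sea else DatabricksClientType.THRIFT

assert connection_mode(False) is DatabricksClientType.THRIFT
assert connection_mode(True) is DatabricksClientType.SEA
```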
diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py
index 12dd0a01f..32b698bed 100644
--- a/src/databricks/sql/cloudfetch/download_manager.py
+++ b/src/databricks/sql/cloudfetch/download_manager.py
@@ -1,7 +1,7 @@
 import logging
 
 from concurrent.futures import ThreadPoolExecutor, Future
-from typing import List, Union
+from typing import List, Union, Tuple, Optional
 
 from databricks.sql.cloudfetch.downloader import (
     ResultSetDownloadHandler,
@@ -9,7 +9,7 @@
     DownloadedFile,
 )
 from databricks.sql.types import SSLOptions
-
+from databricks.sql.telemetry.models.event import StatementType
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 
 logger = logging.getLogger(__name__)
@@ -22,17 +22,22 @@ def __init__(
         max_download_threads: int,
         lz4_compressed: bool,
         ssl_options: SSLOptions,
+        session_id_hex: Optional[str],
+        statement_id: str,
+        chunk_id: int,
     ):
-        self._pending_links: List[TSparkArrowResultLink] = []
-        for link in links:
+        self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = []
+        self.chunk_id = chunk_id
+        for i, link in enumerate(links, start=chunk_id):
             if link.rowCount <= 0:
                 continue
             logger.debug(
-                "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
-                    link.startRowOffset, link.rowCount
+                "ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}".format(
+                    i, link.startRowOffset, link.rowCount
                 )
             )
-            self._pending_links.append(link)
+            self._pending_links.append((i, link))
+        self.chunk_id += len(links)
 
         self._download_tasks: List[Future[DownloadedFile]] = []
         self._max_download_threads: int = max_download_threads
@@ -40,6 +45,8 @@ def __init__(
 
         self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
         self._ssl_options = ssl_options
+        self.session_id_hex = session_id_hex
+        self.statement_id = statement_id
 
     def get_next_downloaded_file(
         self, next_row_offset: int
@@ -89,14 +96,19 @@ def _schedule_downloads(self):
         while (len(self._download_tasks) < self._max_download_threads) and (
             len(self._pending_links) > 0
         ):
-            link = self._pending_links.pop(0)
+            chunk_id, link = self._pending_links.pop(0)
             logger.debug(
-                "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
+                "- chunk: {}, start: {}, row count: {}".format(
+                    chunk_id, link.startRowOffset, link.rowCount
+                )
             )
             handler = ResultSetDownloadHandler(
                 settings=self._downloadable_result_settings,
                 link=link,
                 ssl_options=self._ssl_options,
+                chunk_id=chunk_id,
+                session_id_hex=self.session_id_hex,
+                statement_id=self.statement_id,
             )
             task = self._thread_pool.submit(handler.run)
             self._download_tasks.append(task)
@@ -117,7 +129,8 @@ def add_link(self, link: TSparkArrowResultLink):
                 link.startRowOffset, link.rowCount
             )
         )
-        self._pending_links.append(link)
+        self._pending_links.append((self.chunk_id, link))
+        self.chunk_id += 1
 
     def _shutdown_manager(self):
         # Clear download handlers and shutdown the thread pool
diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py
index 228e07d6c..e19a69046 100644
--- a/src/databricks/sql/cloudfetch/downloader.py
+++ b/src/databricks/sql/cloudfetch/downloader.py
@@ -1,5 +1,6 @@
 import logging
 from dataclasses import dataclass
+from typing import Optional
 
 import requests
 from requests.adapters import HTTPAdapter, Retry
@@ -9,6 +10,8 @@
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
+from databricks.sql.telemetry.latency_logger import log_latency
+from databricks.sql.telemetry.models.event import StatementType
 
 logger = logging.getLogger(__name__)
@@ -66,11 +69,18 @@ def __init__(
         settings: DownloadableResultSettings,
         link: TSparkArrowResultLink,
         ssl_options: SSLOptions,
+        chunk_id: int,
+        session_id_hex: Optional[str],
+        statement_id: str,
     ):
         self.settings = settings
         self.link = link
         self._ssl_options = ssl_options
+        self.chunk_id = chunk_id
+        self.session_id_hex = session_id_hex
+        self.statement_id = statement_id
 
+    @log_latency(StatementType.QUERY)
     def run(self) -> DownloadedFile:
         """
         Download the file described in the cloud fetch link.
@@ -80,8 +90,8 @@ def run(self) -> DownloadedFile:
         """
 
         logger.debug(
-            "ResultSetDownloadHandler: starting file download, offset {}, row count {}".format(
-                self.link.startRowOffset, self.link.rowCount
+            "ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format(
+                self.chunk_id, self.link.startRowOffset, self.link.rowCount
             )
         )
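The numbering scheme above has a subtlety worth calling out: `enumerate(links, start=chunk_id)` assigns an id to every link, so a zero-row link is skipped for download but still consumes its id, and the counter advances by `len(links)` per batch. A self-contained sketch, using `FakeLink` as a stand-in for `TSparkArrowResultLink`:

```python
from typing import List, Tuple

class FakeLink:
    """Stand-in for TSparkArrowResultLink, keeping only the field the manager reads."""
    def __init__(self, row_count: int):
        self.rowCount = row_count

def number_links(links: List[FakeLink], start: int) -> Tuple[List[Tuple[int, FakeLink]], int]:
    # Pair every non-empty link with a monotonically increasing chunk id; an
    # empty link is skipped but still burns its id, and the next batch resumes
    # at start + len(links) -- matching `self.chunk_id += len(links)` above.
    pending = [(i, link) for i, link in enumerate(links, start=start) if link.rowCount > 0]
    return pending, start + len(links)

pending, next_chunk_id = number_links([FakeLink(5), FakeLink(0), FakeLink(3)], start=0)
assert [i for i, _ in pending] == [0, 2]  # id 1 consumed by the empty link
assert next_chunk_id == 3
```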
diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index dc279cf91..cd2f980e8 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -22,6 +22,7 @@
     ColumnQueue,
 )
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
+from databricks.sql.telemetry.models.event import StatementType
 
 logger = logging.getLogger(__name__)
@@ -192,6 +193,7 @@ def __init__(
         connection: "Connection",
         execute_response: "ExecuteResponse",
         thrift_client: "ThriftDatabricksClient",
+        session_id_hex: Optional[str],
         buffer_size_bytes: int = 104857600,
         arraysize: int = 10000,
         use_cloud_fetch: bool = True,
@@ -215,6 +217,7 @@ def __init__(
         :param ssl_options: SSL options for cloud fetch
         :param is_direct_results: Whether there are more rows to fetch
         """
+        self.num_downloaded_chunks = 0
 
         # Initialize ThriftResultSet-specific attributes
         self._use_cloud_fetch = use_cloud_fetch
@@ -234,7 +237,12 @@ def __init__(
                 lz4_compressed=execute_response.lz4_compressed,
                 description=execute_response.description,
                 ssl_options=ssl_options,
+                session_id_hex=session_id_hex,
+                statement_id=execute_response.command_id.to_hex_guid(),
+                chunk_id=self.num_downloaded_chunks,
             )
+            if t_row_set and t_row_set.resultLinks:
+                self.num_downloaded_chunks += len(t_row_set.resultLinks)
 
         # Call parent constructor with common attributes
         super().__init__(
@@ -258,7 +266,7 @@ def __init__(
             self._fill_results_buffer()
 
     def _fill_results_buffer(self):
-        results, is_direct_results = self.backend.fetch_results(
+        results, is_direct_results, result_links_count = self.backend.fetch_results(
            command_id=self.command_id,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
@@ -267,9 +275,11 @@ def _fill_results_buffer(self):
             arrow_schema_bytes=self._arrow_schema_bytes,
             description=self.description,
             use_cloud_fetch=self._use_cloud_fetch,
+            chunk_id=self.num_downloaded_chunks,
         )
         self.results = results
         self.is_direct_results = is_direct_results
+        self.num_downloaded_chunks += result_links_count
 
     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]
diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py
index b956657ee..b0908ac25 100644
--- a/src/databricks/sql/session.py
+++ b/src/databricks/sql/session.py
@@ -97,10 +97,10 @@ def _create_backend(
         kwargs: dict,
     ) -> DatabricksClient:
         """Create and return the appropriate backend client."""
-        use_sea = kwargs.get("use_sea", False)
+        self.use_sea = kwargs.get("use_sea", False)
 
         databricks_client_class: Type[DatabricksClient]
-        if use_sea:
+        if self.use_sea:
             logger.debug("Creating SEA backend client")
             databricks_client_class = SeaDatabricksClient
         else:
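One detail from the `ThriftResultSet` constructor change above: links delivered inline with the initial row set claim chunk ids starting at 0, so the counter must be advanced past them before the first `fetch_results` call. A minimal sketch of that seeding, with `ResultSetCounter` as a hypothetical reduction of `ThriftResultSet`:

```python
from typing import List, Optional

class ResultSetCounter:
    """Hypothetical reduction of ThriftResultSet: only the chunk bookkeeping."""
    def __init__(self, initial_result_links: Optional[List[object]]):
        # The initial batch is numbered starting at 0 (chunk_id=self.num_downloaded_chunks),
        # then the counter jumps past it so the next fetch starts fresh.
        self.num_downloaded_chunks = 0
        if initial_result_links:
            self.num_downloaded_chunks += len(initial_result_links)

rs = ResultSetCounter(initial_result_links=[object(), object()])
assert rs.num_downloaded_chunks == 2  # the next batch's chunk ids start at 2
```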
""" - - def get_statement_id(self) -> Optional[str]: - if self.command_id: - return str(UUID(bytes=self.command_id.operationId.guid)) - return None def get_session_id_hex(self) -> Optional[str]: - return self.connection.get_session_id_hex() + return self._obj.session_id_hex + + def get_statement_id(self) -> Optional[str]: + return self._obj.statement_id def get_is_compressed(self) -> bool: - return self.lz4_compressed + return self._obj.settings.is_lz4_compressed - def get_execution_result(self) -> ExecutionResultFormat: - if isinstance(self.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(self.results, CloudFetchQueue): - return ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(self.results, ArrowQueue): - return ExecutionResultFormat.INLINE_ARROW - return ExecutionResultFormat.FORMAT_UNSPECIFIED + def get_execution_result_format(self) -> ExecutionResultFormat: + return ExecutionResultFormat.EXTERNAL_LINKS - def get_retry_count(self) -> int: - if ( - hasattr(self.thrift_backend, "retry_policy") - and self.thrift_backend.retry_policy - ): - return len(self.thrift_backend.retry_policy.history) - return 0 + def get_retry_count(self) -> Optional[int]: + # standard requests and urllib3 libraries don't expose retry count + return None + + def get_chunk_id(self) -> Optional[int]: + return self._obj.chunk_id def get_extractor(obj): @@ -126,19 +117,19 @@ def get_extractor(obj): that can extract telemetry information from that object type. Args: - obj: The object to create an extractor for. Can be a Cursor, ResultSet, - or any other object. + obj: The object to create an extractor for. Can be a Cursor, + ResultSetDownloadHandler, or any other object. Returns: TelemetryExtractor: A specialized extractor instance: - CursorExtractor for Cursor objects - - ResultSetExtractor for ResultSet objects + - ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects - None for all other objects """ if obj.__class__.__name__ == "Cursor": return CursorExtractor(obj) - elif obj.__class__.__name__ == "ResultSet": - return ResultSetExtractor(obj) + elif obj.__class__.__name__ == "ResultSetDownloadHandler": + return ResultSetDownloadHandlerExtractor(obj) else: logger.debug("No extractor found for %s", obj.__class__.__name__) return None @@ -162,7 +153,7 @@ def log_latency(statement_type: StatementType = StatementType.NONE): statement_type (StatementType): The type of SQL statement being executed. 
diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py
index f5496deec..83f72cd3b 100644
--- a/src/databricks/sql/telemetry/models/event.py
+++ b/src/databricks/sql/telemetry/models/event.py
@@ -122,12 +122,14 @@ class SqlExecutionEvent(JsonSerializableMixin):
         is_compressed (bool): Whether the result is compressed
         execution_result (ExecutionResultFormat): Format of the execution result
         retry_count (int): Number of retry attempts made
+        chunk_id (int): ID of the chunk if applicable
     """
 
     statement_type: StatementType
     is_compressed: bool
     execution_result: ExecutionResultFormat
-    retry_count: int
+    retry_count: Optional[int]
+    chunk_id: Optional[int]
 
 
 @dataclass
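Sketch of the event shape after this change, since both new fields are nullable by design; field names mirror the diff, but the enum types are simplified to strings here so the snippet stands alone:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SqlExecutionEventSketch:
    statement_type: str
    is_compressed: bool
    execution_result: str
    retry_count: Optional[int]  # None when the transport can't report retries
    chunk_id: Optional[int]     # None for events not tied to a cloud-fetch chunk

event = SqlExecutionEventSketch("QUERY", True, "EXTERNAL_LINKS", None, 7)
assert event.chunk_id == 7 and event.retry_count is None
```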
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 79a376d12..f2f9fcb95 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -27,7 +27,8 @@
     TSparkRowSetType,
 )
 from databricks.sql.types import SSLOptions
-
+from databricks.sql.backend.types import CommandId
+from databricks.sql.telemetry.models.event import StatementType
 from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter
 
 import logging
@@ -60,6 +61,9 @@ def build_queue(
         arrow_schema_bytes: bytes,
         max_download_threads: int,
         ssl_options: SSLOptions,
+        session_id_hex: Optional[str],
+        statement_id: str,
+        chunk_id: int,
         lz4_compressed: bool = True,
         description: List[Tuple] = [],
     ) -> ResultSetQueue:
@@ -106,6 +110,9 @@ def build_queue(
             description=description,
             max_download_threads=max_download_threads,
             ssl_options=ssl_options,
+            session_id_hex=session_id_hex,
+            statement_id=statement_id,
+            chunk_id=chunk_id,
         )
     else:
         raise AssertionError("Row set type is not valid")
@@ -214,6 +221,9 @@ def __init__(
         self,
         max_download_threads: int,
         ssl_options: SSLOptions,
+        session_id_hex: Optional[str],
+        statement_id: str,
+        chunk_id: int,
         schema_bytes: Optional[bytes] = None,
         lz4_compressed: bool = True,
         description: List[Tuple] = [],
@@ -234,6 +244,9 @@ def __init__(
         self.lz4_compressed = lz4_compressed
         self.description = description
         self._ssl_options = ssl_options
+        self.session_id_hex = session_id_hex
+        self.statement_id = statement_id
+        self.chunk_id = chunk_id
 
         # Table state
         self.table = None
@@ -245,6 +258,9 @@ def __init__(
             max_download_threads=max_download_threads,
             lz4_compressed=lz4_compressed,
             ssl_options=ssl_options,
+            session_id_hex=session_id_hex,
+            statement_id=statement_id,
+            chunk_id=chunk_id,
         )
 
     def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
@@ -348,6 +364,9 @@ def __init__(
         schema_bytes,
         max_download_threads: int,
         ssl_options: SSLOptions,
+        session_id_hex: Optional[str],
+        statement_id: str,
+        chunk_id: int,
         start_row_offset: int = 0,
         result_links: Optional[List[TSparkArrowResultLink]] = None,
         lz4_compressed: bool = True,
@@ -371,10 +390,16 @@ def __init__(
             schema_bytes=schema_bytes,
             lz4_compressed=lz4_compressed,
             description=description,
+            session_id_hex=session_id_hex,
+            statement_id=statement_id,
+            chunk_id=chunk_id,
         )
 
         self.start_row_index = start_row_offset
         self.result_links = result_links or []
+        self.session_id_hex = session_id_hex
+        self.statement_id = statement_id
+        self.chunk_id = chunk_id
 
         logger.debug(
             "Initialize CloudFetch loader, row set start offset: {}, file list:".format(
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 3b5072cfe..f118d2833 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -115,7 +115,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
         # Mock the backend that will be used by the real ThriftResultSet
         mock_backend = Mock(spec=ThriftDatabricksClient)
         mock_backend.staging_allowed_local_path = None
-        mock_backend.fetch_results.return_value = (Mock(), False)
+        mock_backend.fetch_results.return_value = (Mock(), False, 0)
 
         # Configure the decorator's mock to return our specific mock_backend
         mock_thrift_client_class.return_value = mock_backend
@@ -128,6 +128,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
             connection=connection,
             execute_response=mock_execute_response,
             thrift_client=mock_backend,
+            session_id_hex=Mock(),
         )
 
         # Mock execute_command to return our real result set
@@ -189,12 +190,13 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
         mock_connection = Mock()
         mock_backend = Mock()
         mock_results = Mock()
-        mock_backend.fetch_results.return_value = (Mock(), False)
+        mock_backend.fetch_results.return_value = (Mock(), False, 0)
 
         result_set = ThriftResultSet(
             connection=mock_connection,
             execute_response=Mock(),
             thrift_client=mock_backend,
+            session_id_hex=Mock(),
         )
         result_set.results = mock_results
 
@@ -220,9 +222,9 @@ def test_closing_result_set_hard_closes_commands(self):
         mock_session.open = True
         type(mock_connection).session = PropertyMock(return_value=mock_session)
 
-        mock_thrift_backend.fetch_results.return_value = (Mock(), False)
+        mock_thrift_backend.fetch_results.return_value = (Mock(), False, 0)
         result_set = ThriftResultSet(
-            mock_connection, mock_results_response, mock_thrift_backend
+            mock_connection, mock_results_response, mock_thrift_backend, session_id_hex=Mock()
         )
         result_set.results = mock_results
@@ -268,9 +270,9 @@ def test_closed_cursor_doesnt_allow_operations(self):
 
     def test_negative_fetch_throws_exception(self):
         mock_backend = Mock()
-        mock_backend.fetch_results.return_value = (Mock(), False)
+        mock_backend.fetch_results.return_value = (Mock(), False, 0)
 
-        result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
+        result_set = ThriftResultSet(Mock(), Mock(), mock_backend, session_id_hex=Mock())
 
         with self.assertRaises(ValueError) as e:
             result_set.fetchmany(-1)
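The test updates above and below all follow from one rule: any double for `fetch_results` must now honor the three-element contract, or tuple unpacking in `_fill_results_buffer` raises `ValueError`. A minimal sketch:

```python
from unittest.mock import Mock

# A two-tuple return value would fail to unpack in _fill_results_buffer, so
# every mocked backend now returns (queue, has_more_rows, result_links_count).
mock_backend = Mock()
mock_backend.fetch_results.return_value = (Mock(), False, 0)

queue, has_more, link_count = mock_backend.fetch_results()
assert has_more is False and link_count == 0
```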
diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py
index 275d055c9..f50c1b82d 100644
--- a/tests/unit/test_cloud_fetch_queue.py
+++ b/tests/unit/test_cloud_fetch_queue.py
@@ -4,7 +4,7 @@
     pyarrow = None
 import unittest
 import pytest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, Mock
 
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 import databricks.sql.utils as utils
@@ -63,6 +63,9 @@ def test_initializer_adds_links(self, mock_create_next_table):
             result_links=result_links,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
 
         assert len(queue.download_manager._pending_links) == 10
@@ -77,6 +80,9 @@ def test_initializer_no_links_to_add(self):
             result_links=result_links,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
 
         assert len(queue.download_manager._pending_links) == 0
@@ -93,6 +99,9 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
             result_links=[],
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
 
         assert queue._create_next_table() is None
@@ -114,6 +123,9 @@ def test_initializer_create_next_table_success(
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
 
         expected_result = self.make_arrow_table()
@@ -139,6 +151,9 @@ def test_next_n_rows_0_rows(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -160,6 +175,9 @@ def test_next_n_rows_partial_table(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -180,6 +198,9 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -205,6 +226,9 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -227,6 +251,9 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
 
         assert queue.table is None
@@ -244,6 +271,9 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table)
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -263,6 +293,9 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -282,6 +315,9 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -307,6 +343,9 @@ def test_remaining_rows_multiple_tables_fully_returned(
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -335,6 +374,9 @@ def test_remaining_rows_empty_table(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
 
         assert queue.table is None
diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py
index 64edbdebe..6eb17a05a 100644
--- a/tests/unit/test_download_manager.py
+++ b/tests/unit/test_download_manager.py
@@ -1,5 +1,5 @@
 import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import patch, MagicMock, Mock
 
 import databricks.sql.cloudfetch.download_manager as download_manager
 from databricks.sql.types import SSLOptions
@@ -19,6 +19,9 @@ def create_download_manager(
             max_download_threads,
             lz4_compressed,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            chunk_id=0,
         )
 
     def create_result_link(
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index 2a3b715b5..9879e17c7 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -27,7 +27,7 @@ def test_run_link_expired(self, mock_time):
         # Already expired
         result_link.expiryTime = 999
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
 
         with self.assertRaises(Error) as context:
@@ -43,7 +43,7 @@ def test_run_link_past_expiry_buffer(self, mock_time):
         # Within the expiry buffer time
         result_link.expiryTime = 1004
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
 
         with self.assertRaises(Error) as context:
@@ -63,7 +63,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session):
         result_link = Mock(expiryTime=1001)
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
         with self.assertRaises(requests.exceptions.HTTPError) as context:
             d.run()
@@ -82,7 +82,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session):
         result_link = Mock(bytesNum=100, expiryTime=1001)
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
         file = d.run()
@@ -105,7 +105,7 @@ def test_run_compressed_successful(self, mock_time, mock_session):
         result_link = Mock(bytesNum=100, expiryTime=1001)
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
        )
         file = d.run()
@@ -121,7 +121,7 @@ def test_download_connection_error(self, mock_time, mock_session):
         mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
         with self.assertRaises(ConnectionError):
             d.run()
@@ -136,7 +136,7 @@ def test_download_timeout(self, mock_time, mock_session):
         mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
         with self.assertRaises(TimeoutError):
             d.run()
diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py
index a649941e1..9bb29de8f 100644
--- a/tests/unit/test_fetches.py
+++ b/tests/unit/test_fetches.py
@@ -43,7 +43,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
 
         # Create a mock backend that will return the queue when _fill_results_buffer is called
         mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
-        mock_thrift_backend.fetch_results.return_value = (arrow_queue, False)
+        mock_thrift_backend.fetch_results.return_value = (arrow_queue, False, 0)
 
         num_cols = len(initial_results[0]) if initial_results else 0
         description = [
@@ -54,7 +54,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
         rs = ThriftResultSet(
             connection=Mock(),
             execute_response=ExecuteResponse(
-                command_id=None,
+                command_id=Mock(),
                 status=None,
                 has_been_closed_server_side=True,
                 description=description,
@@ -63,6 +63,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
             ),
             thrift_client=mock_thrift_backend,
             t_row_set=None,
+            session_id_hex=Mock(),
         )
         return rs
@@ -79,12 +80,13 @@ def fetch_results(
             arrow_schema_bytes,
             description,
             use_cloud_fetch=True,
+            chunk_id=0,
         ):
             nonlocal batch_index
             results = FetchTests.make_arrow_queue(batch_list[batch_index])
             batch_index += 1
 
-            return results, batch_index < len(batch_list)
+            return results, batch_index < len(batch_list), 0
 
         mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
         mock_thrift_backend.fetch_results = fetch_results
@@ -98,7 +100,7 @@ def fetch_results(
         rs = ThriftResultSet(
             connection=Mock(),
             execute_response=ExecuteResponse(
-                command_id=None,
+                command_id=Mock(),
                 status=None,
                 has_been_closed_server_side=False,
                 description=description,
@@ -106,6 +108,7 @@ def fetch_results(
                 is_staging_operation=False,
             ),
             thrift_client=mock_thrift_backend,
+            session_id_hex=Mock(),
         )
         return rs
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 37569f755..452eb4d3e 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -731,7 +731,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class):
             ssl_options=SSLOptions(),
         )
         with self.assertRaises(DatabaseError) as cm:
-            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock())
+            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock(), Mock())
 
         self.assertEqual(display_message, str(cm.exception))
         self.assertIn(diagnostic_info, str(cm.exception.message_with_context()))
@@ -772,7 +772,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla
             ssl_options=SSLOptions(),
         )
         with self.assertRaises(DatabaseError) as cm:
-            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock())
+            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock(), Mock())
 
         self.assertEqual(display_message, str(cm.exception))
         self.assertIn(diagnostic_info, str(cm.exception.message_with_context()))
@@ -1097,7 +1097,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
             thrift_backend = self._make_fake_thrift_backend()
             thrift_backend._handle_execute_response(execute_resp, Mock())
 
-            _, has_more_rows_resp = thrift_backend.fetch_results(
+            _, has_more_rows_resp, _ = thrift_backend.fetch_results(
                 command_id=Mock(),
                 max_rows=1,
                 max_bytes=1,
@@ -1105,6 +1105,7 @@
                 lz4_compressed=False,
                 arrow_schema_bytes=Mock(),
                 description=Mock(),
+                chunk_id=0,
             )
 
             self.assertEqual(is_direct_results, has_more_rows_resp)
@@ -1150,7 +1151,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        arrow_queue, has_more_results = thrift_backend.fetch_results(
+        arrow_queue, has_more_results, _ = thrift_backend.fetch_results(
             command_id=Mock(),
             max_rows=1,
             max_bytes=1,
@@ -1158,6 +1159,7 @@
             lz4_compressed=False,
             arrow_schema_bytes=schema,
             description=MagicMock(),
+            chunk_id=0,
         )
 
         self.assertEqual(arrow_queue.n_valid_rows, 15 * 10)
@@ -1183,7 +1185,7 @@ def test_execute_statement_calls_client_and_handle_execute_response(
         cursor_mock = Mock()
 
         result = thrift_backend.execute_command(
-            "foo", Mock(), 100, 200, Mock(), cursor_mock
+            "foo", Mock(), 100, 200, Mock(), cursor_mock, Mock()
         )
         # Verify the result is a ResultSet
         self.assertEqual(result, mock_result_set.return_value)
@@ -1448,7 +1450,7 @@ def test_non_arrow_non_column_based_set_triggers_exception(
         thrift_backend = self._make_fake_thrift_backend()
 
         with self.assertRaises(OperationalError) as cm:
-            thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock())
+            thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock(), Mock())
         self.assertIn(
             "Expected results to be in Arrow or column based format", str(cm.exception)
         )
@@ -2277,7 +2279,7 @@ def test_execute_command_sets_complex_type_fields_correctly(
             ssl_options=SSLOptions(),
             **complex_arg_types,
         )
-        thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock())
+        thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock(), Mock())
         t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[
             0
        ][0]