diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py
index 5bc6c6793..3c0e325fe 100644
--- a/examples/experimental/tests/test_sea_async_query.py
+++ b/examples/experimental/tests/test_sea_async_query.py
@@ -45,7 +45,6 @@ def test_sea_async_query_with_cloud_fetch():
         use_sea=True,
         user_agent_entry="SEA-Test-Client",
         use_cloud_fetch=True,
-        enable_query_result_lz4_compression=False,
     )

    logger.info(
diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py
index 16ee80a78..76941e2d2 100644
--- a/examples/experimental/tests/test_sea_sync_query.py
+++ b/examples/experimental/tests/test_sea_sync_query.py
@@ -43,7 +43,6 @@ def test_sea_sync_query_with_cloud_fetch():
         use_sea=True,
         user_agent_entry="SEA-Test-Client",
         use_cloud_fetch=True,
-        enable_query_result_lz4_compression=False,
     )

    logger.info(
diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/client.py
similarity index 86%
rename from src/databricks/sql/backend/sea/backend.py
rename to src/databricks/sql/backend/sea/client.py
index 42677b903..9f7b552f8 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/client.py
@@ -5,7 +5,7 @@
 import re
 from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set

-from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest
+from databricks.sql.backend.sea.models.base import ResultManifest
 from databricks.sql.backend.sea.utils.constants import (
     ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
     ResultFormat,
@@ -18,8 +18,7 @@
 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
-
-from databricks.sql.backend.sea.result_set import SeaResultSet
+    from databricks.sql.backend.sea.result_set import SeaResultSet

 from databricks.sql.backend.databricks_client import DatabricksClient
 from databricks.sql.backend.types import (
@@ -29,7 +28,7 @@
     BackendType,
     ExecuteResponse,
 )
-from databricks.sql.exc import DatabaseError, ServerOperationError
+from databricks.sql.exc import DatabaseError, ProgrammingError, ServerOperationError
 from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
 from databricks.sql.types import SSLOptions
@@ -45,23 +44,22 @@
     GetStatementResponse,
     CreateSessionResponse,
 )
-from databricks.sql.backend.sea.models.responses import GetChunksResponse

 logger = logging.getLogger(__name__)


 def _filter_session_configuration(
-    session_configuration: Optional[Dict[str, Any]],
-) -> Dict[str, str]:
+    session_configuration: Optional[Dict[str, str]]
+) -> Optional[Dict[str, str]]:
     if not session_configuration:
-        return {}
+        return None

     filtered_session_configuration = {}
     ignored_configs: Set[str] = set()

     for key, value in session_configuration.items():
         if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP:
-            filtered_session_configuration[key.lower()] = str(value)
+            filtered_session_configuration[key.lower()] = value
         else:
             ignored_configs.add(key)
@@ -90,7 +88,6 @@ class SeaDatabricksClient(DatabricksClient):
     STATEMENT_PATH = BASE_PATH + "statements"
     STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
     CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
-    CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"

     # SEA constants
     POLL_INTERVAL_SECONDS = 0.2
@@ -126,24 +123,18 @@ def __init__(
         )

         self._max_download_threads = kwargs.get("max_download_threads", 10)
-        self._ssl_options = ssl_options
-        self._use_arrow_native_complex_types = kwargs.get(
-            "_use_arrow_native_complex_types", True
-        )
-
-        self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)

         # Extract warehouse ID from http_path
         self.warehouse_id = self._extract_warehouse_id(http_path)

         # Initialize HTTP client
-        self._http_client = SeaHttpClient(
+        self.http_client = SeaHttpClient(
             server_hostname=server_hostname,
             port=port,
             http_path=http_path,
             http_headers=http_headers,
             auth_provider=auth_provider,
-            ssl_options=self._ssl_options,
+            ssl_options=ssl_options,
             **kwargs,
         )
@@ -182,7 +173,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
                 f"Note: SEA only works for warehouses."
             )
             logger.error(error_message)
-            raise ValueError(error_message)
+            raise ProgrammingError(error_message)

     @property
     def max_download_threads(self) -> int:
@@ -191,7 +182,7 @@ def max_download_threads(self) -> int:

     def open_session(
         self,
-        session_configuration: Optional[Dict[str, Any]],
+        session_configuration: Optional[Dict[str, str]],
         catalog: Optional[str],
         schema: Optional[str],
     ) -> SessionId:
@@ -229,7 +220,7 @@ def open_session(
             schema=schema,
         )

-        response = self._http_client._make_request(
+        response = self.http_client._make_request(
             method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
         )
@@ -254,7 +245,7 @@ def close_session(self, session_id: SessionId) -> None:
             session_id: The session identifier returned by open_session()

         Raises:
-            ValueError: If the session ID is invalid
+            ProgrammingError: If the session ID is invalid
             OperationalError: If there's an error closing the session
         """
@@ -269,7 +260,7 @@ def close_session(self, session_id: SessionId) -> None:
             session_id=sea_session_id,
         )

-        self._http_client._make_request(
+        self.http_client._make_request(
             method="DELETE",
             path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
             data=request_data.to_dict(),
         )
@@ -333,7 +324,7 @@ def _extract_description_from_manifest(
         return columns

     def _results_message_to_execute_response(
-        self, response: Union[ExecuteStatementResponse, GetStatementResponse]
+        self, response: GetStatementResponse
     ) -> ExecuteResponse:
         """
         Convert a SEA response to an ExecuteResponse and extract result data.
@@ -351,43 +342,20 @@ def _results_message_to_execute_response(

         # Check for compression
         lz4_compressed = (
-            response.manifest.result_compression == ResultCompression.LZ4_FRAME.value
+            response.manifest.result_compression == ResultCompression.LZ4_FRAME
         )

         execute_response = ExecuteResponse(
             command_id=CommandId.from_sea_statement_id(response.statement_id),
             status=response.status.state,
             description=description,
-            has_been_closed_server_side=False,
             lz4_compressed=lz4_compressed,
             is_staging_operation=response.manifest.is_volume_operation,
-            arrow_schema_bytes=None,
             result_format=response.manifest.format,
         )

         return execute_response

-    def _response_to_result_set(
-        self,
-        response: Union[ExecuteStatementResponse, GetStatementResponse],
-        cursor: Cursor,
-    ) -> SeaResultSet:
-        """
-        Convert a SEA response to a SeaResultSet.
- """ - - execute_response = self._results_message_to_execute_response(response) - - return SeaResultSet( - connection=cursor.connection, - execute_response=execute_response, - sea_client=self, - result_data=response.result, - manifest=response.manifest, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, - ) - def _check_command_not_in_failed_or_closed_state( self, state: CommandState, command_id: CommandId ) -> None: @@ -408,24 +376,21 @@ def _check_command_not_in_failed_or_closed_state( def _wait_until_command_done( self, response: ExecuteStatementResponse - ) -> Union[ExecuteStatementResponse, GetStatementResponse]: + ) -> CommandState: """ Wait until a command is done. """ - final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response - - state = final_response.status.state - command_id = CommandId.from_sea_statement_id(final_response.statement_id) + state = response.status.state + command_id = CommandId.from_sea_statement_id(response.statement_id) while state in [CommandState.PENDING, CommandState.RUNNING]: time.sleep(self.POLL_INTERVAL_SECONDS) - final_response = self._poll_query(command_id) - state = final_response.status.state + state = self.get_query_state(command_id) self._check_command_not_in_failed_or_closed_state(state, command_id) - return final_response + return state def execute_command( self, @@ -457,7 +422,7 @@ def execute_command( enforce_embedded_schema_correctness: Whether to enforce schema correctness Returns: - SeaResultSet: A SeaResultSet instance for the executed command + ResultSet: A SeaResultSet instance for the executed command """ if session_id.backend_type != BackendType.SEA: @@ -483,11 +448,7 @@ def execute_command( ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY ).value disposition = ( - ( - ResultDisposition.HYBRID - if self.use_hybrid_disposition - else ResultDisposition.EXTERNAL_LINKS - ) + ResultDisposition.EXTERNAL_LINKS if use_cloud_fetch else ResultDisposition.INLINE ).value @@ -508,7 +469,7 @@ def execute_command( result_compression=result_compression, ) - response_data = self._http_client._make_request( + response_data = self.http_client._make_request( method="POST", path=self.STATEMENT_PATH, data=request.to_dict() ) response = ExecuteStatementResponse.from_dict(response_data) @@ -531,11 +492,8 @@ def execute_command( if async_op: return None - final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response - if response.status.state != CommandState.SUCCEEDED: - final_response = self._wait_until_command_done(response) - - return self._response_to_result_set(final_response, cursor) + self._wait_until_command_done(response) + return self.get_execution_result(command_id, cursor) def cancel_command(self, command_id: CommandId) -> None: """ @@ -545,7 +503,7 @@ def cancel_command(self, command_id: CommandId) -> None: command_id: Command identifier to cancel Raises: - ValueError: If the command ID is invalid + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: @@ -556,7 +514,7 @@ def cancel_command(self, command_id: CommandId) -> None: raise ValueError("Not a valid SEA command ID") request = CancelStatementRequest(statement_id=sea_statement_id) - self._http_client._make_request( + self.http_client._make_request( method="POST", path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), data=request.to_dict(), @@ -570,7 +528,7 @@ def close_command(self, command_id: CommandId) -> None: command_id: Command identifier to 
close Raises: - ValueError: If the command ID is invalid + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: @@ -581,15 +539,24 @@ def close_command(self, command_id: CommandId) -> None: raise ValueError("Not a valid SEA command ID") request = CloseStatementRequest(statement_id=sea_statement_id) - self._http_client._make_request( + self.http_client._make_request( method="DELETE", path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), data=request.to_dict(), ) - def _poll_query(self, command_id: CommandId) -> GetStatementResponse: + def get_query_state(self, command_id: CommandId) -> CommandState: """ - Poll for the current command info. + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: @@ -600,30 +567,14 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse: raise ValueError("Not a valid SEA command ID") request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self._http_client._make_request( + response_data = self.http_client._make_request( method="GET", path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), data=request.to_dict(), ) - response = GetStatementResponse.from_dict(response_data) - - return response - def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ProgrammingError: If the command ID is invalid - """ - - response = self._poll_query(command_id) + # Parse the response + response = GetStatementResponse.from_dict(response_data) return response.status.state def get_execution_result( @@ -645,29 +596,37 @@ def get_execution_result( ValueError: If the command ID is invalid """ - response = self._poll_query(command_id) - return self._response_to_result_set(response, cursor) + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") - def get_chunk_links( - self, statement_id: str, chunk_index: int - ) -> List[ExternalLink]: - """ - Get links for chunks starting from the specified index. 
- Args: - statement_id: The statement ID - chunk_index: The starting chunk index - Returns: - ExternalLink: External link for the chunk - """ + sea_statement_id = command_id.to_sea_statement_id() + if sea_statement_id is None: + raise ValueError("Not a valid SEA command ID") + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self._http_client._make_request( + # Get the statement result + response_data = self.http_client._make_request( method="GET", - path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index), + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) - response = GetChunksResponse.from_dict(response_data) + response = GetStatementResponse.from_dict(response_data) + + # Create and return a SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet - links = response.external_links or [] - return links + execute_response = self._results_message_to_execute_response(response) + + return SeaResultSet( + connection=cursor.connection, + execute_response=execute_response, + result_data=response.result, + manifest=response.manifest, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + ) # == Metadata Operations == diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index 4a2b57327..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -27,7 +27,6 @@ ExecuteStatementResponse, GetStatementResponse, CreateSessionResponse, - GetChunksResponse, ) __all__ = [ @@ -50,5 +49,4 @@ "ExecuteStatementResponse", "GetStatementResponse", "CreateSessionResponse", - "GetChunksResponse", ] diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 5a5580481..302b32d0c 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -4,8 +4,7 @@ These models define the structures used in SEA API responses. """ -import base64 -from typing import Dict, Any, List, Optional +from typing import Dict, Any from dataclasses import dataclass from databricks.sql.backend.types import CommandState @@ -92,11 +91,6 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: ) ) - # Handle attachment field - decode from base64 if present - attachment = result_data.get("attachment") - if attachment is not None: - attachment = base64.b64decode(attachment) - return ResultData( data=result_data.get("data_array"), external_links=external_links, @@ -106,7 +100,7 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: next_chunk_internal_link=result_data.get("next_chunk_internal_link"), row_count=result_data.get("row_count"), row_offset=result_data.get("row_offset"), - attachment=attachment, + attachment=result_data.get("attachment"), ) @@ -160,37 +154,3 @@ class CreateSessionResponse: def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse": """Create a CreateSessionResponse from a dictionary.""" return cls(session_id=data.get("session_id", "")) - - -@dataclass -class GetChunksResponse: - """ - Response from getting chunks for a statement. 
- - The response model can be found in the docs, here: - https://docs.databricks.com/api/workspace/statementexecution/getstatementresultchunkn - """ - - data: Optional[List[List[Any]]] = None - external_links: Optional[List[ExternalLink]] = None - byte_count: Optional[int] = None - chunk_index: Optional[int] = None - next_chunk_index: Optional[int] = None - next_chunk_internal_link: Optional[str] = None - row_count: Optional[int] = None - row_offset: Optional[int] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse": - """Create a GetChunksResponse from a dictionary.""" - result = _parse_result({"result": data}) - return cls( - data=result.data, - external_links=result.external_links, - byte_count=result.byte_count, - chunk_index=result.chunk_index, - next_chunk_index=result.next_chunk_index, - next_chunk_internal_link=result.next_chunk_internal_link, - row_count=result.row_count, - row_offset=result.row_offset, - ) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index 85e4236bc..3a1f6ef51 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -1,59 +1,31 @@ from __future__ import annotations from abc import ABC -from typing import List, Optional, Tuple, Union, TYPE_CHECKING +from typing import List, Optional, Tuple -from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager - -from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler - -try: - import pyarrow -except ImportError: - pyarrow = None - -import dateutil - -if TYPE_CHECKING: - from databricks.sql.backend.sea.backend import SeaDatabricksClient - from databricks.sql.backend.sea.models.base import ( - ExternalLink, - ResultData, - ResultManifest, - ) +from databricks.sql.backend.sea.client import SeaDatabricksClient +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.sea.utils.constants import ResultFormat -from databricks.sql.exc import ProgrammingError, ServerOperationError -from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink -from databricks.sql.types import SSLOptions -from databricks.sql.utils import ( - ArrowQueue, - CloudFetchQueue, - ResultSetQueue, - create_arrow_table_from_arrow_file, -) - -import logging - -logger = logging.getLogger(__name__) +from databricks.sql.exc import ProgrammingError +from databricks.sql.utils import ResultSetQueue class SeaResultSetQueueFactory(ABC): @staticmethod def build_queue( - result_data: ResultData, + sea_result_data: ResultData, manifest: ResultManifest, statement_id: str, - ssl_options: SSLOptions, - description: List[Tuple], - max_download_threads: int, - sea_client: SeaDatabricksClient, - lz4_compressed: bool, + description: List[Tuple] = [], + max_download_threads: Optional[int] = None, + sea_client: Optional[SeaDatabricksClient] = None, + lz4_compressed: bool = False, ) -> ResultSetQueue: """ Factory method to build a result set queue for SEA backend. 
Args: - result_data (ResultData): Result data from SEA response + sea_result_data (ResultData): Result data from SEA response manifest (ResultManifest): Manifest from SEA response statement_id (str): Statement ID for the query description (List[List[Any]]): Column descriptions @@ -67,30 +39,11 @@ def build_queue( if manifest.format == ResultFormat.JSON_ARRAY.value: # INLINE disposition with JSON_ARRAY format - return JsonQueue(result_data.data) + return JsonQueue(sea_result_data.data) elif manifest.format == ResultFormat.ARROW_STREAM.value: - if result_data.attachment is not None: - arrow_file = ( - ResultSetDownloadHandler._decompress_data(result_data.attachment) - if lz4_compressed - else result_data.attachment - ) - arrow_table = create_arrow_table_from_arrow_file( - arrow_file, description - ) - logger.debug(f"Created arrow table with {arrow_table.num_rows} rows") - return ArrowQueue(arrow_table, manifest.total_row_count) - # EXTERNAL_LINKS disposition - return SeaCloudFetchQueue( - result_data=result_data, - max_download_threads=max_download_threads, - ssl_options=ssl_options, - sea_client=sea_client, - statement_id=statement_id, - total_chunk_count=manifest.total_chunk_count, - lz4_compressed=lz4_compressed, - description=description, + raise NotImplementedError( + "EXTERNAL_LINKS disposition is not implemented for SEA backend" ) raise ProgrammingError("Invalid result format") @@ -119,119 +72,3 @@ def remaining_rows(self) -> List[List[str]]: def close(self): return - - -class SeaCloudFetchQueue(CloudFetchQueue): - """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend.""" - - def __init__( - self, - result_data: ResultData, - max_download_threads: int, - ssl_options: SSLOptions, - sea_client: SeaDatabricksClient, - statement_id: str, - total_chunk_count: int, - lz4_compressed: bool = False, - description: List[Tuple] = [], - ): - """ - Initialize the SEA CloudFetchQueue. 
- - Args: - initial_links: Initial list of external links to download - schema_bytes: Arrow schema bytes - max_download_threads: Maximum number of download threads - ssl_options: SSL options for downloads - sea_client: SEA client for fetching additional links - statement_id: Statement ID for the query - total_chunk_count: Total number of chunks in the result set - lz4_compressed: Whether the data is LZ4 compressed - description: Column descriptions - """ - - super().__init__( - max_download_threads=max_download_threads, - ssl_options=ssl_options, - schema_bytes=None, - lz4_compressed=lz4_compressed, - description=description, - ) - - self._sea_client = sea_client - self._statement_id = statement_id - self._total_chunk_count = total_chunk_count - - logger.debug( - "SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}".format( - statement_id, total_chunk_count - ) - ) - - initial_links = result_data.external_links or [] - self._chunk_index_to_link = {link.chunk_index: link for link in initial_links} - - # Track the current chunk we're processing - self._current_chunk_index = 0 - first_link = self._chunk_index_to_link.get(self._current_chunk_index, None) - if not first_link: - # possibly an empty response - return None - - # Track the current chunk we're processing - self._current_chunk_index = 0 - # Initialize table and position - self.table = self._create_table_from_link(first_link) - - def _convert_to_thrift_link(self, link: ExternalLink) -> TSparkArrowResultLink: - """Convert SEA external links to Thrift format for compatibility with existing download manager.""" - # Parse the ISO format expiration time - expiry_time = int(dateutil.parser.parse(link.expiration).timestamp()) - return TSparkArrowResultLink( - fileLink=link.external_link, - expiryTime=expiry_time, - rowCount=link.row_count, - bytesNum=link.byte_count, - startRowOffset=link.row_offset, - httpHeaders=link.http_headers or {}, - ) - - def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]: - if chunk_index >= self._total_chunk_count: - return None - - if chunk_index not in self._chunk_index_to_link: - links = self._sea_client.get_chunk_links(self._statement_id, chunk_index) - self._chunk_index_to_link.update({l.chunk_index: l for l in links}) - - link = self._chunk_index_to_link.get(chunk_index, None) - if not link: - raise ServerOperationError( - f"Error fetching link for chunk {chunk_index}", - { - "operation-id": self._statement_id, - "diagnostic-info": None, - }, - ) - return link - - def _create_table_from_link( - self, link: ExternalLink - ) -> Union["pyarrow.Table", None]: - """Create a table from a link.""" - - thrift_link = self._convert_to_thrift_link(link) - self.download_manager.add_link(thrift_link) - - row_offset = link.row_offset - arrow_table = self._create_table_at_offset(row_offset) - - return arrow_table - - def _create_next_table(self) -> Union["pyarrow.Table", None]: - """Create next table by retrieving the logical next downloaded file.""" - self._current_chunk_index += 1 - next_chunk_link = self._get_chunk_link(self._current_chunk_index) - if not next_chunk_link: - return None - return self._create_table_from_link(next_chunk_link) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index a6a0a298b..6c7d20636 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -4,6 +4,7 @@ import logging +from databricks.sql.backend.sea.client import 
SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter @@ -14,10 +15,10 @@ if TYPE_CHECKING: from databricks.sql.client import Connection - from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.exc import CursorAlreadyClosedError, ProgrammingError, RequestError from databricks.sql.types import Row from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory -from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.backend.types import CommandState, ExecuteResponse from databricks.sql.result_set import ResultSet logger = logging.getLogger(__name__) @@ -30,7 +31,6 @@ def __init__( self, connection: Connection, execute_response: ExecuteResponse, - sea_client: SeaDatabricksClient, result_data: ResultData, manifest: ResultManifest, buffer_size_bytes: int = 104857600, @@ -42,7 +42,6 @@ def __init__( Args: connection: The parent connection execute_response: Response from the execute command - sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch result_data: Result data from SEA response @@ -55,33 +54,36 @@ def __init__( if statement_id is None: raise ValueError("Command ID is not a SEA statement ID") - results_queue = SeaResultSetQueueFactory.build_queue( - result_data, - self.manifest, - statement_id, - ssl_options=connection.session.ssl_options, - description=execute_response.description, - max_download_threads=sea_client.max_download_threads, - sea_client=sea_client, - lz4_compressed=execute_response.lz4_compressed, - ) - # Call parent constructor with common attributes super().__init__( connection=connection, - backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, command_id=execute_response.command_id, status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, ) + # Assert that the backend is of the correct type + assert isinstance( + self.backend, SeaDatabricksClient + ), "Backend must be a SeaDatabricksClient" + + results_queue = SeaResultSetQueueFactory.build_queue( + result_data, + self.manifest, + statement_id, + description=execute_response.description, + max_download_threads=self.backend.max_download_threads, + sea_client=self.backend, + lz4_compressed=execute_response.lz4_compressed, + ) + + # Set the results queue + self.results = results_queue + def _convert_json_types(self, row: List[str]) -> List[Any]: """ Convert string values in the row to appropriate Python types based on column metadata. 
@@ -160,6 +162,9 @@ def fetchmany_json(self, size: int) -> List[List[str]]: if size < 0: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.next_n_rows(size) self._next_row_index += len(results) @@ -173,6 +178,9 @@ def fetchall_json(self) -> List[List[str]]: Columnar table containing all remaining rows """ + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.remaining_rows() self._next_row_index += len(results) @@ -196,10 +204,10 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": if size < 0: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") - results = self.results.next_n_rows(size) - if isinstance(self.results, JsonQueue): - results = self._convert_json_to_arrow_table(results) + if not isinstance(self.results, JsonQueue): + raise NotImplementedError("fetchmany_arrow only supported for JSON data") + results = self._convert_json_to_arrow_table(self.results.next_n_rows(size)) self._next_row_index += results.num_rows return results @@ -209,10 +217,10 @@ def fetchall_arrow(self) -> "pyarrow.Table": Fetch all remaining rows as an Arrow table. """ - results = self.results.remaining_rows() - if isinstance(self.results, JsonQueue): - results = self._convert_json_to_arrow_table(results) + if not isinstance(self.results, JsonQueue): + raise NotImplementedError("fetchall_arrow only supported for JSON data") + results = self._convert_json_to_arrow_table(self.results.remaining_rows()) self._next_row_index += results.num_rows return results @@ -229,7 +237,7 @@ def fetchone(self) -> Optional[Row]: if isinstance(self.results, JsonQueue): res = self._create_json_table(self.fetchmany_json(1)) else: - res = self._convert_arrow_table(self.fetchmany_arrow(1)) + raise NotImplementedError("fetchone only supported for JSON data") return res[0] if res else None @@ -250,7 +258,7 @@ def fetchmany(self, size: int) -> List[Row]: if isinstance(self.results, JsonQueue): return self._create_json_table(self.fetchmany_json(size)) else: - return self._convert_arrow_table(self.fetchmany_arrow(size)) + raise NotImplementedError("fetchmany only supported for JSON data") def fetchall(self) -> List[Row]: """ @@ -263,4 +271,22 @@ def fetchall(self) -> List[Row]: if isinstance(self.results, JsonQueue): return self._create_json_table(self.fetchall_json()) else: - return self._convert_arrow_table(self.fetchall_arrow()) + raise NotImplementedError("fetchall only supported for JSON data") + + def close(self) -> None: + """ + Close the result set. + + If the connection has not been closed, and the result set has not already + been closed on the server for some other reason, issue a request to the server to close it. 
+ """ + try: + if self.results is not None: + self.results.close() + if self.status != CommandState.CLOSED and self.connection.open: + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.status = CommandState.CLOSED diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 46ce8c98a..402da0de5 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -28,7 +28,7 @@ class ResultFormat(Enum): class ResultDisposition(Enum): """Enum for result disposition values.""" - HYBRID = "INLINE_OR_EXTERNAL_LINKS" + # TODO: add support for hybrid disposition EXTERNAL_LINKS = "EXTERNAL_LINKS" INLINE = "INLINE" diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index ef6c91d7d..9e7a85c56 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -12,14 +12,13 @@ Optional, Any, Callable, - cast, TYPE_CHECKING, ) if TYPE_CHECKING: from databricks.sql.backend.sea.result_set import SeaResultSet -from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse, CommandId, CommandState logger = logging.getLogger(__name__) @@ -45,6 +44,9 @@ def _filter_sea_result_set( """ # Get all remaining rows + if result_set.results is None: + raise RuntimeError("Results queue is not initialized") + all_rows = result_set.results.remaining_rows() # Filter rows @@ -58,9 +60,7 @@ def _filter_sea_result_set( command_id=command_id, status=result_set.status, description=result_set.description, - has_been_closed_server_side=result_set.has_been_closed_server_side, lz4_compressed=result_set.lz4_compressed, - arrow_schema_bytes=result_set._arrow_schema_bytes, is_staging_operation=False, ) @@ -69,7 +69,7 @@ def _filter_sea_result_set( result_data = ResultData(data=filtered_rows, external_links=None) - from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.sea.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data @@ -79,7 +79,6 @@ def _filter_sea_result_set( filtered_result_set = SeaResultSet( connection=result_set.connection, execute_response=execute_response, - sea_client=cast(SeaDatabricksClient, result_set.backend), result_data=result_data, manifest=manifest, buffer_size_bytes=result_set.buffer_size_bytes, diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 50a256f48..48a7a1ddb 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -165,7 +165,6 @@ def __init__( self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True ) - self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True) self._use_arrow_native_timestamps = kwargs.get( "_use_arrow_native_timestamps", True @@ -822,14 +821,17 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id=command_id, status=status, description=description, - has_been_closed_server_side=has_been_closed_server_side, lz4_compressed=lz4_compressed, is_staging_operation=t_result_set_metadata_resp.isStagingOperation, - 
arrow_schema_bytes=schema_bytes, result_format=t_result_set_metadata_resp.resultFormat, ) - return execute_response, is_direct_results + return ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) def get_execution_result( self, command_id: CommandId, cursor: "Cursor" @@ -882,17 +884,14 @@ def get_execution_result( command_id=command_id, status=status, description=description, - has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - arrow_schema_bytes=schema_bytes, result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -900,6 +899,8 @@ def get_execution_result( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + arrow_schema_bytes=schema_bytes, + has_been_closed_server_side=False, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -1018,9 +1019,12 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1029,7 +1033,6 @@ def execute_command( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, @@ -1037,6 +1040,8 @@ def execute_command( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def get_catalogs( @@ -1058,9 +1063,12 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1069,7 +1077,6 @@ def get_catalogs( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1077,6 +1084,8 @@ def get_catalogs( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def get_schemas( @@ -1104,9 +1113,12 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1115,7 +1127,6 @@ def get_schemas( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - 
thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1123,6 +1134,8 @@ def get_schemas( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def get_tables( @@ -1154,9 +1167,12 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1165,7 +1181,6 @@ def get_tables( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1173,6 +1188,8 @@ def get_tables( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def get_columns( @@ -1204,9 +1221,12 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1215,7 +1235,6 @@ def get_columns( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1223,6 +1242,8 @@ def get_columns( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index f6428a187..b188b7ba1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -419,8 +419,6 @@ class ExecuteResponse: command_id: CommandId status: CommandState description: List[Tuple] - has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False - arrow_schema_bytes: Optional[bytes] = None result_format: Optional[Any] = None diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index dfa732c2d..75e89d92a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -99,10 +99,6 @@ def __init__( Connect to a Databricks SQL endpoint or a Databricks cluster. Parameters: - :param use_sea: `bool`, optional (default is False) - Use the SEA backend instead of the Thrift backend. - :param use_hybrid_disposition: `bool`, optional (default is False) - Use the hybrid disposition instead of the inline disposition. :param server_hostname: Databricks instance host name. :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. 
/sql/protocolv1/o/1234567890123456/1234-123456-slid123) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 12dd0a01f..7e96cd323 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -101,24 +101,6 @@ def _schedule_downloads(self): task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) - def add_link(self, link: TSparkArrowResultLink): - """ - Add more links to the download manager. - - Args: - link: Link to add - """ - - if link.rowCount <= 0: - return - - logger.debug( - "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format( - link.startRowOffset, link.rowCount - ) - ) - self._pending_links.append(link) - def _shutdown_manager(self): # Clear download handlers and shutdown the thread pool self._pending_links = [] diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index dc279cf91..51128da8c 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -20,6 +20,7 @@ from databricks.sql.utils import ( ColumnTable, ColumnQueue, + ResultSetQueue, ) from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse @@ -36,50 +37,41 @@ class ResultSet(ABC): def __init__( self, connection: "Connection", - backend: "DatabricksClient", arraysize: int, buffer_size_bytes: int, command_id: CommandId, status: CommandState, - has_been_closed_server_side: bool = False, is_direct_results: bool = False, - results_queue=None, description: List[Tuple] = [], is_staging_operation: bool = False, lz4_compressed: bool = False, - arrow_schema_bytes: Optional[bytes] = None, ): """ A ResultSet manages the results of a single command. 
Parameters: :param connection: The parent connection - :param backend: The backend client :param arraysize: The max number of rows to fetch at a time (PEP-249) :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch :param command_id: The command ID :param status: The command status - :param has_been_closed_server_side: Whether the command has been closed on the server :param is_direct_results: Whether the command has more rows - :param results_queue: The results queue :param description: column description of the results :param is_staging_operation: Whether the command is a staging operation """ - self.connection = connection - self.backend = backend - self.arraysize = arraysize - self.buffer_size_bytes = buffer_size_bytes - self._next_row_index = 0 - self.description = description - self.command_id = command_id - self.status = status - self.has_been_closed_server_side = has_been_closed_server_side - self.is_direct_results = is_direct_results - self.results = results_queue - self._is_staging_operation = is_staging_operation - self.lz4_compressed = lz4_compressed - self._arrow_schema_bytes = arrow_schema_bytes + self.connection: "Connection" = connection + self.backend: DatabricksClient = connection.session.backend + self.arraysize: int = arraysize + self.buffer_size_bytes: int = buffer_size_bytes + self._next_row_index: int = 0 + self.description: List[Tuple] = description + self.command_id: CommandId = command_id + self.status: CommandState = status + self.is_direct_results: bool = is_direct_results + self.results: Optional[ResultSetQueue] = None + self._is_staging_operation: bool = is_staging_operation + self.lz4_compressed: bool = lz4_compressed def __iter__(self): while True: @@ -161,27 +153,12 @@ def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass + @abstractmethod def close(self) -> None: """ Close the result set. - - If the connection has not been closed, and the result set has not already - been closed on the server for some other reason, issue a request to the server to close it. """ - try: - self.results.close() - if ( - self.status != CommandState.CLOSED - and not self.has_been_closed_server_side - and self.connection.open - ): - self.backend.close_command(self.command_id) - except RequestError as e: - if isinstance(e.args[1], CursorAlreadyClosedError): - logger.info("Operation was canceled by a prior request") - finally: - self.has_been_closed_server_side = True - self.status = CommandState.CLOSED + pass class ThriftResultSet(ResultSet): @@ -191,7 +168,6 @@ def __init__( self, connection: "Connection", execute_response: "ExecuteResponse", - thrift_client: "ThriftDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, @@ -199,6 +175,8 @@ def __init__( max_download_threads: int = 10, ssl_options=None, is_direct_results: bool = True, + has_been_closed_server_side: bool = False, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
@@ -206,7 +184,6 @@ def __init__( Parameters: :param connection: The parent connection :param execute_response: Response from the execute command - :param thrift_client: The ThriftDatabricksClient instance for direct access :param buffer_size_bytes: Buffer size for fetching results :param arraysize: Default number of rows to fetch :param use_cloud_fetch: Whether to use cloud fetch for retrieving results @@ -214,11 +191,15 @@ def __init__( :param max_download_threads: Maximum number of download threads for cloud fetch :param ssl_options: SSL options for cloud fetch :param is_direct_results: Whether there are more rows to fetch + :param has_been_closed_server_side: Whether the command has been closed on the server + :param arrow_schema_bytes: The schema of the result set """ # Initialize ThriftResultSet-specific attributes self._use_cloud_fetch = use_cloud_fetch self.is_direct_results = is_direct_results + self.has_been_closed_server_side = has_been_closed_server_side + self._arrow_schema_bytes = arrow_schema_bytes # Build the results queue if t_row_set is provided results_queue = None @@ -229,7 +210,7 @@ def __init__( results_queue = ThriftResultSetQueueFactory.build_queue( row_set_type=execute_response.result_format, t_row_set=t_row_set, - arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + arrow_schema_bytes=self._arrow_schema_bytes or b"", max_download_threads=max_download_threads, lz4_compressed=execute_response.lz4_compressed, description=execute_response.description, @@ -239,20 +220,26 @@ def __init__( # Call parent constructor with common attributes super().__init__( connection=connection, - backend=thrift_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, command_id=execute_response.command_id, status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, is_direct_results=is_direct_results, - results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, ) + # Assert that the backend is of the correct type + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + assert isinstance( + self.backend, ThriftDatabricksClient + ), "Backend must be a ThriftDatabricksClient" + + # Set the results queue + self.results = results_queue + # Initialize results queue if not provided if not self.results: self._fill_results_buffer() @@ -308,6 +295,10 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows @@ -333,6 +324,9 @@ def fetchmany_columnar(self, size: int): if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows @@ -352,6 +346,9 @@ def fetchmany_columnar(self, size: int): def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" + if self.results is None: + raise RuntimeError("Results queue is not 
initialized") + results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -378,6 +375,9 @@ def fetchall_arrow(self) -> "pyarrow.Table": def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -394,6 +394,9 @@ def fetchone(self) -> Optional[Row]: Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. """ + if self.results is None: + raise RuntimeError("Results queue is not initialized") + if isinstance(self.results, ColumnQueue): res = self._convert_columnar_table(self.fetchmany_columnar(1)) else: @@ -440,3 +443,26 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] + + def close(self) -> None: + """ + Close the result set. + + If the connection has not been closed, and the result set has not already + been closed on the server for some other reason, issue a request to the server to close it. + """ + try: + if self.results: + self.results.close() + if ( + self.status != CommandState.CLOSED + and not self.has_been_closed_server_side + and self.connection.open + ): + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.has_been_closed_server_side = True + self.status = CommandState.CLOSED diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index b956657ee..6f3f7387a 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -8,7 +8,7 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType @@ -64,7 +64,7 @@ def __init__( base_headers = [("User-Agent", self.useragent_header)] all_headers = (http_headers or []) + base_headers - self.ssl_options = SSLOptions( + self._ssl_options = SSLOptions( # Double negation is generally a bad thing, but we have to keep backward compatibility tls_verify=not kwargs.get( "_tls_no_verify", False @@ -113,7 +113,7 @@ def _create_backend( "http_path": http_path, "http_headers": all_headers, "auth_provider": auth_provider, - "ssl_options": self.ssl_options, + "ssl_options": self._ssl_options, "_use_arrow_native_complex_types": _use_arrow_native_complex_types, **kwargs, } diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 79a376d12..fa0bb1e69 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,5 +1,4 @@ from __future__ import annotations -from typing import Dict, List, Optional, Union from dateutil import parser import datetime @@ -9,17 +8,21 @@ from collections.abc import Mapping from decimal import Decimal from enum import Enum -from typing import Dict, List, Optional, Tuple, Union, Sequence +from typing import Any, Dict, List, Optional, Tuple, Union, Sequence import re import lz4.frame +from databricks.sql.backend.sea.client import SeaDatabricksClient 
+from databricks.sql.backend.sea.models.base import ResultData, ResultManifest + try: import pyarrow except ImportError: pyarrow = None from databricks.sql import OperationalError +from databricks.sql.exc import ProgrammingError from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager from databricks.sql.thrift_api.TCLIService.ttypes import ( TRowSet, @@ -27,6 +30,7 @@ TSparkRowSetType, ) from databricks.sql.types import SSLOptions +from databricks.sql.backend.types import CommandId from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter @@ -64,7 +68,7 @@ def build_queue( description: List[Tuple] = [], ) -> ResultSetQueue: """ - Factory method to build a result set queue for Thrift backend. + Factory method to build a result set queue. Args: row_set_type (enum): Row set type (Arrow, Column, or URL). @@ -98,7 +102,7 @@ def build_queue( return ColumnQueue(ColumnTable(converted_column_table, column_names)) elif row_set_type == TSparkRowSetType.URL_BASED_SET: - return ThriftCloudFetchQueue( + return CloudFetchQueue( schema_bytes=arrow_schema_bytes, start_row_offset=t_row_set.startRowOffset, result_links=t_row_set.resultLinks, @@ -207,55 +211,70 @@ def close(self): return -class CloudFetchQueue(ResultSetQueue, ABC): - """Base class for cloud fetch queues that handle EXTERNAL_LINKS disposition with ARROW format.""" - +class CloudFetchQueue(ResultSetQueue): def __init__( self, + schema_bytes, max_download_threads: int, ssl_options: SSLOptions, - schema_bytes: Optional[bytes] = None, + start_row_offset: int = 0, + result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, description: List[Tuple] = [], ): """ - Initialize the base CloudFetchQueue. + A queue-like wrapper over CloudFetch arrow batches. - Args: - max_download_threads: Maximum number of download threads - ssl_options: SSL options for downloads - schema_bytes: Arrow schema bytes - lz4_compressed: Whether the data is LZ4 compressed - description: Column descriptions + Attributes: + schema_bytes (bytes): Table schema in bytes. + max_download_threads (int): Maximum number of downloader thread pool threads. + start_row_offset (int): The offset of the first row of the cloud fetch links. + result_links (List[TSparkArrowResultLink]): Links containing the downloadable URL and metadata. + lz4_compressed (bool): Whether the files are lz4 compressed. + description (List[List[Any]]): Hive table schema description. 
""" self.schema_bytes = schema_bytes self.max_download_threads = max_download_threads + self.start_row_index = start_row_offset + self.result_links = result_links self.lz4_compressed = lz4_compressed self.description = description self._ssl_options = ssl_options - # Table state - self.table = None - self.table_row_index = 0 - - # Initialize download manager + logger.debug( + "Initialize CloudFetch loader, row set start offset: {}, file list:".format( + start_row_offset + ) + ) + if result_links is not None: + for result_link in result_links: + logger.debug( + "- start row offset: {}, row count: {}".format( + result_link.startRowOffset, result_link.rowCount + ) + ) self.download_manager = ResultFileDownloadManager( - links=[], - max_download_threads=max_download_threads, - lz4_compressed=lz4_compressed, - ssl_options=ssl_options, + links=result_links or [], + max_download_threads=self.max_download_threads, + lz4_compressed=self.lz4_compressed, + ssl_options=self._ssl_options, ) + self.table = self._create_next_table() + self.table_row_index = 0 + def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """ Get up to the next n rows of the cloud fetch Arrow dataframes. Args: num_rows (int): Number of rows to retrieve. + Returns: pyarrow.Table """ + if not self.table: logger.debug("CloudFetchQueue: no more rows available") # Return empty pyarrow table to cause retry of fetch @@ -300,14 +319,21 @@ def remaining_rows(self) -> "pyarrow.Table": self.table_row_index = 0 return results - def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]: - """Create next table at the given row offset""" - + def _create_next_table(self) -> Union["pyarrow.Table", None]: + logger.debug( + "CloudFetchQueue: Trying to get downloaded file for row {}".format( + self.start_row_index + ) + ) # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue - downloaded_file = self.download_manager.get_next_downloaded_file(offset) + downloaded_file = self.download_manager.get_next_downloaded_file( + self.start_row_index + ) if not downloaded_file: logger.debug( - "CloudFetchQueue: Cannot find downloaded file for row {}".format(offset) + "CloudFetchQueue: Cannot find downloaded file for row {}".format( + self.start_row_index + ) ) # None signals no more Arrow tables can be built from the remaining handlers if any remain return None @@ -322,94 +348,24 @@ def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]: # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows assert downloaded_file.row_count == arrow_table.num_rows + self.start_row_index += arrow_table.num_rows - return arrow_table + logger.debug( + "CloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format( + arrow_table.num_rows, self.start_row_index + ) + ) - @abstractmethod - def _create_next_table(self) -> Union["pyarrow.Table", None]: - """Create next table by retrieving the logical next downloaded file.""" - pass + return arrow_table def _create_empty_table(self) -> "pyarrow.Table": - """Create a 0-row table with just the schema bytes.""" - if not self.schema_bytes: - return pyarrow.Table.from_pydict({}) + # Create a 0-row table with just the schema bytes return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) def close(self): self.download_manager._shutdown_manager() -class ThriftCloudFetchQueue(CloudFetchQueue): - """Queue implementation for EXTERNAL_LINKS disposition 
with ARROW format for Thrift backend.""" - - def __init__( - self, - schema_bytes, - max_download_threads: int, - ssl_options: SSLOptions, - start_row_offset: int = 0, - result_links: Optional[List[TSparkArrowResultLink]] = None, - lz4_compressed: bool = True, - description: List[Tuple] = [], - ): - """ - Initialize the Thrift CloudFetchQueue. - - Args: - schema_bytes: Table schema in bytes - max_download_threads: Maximum number of downloader thread pool threads - ssl_options: SSL options for downloads - start_row_offset: The offset of the first row of the cloud fetch links - result_links: Links containing the downloadable URL and metadata - lz4_compressed: Whether the files are lz4 compressed - description: Hive table schema description - """ - super().__init__( - max_download_threads=max_download_threads, - ssl_options=ssl_options, - schema_bytes=schema_bytes, - lz4_compressed=lz4_compressed, - description=description, - ) - - self.start_row_index = start_row_offset - self.result_links = result_links or [] - - logger.debug( - "Initialize CloudFetch loader, row set start offset: {}, file list:".format( - start_row_offset - ) - ) - if self.result_links: - for result_link in self.result_links: - logger.debug( - "- start row offset: {}, row count: {}".format( - result_link.startRowOffset, result_link.rowCount - ) - ) - self.download_manager.add_link(result_link) - - # Initialize table and position - self.table = self._create_next_table() - - def _create_next_table(self) -> Union["pyarrow.Table", None]: - logger.debug( - "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format( - self.start_row_index - ) - ) - arrow_table = self._create_table_at_offset(self.start_row_index) - if arrow_table: - self.start_row_index += arrow_table.num_rows - logger.debug( - "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format( - arrow_table.num_rows, self.start_row_index - ) - ) - return arrow_table - - def _bound(min_x, max_x, x): """Bound x by [min_x, max_x] @@ -712,6 +668,7 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": def convert_to_assigned_datatypes_in_column_table(column_table, description): + converted_column_table = [] for i, col in enumerate(column_table): if description[i][1] == "decimal": diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index aeeb67974..1181ef154 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -2,8 +2,6 @@ import math import time -import pytest - log = logging.getLogger(__name__) @@ -44,14 +42,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): + "assuming 10K fetch size." 
) - @pytest.mark.parametrize( - "extra_params", - [ - {}, - {"use_sea": True}, - ], - ) - def test_query_with_large_wide_result_set(self, extra_params): + def test_query_with_large_wide_result_set(self): resultSize = 300 * 1000 * 1000 # 300 MB width = 8192 # B rows = resultSize // width @@ -61,7 +52,7 @@ def test_query_with_large_wide_result_set(self, extra_params): fetchmany_size = 10 * 1024 * 1024 // width # This is used by PyHive tests to determine the buffer size self.arraysize = 1000 - with self.cursor(extra_params) as cursor: + with self.cursor() as cursor: for lz4_compression in [False, True]: cursor.connection.lz4_compression = lz4_compression uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) @@ -77,14 +68,7 @@ def test_query_with_large_wide_result_set(self, extra_params): assert row[0] == row_id # Verify no rows are dropped in the middle. assert len(row[1]) == 36 - @pytest.mark.parametrize( - "extra_params", - [ - {}, - {"use_sea": True}, - ], - ) - def test_query_with_large_narrow_result_set(self, extra_params): + def test_query_with_large_narrow_result_set(self): resultSize = 300 * 1000 * 1000 # 300 MB width = 8 # sizeof(long) rows = resultSize / width @@ -93,19 +77,12 @@ def test_query_with_large_narrow_result_set(self, extra_params): fetchmany_size = 10 * 1024 * 1024 // width # This is used by PyHive tests to determine the buffer size self.arraysize = 10000000 - with self.cursor(extra_params) as cursor: + with self.cursor() as cursor: cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows)) for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): assert row[0] == row_id - @pytest.mark.parametrize( - "extra_params", - [ - {}, - {"use_sea": True}, - ], - ) - def test_long_running_query(self, extra_params): + def test_long_running_query(self): """Incrementally increase query size until it takes at least 3 minutes, and asserts that the query completes successfully. 
""" @@ -115,7 +92,7 @@ def test_long_running_query(self, extra_params): duration = -1 scale0 = 10000 scale_factor = 1 - with self.cursor(extra_params) as cursor: + with self.cursor() as cursor: while duration < min_duration: assert scale_factor < 1024, "Detected infinite loop" start = time.time() diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 3fa87b1af..3ceb8c773 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -182,19 +182,10 @@ def test_cloud_fetch(self): class TestPySQLAsyncQueriesSuite(PySQLPytestTestCase): - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - }, - ], - ) - def test_execute_async__long_running(self, extra_params): + def test_execute_async__long_running(self): long_running_query = "SELECT COUNT(*) FROM RANGE(10000 * 16) x JOIN RANGE(10000) y ON FROM_UNIXTIME(x.id * y.id, 'yyyy-MM-dd') LIKE '%not%a%date%'" - with self.cursor(extra_params) as cursor: + with self.cursor() as cursor: cursor.execute_async(long_running_query) ## Polling after every POLLING_INTERVAL seconds @@ -237,16 +228,7 @@ def test_execute_async__small_result(self, extra_params): assert result[0].asDict() == {"1": 1} - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - }, - ], - ) - def test_execute_async__large_result(self, extra_params): + def test_execute_async__large_result(self): x_dimension = 1000 y_dimension = 1000 large_result_query = f""" @@ -260,7 +242,7 @@ def test_execute_async__large_result(self, extra_params): RANGE({y_dimension}) y """ - with self.cursor(extra_params) as cursor: + with self.cursor() as cursor: cursor.execute_async(large_result_query) ## Fake sleep for 5 secs @@ -368,9 +350,6 @@ def test_incorrect_query_throws_exception(self): "use_cloud_fetch": False, "enable_query_result_lz4_compression": False, }, - { - "use_sea": True, - }, ], ) def test_create_table_will_return_empty_result_set(self, extra_params): @@ -581,9 +560,6 @@ def test_get_catalogs(self): "use_cloud_fetch": False, "enable_query_result_lz4_compression": False, }, - { - "use_sea": True, - }, ], ) def test_get_arrow(self, extra_params): @@ -657,9 +633,6 @@ def execute_really_long_query(): "use_cloud_fetch": False, "enable_query_result_lz4_compression": False, }, - { - "use_sea": True, - }, ], ) def test_can_execute_command_after_failure(self, extra_params): @@ -682,9 +655,6 @@ def test_can_execute_command_after_failure(self, extra_params): "use_cloud_fetch": False, "enable_query_result_lz4_compression": False, }, - { - "use_sea": True, - }, ], ) def test_can_execute_command_after_success(self, extra_params): @@ -709,9 +679,6 @@ def generate_multi_row_query(self): "use_cloud_fetch": False, "enable_query_result_lz4_compression": False, }, - { - "use_sea": True, - }, ], ) def test_fetchone(self, extra_params): @@ -756,9 +723,6 @@ def test_fetchall(self, extra_params): "use_cloud_fetch": False, "enable_query_result_lz4_compression": False, }, - { - "use_sea": True, - }, ], ) def test_fetchmany_when_stride_fits(self, extra_params): @@ -779,9 +743,6 @@ def test_fetchmany_when_stride_fits(self, extra_params): "use_cloud_fetch": False, "enable_query_result_lz4_compression": False, }, - { - "use_sea": True, - }, ], ) def test_fetchmany_in_excess(self, extra_params): @@ -802,9 +763,6 @@ def test_fetchmany_in_excess(self, extra_params): "use_cloud_fetch": False, "enable_query_result_lz4_compression": False, }, - { - "use_sea": True, - }, ], ) def test_iterator_api(self, extra_params): @@ -890,9 +848,6 @@ def 
test_timestamps_arrow(self): "use_cloud_fetch": False, "enable_query_result_lz4_compression": False, }, - { - "use_sea": True, - }, ], ) def test_multi_timestamps_arrow(self, extra_params): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 3b5072cfe..51430e9e0 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -108,7 +108,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): mock_execute_response.status = ( CommandState.SUCCEEDED if not closed else CommandState.CLOSED ) - mock_execute_response.has_been_closed_server_side = closed mock_execute_response.is_staging_operation = False mock_execute_response.description = [] @@ -127,7 +126,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): real_result_set = ThriftResultSet( connection=connection, execute_response=mock_execute_response, - thrift_client=mock_backend, + has_been_closed_server_side=closed, ) # Mock execute_command to return our real result set @@ -187,22 +186,26 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() - mock_backend = Mock() + mock_backend = Mock(spec=ThriftDatabricksClient) mock_results = Mock() mock_backend.fetch_results.return_value = (Mock(), False) + # Ensure connection appears closed + type(mock_connection).open = PropertyMock(return_value=False) + # Ensure isinstance check passes if needed + mock_backend.__class__ = ThriftDatabricksClient + + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.backend = mock_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) + result_set = ThriftResultSet( connection=mock_connection, execute_response=Mock(), - thrift_client=mock_backend, ) result_set.results = mock_results - # Setup session mock on the mock_connection - mock_session = Mock() - mock_session.open = False - type(mock_connection).session = PropertyMock(return_value=mock_session) - result_set.close() self.assertFalse(mock_backend.close_command.called) @@ -213,16 +216,18 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response = Mock() mock_results_response.has_been_closed_server_side = False mock_connection = Mock() - mock_thrift_backend = Mock() + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_results = Mock() # Setup session mock on the mock_connection mock_session = Mock() mock_session.open = True + mock_session.backend = mock_thrift_backend type(mock_connection).session = PropertyMock(return_value=mock_session) mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( - mock_connection, mock_results_response, mock_thrift_backend + mock_connection, + mock_results_response, ) result_set.results = mock_results @@ -267,10 +272,20 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - mock_backend = Mock() + mock_connection = Mock() + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.fetch_results.return_value = (Mock(), False) + # Ensure isinstance check passes + mock_backend.__class__ = ThriftDatabricksClient + + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.backend = mock_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) - result_set = 
ThriftResultSet(Mock(), Mock(), mock_backend) + result_set = ThriftResultSet(mock_connection, Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -565,10 +580,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( - self, - mock_client_class, - mock_handle_staging_operation, - mock_execute_response, + self, mock_client_class, mock_handle_staging_operation, mock_execute_response ): # If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 275d055c9..7dec4e680 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -52,13 +52,13 @@ def get_schema_bytes(): return sink.getvalue().to_pybytes() @patch( - "databricks.sql.utils.ThriftCloudFetchQueue._create_next_table", + "databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=[None, None], ) def test_initializer_adds_links(self, mock_create_next_table): schema_bytes = MagicMock() result_links = self.create_result_links(10) - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=result_links, max_download_threads=10, @@ -72,7 +72,7 @@ def test_initializer_adds_links(self, mock_create_next_table): def test_initializer_no_links_to_add(self): schema_bytes = MagicMock() result_links = [] - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=result_links, max_download_threads=10, @@ -88,7 +88,7 @@ def test_initializer_no_links_to_add(self): return_value=None, ) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( MagicMock(), result_links=[], max_download_threads=10, @@ -108,7 +108,7 @@ def test_initializer_create_next_table_success( ): mock_create_arrow_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -129,11 +129,11 @@ def test_initializer_create_next_table_success( assert table.num_rows == 4 assert queue.start_row_index == 8 - @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_0_rows(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -147,14 +147,13 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): result = queue.next_n_rows(0) assert result.num_rows == 0 assert queue.table_row_index == 0 - # Instead of comparing tables directly, just check the row count - # This avoids issues with empty table schema differences + assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_partial_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, 
description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -170,11 +169,11 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == self.make_arrow_table()[:3] - @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_more_than_one_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -195,11 +194,11 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): )[:7] ) - @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -214,14 +213,11 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): assert result.num_rows == 4 assert result == self.make_arrow_table() - @patch( - "databricks.sql.utils.ThriftCloudFetchQueue._create_next_table", - return_value=None, - ) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -234,11 +230,11 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): mock_create_next_table.assert_called() assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() - @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -253,11 +249,11 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table) assert result.num_rows == 0 assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -272,11 +268,11 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl assert result.num_rows == 2 assert result == self.make_arrow_table()[2:] - 
@patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -291,7 +287,7 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): assert result.num_rows == 4 assert result == self.make_arrow_table() - @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_multiple_tables_fully_returned( self, mock_create_next_table ): @@ -301,7 +297,7 @@ def test_remaining_rows_multiple_tables_fully_returned( None, ] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, @@ -322,14 +318,11 @@ def test_remaining_rows_multiple_tables_fully_returned( )[3:] ) - @patch( - "databricks.sql.utils.ThriftCloudFetchQueue._create_next_table", - return_value=None, - ) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() - queue = utils.ThriftCloudFetchQueue( + queue = utils.CloudFetchQueue( schema_bytes, result_links=[], description=description, diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index a649941e1..8643404ba 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -1,6 +1,6 @@ import unittest import pytest -from unittest.mock import Mock +from unittest.mock import Mock, PropertyMock try: import pyarrow as pa @@ -38,12 +38,19 @@ def make_arrow_queue(batch): @staticmethod def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more - schema, arrow_table = FetchTests.make_arrow_table(initial_results) - arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) + arrow_queue = FetchTests.make_arrow_queue(initial_results) + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient - # Create a mock backend that will return the queue when _fill_results_buffer is called mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + # Ensure isinstance check passes + mock_thrift_backend.__class__ = ThriftDatabricksClient + + # Setup mock connection with session.backend + mock_connection = Mock() + mock_session = Mock() + mock_session.backend = mock_thrift_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) num_cols = len(initial_results[0]) if initial_results else 0 description = [ @@ -52,17 +59,16 @@ def make_dummy_result_set_from_initial_results(initial_results): ] rs = ThriftResultSet( - connection=Mock(), + connection=mock_connection, execute_response=ExecuteResponse( command_id=None, status=None, - has_been_closed_server_side=True, description=description, lz4_compressed=True, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, t_row_set=None, + has_been_closed_server_side=True, ) return rs @@ -86,8 +92,19 @@ def fetch_results( return 
results, batch_index < len(batch_list) + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results = fetch_results + # Ensure isinstance check passes + mock_thrift_backend.__class__ = ThriftDatabricksClient + + # Setup mock connection with session.backend + mock_connection = Mock() + mock_session = Mock() + mock_session.backend = mock_thrift_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) + num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 description = [ @@ -96,16 +113,15 @@ def fetch_results( ] rs = ThriftResultSet( - connection=Mock(), + connection=mock_connection, execute_response=ExecuteResponse( command_id=None, status=None, - has_been_closed_server_side=False, description=description, lz4_compressed=True, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, + has_been_closed_server_side=False, ) return rs diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index ac9648a0e..e4a9e5cdd 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -39,7 +39,8 @@ def make_dummy_result_set_from_initial_results(arrow_table): is_direct_results=False, description=Mock(), command_id=None, - arrow_schema_bytes=arrow_table.schema, + arrow_queue=arrow_queue, + arrow_schema=arrow_table.schema, ), ) rs.description = [ diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 5f920e246..e3dda1818 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -8,7 +8,7 @@ import pytest from unittest.mock import patch, MagicMock, Mock -from databricks.sql.backend.sea.backend import ( +from databricks.sql.backend.sea.client import ( SeaDatabricksClient, _filter_session_configuration, ) @@ -33,7 +33,7 @@ class TestSeaBackend: def mock_http_client(self): """Create a mock HTTP client.""" with patch( - "databricks.sql.backend.sea.backend.SeaHttpClient" + "databricks.sql.backend.sea.client.SeaHttpClient" ) as mock_client_class: mock_client = mock_client_class.return_value yield mock_client @@ -70,12 +70,20 @@ def sea_command_id(self): return CommandId.from_sea_statement_id("test-statement-123") @pytest.fixture - def mock_cursor(self): + def mock_cursor(self, sea_client): """Create a mock cursor.""" cursor = Mock() cursor.active_command_id = None cursor.buffer_size_bytes = 1000 cursor.arraysize = 100 + + # Set up a mock connection with session.backend pointing to the sea_client + mock_connection = Mock() + mock_session = Mock() + mock_session.backend = sea_client + mock_connection.session = mock_session + cursor.connection = mock_connection + return cursor @pytest.fixture @@ -132,7 +140,7 @@ def test_initialization(self, mock_http_client): assert client3.max_download_threads == 5 # Test with invalid HTTP path - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: SeaDatabricksClient( server_hostname="test-server.databricks.com", port=443, @@ -227,7 +235,7 @@ def test_command_execution_sync( mock_http_client._make_request.return_value = execute_response with patch.object( - sea_client, "_response_to_result_set", return_value="mock_result_set" + sea_client, "get_execution_result", return_value="mock_result_set" ) as mock_get_result: result = sea_client.execute_command( operation="SELECT 1", @@ -242,6 +250,9 @@ def test_command_execution_sync( enforce_embedded_schema_correctness=False, ) assert 
result == "mock_result_set" + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" # Test with invalid session ID with pytest.raises(ValueError) as excinfo: @@ -329,7 +340,7 @@ def test_command_execution_advanced( mock_http_client._make_request.side_effect = [initial_response, poll_response] with patch.object( - sea_client, "_response_to_result_set", return_value="mock_result_set" + sea_client, "get_execution_result", return_value="mock_result_set" ) as mock_get_result: with patch("time.sleep"): result = sea_client.execute_command( @@ -357,7 +368,7 @@ def test_command_execution_advanced( dbsql_param = IntegerParameter(name="param1", value=1) param = dbsql_param.as_tspark_param(named=True) - with patch.object(sea_client, "_response_to_result_set"): + with patch.object(sea_client, "get_execution_result"): sea_client.execute_command( operation="SELECT * FROM table WHERE col = :param1", session_id=sea_session_id, @@ -621,71 +632,6 @@ def test_utility_methods(self, sea_client): assert description[1][1] == "INT" # type_code assert description[1][6] is False # null_ok - def test_filter_session_configuration(self): - """Test that _filter_session_configuration converts all values to strings.""" - session_config = { - "ANSI_MODE": True, - "statement_timeout": 3600, - "TIMEZONE": "UTC", - "enable_photon": False, - "MAX_FILE_PARTITION_BYTES": 128.5, - "unsupported_param": "value", - "ANOTHER_UNSUPPORTED": 42, - } - - result = _filter_session_configuration(session_config) - - # Verify result is not None - assert result is not None - - # Verify all returned values are strings - for key, value in result.items(): - assert isinstance( - value, str - ), f"Value for key '{key}' is not a string: {type(value)}" - - # Verify specific conversions - expected_result = { - "ansi_mode": "True", # boolean True -> "True", key lowercased - "statement_timeout": "3600", # int -> "3600", key lowercased - "timezone": "UTC", # string -> "UTC", key lowercased - "enable_photon": "False", # boolean False -> "False", key lowercased - "max_file_partition_bytes": "128.5", # float -> "128.5", key lowercased - } - - assert result == expected_result - - # Test with None input - assert _filter_session_configuration(None) == {} - - # Test with only unsupported parameters - unsupported_config = { - "unsupported_param1": "value1", - "unsupported_param2": 123, - } - result = _filter_session_configuration(unsupported_config) - assert result == {} - - # Test case insensitivity for keys - case_insensitive_config = { - "ansi_mode": "false", # lowercase key - "STATEMENT_TIMEOUT": 7200, # uppercase key - "TiMeZoNe": "America/New_York", # mixed case key - } - result = _filter_session_configuration(case_insensitive_config) - expected_case_result = { - "ansi_mode": "false", - "statement_timeout": "7200", - "timezone": "America/New_York", - } - assert result == expected_case_result - - # Verify all values are strings in case insensitive test - for key, value in result.items(): - assert isinstance( - value, str - ), f"Value for key '{key}' is not a string: {type(value)}" - def test_results_message_to_execute_response_is_staging_operation(self, sea_client): """Test that is_staging_operation is correctly set from manifest.is_volume_operation.""" # Test when is_volume_operation is True @@ -955,67 +901,3 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor): cursor=mock_cursor, ) assert "Catalog name is required for get_columns" in str(excinfo.value) - 
- def test_get_chunk_links(self, sea_client, mock_http_client, sea_command_id): - """Test get_chunk_links method when links are available.""" - # Setup mock response - mock_response = { - "external_links": [ - { - "external_link": "https://example.com/data/chunk0", - "expiration": "2025-07-03T05:51:18.118009", - "row_count": 100, - "byte_count": 1024, - "row_offset": 0, - "chunk_index": 0, - "next_chunk_index": 1, - "http_headers": {"Authorization": "Bearer token123"}, - } - ] - } - mock_http_client._make_request.return_value = mock_response - - # Call the method - results = sea_client.get_chunk_links("test-statement-123", 0) - - # Verify the HTTP client was called correctly - mock_http_client._make_request.assert_called_once_with( - method="GET", - path=sea_client.CHUNK_PATH_WITH_ID_AND_INDEX.format( - "test-statement-123", 0 - ), - ) - - # Verify the results - assert isinstance(results, list) - assert len(results) == 1 - result = results[0] - assert result.external_link == "https://example.com/data/chunk0" - assert result.expiration == "2025-07-03T05:51:18.118009" - assert result.row_count == 100 - assert result.byte_count == 1024 - assert result.row_offset == 0 - assert result.chunk_index == 0 - assert result.next_chunk_index == 1 - assert result.http_headers == {"Authorization": "Bearer token123"} - - def test_get_chunk_links_empty(self, sea_client, mock_http_client): - """Test get_chunk_links when no links are returned (empty list).""" - # Setup mock response with no matching chunk - mock_response = {"external_links": []} - mock_http_client._make_request.return_value = mock_response - - # Call the method - results = sea_client.get_chunk_links("test-statement-123", 0) - - # Verify the HTTP client was called correctly - mock_http_client._make_request.assert_called_once_with( - method="GET", - path=sea_client.CHUNK_PATH_WITH_ID_AND_INDEX.format( - "test-statement-123", 0 - ), - ) - - # Verify the results are empty - assert isinstance(results, list) - assert results == [] diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 4e5af0658..93d3dc4d7 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -1,28 +1,15 @@ """ -Tests for SEA-related queue classes. +Tests for SEA-related queue classes in utils.py. -This module contains tests for the JsonQueue, SeaResultSetQueueFactory, and SeaCloudFetchQueue classes. -It also tests the Hybrid disposition which can create either ArrowQueue or SeaCloudFetchQueue based on -whether attachment is set. +This module contains tests for the JsonQueue and SeaResultSetQueueFactory classes. 
""" import pytest -from unittest.mock import Mock, patch +from unittest.mock import Mock, MagicMock, patch -from databricks.sql.backend.sea.queue import ( - JsonQueue, - SeaResultSetQueueFactory, - SeaCloudFetchQueue, -) -from databricks.sql.backend.sea.models.base import ( - ResultData, - ResultManifest, - ExternalLink, -) +from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.sea.utils.constants import ResultFormat -from databricks.sql.exc import ProgrammingError, ServerOperationError -from databricks.sql.types import SSLOptions -from databricks.sql.utils import ArrowQueue class TestJsonQueue: @@ -46,13 +33,6 @@ def test_init(self, sample_data): assert queue.cur_row_index == 0 assert queue.num_rows == len(sample_data) - def test_init_with_none(self): - """Test initialization with None data.""" - queue = JsonQueue(None) - assert queue.data_array == [] - assert queue.cur_row_index == 0 - assert queue.num_rows == 0 - def test_next_n_rows_partial(self, sample_data): """Test fetching a subset of rows.""" queue = JsonQueue(sample_data) @@ -74,189 +54,41 @@ def test_next_n_rows_more_than_available(self, sample_data): assert result == sample_data assert queue.cur_row_index == len(sample_data) - def test_next_n_rows_zero(self, sample_data): - """Test fetching zero rows.""" - queue = JsonQueue(sample_data) - result = queue.next_n_rows(0) - assert result == [] - assert queue.cur_row_index == 0 - - def test_remaining_rows(self, sample_data): - """Test fetching all remaining rows.""" + def test_next_n_rows_after_partial(self, sample_data): + """Test fetching rows after a partial fetch.""" queue = JsonQueue(sample_data) - - # Fetch some rows first - queue.next_n_rows(2) - - # Now fetch remaining - result = queue.remaining_rows() - assert result == sample_data[2:] - assert queue.cur_row_index == len(sample_data) + queue.next_n_rows(2) # Fetch first 2 rows + result = queue.next_n_rows(2) # Fetch next 2 rows + assert result == sample_data[2:4] + assert queue.cur_row_index == 4 def test_remaining_rows_all(self, sample_data): - """Test fetching all remaining rows from the start.""" + """Test fetching all remaining rows at once.""" queue = JsonQueue(sample_data) result = queue.remaining_rows() assert result == sample_data assert queue.cur_row_index == len(sample_data) - def test_remaining_rows_empty(self, sample_data): - """Test fetching remaining rows when none are left.""" + def test_remaining_rows_after_partial(self, sample_data): + """Test fetching remaining rows after a partial fetch.""" queue = JsonQueue(sample_data) - - # Fetch all rows first - queue.next_n_rows(len(sample_data)) - - # Now fetch remaining (should be empty) - result = queue.remaining_rows() - assert result == [] + queue.next_n_rows(2) # Fetch first 2 rows + result = queue.remaining_rows() # Fetch remaining rows + assert result == sample_data[2:] assert queue.cur_row_index == len(sample_data) + def test_empty_data(self): + """Test with empty data array.""" + queue = JsonQueue([]) + assert queue.next_n_rows(10) == [] + assert queue.remaining_rows() == [] + assert queue.cur_row_index == 0 + assert queue.num_rows == 0 + class TestSeaResultSetQueueFactory: """Test suite for the SeaResultSetQueueFactory class.""" - @pytest.fixture - def json_manifest(self): - """Create a JSON manifest for testing.""" - return ResultManifest( - format=ResultFormat.JSON_ARRAY.value, - schema={}, - total_row_count=5, - 
total_byte_count=1000, - total_chunk_count=1, - ) - - @pytest.fixture - def arrow_manifest(self): - """Create an Arrow manifest for testing.""" - return ResultManifest( - format=ResultFormat.ARROW_STREAM.value, - schema={}, - total_row_count=5, - total_byte_count=1000, - total_chunk_count=1, - ) - - @pytest.fixture - def invalid_manifest(self): - """Create an invalid manifest for testing.""" - return ResultManifest( - format="INVALID_FORMAT", - schema={}, - total_row_count=5, - total_byte_count=1000, - total_chunk_count=1, - ) - - @pytest.fixture - def sample_data(self): - """Create sample result data.""" - return [ - ["value1", "1", "true"], - ["value2", "2", "false"], - ] - - @pytest.fixture - def ssl_options(self): - """Create SSL options for testing.""" - return SSLOptions(tls_verify=True) - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - client = Mock() - client.max_download_threads = 10 - return client - - @pytest.fixture - def description(self): - """Create column descriptions.""" - return [ - ("col1", "string", None, None, None, None, None), - ("col2", "int", None, None, None, None, None), - ("col3", "boolean", None, None, None, None, None), - ] - - def test_build_queue_json_array(self, json_manifest, sample_data): - """Test building a JSON array queue.""" - result_data = ResultData(data=sample_data) - - queue = SeaResultSetQueueFactory.build_queue( - result_data=result_data, - manifest=json_manifest, - statement_id="test-statement", - ssl_options=SSLOptions(), - description=[], - max_download_threads=10, - sea_client=Mock(), - lz4_compressed=False, - ) - - assert isinstance(queue, JsonQueue) - assert queue.data_array == sample_data - - def test_build_queue_arrow_stream( - self, arrow_manifest, ssl_options, mock_sea_client, description - ): - """Test building an Arrow stream queue.""" - external_links = [ - ExternalLink( - external_link="https://example.com/data/chunk0", - expiration="2025-07-03T05:51:18.118009", - row_count=100, - byte_count=1024, - row_offset=0, - chunk_index=0, - next_chunk_index=1, - http_headers={"Authorization": "Bearer token123"}, - ) - ] - result_data = ResultData(data=None, external_links=external_links) - - with patch( - "databricks.sql.backend.sea.queue.ResultFileDownloadManager" - ), patch.object( - SeaCloudFetchQueue, "_create_table_from_link", return_value=None - ): - queue = SeaResultSetQueueFactory.build_queue( - result_data=result_data, - manifest=arrow_manifest, - statement_id="test-statement", - ssl_options=ssl_options, - description=description, - max_download_threads=10, - sea_client=mock_sea_client, - lz4_compressed=False, - ) - - assert isinstance(queue, SeaCloudFetchQueue) - - def test_build_queue_invalid_format(self, invalid_manifest): - """Test building a queue with invalid format.""" - result_data = ResultData(data=[]) - - with pytest.raises(ProgrammingError, match="Invalid result format"): - SeaResultSetQueueFactory.build_queue( - result_data=result_data, - manifest=invalid_manifest, - statement_id="test-statement", - ssl_options=SSLOptions(), - description=[], - max_download_threads=10, - sea_client=Mock(), - lz4_compressed=False, - ) - - -class TestSeaCloudFetchQueue: - """Test suite for the SeaCloudFetchQueue class.""" - - @pytest.fixture - def ssl_options(self): - """Create SSL options for testing.""" - return SSLOptions(tls_verify=True) - @pytest.fixture def mock_sea_client(self): """Create a mock SEA client.""" @@ -265,317 +97,86 @@ def mock_sea_client(self): return client @pytest.fixture - def 
description(self): - """Create column descriptions.""" + def mock_description(self): + """Create a mock column description.""" return [ ("col1", "string", None, None, None, None, None), ("col2", "int", None, None, None, None, None), ("col3", "boolean", None, None, None, None, None), ] - @pytest.fixture - def sample_external_link(self): - """Create a sample external link.""" - return ExternalLink( - external_link="https://example.com/data/chunk0", - expiration="2025-07-03T05:51:18.118009", - row_count=100, - byte_count=1024, - row_offset=0, - chunk_index=0, - next_chunk_index=1, - http_headers={"Authorization": "Bearer token123"}, - ) - - @pytest.fixture - def sample_external_link_no_headers(self): - """Create a sample external link without headers.""" - return ExternalLink( - external_link="https://example.com/data/chunk0", - expiration="2025-07-03T05:51:18.118009", - row_count=100, - byte_count=1024, - row_offset=0, - chunk_index=0, - next_chunk_index=1, - http_headers=None, - ) - - def test_convert_to_thrift_link(self, sample_external_link): - """Test conversion of ExternalLink to TSparkArrowResultLink.""" - queue = Mock(spec=SeaCloudFetchQueue) - - # Call the method directly - result = SeaCloudFetchQueue._convert_to_thrift_link(queue, sample_external_link) - - # Verify the conversion - assert result.fileLink == sample_external_link.external_link - assert result.rowCount == sample_external_link.row_count - assert result.bytesNum == sample_external_link.byte_count - assert result.startRowOffset == sample_external_link.row_offset - assert result.httpHeaders == sample_external_link.http_headers - - def test_convert_to_thrift_link_no_headers(self, sample_external_link_no_headers): - """Test conversion of ExternalLink with no headers to TSparkArrowResultLink.""" - queue = Mock(spec=SeaCloudFetchQueue) - - # Call the method directly - result = SeaCloudFetchQueue._convert_to_thrift_link( - queue, sample_external_link_no_headers - ) - - # Verify the conversion - assert result.fileLink == sample_external_link_no_headers.external_link - assert result.rowCount == sample_external_link_no_headers.row_count - assert result.bytesNum == sample_external_link_no_headers.byte_count - assert result.startRowOffset == sample_external_link_no_headers.row_offset - assert result.httpHeaders == {} - - @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager") - @patch("databricks.sql.backend.sea.queue.logger") - def test_init_with_valid_initial_link( - self, - mock_logger, - mock_download_manager_class, - mock_sea_client, - ssl_options, - description, - sample_external_link, - ): - """Test initialization with valid initial link.""" - # Create a queue with valid initial link - with patch.object( - SeaCloudFetchQueue, "_create_table_from_link", return_value=None - ): - queue = SeaCloudFetchQueue( - result_data=ResultData(external_links=[sample_external_link]), - max_download_threads=5, - ssl_options=ssl_options, - sea_client=mock_sea_client, - statement_id="test-statement-123", - total_chunk_count=1, - lz4_compressed=False, - description=description, - ) - - # Verify debug message was logged - mock_logger.debug.assert_called_with( - "SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}".format( - "test-statement-123", 1 - ) - ) - - # Verify attributes - assert queue._statement_id == "test-statement-123" - assert queue._current_chunk_index == 0 - - @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager") - @patch("databricks.sql.backend.sea.queue.logger") - def 
test_init_no_initial_links( - self, - mock_logger, - mock_download_manager_class, - mock_sea_client, - ssl_options, - description, - ): - """Test initialization with no initial links.""" - # Create a queue with empty initial links - queue = SeaCloudFetchQueue( - result_data=ResultData(external_links=[]), - max_download_threads=5, - ssl_options=ssl_options, - sea_client=mock_sea_client, - statement_id="test-statement-123", - total_chunk_count=0, - lz4_compressed=False, - description=description, - ) - assert queue.table is None - - @patch("databricks.sql.backend.sea.queue.logger") - def test_create_next_table_success(self, mock_logger): - """Test _create_next_table with successful table creation.""" - # Create a queue instance without initializing - queue = Mock(spec=SeaCloudFetchQueue) - queue._current_chunk_index = 0 - queue.download_manager = Mock() - - # Mock the dependencies - mock_table = Mock() - mock_chunk_link = Mock() - queue._get_chunk_link = Mock(return_value=mock_chunk_link) - queue._create_table_from_link = Mock(return_value=mock_table) - - # Call the method directly - result = SeaCloudFetchQueue._create_next_table(queue) - - # Verify the chunk index was incremented - assert queue._current_chunk_index == 1 - - # Verify the chunk link was retrieved - queue._get_chunk_link.assert_called_once_with(1) - - # Verify the table was created from the link - queue._create_table_from_link.assert_called_once_with(mock_chunk_link) - - # Verify the result is the table - assert result == mock_table - - -class TestHybridDisposition: - """Test suite for the Hybrid disposition handling in SeaResultSetQueueFactory.""" - - @pytest.fixture - def arrow_manifest(self): - """Create an Arrow manifest for testing.""" + def _create_empty_manifest(self, format: ResultFormat): return ResultManifest( - format=ResultFormat.ARROW_STREAM.value, + format=format.value, schema={}, - total_row_count=5, - total_byte_count=1000, - total_chunk_count=1, + total_row_count=-1, + total_byte_count=-1, + total_chunk_count=-1, ) - @pytest.fixture - def description(self): - """Create column descriptions.""" - return [ - ("col1", "string", None, None, None, None, None), - ("col2", "int", None, None, None, None, None), - ("col3", "boolean", None, None, None, None, None), + def test_build_queue_with_inline_data(self, mock_sea_client, mock_description): + """Test building a queue with inline JSON data.""" + # Create sample data for inline JSON result + data = [ + ["value1", "1", "true"], + ["value2", "2", "false"], ] - @pytest.fixture - def ssl_options(self): - """Create SSL options for testing.""" - return SSLOptions(tls_verify=True) - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - client = Mock() - client.max_download_threads = 10 - return client - - @patch("databricks.sql.backend.sea.queue.create_arrow_table_from_arrow_file") - def test_hybrid_disposition_with_attachment( - self, - mock_create_table, - arrow_manifest, - description, - ssl_options, - mock_sea_client, - ): - """Test that ArrowQueue is created when attachment is present.""" - # Create mock arrow table - mock_arrow_table = Mock() - mock_arrow_table.num_rows = 5 - mock_create_table.return_value = mock_arrow_table + # Create a ResultData object with inline data + result_data = ResultData(data=data, external_links=None, row_count=len(data)) - # Create result data with attachment - attachment_data = b"mock_arrow_data" - result_data = ResultData(attachment=attachment_data) + # Create a manifest (not used for inline data) + manifest 
= self._create_empty_manifest(ResultFormat.JSON_ARRAY) - # Build queue + # Build the queue queue = SeaResultSetQueueFactory.build_queue( - result_data=result_data, - manifest=arrow_manifest, - statement_id="test-statement", - ssl_options=ssl_options, - description=description, - max_download_threads=10, + result_data, + manifest, + "test-statement-123", + description=mock_description, sea_client=mock_sea_client, - lz4_compressed=False, ) - # Verify ArrowQueue was created - assert isinstance(queue, ArrowQueue) - mock_create_table.assert_called_once_with(attachment_data, description) - - @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager") - @patch.object(SeaCloudFetchQueue, "_create_table_from_link", return_value=None) - def test_hybrid_disposition_with_external_links( - self, - mock_create_table, - mock_download_manager, - arrow_manifest, - description, - ssl_options, - mock_sea_client, - ): - """Test that SeaCloudFetchQueue is created when attachment is None but external links are present.""" - # Create external links - external_links = [ - ExternalLink( - external_link="https://example.com/data/chunk0", - expiration="2025-07-03T05:51:18.118009", - row_count=100, - byte_count=1024, - row_offset=0, - chunk_index=0, - next_chunk_index=1, - http_headers={"Authorization": "Bearer token123"}, - ) - ] + # Verify the queue is a JsonQueue with the correct data + assert isinstance(queue, JsonQueue) + assert queue.data_array == data + assert queue.num_rows == len(data) - # Create result data with external links but no attachment - result_data = ResultData(external_links=external_links, attachment=None) + def test_build_queue_with_empty_data(self, mock_sea_client, mock_description): + """Test building a queue with empty data.""" + # Create a ResultData object with no data + result_data = ResultData(data=[], external_links=None, row_count=0) - # Build queue + # Build the queue queue = SeaResultSetQueueFactory.build_queue( - result_data=result_data, - manifest=arrow_manifest, - statement_id="test-statement", - ssl_options=ssl_options, - description=description, - max_download_threads=10, + result_data, + self._create_empty_manifest(ResultFormat.JSON_ARRAY), + "test-statement-123", + description=mock_description, sea_client=mock_sea_client, - lz4_compressed=False, ) - # Verify SeaCloudFetchQueue was created - assert isinstance(queue, SeaCloudFetchQueue) - mock_create_table.assert_called_once() - - @patch("databricks.sql.backend.sea.queue.ResultSetDownloadHandler._decompress_data") - @patch("databricks.sql.backend.sea.queue.create_arrow_table_from_arrow_file") - def test_hybrid_disposition_with_compressed_attachment( - self, - mock_create_table, - mock_decompress, - arrow_manifest, - description, - ssl_options, - mock_sea_client, - ): - """Test that ArrowQueue is created with decompressed data when attachment is present and lz4_compressed is True.""" - # Create mock arrow table - mock_arrow_table = Mock() - mock_arrow_table.num_rows = 5 - mock_create_table.return_value = mock_arrow_table - - # Setup decompression mock - compressed_data = b"compressed_data" - decompressed_data = b"decompressed_data" - mock_decompress.return_value = decompressed_data - - # Create result data with attachment - result_data = ResultData(attachment=compressed_data) + # Verify the queue is a JsonQueue with empty data + assert isinstance(queue, JsonQueue) + assert queue.data_array == [] + assert queue.num_rows == 0 - # Build queue with lz4_compressed=True - queue = SeaResultSetQueueFactory.build_queue( - 
result_data=result_data, - manifest=arrow_manifest, - statement_id="test-statement", - ssl_options=ssl_options, - description=description, - max_download_threads=10, - sea_client=mock_sea_client, - lz4_compressed=True, + def test_build_queue_with_external_links(self, mock_sea_client, mock_description): + """Test building a queue with external links raises NotImplementedError.""" + # Create a ResultData object with external links + result_data = ResultData( + data=None, external_links=["link1", "link2"], row_count=10 ) - # Verify ArrowQueue was created with decompressed data - assert isinstance(queue, ArrowQueue) - mock_decompress.assert_called_once_with(compressed_data) - mock_create_table.assert_called_once_with(decompressed_data, description) + # Verify that NotImplementedError is raised + with pytest.raises( + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", + ): + SeaResultSetQueueFactory.build_queue( + result_data, + self._create_empty_manifest(ResultFormat.ARROW_STREAM), + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index dbf81ba7c..25ac23133 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -6,12 +6,7 @@ """ import pytest -from unittest.mock import Mock, patch - -try: - import pyarrow -except ImportError: - pyarrow = None +from unittest.mock import Mock from databricks.sql.backend.sea.result_set import SeaResultSet, Row from databricks.sql.backend.sea.queue import JsonQueue @@ -28,16 +23,20 @@ def mock_connection(self): """Create a mock connection.""" connection = Mock() connection.open = True - connection.session = Mock() - connection.session.ssl_options = Mock() - return connection - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - client = Mock() - client.max_download_threads = 10 - return client + # Mock the session.backend to return a SeaDatabricksClient + mock_session = Mock() + from databricks.sql.backend.sea.client import SeaDatabricksClient + + mock_backend = Mock(spec=SeaDatabricksClient) + mock_backend.max_download_threads = 10 + mock_backend.close_command = Mock() + # Ensure isinstance check passes + mock_backend.__class__ = SeaDatabricksClient + mock_session.backend = mock_backend + connection.session = mock_session + + return connection @pytest.fixture def execute_response(self): @@ -80,9 +79,7 @@ def _create_empty_manifest(self, format: ResultFormat): ) @pytest.fixture - def result_set_with_data( - self, mock_connection, mock_sea_client, execute_response, sample_data - ): + def result_set_with_data(self, mock_connection, execute_response, sample_data): """Create a SeaResultSet with sample data.""" # Create ResultData with inline data result_data = ResultData( @@ -90,224 +87,88 @@ def result_set_with_data( ) # Initialize SeaResultSet with result data - with patch( - "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue", - return_value=JsonQueue(sample_data), - ): - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - result_data=result_data, - manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), - buffer_size_bytes=1000, - arraysize=100, - ) - - return result_set - - @pytest.fixture - def mock_arrow_queue(self): - """Create a mock Arrow queue.""" - queue = Mock() - if pyarrow is not None: - queue.next_n_rows.return_value = 
Mock(spec=pyarrow.Table) - queue.next_n_rows.return_value.num_rows = 0 - queue.remaining_rows.return_value = Mock(spec=pyarrow.Table) - queue.remaining_rows.return_value.num_rows = 0 - return queue - - @pytest.fixture - def mock_json_queue(self): - """Create a mock JSON queue.""" - queue = Mock(spec=JsonQueue) - queue.next_n_rows.return_value = [] - queue.remaining_rows.return_value = [] - return queue - - @pytest.fixture - def result_set_with_arrow_queue( - self, mock_connection, mock_sea_client, execute_response, mock_arrow_queue - ): - """Create a SeaResultSet with an Arrow queue.""" - # Create ResultData with external links - result_data = ResultData(data=None, external_links=[], row_count=0) - - # Initialize SeaResultSet with result data - with patch( - "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue", - return_value=mock_arrow_queue, - ): - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - result_data=result_data, - manifest=ResultManifest( - format=ResultFormat.ARROW_STREAM.value, - schema={}, - total_row_count=0, - total_byte_count=0, - total_chunk_count=0, - ), - buffer_size_bytes=1000, - arraysize=100, - ) + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + result_data=result_data, + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = JsonQueue(sample_data) return result_set @pytest.fixture - def result_set_with_json_queue( - self, mock_connection, mock_sea_client, execute_response, mock_json_queue - ): - """Create a SeaResultSet with a JSON queue.""" - # Create ResultData with inline data - result_data = ResultData(data=[], external_links=None, row_count=0) - - # Initialize SeaResultSet with result data - with patch( - "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue", - return_value=mock_json_queue, - ): - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - result_data=result_data, - manifest=ResultManifest( - format=ResultFormat.JSON_ARRAY.value, - schema={}, - total_row_count=0, - total_byte_count=0, - total_chunk_count=0, - ), - buffer_size_bytes=1000, - arraysize=100, - ) - - return result_set + def json_queue(self, sample_data): + """Create a JsonQueue with sample data.""" + return JsonQueue(sample_data) - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): + def test_init_with_execute_response(self, mock_connection, execute_response): """Test initializing SeaResultSet with an execute response.""" - with patch( - "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" - ): - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - result_data=ResultData(data=[]), - manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), - buffer_size_bytes=1000, - arraysize=100, - ) + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) # Verify basic properties assert result_set.command_id == execute_response.command_id assert result_set.status == CommandState.SUCCEEDED assert result_set.connection == mock_connection - assert result_set.backend == 
mock_sea_client assert result_set.buffer_size_bytes == 1000 assert result_set.arraysize == 100 assert result_set.description == execute_response.description - def test_init_with_invalid_command_id( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with invalid command ID.""" - # Mock the command ID to return None - mock_command_id = Mock() - mock_command_id.to_sea_statement_id.return_value = None - execute_response.command_id = mock_command_id - - with pytest.raises(ValueError, match="Command ID is not a SEA statement ID"): - SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - result_data=ResultData(data=[]), - manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_close(self, mock_connection, mock_sea_client, execute_response): + def test_close(self, mock_connection, execute_response): """Test closing a result set.""" - with patch( - "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" - ): - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - result_data=ResultData(data=[]), - manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), - buffer_size_bytes=1000, - arraysize=100, - ) + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) # Close the result set result_set.close() # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - with patch( - "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" - ): - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - result_data=ResultData(data=[]), - manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True + mock_connection.session.backend.close_command.assert_called_once_with( + result_set.command_id + ) assert result_set.status == CommandState.CLOSED - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): + def test_close_when_connection_closed(self, mock_connection, execute_response): """Test closing a result set when the connection is closed.""" mock_connection.open = False - with patch( - "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" - ): - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - result_data=ResultData(data=[]), - manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), - buffer_size_bytes=1000, - arraysize=100, - ) + result_set = SeaResultSet( + 
connection=mock_connection, + execute_response=execute_response, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) # Close the result set result_set.close() # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True + mock_connection.session.backend.close_command.assert_not_called() assert result_set.status == CommandState.CLOSED + def test_init_with_result_data(self, result_set_with_data, sample_data): + """Test initializing SeaResultSet with result data.""" + # Verify the results queue was created correctly + assert isinstance(result_set_with_data.results, JsonQueue) + assert result_set_with_data.results.data_array == sample_data + assert result_set_with_data.results.num_rows == len(sample_data) + def test_convert_json_types(self, result_set_with_data, sample_data): """Test the _convert_json_types method.""" # Call _convert_json_types @@ -318,27 +179,6 @@ def test_convert_json_types(self, result_set_with_data, sample_data): assert converted_row[1] == 1 # "1" converted to int assert converted_row[2] is True # "true" converted to boolean - @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") - def test_convert_json_to_arrow_table(self, result_set_with_data, sample_data): - """Test the _convert_json_to_arrow_table method.""" - # Call _convert_json_to_arrow_table - result_table = result_set_with_data._convert_json_to_arrow_table(sample_data) - - # Verify the result - assert isinstance(result_table, pyarrow.Table) - assert result_table.num_rows == len(sample_data) - assert result_table.num_columns == 3 - - @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") - def test_convert_json_to_arrow_table_empty(self, result_set_with_data): - """Test the _convert_json_to_arrow_table method with empty data.""" - # Call _convert_json_to_arrow_table with empty data - result_table = result_set_with_data._convert_json_to_arrow_table([]) - - # Verify the result - assert isinstance(result_table, pyarrow.Table) - assert result_table.num_rows == 0 - def test_create_json_table(self, result_set_with_data, sample_data): """Test the _create_json_table method.""" # Call _create_json_table @@ -368,13 +208,6 @@ def test_fetchmany_json(self, result_set_with_data): assert len(result) == 1 # Only one row left assert result_set_with_data._next_row_index == 5 - def test_fetchmany_json_negative_size(self, result_set_with_data): - """Test the fetchmany_json method with negative size.""" - with pytest.raises( - ValueError, match="size argument for fetchmany is -1 but must be >= 0" - ): - result_set_with_data.fetchmany_json(-1) - def test_fetchall_json(self, result_set_with_data, sample_data): """Test the fetchall_json method.""" # Test fetching all rows @@ -387,32 +220,6 @@ def test_fetchall_json(self, result_set_with_data, sample_data): assert result == [] assert result_set_with_data._next_row_index == len(sample_data) - @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") - def test_fetchmany_arrow(self, result_set_with_data, sample_data): - """Test the fetchmany_arrow method.""" - # Test with JSON queue (should convert to Arrow) - result = result_set_with_data.fetchmany_arrow(2) - assert isinstance(result, pyarrow.Table) - assert result.num_rows == 2 - assert result_set_with_data._next_row_index == 2 - - @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not 
installed") - def test_fetchmany_arrow_negative_size(self, result_set_with_data): - """Test the fetchmany_arrow method with negative size.""" - with pytest.raises( - ValueError, match="size argument for fetchmany is -1 but must be >= 0" - ): - result_set_with_data.fetchmany_arrow(-1) - - @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") - def test_fetchall_arrow(self, result_set_with_data, sample_data): - """Test the fetchall_arrow method.""" - # Test with JSON queue (should convert to Arrow) - result = result_set_with_data.fetchall_arrow() - assert isinstance(result, pyarrow.Table) - assert result.num_rows == len(sample_data) - assert result_set_with_data._next_row_index == len(sample_data) - def test_fetchone(self, result_set_with_data): """Test the fetchone method.""" # Test fetching one row at a time @@ -482,133 +289,59 @@ def test_iteration(self, result_set_with_data, sample_data): assert rows[0].col2 == 1 assert rows[0].col3 is True - def test_is_staging_operation( - self, mock_connection, mock_sea_client, execute_response + def test_fetchmany_arrow_not_implemented( + self, mock_connection, execute_response, sample_data ): - """Test the is_staging_operation property.""" - # Set is_staging_operation to True - execute_response.is_staging_operation = True + """Test that fetchmany_arrow raises NotImplementedError for non-JSON data.""" - with patch( - "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" + # Test that NotImplementedError is raised + with pytest.raises( + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", ): - # Create a result set + # Create a result set without JSON data result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, - result_data=ResultData(data=[]), - manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + result_data=ResultData(data=None, external_links=[]), + manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), buffer_size_bytes=1000, arraysize=100, ) - # Test the property - assert result_set.is_staging_operation is True - - # Edge case tests - @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") - def test_fetchone_empty_arrow_queue(self, result_set_with_arrow_queue): - """Test fetchone with an empty Arrow queue.""" - # Setup _convert_arrow_table to return empty list - result_set_with_arrow_queue._convert_arrow_table = Mock(return_value=[]) - - # Call fetchone - result = result_set_with_arrow_queue.fetchone() - - # Verify result is None - assert result is None - - # Verify _convert_arrow_table was called - result_set_with_arrow_queue._convert_arrow_table.assert_called_once() - - def test_fetchone_empty_json_queue(self, result_set_with_json_queue): - """Test fetchone with an empty JSON queue.""" - # Setup _create_json_table to return empty list - result_set_with_json_queue._create_json_table = Mock(return_value=[]) - - # Call fetchone - result = result_set_with_json_queue.fetchone() - - # Verify result is None - assert result is None - - # Verify _create_json_table was called - result_set_with_json_queue._create_json_table.assert_called_once() - - @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") - def test_fetchmany_empty_arrow_queue(self, result_set_with_arrow_queue): - """Test fetchmany with an empty Arrow queue.""" - # Setup _convert_arrow_table to return empty list - result_set_with_arrow_queue._convert_arrow_table = Mock(return_value=[]) - - # 
Call fetchmany - result = result_set_with_arrow_queue.fetchmany(10) - - # Verify result is an empty list - assert result == [] - - # Verify _convert_arrow_table was called - result_set_with_arrow_queue._convert_arrow_table.assert_called_once() - - @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") - def test_fetchall_empty_arrow_queue(self, result_set_with_arrow_queue): - """Test fetchall with an empty Arrow queue.""" - # Setup _convert_arrow_table to return empty list - result_set_with_arrow_queue._convert_arrow_table = Mock(return_value=[]) - - # Call fetchall - result = result_set_with_arrow_queue.fetchall() - - # Verify result is an empty list - assert result == [] - - # Verify _convert_arrow_table was called - result_set_with_arrow_queue._convert_arrow_table.assert_called_once() - - @patch("databricks.sql.backend.sea.utils.conversion.SqlTypeConverter.convert_value") - def test_convert_json_types_with_errors( - self, mock_convert_value, result_set_with_data - ): - """Test error handling in _convert_json_types.""" - # Mock the conversion to fail for the second and third values - mock_convert_value.side_effect = [ - "value1", # First value converts normally - Exception("Invalid int"), # Second value fails - Exception("Invalid boolean"), # Third value fails - ] - - # Data with invalid values - data_row = ["value1", "not_an_int", "not_a_boolean"] - - # Should not raise an exception but log warnings - result = result_set_with_data._convert_json_types(data_row) - - # The first value should be converted normally - assert result[0] == "value1" - - # The invalid values should remain as strings - assert result[1] == "not_an_int" - assert result[2] == "not_a_boolean" - - @patch("databricks.sql.backend.sea.result_set.logger") - @patch("databricks.sql.backend.sea.utils.conversion.SqlTypeConverter.convert_value") - def test_convert_json_types_with_logging( - self, mock_convert_value, mock_logger, result_set_with_data + def test_fetchall_arrow_not_implemented( + self, mock_connection, execute_response, sample_data ): - """Test that errors in _convert_json_types are logged.""" - # Mock the conversion to fail for the second and third values - mock_convert_value.side_effect = [ - "value1", # First value converts normally - Exception("Invalid int"), # Second value fails - Exception("Invalid boolean"), # Third value fails - ] + """Test that fetchall_arrow raises NotImplementedError for non-JSON data.""" + # Test that NotImplementedError is raised + with pytest.raises( + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", + ): + # Create a result set without JSON data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + result_data=ResultData(data=None, external_links=[]), + manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), + buffer_size_bytes=1000, + arraysize=100, + ) - # Data with invalid values - data_row = ["value1", "not_an_int", "not_a_boolean"] + def test_is_staging_operation(self, mock_connection, execute_response): + """Test the is_staging_operation property.""" + # Set is_staging_operation to True + execute_response.is_staging_operation = True - # Call the method - result_set_with_data._convert_json_types(data_row) + # Create a result set + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + 
) - # Verify warnings were logged - assert mock_logger.warning.call_count == 2 + # Test the property + assert result_set.is_staging_operation is True diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 37569f755..f66b356ca 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -649,7 +649,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( + execute_response, _, _, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -889,10 +889,9 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - ( - execute_response, - _, - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + (execute_response, _, _, _) = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -965,11 +964,14 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): ) ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( - t_execute_resp, Mock() - ) + ( + execute_response, + _, + _, + arrow_schema_bytes, + ) = thrift_backend._handle_execute_response(t_execute_resp, Mock()) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -997,7 +999,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) ) thrift_backend = self._make_fake_thrift_backend() - _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + _, _, _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -1046,6 +1048,8 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ( execute_response, has_more_rows_result, + _, + _, ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual(is_direct_results, has_more_rows_result) @@ -1179,7 +1183,12 @@ def test_execute_statement_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1215,7 +1224,12 @@ def test_get_catalogs_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1248,7 +1262,12 @@ def test_get_schemas_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) 
cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1290,7 +1309,12 @@ def test_get_tables_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1336,7 +1360,12 @@ def test_get_columns_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -2254,6 +2283,8 @@ def test_execute_command_sets_complex_type_fields_correctly( mock_handle_execute_response.return_value = ( mock_execute_response, mock_arrow_schema, + Mock(), + Mock(), ) # Iterate through each possible combination of native types (True, False and unset)