diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index 3c0e325f..5bc6c679 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -45,6 +45,7 @@ def test_sea_async_query_with_cloud_fetch(): use_sea=True, user_agent_entry="SEA-Test-Client", use_cloud_fetch=True, + enable_query_result_lz4_compression=False, ) logger.info( diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 76941e2d..16ee80a7 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -43,6 +43,7 @@ def test_sea_sync_query_with_cloud_fetch(): use_sea=True, user_agent_entry="SEA-Test-Client", use_cloud_fetch=True, + enable_query_result_lz4_compression=False, ) logger.info( diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 5592de03..6f39e264 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -130,6 +130,8 @@ def __init__( "_use_arrow_native_complex_types", True ) + self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True) + # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -456,7 +458,11 @@ def execute_command( ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY ).value disposition = ( - ResultDisposition.EXTERNAL_LINKS + ( + ResultDisposition.HYBRID + if self.use_hybrid_disposition + else ResultDisposition.EXTERNAL_LINKS + ) if use_cloud_fetch else ResultDisposition.INLINE ).value @@ -637,7 +643,9 @@ def get_execution_result( arraysize=cursor.arraysize, ) - def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink: + def get_chunk_links( + self, statement_id: str, chunk_index: int + ) -> List[ExternalLink]: """ Get links 
for chunks starting from the specified index. Args: @@ -654,17 +662,7 @@ def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink: response = GetChunksResponse.from_dict(response_data) links = response.external_links or [] - link = next((l for l in links if l.chunk_index == chunk_index), None) - if not link: - raise ServerOperationError( - f"No link found for chunk index {chunk_index}", - { - "operation-id": statement_id, - "diagnostic-info": None, - }, - ) - - return link + return links # == Metadata Operations == diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 6bd28c9b..5a558048 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -4,6 +4,7 @@ These models define the structures used in SEA API responses. """ +import base64 from typing import Dict, Any, List, Optional from dataclasses import dataclass @@ -91,6 +92,11 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: ) ) + # Handle attachment field - decode from base64 if present + attachment = result_data.get("attachment") + if attachment is not None: + attachment = base64.b64decode(attachment) + return ResultData( data=result_data.get("data_array"), external_links=external_links, @@ -100,7 +106,7 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: next_chunk_internal_link=result_data.get("next_chunk_internal_link"), row_count=result_data.get("row_count"), row_offset=result_data.get("row_offset"), - attachment=result_data.get("attachment"), + attachment=attachment, ) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index df6d6a80..065be2d3 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -5,6 +5,8 @@ from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager +from databricks.sql.cloudfetch.downloader import 
ResultSetDownloadHandler + try: import pyarrow except ImportError: @@ -23,7 +25,12 @@ from databricks.sql.exc import ProgrammingError, ServerOperationError from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.types import SSLOptions -from databricks.sql.utils import CloudFetchQueue, ResultSetQueue +from databricks.sql.utils import ( + ArrowQueue, + CloudFetchQueue, + ResultSetQueue, + create_arrow_table_from_arrow_file, +) import logging @@ -62,6 +69,18 @@ def build_queue( # INLINE disposition with JSON_ARRAY format return JsonQueue(result_data.data) elif manifest.format == ResultFormat.ARROW_STREAM.value: + if result_data.attachment is not None: + arrow_file = ( + ResultSetDownloadHandler._decompress_data(result_data.attachment) + if lz4_compressed + else result_data.attachment + ) + arrow_table = create_arrow_table_from_arrow_file( + arrow_file, description + ) + logger.debug(f"Created arrow table with {arrow_table.num_rows} rows") + return ArrowQueue(arrow_table, manifest.total_row_count) + # EXTERNAL_LINKS disposition return SeaCloudFetchQueue( result_data=result_data, @@ -150,7 +169,11 @@ def __init__( ) initial_links = result_data.external_links or [] - first_link = next((l for l in initial_links if l.chunk_index == 0), None) + self._chunk_index_to_link = {link.chunk_index: link for link in initial_links} + + # Track the current chunk we're processing + self._current_chunk_index = 0 + first_link = self._chunk_index_to_link.get(self._current_chunk_index, None) if not first_link: # possibly an empty response return None @@ -173,21 +196,24 @@ def _convert_to_thrift_link(self, link: ExternalLink) -> TSparkArrowResultLink:
httpHeaders=link.http_headers or {}, ) - def _get_chunk_link(self, chunk_index: int) -> Optional[ExternalLink]: - """Progress to the next chunk link.""" + def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]: if chunk_index >= self._total_chunk_count: return None - try: - return self._sea_client.get_chunk_link(self._statement_id, chunk_index) - except Exception as e: + if chunk_index not in self._chunk_index_to_link: + links = self._sea_client.get_chunk_links(self._statement_id, chunk_index) + self._chunk_index_to_link.update({l.chunk_index: l for l in links}) + + link = self._chunk_index_to_link.get(chunk_index, None) + if not link: raise ServerOperationError( - f"Error fetching link for chunk {chunk_index}: {e}", + f"Error fetching link for chunk {chunk_index}", { "operation-id": self._statement_id, "diagnostic-info": None, }, ) + return link def _create_table_from_link( self, link: ExternalLink diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 402da0de..46ce8c98 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -28,7 +28,7 @@ class ResultFormat(Enum): class ResultDisposition(Enum): """Enum for result disposition values.""" - # TODO: add support for hybrid disposition + HYBRID = "INLINE_OR_EXTERNAL_LINKS" EXTERNAL_LINKS = "EXTERNAL_LINKS" INLINE = "INLINE" diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 75e89d92..dfa732c2 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -99,6 +99,10 @@ def __init__( Connect to a Databricks SQL endpoint or a Databricks cluster. Parameters: + :param use_sea: `bool`, optional (default is False) + Use the SEA backend instead of the Thrift backend. + :param use_hybrid_disposition: `bool`, optional (default is True) + Use the hybrid disposition instead of the external links disposition.
:param server_hostname: Databricks instance host name. :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 28026807..877136cf 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -959,8 +959,8 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor): ) assert "Catalog name is required for get_columns" in str(excinfo.value) - def test_get_chunk_link(self, sea_client, mock_http_client, sea_command_id): - """Test get_chunk_link method.""" + def test_get_chunk_links(self, sea_client, mock_http_client, sea_command_id): + """Test get_chunk_links method when links are available.""" # Setup mock response mock_response = { "external_links": [ @@ -979,7 +979,7 @@ def test_get_chunk_link(self, sea_client, mock_http_client, sea_command_id): mock_http_client._make_request.return_value = mock_response # Call the method - result = sea_client.get_chunk_link("test-statement-123", 0) + results = sea_client.get_chunk_links("test-statement-123", 0) # Verify the HTTP client was called correctly mock_http_client._make_request.assert_called_once_with( @@ -989,7 +989,10 @@ def test_get_chunk_link(self, sea_client, mock_http_client, sea_command_id): ), ) - # Verify the result + # Verify the results + assert isinstance(results, list) + assert len(results) == 1 + result = results[0] assert result.external_link == "https://example.com/data/chunk0" assert result.expiration == "2025-07-03T05:51:18.118009" assert result.row_count == 100 @@ -999,30 +1002,14 @@ def test_get_chunk_link(self, sea_client, mock_http_client, sea_command_id): assert result.next_chunk_index == 1 assert result.http_headers == {"Authorization": "Bearer token123"} - def test_get_chunk_link_not_found(self, sea_client, mock_http_client): - """Test get_chunk_link 
when the requested chunk is not found.""" + def test_get_chunk_links_empty(self, sea_client, mock_http_client): + """Test get_chunk_links when no links are returned (empty list).""" # Setup mock response with no matching chunk - mock_response = { - "external_links": [ - { - "external_link": "https://example.com/data/chunk1", - "expiration": "2025-07-03T05:51:18.118009", - "row_count": 100, - "byte_count": 1024, - "row_offset": 100, - "chunk_index": 1, # Different chunk index - "next_chunk_index": 2, - "http_headers": {"Authorization": "Bearer token123"}, - } - ] - } + mock_response = {"external_links": []} mock_http_client._make_request.return_value = mock_response - # Call the method and expect an exception - with pytest.raises( - ServerOperationError, match="No link found for chunk index 0" - ): - sea_client.get_chunk_link("test-statement-123", 0) + # Call the method + results = sea_client.get_chunk_links("test-statement-123", 0) # Verify the HTTP client was called correctly mock_http_client._make_request.assert_called_once_with( @@ -1031,3 +1018,7 @@ def test_get_chunk_link_not_found(self, sea_client, mock_http_client): "test-statement-123", 0 ), ) + + # Verify the results are empty + assert isinstance(results, list) + assert results == [] diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 60c967ba..4e5af065 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -2,6 +2,8 @@ Tests for SEA-related queue classes. This module contains tests for the JsonQueue, SeaResultSetQueueFactory, and SeaCloudFetchQueue classes. +It also tests the Hybrid disposition which can create either ArrowQueue or SeaCloudFetchQueue based on +whether attachment is set. 
""" import pytest @@ -20,6 +22,7 @@ from databricks.sql.backend.sea.utils.constants import ResultFormat from databricks.sql.exc import ProgrammingError, ServerOperationError from databricks.sql.types import SSLOptions +from databricks.sql.utils import ArrowQueue class TestJsonQueue: @@ -418,3 +421,161 @@ def test_create_next_table_success(self, mock_logger): # Verify the result is the table assert result == mock_table + + +class TestHybridDisposition: + """Test suite for the Hybrid disposition handling in SeaResultSetQueueFactory.""" + + @pytest.fixture + def arrow_manifest(self): + """Create an Arrow manifest for testing.""" + return ResultManifest( + format=ResultFormat.ARROW_STREAM.value, + schema={}, + total_row_count=5, + total_byte_count=1000, + total_chunk_count=1, + ) + + @pytest.fixture + def description(self): + """Create column descriptions.""" + return [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), + ] + + @pytest.fixture + def ssl_options(self): + """Create SSL options for testing.""" + return SSLOptions(tls_verify=True) + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + client = Mock() + client.max_download_threads = 10 + return client + + @patch("databricks.sql.backend.sea.queue.create_arrow_table_from_arrow_file") + def test_hybrid_disposition_with_attachment( + self, + mock_create_table, + arrow_manifest, + description, + ssl_options, + mock_sea_client, + ): + """Test that ArrowQueue is created when attachment is present.""" + # Create mock arrow table + mock_arrow_table = Mock() + mock_arrow_table.num_rows = 5 + mock_create_table.return_value = mock_arrow_table + + # Create result data with attachment + attachment_data = b"mock_arrow_data" + result_data = ResultData(attachment=attachment_data) + + # Build queue + queue = SeaResultSetQueueFactory.build_queue( + result_data=result_data, + 
manifest=arrow_manifest, + statement_id="test-statement", + ssl_options=ssl_options, + description=description, + max_download_threads=10, + sea_client=mock_sea_client, + lz4_compressed=False, + ) + + # Verify ArrowQueue was created + assert isinstance(queue, ArrowQueue) + mock_create_table.assert_called_once_with(attachment_data, description) + + @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager") + @patch.object(SeaCloudFetchQueue, "_create_table_from_link", return_value=None) + def test_hybrid_disposition_with_external_links( + self, + mock_create_table, + mock_download_manager, + arrow_manifest, + description, + ssl_options, + mock_sea_client, + ): + """Test that SeaCloudFetchQueue is created when attachment is None but external links are present.""" + # Create external links + external_links = [ + ExternalLink( + external_link="https://example.com/data/chunk0", + expiration="2025-07-03T05:51:18.118009", + row_count=100, + byte_count=1024, + row_offset=0, + chunk_index=0, + next_chunk_index=1, + http_headers={"Authorization": "Bearer token123"}, + ) + ] + + # Create result data with external links but no attachment + result_data = ResultData(external_links=external_links, attachment=None) + + # Build queue + queue = SeaResultSetQueueFactory.build_queue( + result_data=result_data, + manifest=arrow_manifest, + statement_id="test-statement", + ssl_options=ssl_options, + description=description, + max_download_threads=10, + sea_client=mock_sea_client, + lz4_compressed=False, + ) + + # Verify SeaCloudFetchQueue was created + assert isinstance(queue, SeaCloudFetchQueue) + mock_create_table.assert_called_once() + + @patch("databricks.sql.backend.sea.queue.ResultSetDownloadHandler._decompress_data") + @patch("databricks.sql.backend.sea.queue.create_arrow_table_from_arrow_file") + def test_hybrid_disposition_with_compressed_attachment( + self, + mock_create_table, + mock_decompress, + arrow_manifest, + description, + ssl_options, + mock_sea_client, + ): 
+ """Test that ArrowQueue is created with decompressed data when attachment is present and lz4_compressed is True.""" + # Create mock arrow table + mock_arrow_table = Mock() + mock_arrow_table.num_rows = 5 + mock_create_table.return_value = mock_arrow_table + + # Setup decompression mock + compressed_data = b"compressed_data" + decompressed_data = b"decompressed_data" + mock_decompress.return_value = decompressed_data + + # Create result data with attachment + result_data = ResultData(attachment=compressed_data) + + # Build queue with lz4_compressed=True + queue = SeaResultSetQueueFactory.build_queue( + result_data=result_data, + manifest=arrow_manifest, + statement_id="test-statement", + ssl_options=ssl_options, + description=description, + max_download_threads=10, + sea_client=mock_sea_client, + lz4_compressed=True, + ) + + # Verify ArrowQueue was created with decompressed data + assert isinstance(queue, ArrowQueue) + mock_decompress.assert_called_once_with(compressed_data) + mock_create_table.assert_called_once_with(decompressed_data, description)