diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py
index 85e4236b..dafa818c 100644
--- a/src/databricks/sql/backend/sea/queue.py
+++ b/src/databricks/sql/backend/sea/queue.py
@@ -1,7 +1,8 @@
 from __future__ import annotations

 from abc import ABC
-from typing import List, Optional, Tuple, Union, TYPE_CHECKING
+import threading
+from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING

 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
@@ -121,6 +122,110 @@ def close(self):
         return


+class LinkFetcher:
+    """Background fetcher of external result links.
+
+    A worker thread walks the chunk chain via ``next_chunk_index``, registers
+    every new link with the download manager, and notifies any thread blocked
+    in ``get_chunk_link`` once the link it is waiting for becomes available.
+    """
+
+    def __init__(
+        self,
+        download_manager: ResultFileDownloadManager,
+        backend: SeaDatabricksClient,
+        statement_id: str,
+        initial_links: List[ExternalLink],
+        total_chunk_count: int,
+    ):
+        self.download_manager = download_manager
+        self.backend = backend
+        self._statement_id = statement_id
+
+        self._shutdown_event = threading.Event()
+
+        self._link_data_update = threading.Condition()
+        self._error: Optional[Exception] = None
+        self.chunk_index_to_link: Dict[int, ExternalLink] = {}
+
+        self._add_links(initial_links)
+        self.total_chunk_count = total_chunk_count
+
+    def _add_links(self, links: List[ExternalLink]):
+        """Record the links and hand them to the download manager."""
+        for link in links:
+            self.chunk_index_to_link[link.chunk_index] = link
+            self.download_manager.add_link(LinkFetcher._convert_to_thrift_link(link))
+
+    def _get_next_chunk_index(self) -> Optional[int]:
+        """Return the next chunk index to fetch, or None once the chain is exhausted."""
+        with self._link_data_update:
+            max_chunk_index = max(self.chunk_index_to_link.keys(), default=None)
+            if max_chunk_index is None:
+                return 0
+            max_link = self.chunk_index_to_link[max_chunk_index]
+            return max_link.next_chunk_index
+
+    def _trigger_next_batch_download(self) -> bool:
+        next_chunk_index = self._get_next_chunk_index()
+        if next_chunk_index is None:
+            return False
+
+        try:
+            links = self.backend.get_chunk_links(self._statement_id, next_chunk_index)
+            with self._link_data_update:
+                self._add_links(links)
+                self._link_data_update.notify_all()
+        except Exception as e:
+            logger.error(
+                f"LinkFetcher: Error fetching links for chunk {next_chunk_index}: {e}"
+            )
+            with self._link_data_update:
+                self._error = e
+                self._link_data_update.notify_all()
+            return False
+
+        return True
+
+    def get_chunk_link(self, chunk_index: int) -> Optional[ExternalLink]:
+        """Block until the link for chunk_index is available, then return it."""
+        if chunk_index >= self.total_chunk_count:
+            return None
+
+        with self._link_data_update:
+            while chunk_index not in self.chunk_index_to_link:
+                if self._error:
+                    raise self._error
+                if self._shutdown_event.is_set():
+                    raise ProgrammingError(
+                        "LinkFetcher is shutting down without providing link for chunk index {}".format(
+                            chunk_index
+                        )
+                    )
+                self._link_data_update.wait()
+
+            return self.chunk_index_to_link.get(chunk_index, None)
+
+    @staticmethod
+    def _convert_to_thrift_link(link: ExternalLink) -> TSparkArrowResultLink:
+        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
+        # Parse the ISO format expiration time
+        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
+        return TSparkArrowResultLink(
+            fileLink=link.external_link,
+            expiryTime=expiry_time,
+            rowCount=link.row_count,
+            bytesNum=link.byte_count,
+            startRowOffset=link.row_offset,
+            httpHeaders=link.http_headers or {},
+        )
+
+    def _worker_loop(self):
+        while not self._shutdown_event.is_set():
+            links_downloaded = self._trigger_next_batch_download()
+            if not links_downloaded:
+                break
+        # notify_all() must be called with the condition's lock held; wake any
+        # blocked consumers so they can observe completion or shutdown.
+        with self._link_data_update:
+            self._link_data_update.notify_all()
+
+    def start(self):
+        self._worker_thread = threading.Thread(target=self._worker_loop)
+        self._worker_thread.start()
+
+    def stop(self):
+        self._shutdown_event.set()
+        self._worker_thread.join()
+
+
 class SeaCloudFetchQueue(CloudFetchQueue):
     """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""
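`LinkFetcher` is a single-producer/multi-consumer handshake built on `threading.Condition`: the worker publishes links under the lock and wakes waiters; consumers re-check their predicate in a loop around `wait()`. The sketch below is not part of the diff — all names are stand-ins — but it isolates exactly that pattern, mirroring `_add_links`, `get_chunk_link`, and the final notify in `_worker_loop`:

```python
import threading

class ChunkRegistry:
    """Minimal sketch of the wait/notify handshake LinkFetcher relies on."""

    def __init__(self):
        self._cond = threading.Condition()
        self._links = {}
        self._done = False

    def publish(self, index, link):
        # Producer side: mutate shared state under the lock, then wake waiters.
        with self._cond:
            self._links[index] = link
            self._cond.notify_all()

    def finish(self):
        # Without a final notify, consumers waiting for never-arriving
        # chunks would sleep forever.
        with self._cond:
            self._done = True
            self._cond.notify_all()

    def get(self, index):
        # Consumer side: wait() can wake spuriously, so the predicate is
        # re-checked in a loop while the lock is held.
        with self._cond:
            while index not in self._links:
                if self._done:
                    return None
                self._cond.wait()  # releases the lock while sleeping
            return self._links[index]

registry = ChunkRegistry()
consumer = threading.Thread(target=lambda: print(registry.get(1)))
consumer.start()
registry.publish(1, "https://example.com/chunk1")  # stand-in payload
consumer.join()
```

The `while` (rather than `if`) around `wait()` is the load-bearing detail: it makes the wait safe against spurious wakeups and against notifications meant for a different chunk.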
@@ -176,62 +281,35 @@ def __init__(
+        # No initial link can mean an empty result set: leave the fetcher
+        # unset so close() can tell that nothing needs to be stopped.
+        self.link_fetcher: Optional[LinkFetcher] = None
         first_link = self._chunk_index_to_link.get(self._current_chunk_index, None)
         if not first_link:
             # possibly an empty response
-            return None
+            return

-        # Track the current chunk we're processing
-        self._current_chunk_index = 0
-        # Initialize table and position
-        self.table = self._create_table_from_link(first_link)
+        self.current_chunk_index = 0

-    def _convert_to_thrift_link(self, link: ExternalLink) -> TSparkArrowResultLink:
-        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
-        # Parse the ISO format expiration time
-        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
-        return TSparkArrowResultLink(
-            fileLink=link.external_link,
-            expiryTime=expiry_time,
-            rowCount=link.row_count,
-            bytesNum=link.byte_count,
-            startRowOffset=link.row_offset,
-            httpHeaders=link.http_headers or {},
+        self.link_fetcher = LinkFetcher(
+            download_manager=self.download_manager,
+            backend=self._sea_client,
+            statement_id=self._statement_id,
+            initial_links=initial_links,
+            total_chunk_count=total_chunk_count,
         )
+        self.link_fetcher.start()

-    def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
-        if chunk_index >= self._total_chunk_count:
-            return None
-
-        if chunk_index not in self._chunk_index_to_link:
-            links = self._sea_client.get_chunk_links(self._statement_id, chunk_index)
-            self._chunk_index_to_link.update({l.chunk_index: l for l in links})
-
-        link = self._chunk_index_to_link.get(chunk_index, None)
-        if not link:
-            raise ServerOperationError(
-                f"Error fetching link for chunk {chunk_index}",
-                {
-                    "operation-id": self._statement_id,
-                    "diagnostic-info": None,
-                },
-            )
-        return link
-
-    def _create_table_from_link(
-        self, link: ExternalLink
-    ) -> Union["pyarrow.Table", None]:
-        """Create a table from a link."""
+        # Initialize table and position
+        self.table = self._create_next_table()

-        thrift_link = self._convert_to_thrift_link(link)
-        self.download_manager.add_link(thrift_link)
+    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+        """Create next table by retrieving the logical next downloaded file."""
+        chunk_link = self.link_fetcher.get_chunk_link(self.current_chunk_index)
+        if not chunk_link:
+            return None

-        row_offset = link.row_offset
+        row_offset = chunk_link.row_offset
         arrow_table = self._create_table_at_offset(row_offset)

+        self.current_chunk_index += 1
+
         return arrow_table

-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
-        """Create next table by retrieving the logical next downloaded file."""
-        self._current_chunk_index += 1
-        next_chunk_link = self._get_chunk_link(self._current_chunk_index)
-        if not next_chunk_link:
-            return None
-        return self._create_table_from_link(next_chunk_link)
+    def close(self):
+        super().close()
+        if self.link_fetcher is not None:
+            self.link_fetcher.stop()
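Before the test changes, one subtlety worth spelling out: the fetcher never counts chunk indices up by itself. `_get_next_chunk_index` asks for the successor of the newest link it holds, following the server-provided `next_chunk_index` chain until it hits `None`. A minimal sketch of that walk (the `Link` stand-in here is not the diff's `ExternalLink`):

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class Link:
    chunk_index: int
    next_chunk_index: Optional[int]  # None terminates the chain

def next_to_fetch(known: Dict[int, Link]) -> Optional[int]:
    if not known:
        return 0  # nothing cached yet: start at chunk 0
    newest = known[max(known)]
    return newest.next_chunk_index

assert next_to_fetch({}) == 0
assert next_to_fetch({0: Link(0, 1)}) == 1
assert next_to_fetch({0: Link(0, 1), 1: Link(1, None)}) is None  # chain complete
```

This is why `_worker_loop` can simply stop once `_trigger_next_batch_download` returns `False`: an exhausted chain, not a counter reaching `total_chunk_count`, is the termination signal.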
diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py
index 4e5af065..e4f2ba78 100644
--- a/tests/unit/test_sea_queue.py
+++ b/tests/unit/test_sea_queue.py
@@ -11,6 +11,7 @@
 from databricks.sql.backend.sea.queue import (
     JsonQueue,
+    LinkFetcher,
     SeaResultSetQueueFactory,
     SeaCloudFetchQueue,
 )
@@ -23,6 +24,8 @@
 from databricks.sql.exc import ProgrammingError, ServerOperationError
 from databricks.sql.types import SSLOptions
 from databricks.sql.utils import ArrowQueue
+import threading
+import time


 class TestJsonQueue:
@@ -216,9 +219,7 @@ def test_build_queue_arrow_stream(
         with patch(
             "databricks.sql.backend.sea.queue.ResultFileDownloadManager"
-        ), patch.object(
-            SeaCloudFetchQueue, "_create_table_from_link", return_value=None
-        ):
+        ), patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None):
             queue = SeaResultSetQueueFactory.build_queue(
                 result_data=result_data,
                 manifest=arrow_manifest,
@@ -303,10 +304,8 @@ def sample_external_link_no_headers(self):

     def test_convert_to_thrift_link(self, sample_external_link):
         """Test conversion of ExternalLink to TSparkArrowResultLink."""
-        queue = Mock(spec=SeaCloudFetchQueue)
-
         # Call the method directly
-        result = SeaCloudFetchQueue._convert_to_thrift_link(queue, sample_external_link)
+        result = LinkFetcher._convert_to_thrift_link(sample_external_link)

         # Verify the conversion
         assert result.fileLink == sample_external_link.external_link
@@ -317,12 +316,8 @@ def test_convert_to_thrift_link_no_headers(self, sample_external_link_no_headers):
         """Test conversion of ExternalLink with no headers to TSparkArrowResultLink."""
-        queue = Mock(spec=SeaCloudFetchQueue)
-
         # Call the method directly
-        result = SeaCloudFetchQueue._convert_to_thrift_link(
-            queue, sample_external_link_no_headers
-        )
+        result = LinkFetcher._convert_to_thrift_link(sample_external_link_no_headers)

         # Verify the conversion
         assert result.fileLink == sample_external_link_no_headers.external_link
@@ -344,9 +339,7 @@ def test_init_with_valid_initial_link(
     ):
         """Test initialization with valid initial link."""
         # Create a queue with valid initial link
-        with patch.object(
-            SeaCloudFetchQueue, "_create_table_from_link", return_value=None
-        ):
+        with patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None):
             queue = SeaCloudFetchQueue(
                 result_data=ResultData(external_links=[sample_external_link]),
                 max_download_threads=5,
@@ -398,29 +391,29 @@ def test_create_next_table_success(self, mock_logger):
         """Test _create_next_table with successful table creation."""
         # Create a queue instance without initializing
         queue = Mock(spec=SeaCloudFetchQueue)
-        queue._current_chunk_index = 0
+        queue.current_chunk_index = 0
         queue.download_manager = Mock()
+        queue.link_fetcher = Mock()

         # Mock the dependencies
         mock_table = Mock()
         mock_chunk_link = Mock()
-        queue._get_chunk_link = Mock(return_value=mock_chunk_link)
-        queue._create_table_from_link = Mock(return_value=mock_table)
+        queue.link_fetcher.get_chunk_link = Mock(return_value=mock_chunk_link)
+        queue._create_table_at_offset = Mock(return_value=mock_table)

         # Call the method directly
-        result = SeaCloudFetchQueue._create_next_table(queue)
+        SeaCloudFetchQueue._create_next_table(queue)

         # Verify the chunk index was incremented
-        assert queue._current_chunk_index == 1
+        assert queue.current_chunk_index == 1

         # Verify the chunk link was retrieved
-        queue._get_chunk_link.assert_called_once_with(1)
+        queue.link_fetcher.get_chunk_link.assert_called_once_with(0)

         # Verify the table was created from the link
-        queue._create_table_from_link.assert_called_once_with(mock_chunk_link)
-
-        # Verify the result is the table
-        assert result == mock_table
+        queue._create_table_at_offset.assert_called_once_with(
+            mock_chunk_link.row_offset
+        )


 class TestHybridDisposition:
@@ -494,7 +487,7 @@ def test_hybrid_disposition_with_attachment(
         mock_create_table.assert_called_once_with(attachment_data, description)
@patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager") - @patch.object(SeaCloudFetchQueue, "_create_table_from_link", return_value=None) + @patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None) def test_hybrid_disposition_with_external_links( self, mock_create_table, @@ -579,3 +572,156 @@ def test_hybrid_disposition_with_compressed_attachment( assert isinstance(queue, ArrowQueue) mock_decompress.assert_called_once_with(compressed_data) mock_create_table.assert_called_once_with(decompressed_data, description) + + +class TestLinkFetcher: + """Unit tests for the LinkFetcher helper class.""" + + @pytest.fixture + def sample_links(self): + """Provide a pair of ExternalLink objects forming two sequential chunks.""" + link0 = ExternalLink( + external_link="https://example.com/data/chunk0", + expiration="2030-01-01T00:00:00.000000", + row_count=100, + byte_count=1024, + row_offset=0, + chunk_index=0, + next_chunk_index=1, + http_headers={"Authorization": "Bearer token0"}, + ) + + link1 = ExternalLink( + external_link="https://example.com/data/chunk1", + expiration="2030-01-01T00:00:00.000000", + row_count=100, + byte_count=1024, + row_offset=100, + chunk_index=1, + next_chunk_index=None, + http_headers={"Authorization": "Bearer token1"}, + ) + + return link0, link1 + + def _create_fetcher( + self, + initial_links, + backend_mock=None, + download_manager_mock=None, + total_chunk_count=10, + ): + """Helper to create a LinkFetcher instance with supplied mocks.""" + if backend_mock is None: + backend_mock = Mock() + if download_manager_mock is None: + download_manager_mock = Mock() + + return ( + LinkFetcher( + download_manager=download_manager_mock, + backend=backend_mock, + statement_id="statement-123", + initial_links=list(initial_links), + total_chunk_count=total_chunk_count, + ), + backend_mock, + download_manager_mock, + ) + + def test_add_links_and_get_next_chunk_index(self, sample_links): + """Verify that initial links are stored and next chunk index is computed correctly.""" + link0, link1 = sample_links + + fetcher, _backend, download_manager = self._create_fetcher([link0]) + + # add_link should have been called for the initial link + download_manager.add_link.assert_called_once() + + # Internal mapping should contain the link + assert fetcher.chunk_index_to_link[0] == link0 + + # The next chunk index should be 1 (from link0.next_chunk_index) + assert fetcher._get_next_chunk_index() == 1 + + # Add second link and validate it is present + fetcher._add_links([link1]) + assert fetcher.chunk_index_to_link[1] == link1 + + def test_trigger_next_batch_download_success(self, sample_links): + """Check that _trigger_next_batch_download fetches and stores new links.""" + link0, link1 = sample_links + + backend_mock = Mock() + backend_mock.get_chunk_links = Mock(return_value=[link1]) + + fetcher, backend, download_manager = self._create_fetcher( + [link0], backend_mock=backend_mock + ) + + # Trigger download of the next chunk (index 1) + success = fetcher._trigger_next_batch_download() + + assert success is True + backend.get_chunk_links.assert_called_once_with("statement-123", 1) + assert fetcher.chunk_index_to_link[1] == link1 + # Two calls to add_link: one for initial link, one for new link + assert download_manager.add_link.call_count == 2 + + def test_trigger_next_batch_download_error(self, sample_links): + """Ensure that errors from backend are captured and surfaced.""" + link0, _link1 = sample_links + + backend_mock = Mock() + 
+        backend_mock.get_chunk_links.side_effect = ServerOperationError(
+            "Backend failure"
+        )
+
+        fetcher, _backend, _download_manager = self._create_fetcher(
+            [link0], backend_mock=backend_mock
+        )
+
+        success = fetcher._trigger_next_batch_download()
+
+        assert success is False
+        assert fetcher._error is not None
+
+    def test_get_chunk_link_waits_until_available(self, sample_links):
+        """Validate that get_chunk_link blocks until the requested link is available and then returns it."""
+        link0, link1 = sample_links
+
+        backend_mock = Mock()
+        # Configure the backend to return link1 when asked for chunk index 1
+        backend_mock.get_chunk_links = Mock(return_value=[link1])
+
+        fetcher, _backend, _download_manager = self._create_fetcher(
+            [link0], backend_mock=backend_mock, total_chunk_count=2
+        )
+
+        # Holder to capture the link returned from the background thread
+        result_container = {}
+
+        def _worker():
+            result_container["link"] = fetcher.get_chunk_link(1)
+
+        thread = threading.Thread(target=_worker)
+        thread.start()
+
+        # Give the thread a brief moment to start and block on the missing link
+        time.sleep(0.1)
+
+        # Trigger the backend fetch, which adds link1 and notifies waiting threads
+        fetcher._trigger_next_batch_download()
+
+        thread.join(timeout=2)
+
+        # The thread should have finished and captured link1
+        assert result_container.get("link") == link1
+
+    def test_get_chunk_link_out_of_range_returns_none(self, sample_links):
+        """Requesting a chunk index >= total_chunk_count should immediately return None."""
+        link0, _ = sample_links
+
+        fetcher, _backend, _dm = self._create_fetcher([link0], total_chunk_count=1)
+
+        assert fetcher.get_chunk_link(10) is None
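For orientation, a hedged end-to-end sketch of driving `LinkFetcher` directly, mirroring the fixtures above. The `ExternalLink` import path is an assumption on my part; mocks stand in for the real backend and download manager:

```python
from unittest.mock import Mock

from databricks.sql.backend.sea.models.base import ExternalLink  # assumed import path
from databricks.sql.backend.sea.queue import LinkFetcher

# Single-chunk result: next_chunk_index=None ends the chain immediately.
link0 = ExternalLink(
    external_link="https://example.com/data/chunk0",
    expiration="2030-01-01T00:00:00.000000",
    row_count=100,
    byte_count=1024,
    row_offset=0,
    chunk_index=0,
    next_chunk_index=None,
    http_headers={},
)

fetcher = LinkFetcher(
    download_manager=Mock(),
    backend=Mock(),
    statement_id="statement-123",
    initial_links=[link0],
    total_chunk_count=1,
)
fetcher.start()
try:
    assert fetcher.get_chunk_link(0) == link0  # already cached, returns at once
    assert fetcher.get_chunk_link(5) is None   # beyond total_chunk_count
finally:
    fetcher.stop()  # always join the worker, as SeaCloudFetchQueue.close() does
```

Wrapping `stop()` in `finally` mirrors what the new `SeaCloudFetchQueue.close()` guarantees: the worker thread never outlives the result set that started it.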