diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py new file mode 100644 index 000000000..712f033c6 --- /dev/null +++ b/examples/experimental/sea_connector_test.py @@ -0,0 +1,121 @@ +""" +Main script to run all SEA connector tests. + +This script runs all the individual test modules and displays +a summary of test results with visual indicators. + +In order to run the script, the following environment variables need to be set: +- DATABRICKS_SERVER_HOSTNAME: The hostname of the Databricks server +- DATABRICKS_HTTP_PATH: The HTTP path of the Databricks server +- DATABRICKS_TOKEN: The token to use for authentication +""" + +import os +import sys +import logging +import subprocess +from typing import List, Tuple + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +TEST_MODULES = [ + "test_sea_session", + "test_sea_sync_query", + "test_sea_async_query", + "test_sea_metadata", +] + + +def run_test_module(module_name: str) -> bool: + """Run a test module and return success status.""" + module_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" + ) + + # Simply run the module as a script - each module handles its own test execution + result = subprocess.run( + [sys.executable, module_path], capture_output=True, text=True + ) + + # Log the output from the test module + if result.stdout: + for line in result.stdout.strip().split("\n"): + logger.info(line) + + if result.stderr: + for line in result.stderr.strip().split("\n"): + logger.error(line) + + return result.returncode == 0 + + +def run_tests() -> List[Tuple[str, bool]]: + """Run all tests and return results.""" + results = [] + + for module_name in TEST_MODULES: + try: + logger.info(f"\n{'=' * 50}") + logger.info(f"Running test: {module_name}") + logger.info(f"{'-' * 50}") + + success = run_test_module(module_name) + results.append((module_name, success)) + + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"Test {module_name}: {status}") + + except Exception as e: + logger.error(f"Error loading or running test {module_name}: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + results.append((module_name, False)) + + return results + + +def print_summary(results: List[Tuple[str, bool]]) -> None: + """Print a summary of test results.""" + logger.info(f"\n{'=' * 50}") + logger.info("TEST SUMMARY") + logger.info(f"{'-' * 50}") + + passed = sum(1 for _, success in results if success) + total = len(results) + + for module_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"{status} - {module_name}") + + logger.info(f"{'-' * 50}") + logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") + logger.info(f"{'=' * 50}") + + +if __name__ == "__main__": + # Check if required environment variables are set + required_vars = [ + "DATABRICKS_SERVER_HOSTNAME", + "DATABRICKS_HTTP_PATH", + "DATABRICKS_TOKEN", + ] + missing_vars = [var for var in required_vars if not os.environ.get(var)] + + if missing_vars: + logger.error( + f"Missing required environment variables: {', '.join(missing_vars)}" + ) + logger.error("Please set these variables before running the tests.") + sys.exit(1) + + # Run all tests + results = run_tests() + + # Print summary + print_summary(results) + + # Exit with appropriate status code + all_passed = all(success for _, success in results) + sys.exit(0 if all_passed else 1) diff --git a/examples/experimental/tests/__init__.py 
b/examples/experimental/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py new file mode 100644 index 000000000..2742e8cb2 --- /dev/null +++ b/examples/experimental/tests/test_sea_async_query.py @@ -0,0 +1,192 @@ +""" +Test for SEA asynchronous query execution functionality. +""" +import os +import sys +import logging +import time +from databricks.sql.client import Connection +from databricks.sql.backend.types import CommandState + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_async_query_with_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info( + "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + ) + cursor.execute_async("SELECT 1 as test_value") + logger.info( + "Asynchronous query submitted successfully with cloud fetch enabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch enabled" + ) + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_without_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. 
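+
+    The polling flow exercised below is, roughly::
+
+        cursor.execute_async("SELECT 1 as test_value")
+        while cursor.is_query_pending():
+            time.sleep(1)
+        cursor.get_async_execution_result()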
+ """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch disabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info( + "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + ) + cursor.execute_async("SELECT 1 as test_value") + logger.info( + "Asynchronous query submitted successfully with cloud fetch disabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch disabled" + ) + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_exec(): + """ + Run both asynchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() + logger.info( + f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() + logger.info( + f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_async_query_exec() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py new file mode 100644 index 000000000..a200d97d3 --- /dev/null +++ b/examples/experimental/tests/test_sea_metadata.py @@ -0,0 +1,98 @@ +""" +Test for SEA metadata functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_metadata(): + """ + Test metadata operations using the SEA backend. + + This function connects to a Databricks SQL endpoint using the SEA backend, + and executes metadata operations like catalogs(), schemas(), tables(), and columns(). 
+ """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + if not catalog: + logger.error( + "DATABRICKS_CATALOG environment variable is required for metadata tests." + ) + return False + + try: + # Create connection + logger.info("Creating connection for metadata operations") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Test catalogs + cursor = connection.cursor() + logger.info("Fetching catalogs...") + cursor.catalogs() + logger.info("Successfully fetched catalogs") + + # Test schemas + logger.info(f"Fetching schemas for catalog '{catalog}'...") + cursor.schemas(catalog_name=catalog) + logger.info("Successfully fetched schemas") + + # Test tables + logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") + cursor.tables(catalog_name=catalog, schema_name="default") + logger.info("Successfully fetched tables") + + # Test columns for a specific table + # Using a common table that should exist in most environments + logger.info( + f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." + ) + cursor.columns( + catalog_name=catalog, schema_name="default", table_name="customer" + ) + logger.info("Successfully fetched columns") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA metadata test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_metadata() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py new file mode 100644 index 000000000..516c1bbb8 --- /dev/null +++ b/examples/experimental/tests/test_sea_session.py @@ -0,0 +1,71 @@ +""" +Test for SEA session management functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"Backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_session() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py new file mode 100644 index 000000000..5ab6d823b --- /dev/null +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -0,0 +1,162 @@ +""" +Test for SEA synchronous query execution functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_sync_query_with_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for synchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info( + "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + ) + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_without_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info( + "Creating connection for synchronous query execution with cloud fetch disabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info( + "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + ) + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_exec(): + """ + Run both synchronous query tests and return overall success. 
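+
+    Both tests always run, even if the first fails; the return value is the
+    conjunction of their results.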
+ """ + with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() + logger.info( + f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() + logger.info( + f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_sync_query_exec() + sys.exit(0 if success else 1) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index ee158b452..2213635fe 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -82,6 +82,7 @@ def execute_command( parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, + row_limit: Optional[int] = None, ) -> Union[ResultSet, None]: """ Executes a SQL command or query within the specified session. @@ -100,6 +101,7 @@ def execute_command( parameters: List of parameters to bind to the query async_op: Whether to execute the command asynchronously enforce_embedded_schema_correctness: Whether to enforce schema correctness + row_limit: Maximum number of rows in the response. Returns: If async_op is False, returns a ResultSet object containing the diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py new file mode 100644 index 000000000..c0b89da75 --- /dev/null +++ b/src/databricks/sql/backend/sea/backend.py @@ -0,0 +1,794 @@ +from __future__ import annotations + +import logging +import time +import re +from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set + +from databricks.sql.backend.sea.models.base import ResultManifest, StatementStatus +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, + ResultFormat, + ResultDisposition, + ResultCompression, + WaitTimeout, + MetadataCommands, +) +from databricks.sql.thrift_api.TCLIService import ttypes + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + +from databricks.sql.result_set import SeaResultSet + +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.types import ( + SessionId, + CommandId, + CommandState, + BackendType, + ExecuteResponse, +) +from databricks.sql.exc import DatabaseError, ServerOperationError +from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.types import SSLOptions + +from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, + CreateSessionRequest, + DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, + CreateSessionResponse, +) + +logger = logging.getLogger(__name__) + + +def _filter_session_configuration( + session_configuration: Optional[Dict[str, Any]], +) -> Dict[str, str]: + """ + Filter and normalise the provided session configuration parameters. + + The Statement Execution API supports only a subset of SQL session + configuration options. This helper validates the supplied + ``session_configuration`` dictionary against the allow-list defined in + ``ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP`` and returns a new + dictionary that contains **only** the supported parameters. 
+ + Args: + session_configuration: Optional mapping of session configuration + names to their desired values. Key comparison is + case-insensitive. + + Returns: + Dict[str, str]: A dictionary containing only the supported + configuration parameters with lower-case keys and string values. If + *session_configuration* is ``None`` or empty, an empty dictionary is + returned. + """ + + if not session_configuration: + return {} + + filtered_session_configuration = {} + ignored_configs: Set[str] = set() + + for key, value in session_configuration.items(): + if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: + filtered_session_configuration[key.lower()] = str(value) + else: + ignored_configs.add(key) + + if ignored_configs: + logger.warning( + "Some session configurations were ignored because they are not supported: %s", + ignored_configs, + ) + logger.warning( + "Supported session configurations are: %s", + list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()), + ) + + return filtered_session_configuration + + +class SeaDatabricksClient(DatabricksClient): + """ + Statement Execution API (SEA) implementation of the DatabricksClient interface. + """ + + # SEA API paths + BASE_PATH = "/api/2.0/sql/" + SESSION_PATH = BASE_PATH + "sessions" + SESSION_PATH_WITH_ID = SESSION_PATH + "/{}" + STATEMENT_PATH = BASE_PATH + "statements" + STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" + CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + + # SEA constants + POLL_INTERVAL_SECONDS = 0.2 + + def __init__( + self, + server_hostname: str, + port: int, + http_path: str, + http_headers: List[Tuple[str, str]], + auth_provider, + ssl_options: SSLOptions, + **kwargs, + ): + """ + Initialize the SEA backend client. + + Args: + server_hostname: Hostname of the Databricks server + port: Port number for the connection + http_path: HTTP path for the connection + http_headers: List of HTTP headers to include in requests + auth_provider: Authentication provider + ssl_options: SSL configuration options + **kwargs: Additional keyword arguments + """ + + logger.debug( + "SeaDatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)", + server_hostname, + port, + http_path, + ) + + self._max_download_threads = kwargs.get("max_download_threads", 10) + self._ssl_options = ssl_options + self._use_arrow_native_complex_types = kwargs.get( + "_use_arrow_native_complex_types", True + ) + + self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True) + + # Extract warehouse ID from http_path + self.warehouse_id = self._extract_warehouse_id(http_path) + + # Initialize HTTP client + self._http_client = SeaHttpClient( + server_hostname=server_hostname, + port=port, + http_path=http_path, + http_headers=http_headers, + auth_provider=auth_provider, + ssl_options=ssl_options, + **kwargs, + ) + + def _extract_warehouse_id(self, http_path: str) -> str: + """ + Extract the warehouse ID from the HTTP path. 
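+
+        Example (illustrative path)::
+
+            >>> client._extract_warehouse_id("/sql/1.0/warehouses/abc123")
+            'abc123'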
+ + Args: + http_path: The HTTP path from which to extract the warehouse ID + + Returns: + The extracted warehouse ID + + Raises: + ValueError: If the warehouse ID cannot be extracted from the path + """ + + warehouse_pattern = re.compile(r".*/warehouses/(.+)") + endpoint_pattern = re.compile(r".*/endpoints/(.+)") + + for pattern in [warehouse_pattern, endpoint_pattern]: + match = pattern.match(http_path) + if not match: + continue + warehouse_id = match.group(1) + logger.debug( + f"Extracted warehouse ID: {warehouse_id} from path: {http_path}" + ) + return warehouse_id + + # If no match found, raise error + error_message = ( + f"Could not extract warehouse ID from http_path: {http_path}. " + f"Expected format: /path/to/warehouses/{{warehouse_id}} or " + f"/path/to/endpoints/{{warehouse_id}}. " + f"Note: SEA only works for warehouses." + ) + logger.error(error_message) + raise ValueError(error_message) + + @property + def max_download_threads(self) -> int: + """Get the maximum number of download threads for cloud fetch operations.""" + return self._max_download_threads + + def open_session( + self, + session_configuration: Optional[Dict[str, Any]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + """ + Opens a new session with the Databricks SQL service using SEA. + + Args: + session_configuration: Optional dictionary of configuration parameters for the session. + Only specific parameters are supported as documented at: + https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters + catalog: Optional catalog name to use as the initial catalog for the session + schema: Optional schema name to use as the initial schema for the session + + Returns: + SessionId: A session identifier object that can be used for subsequent operations + + Raises: + Error: If the session configuration is invalid + OperationalError: If there's an error establishing the session + """ + + logger.debug( + "SeaDatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)", + session_configuration, + catalog, + schema, + ) + + session_configuration = _filter_session_configuration(session_configuration) + + request_data = CreateSessionRequest( + warehouse_id=self.warehouse_id, + session_confs=session_configuration, + catalog=catalog, + schema=schema, + ) + + response = self._http_client._make_request( + method="POST", path=self.SESSION_PATH, data=request_data.to_dict() + ) + + session_response = CreateSessionResponse.from_dict(response) + session_id = session_response.session_id + if not session_id: + raise ServerOperationError( + "Failed to create session: No session ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + return SessionId.from_sea_session_id(session_id) + + def close_session(self, session_id: SessionId) -> None: + """ + Closes an existing session with the Databricks SQL service. 
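+
+        This issues ``DELETE /api/2.0/sql/sessions/{session_id}`` via the
+        SEA HTTP client.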
+ + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + + logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) + + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + sea_session_id = session_id.to_sea_session_id() + + request_data = DeleteSessionRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + ) + + self._http_client._make_request( + method="DELETE", + path=self.SESSION_PATH_WITH_ID.format(sea_session_id), + data=request_data.to_dict(), + ) + + def _extract_description_from_manifest( + self, manifest: ResultManifest + ) -> Optional[List]: + """ + Extract column description from a manifest object, in the format defined by + the spec: https://peps.python.org/pep-0249/#description + + Args: + manifest: The ResultManifest object containing schema information + + Returns: + Optional[List]: A list of column tuples or None if no columns are found + """ + + schema_data = manifest.schema + columns_data = schema_data.get("columns", []) + + if not columns_data: + return None + + columns = [] + for col_data in columns_data: + # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) + name = col_data.get("name", "") + type_name = col_data.get("type_name", "") + type_name = ( + type_name[:-5] if type_name.endswith("_TYPE") else type_name + ).lower() + precision = col_data.get("type_precision") + scale = col_data.get("type_scale") + + columns.append( + ( + name, # name + type_name, # type_code + None, # display_size (not provided by SEA) + None, # internal_size (not provided by SEA) + precision, # precision + scale, # scale + None, # null_ok + ) + ) + + return columns if columns else None + + def _results_message_to_execute_response( + self, response: Union[ExecuteStatementResponse, GetStatementResponse] + ) -> ExecuteResponse: + """ + Convert a SEA response to an ExecuteResponse. + + Args: + response: The ExecuteStatementResponse or GetStatementResponse returned by the SEA API + + Returns: + ExecuteResponse: The normalized execute response + """ + + # Extract description from manifest schema + description = self._extract_description_from_manifest(response.manifest) + + # Check for compression + lz4_compressed = ( + response.manifest.result_compression == ResultCompression.LZ4_FRAME.value + ) + + execute_response = ExecuteResponse( + command_id=CommandId.from_sea_statement_id(response.statement_id), + status=response.status.state, + description=description, + has_been_closed_server_side=False, + lz4_compressed=lz4_compressed, + is_staging_operation=response.manifest.is_volume_operation, + arrow_schema_bytes=None, + result_format=response.manifest.format, + ) + + return execute_response + + def _response_to_result_set( + self, + response: Union[ExecuteStatementResponse, GetStatementResponse], + cursor: Cursor, + ) -> SeaResultSet: + """ + Convert a SEA response to a SeaResultSet. 
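+
+        This pairs the normalized ExecuteResponse with the raw result data
+        and manifest carried on the same response.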
+ """ + + execute_response = self._results_message_to_execute_response(response) + + return SeaResultSet( + connection=cursor.connection, + execute_response=execute_response, + sea_client=self, + result_data=response.result, + manifest=response.manifest, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + ) + + def _check_command_not_in_failed_or_closed_state( + self, status: StatementStatus, command_id: CommandId + ) -> None: + state = status.state + if state == CommandState.CLOSED: + raise DatabaseError( + "Command {} unexpectedly closed server side".format(command_id), + { + "operation-id": command_id, + }, + ) + if state == CommandState.FAILED: + error = status.error + error_code = error.error_code if error else "UNKNOWN_ERROR_CODE" + error_message = error.message if error else "UNKNOWN_ERROR_MESSAGE" + raise ServerOperationError( + "Command failed: {} - {}".format(error_code, error_message), + { + "operation-id": command_id, + }, + ) + + def _wait_until_command_done( + self, response: ExecuteStatementResponse + ) -> Union[ExecuteStatementResponse, GetStatementResponse]: + """ + Wait until a command is done. + """ + + final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response + command_id = CommandId.from_sea_statement_id(final_response.statement_id) + + while final_response.status.state in [ + CommandState.PENDING, + CommandState.RUNNING, + ]: + time.sleep(self.POLL_INTERVAL_SECONDS) + final_response = self._poll_query(command_id) + + self._check_command_not_in_failed_or_closed_state( + final_response.status, command_id + ) + + return final_response + + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: Cursor, + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + row_limit: Optional[int] = None, + ) -> Union[SeaResultSet, None]: + """ + Execute a SQL command using the SEA backend. 
+ + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=( + param.value.stringValue if param.value is not None else None + ), + type=param.type, + ) + ) + + format = ( + ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY + ).value + disposition = ( + ( + ResultDisposition.HYBRID + if self.use_hybrid_disposition + else ResultDisposition.EXTERNAL_LINKS + ) + if use_cloud_fetch + else ResultDisposition.INLINE + ).value + result_compression = ( + ResultCompression.LZ4_FRAME if lz4_compression else ResultCompression.NONE + ).value + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value, + on_wait_timeout="CONTINUE", + row_limit=row_limit, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, + ) + + response_data = self._http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return and let the client poll for results + if async_op: + return None + + final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response + if response.status.state != CommandState.SUCCEEDED: + final_response = self._wait_until_command_done(response) + + return self._response_to_result_set(final_response, cursor) + + def cancel_command(self, command_id: CommandId) -> None: + """ + Cancel a running command. + + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self._http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + def close_command(self, command_id: CommandId) -> None: + """ + Close a command and release resources. 
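+
+        This issues ``DELETE /api/2.0/sql/statements/{statement_id}``.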
+ + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self._http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + def _poll_query(self, command_id: CommandId) -> GetStatementResponse: + """ + Poll for the current command info. + """ + + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self._http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + response = GetStatementResponse.from_dict(response_data) + + return response + + def get_query_state(self, command_id: CommandId) -> CommandState: + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + + response = self._poll_query(command_id) + return response.status.state + + def get_execution_result( + self, + command_id: CommandId, + cursor: Cursor, + ) -> SeaResultSet: + """ + Get the result of a command execution. + + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + SeaResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + + response = self._poll_query(command_id) + return self._response_to_result_set(response, cursor) + + # == Metadata Operations == + + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, + ) -> SeaResultSet: + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation=MetadataCommands.SHOW_CATALOGS.value, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result + + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ) -> SeaResultSet: + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise DatabaseError("Catalog name is required for get_schemas") + + operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name) + + if schema_name: + operation += MetadataCommands.LIKE_PATTERN.value.format(schema_name) + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result + + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, + catalog_name: Optional[str] = 
None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ) -> SeaResultSet: + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + operation = ( + MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value + if catalog_name in [None, "*", "%"] + else MetadataCommands.SHOW_TABLES.value.format( + MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name) + ) + ) + + if schema_name: + operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name) + + if table_name: + operation += MetadataCommands.LIKE_PATTERN.value.format(table_name) + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types + from databricks.sql.backend.sea.utils.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result + + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ) -> SeaResultSet: + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise DatabaseError("Catalog name is required for get_columns") + + operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name) + + if schema_name: + operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name) + + if table_name: + operation += MetadataCommands.TABLE_LIKE_PATTERN.value.format(table_name) + + if column_name: + operation += MetadataCommands.LIKE_PATTERN.value.format(column_name) + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py new file mode 100644 index 000000000..b899b791d --- /dev/null +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -0,0 +1,50 @@ +""" +Models for the SEA (Statement Execution API) backend. + +This package contains data models for SEA API requests and responses. 
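+
+The models are grouped into base structures, request models, and response
+models, re-exported here via ``__all__``.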
+""" + +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ResultManifest, +) + +from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, + CreateSessionRequest, + DeleteSessionRequest, +) + +from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, + CreateSessionResponse, +) + +__all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ResultManifest", + # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", + "CreateSessionRequest", + "DeleteSessionRequest", + # Response models + "ExecuteStatementResponse", + "GetStatementResponse", + "CreateSessionResponse", +] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..3eacc8887 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,82 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. +""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + byte_count: int = 0 + row_count: int = 0 + row_offset: int = 0 + next_chunk_index: Optional[int] = None + next_chunk_internal_link: Optional[str] = None + http_headers: Optional[Dict[str, str]] = None + + +@dataclass +class ChunkInfo: + """Information about a chunk in the result set.""" + + chunk_index: int + byte_count: int + row_offset: int + row_count: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + byte_count: Optional[int] = None + chunk_index: Optional[int] = None + next_chunk_index: Optional[int] = None + next_chunk_internal_link: Optional[str] = None + row_count: Optional[int] = None + row_offset: Optional[int] = None + attachment: Optional[bytes] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + format: str + schema: Dict[str, Any] + total_row_count: int + total_byte_count: int + total_chunk_count: int + truncated: bool = False + chunks: Optional[List[ChunkInfo]] = None + result_compression: Optional[str] = None + is_volume_operation: bool = False diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py new file mode 100644 index 000000000..ad046ff54 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -0,0 +1,133 @@ +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Representation of a parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Representation of a request to execute a SQL statement.""" + + session_id: str + statement: str + warehouse_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + result_compression: Optional[str] = None + parameters: Optional[List[StatementParameter]] = None + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + "value": param.value, + "type": param.type, + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Representation of a request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Representation of a request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Representation of a request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CreateSessionRequest: + """Representation of a request to create a new session.""" + + warehouse_id: str + session_confs: Optional[Dict[str, str]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = {"warehouse_id": self.warehouse_id} + + if self.session_confs: + result["session_confs"] = self.session_confs + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + return result + + +@dataclass +class DeleteSessionRequest: + """Representation of a request to delete a session.""" + + warehouse_id: str + session_id: str + + def to_dict(self) -> Dict[str, str]: + """Convert the request to a dictionary for JSON serialization.""" + return {"warehouse_id": self.warehouse_id, "session_id": self.session_id} diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py new file mode 100644 index 000000000..75596ec9b --- /dev/null +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -0,0 +1,162 @@ +""" +Response models for the SEA (Statement 
Execution API) backend. + +These models define the structures used in SEA API responses. +""" + +import base64 +from typing import Dict, Any, List, Optional +from dataclasses import dataclass + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, + ExternalLink, + ChunkInfo, +) + + +def _parse_status(data: Dict[str, Any]) -> StatementStatus: + """Parse status from response data.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + return StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + +def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: + """Parse manifest from response data.""" + + manifest_data = data.get("manifest", {}) + chunks = None + if "chunks" in manifest_data: + chunks = [ + ChunkInfo( + chunk_index=chunk.get("chunk_index", 0), + byte_count=chunk.get("byte_count", 0), + row_offset=chunk.get("row_offset", 0), + row_count=chunk.get("row_count", 0), + ) + for chunk in manifest_data.get("chunks", []) + ] + + return ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=chunks, + result_compression=manifest_data.get("result_compression"), + is_volume_operation=manifest_data.get("is_volume_operation", False), + ) + + +def _parse_result(data: Dict[str, Any]) -> ResultData: + """Parse result data from response data.""" + result_data = data.get("result", {}) + external_links = None + + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get("next_chunk_internal_link"), + http_headers=link_data.get("http_headers"), + ) + ) + + # Handle attachment field - decode from base64 if present + attachment = result_data.get("attachment") + if attachment is not None: + attachment = base64.b64decode(attachment) + + return ResultData( + data=result_data.get("data_array"), + external_links=external_links, + byte_count=result_data.get("byte_count"), + chunk_index=result_data.get("chunk_index"), + next_chunk_index=result_data.get("next_chunk_index"), + next_chunk_internal_link=result_data.get("next_chunk_internal_link"), + row_count=result_data.get("row_count"), + row_offset=result_data.get("row_offset"), + attachment=attachment, + ) + + +@dataclass +class ExecuteStatementResponse: + """Representation of the response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: ResultManifest + result: ResultData + + 
@classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + return cls( + statement_id=data.get("statement_id", ""), + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), + ) + + +@dataclass +class GetStatementResponse: + """Representation of the response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: ResultManifest + result: ResultData + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + return cls( + statement_id=data.get("statement_id", ""), + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), + ) + + +@dataclass +class CreateSessionResponse: + """Representation of the response from creating a new session.""" + + session_id: str + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse": + """Create a CreateSessionResponse from a dictionary.""" + return cls(session_id=data.get("session_id", "")) diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py new file mode 100644 index 000000000..46ce8c98a --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -0,0 +1,67 @@ +""" +Constants for the Statement Execution API (SEA) backend. +""" + +from typing import Dict +from enum import Enum + +# from https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters +ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: Dict[str, str] = { + "ANSI_MODE": "true", + "ENABLE_PHOTON": "true", + "LEGACY_TIME_PARSER_POLICY": "Exception", + "MAX_FILE_PARTITION_BYTES": "128m", + "READ_ONLY_EXTERNAL_METASTORE": "false", + "STATEMENT_TIMEOUT": "0", + "TIMEZONE": "UTC", + "USE_CACHED_RESULT": "true", +} + + +class ResultFormat(Enum): + """Enum for result format values.""" + + ARROW_STREAM = "ARROW_STREAM" + JSON_ARRAY = "JSON_ARRAY" + + +class ResultDisposition(Enum): + """Enum for result disposition values.""" + + HYBRID = "INLINE_OR_EXTERNAL_LINKS" + EXTERNAL_LINKS = "EXTERNAL_LINKS" + INLINE = "INLINE" + + +class ResultCompression(Enum): + """Enum for result compression values.""" + + LZ4_FRAME = "LZ4_FRAME" + NONE = None + + +class WaitTimeout(Enum): + """Enum for wait timeout values.""" + + ASYNC = "0s" + SYNC = "10s" + + +class MetadataCommands(Enum): + """SQL commands used in the SEA backend. + + These constants are used for metadata operations and other SQL queries + to ensure consistency and avoid string literal duplication. + """ + + SHOW_CATALOGS = "SHOW CATALOGS" + SHOW_SCHEMAS = "SHOW SCHEMAS IN {}" + SHOW_TABLES = "SHOW TABLES IN {}" + SHOW_TABLES_ALL_CATALOGS = "SHOW TABLES IN ALL CATALOGS" + SHOW_COLUMNS = "SHOW COLUMNS IN CATALOG {}" + + LIKE_PATTERN = " LIKE '{}'" + SCHEMA_LIKE_PATTERN = " SCHEMA" + LIKE_PATTERN + TABLE_LIKE_PATTERN = " TABLE" + LIKE_PATTERN + + CATALOG_SPECIFIC = "CATALOG {}" diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py new file mode 100644 index 000000000..43db35984 --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -0,0 +1,152 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. 
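+
+At present only SEA result sets are handled, and filtering happens entirely
+client-side on rows that have already been fetched.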
+""" + +from __future__ import annotations + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + cast, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import SeaResultSet + +from databricks.sql.backend.types import ExecuteResponse + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool] + ) -> SeaResultSet: + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + + # Get all remaining rows + all_rows = result_set.results.remaining_rows() + + # Filter rows + filtered_rows = [row for row in all_rows if filter_func(row)] + + # Reuse the command_id from the original result set + command_id = result_set.command_id + + # Create an ExecuteResponse for the filtered data + execute_response = ExecuteResponse( + command_id=command_id, + status=result_set.status, + description=result_set.description, + has_been_closed_server_side=result_set.has_been_closed_server_side, + lz4_compressed=result_set.lz4_compressed, + arrow_schema_bytes=result_set._arrow_schema_bytes, + is_staging_operation=False, + ) + + # Create a new ResultData object with filtered data + from databricks.sql.backend.sea.models.base import ResultData + + result_data = ResultData(data=filtered_rows, external_links=None) + + from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.result_set import SeaResultSet + + # Create a new SeaResultSet with the filtered data + filtered_result_set = SeaResultSet( + connection=result_set.connection, + execute_response=execute_response, + sea_client=cast(SeaDatabricksClient, result_set.backend), + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + result_data=result_data, + ) + + return filtered_result_set + + @staticmethod + def filter_by_column_values( + result_set: SeaResultSet, + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> SeaResultSet: + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + @staticmethod + def filter_tables_by_type( + result_set: SeaResultSet, table_types: Optional[List[str]] = None + ) -> SeaResultSet: + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=True + ) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py new file mode 100644 index 000000000..fe292919c --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -0,0 +1,186 @@ +import json +import logging +import requests +from typing import Callable, Dict, Any, Optional, List, Tuple +from urllib.parse import urljoin + +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.types import SSLOptions + +logger = logging.getLogger(__name__) + + +class SeaHttpClient: + """ + HTTP client for Statement Execution API (SEA). + + This client handles the HTTP communication with the SEA endpoints, + including authentication, request formatting, and response parsing. + """ + + def __init__( + self, + server_hostname: str, + port: int, + http_path: str, + http_headers: List[Tuple[str, str]], + auth_provider: AuthProvider, + ssl_options: SSLOptions, + **kwargs, + ): + """ + Initialize the SEA HTTP client. + + Args: + server_hostname: Hostname of the Databricks server + port: Port number for the connection + http_path: HTTP path for the connection + http_headers: List of HTTP headers to include in requests + auth_provider: Authentication provider + ssl_options: SSL configuration options + **kwargs: Additional keyword arguments + """ + + self.server_hostname = server_hostname + self.port = port + self.http_path = http_path + self.auth_provider = auth_provider + self.ssl_options = ssl_options + + self.base_url = f"https://{server_hostname}:{port}" + + self.headers: Dict[str, str] = dict(http_headers) + self.headers.update({"Content-Type": "application/json"}) + + self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30) + + # Create a session for connection pooling + self.session = requests.Session() + + # Configure SSL verification + if ssl_options.tls_verify: + self.session.verify = ssl_options.tls_trusted_ca_file or True + else: + self.session.verify = False + + # Configure client certificates if provided + if ssl_options.tls_client_cert_file: + client_cert = ssl_options.tls_client_cert_file + client_key = ssl_options.tls_client_cert_key_file + client_key_password = ssl_options.tls_client_cert_key_password + + if client_key: + self.session.cert = (client_cert, client_key) + else: + self.session.cert = client_cert + + if client_key_password: + # Note: requests doesn't directly support key passwords + # This would require more complex handling with libraries like pyOpenSSL + logger.warning( + "Client key password provided but not supported by requests library" + ) + + def _get_auth_headers(self) -> Dict[str, str]: + """Get authentication headers from the auth provider.""" + headers: Dict[str, str] = {} + self.auth_provider.add_headers(headers) + return headers + + def _get_call(self, method: str) -> Callable: + """Get the appropriate HTTP method function.""" + method = method.upper() + if method == "GET": + return 
self.session.get + if method == "POST": + return self.session.post + if method == "DELETE": + return self.session.delete + raise ValueError(f"Unsupported HTTP method: {method}") + + def _make_request( + self, + method: str, + path: str, + data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Make an HTTP request to the SEA endpoint. + + Args: + method: HTTP method (GET, POST, DELETE) + path: API endpoint path + data: Request payload data + params: Query parameters + + Returns: + Dict[str, Any]: Response data parsed from JSON + + Raises: + RequestError: If the request fails + """ + + url = urljoin(self.base_url, path) + headers: Dict[str, str] = {**self.headers, **self._get_auth_headers()} + + logger.debug(f"making {method} request to {url}") + + try: + call = self._get_call(method) + response = call( + url=url, + headers=headers, + json=data, + params=params, + ) + + # Check for HTTP errors + response.raise_for_status() + + # Log response details + logger.debug(f"Response status: {response.status_code}") + + # Parse JSON response + if response.content: + result = response.json() + # Log response content (but limit it for large responses) + content_str = json.dumps(result) + if len(content_str) > 1000: + logger.debug( + f"Response content (truncated): {content_str[:1000]}..." + ) + else: + logger.debug(f"Response content: {content_str}") + return result + return {} + + except requests.exceptions.RequestException as e: + # Handle request errors and extract details from response if available + error_message = f"SEA HTTP request failed: {str(e)}" + + if hasattr(e, "response") and e.response is not None: + status_code = e.response.status_code + try: + error_details = e.response.json() + error_message = ( + f"{error_message}: {error_details.get('message', '')}" + ) + logger.error( + f"Request failed (status {status_code}): {error_details}" + ) + except (ValueError, KeyError): + # If we can't parse JSON, log raw content + content = ( + e.response.content.decode("utf-8", errors="replace") + if isinstance(e.response.content, bytes) + else str(e.response.content) + ) + logger.error(f"Request failed (status {status_code}): {content}") + else: + logger.error(error_message) + + # Re-raise as a RequestError + from databricks.sql.exc import RequestError + + raise RequestError(error_message, e) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index c40dee604..16a664e78 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,10 +5,11 @@ import math import time import threading -from typing import Union, TYPE_CHECKING +from typing import List, Optional, Union, Any, TYPE_CHECKING from databricks.sql.result_set import ThriftResultSet + if TYPE_CHECKING: from databricks.sql.client import Cursor from databricks.sql.result_set import ResultSet @@ -17,8 +18,9 @@ CommandState, SessionId, CommandId, + ExecuteResponse, ) -from databricks.sql.backend.utils.guid_utils import guid_to_hex_id +from databricks.sql.backend.utils import guid_to_hex_id try: import pyarrow @@ -36,13 +38,12 @@ from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes from databricks.sql import * -from databricks.sql.exc import MaxRetryDurationError from databricks.sql.thrift_api.TCLIService.TCLIService import ( Client as TCLIServiceClient, ) from databricks.sql.utils import ( - ExecuteResponse, + 
ResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, @@ -786,11 +787,13 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - has_more_rows = ( + + is_direct_results = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) + description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, self._session_id_hex, @@ -809,39 +812,25 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - if direct_results and direct_results.resultSet: - assert direct_results.resultSet.results.startRowOffset == 0 - assert direct_results.resultSetMetadata - - arrow_queue_opt = ResultSetQueueFactory.build_queue( - row_set_type=t_result_set_metadata_resp.resultFormat, - t_row_set=direct_results.resultSet.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) - else: - arrow_queue_opt = None - command_id = CommandId.from_thrift_handle(resp.operationHandle) - if command_id is None: - raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}") - return ExecuteResponse( - arrow_queue=arrow_queue_opt, - status=CommandState.from_thrift_state(operation_state), - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, + status = CommandState.from_thrift_state(operation_state) + if status is None: + raise ValueError(f"Unknown command state: {operation_state}") + + execute_response = ExecuteResponse( command_id=command_id, + status=status, description=description, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=t_result_set_metadata_resp.isStagingOperation, arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) + return execute_response, is_direct_results + def get_execution_result( self, command_id: CommandId, cursor: Cursor ) -> "ResultSet": @@ -866,9 +855,6 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, self._session_id_hex, @@ -886,26 +872,21 @@ def get_execution_result( else: schema_bytes = None - queue = ResultSetQueueFactory.build_queue( - row_set_type=resp.resultSetMetadata.resultFormat, - t_row_set=resp.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + is_direct_results = resp.hasMoreRows + + status = CommandState.from_thrift_state(resp.status) or CommandState.RUNNING execute_response = ExecuteResponse( - arrow_queue=queue, - status=CommandState.from_thrift_state(resp.status), + command_id=command_id, + status=status, + description=description, 
has_been_closed_server_side=False, - has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -915,6 +896,10 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=resp.results, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -984,6 +969,7 @@ def execute_command( parameters=[], async_op=False, enforce_embedded_schema_correctness=False, + row_limit: Optional[int] = None, ) -> Union["ResultSet", None]: thrift_handle = session_id.to_thrift_handle() if not thrift_handle: @@ -1024,6 +1010,7 @@ def execute_command( useArrowNativeTypes=spark_arrow_types, parameters=parameters, enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness, + resultRowLimit=row_limit, ) resp = self.make_request(self._client.ExecuteStatement, req) @@ -1031,7 +1018,13 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) + + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results return ThriftResultSet( connection=cursor.connection, @@ -1040,6 +1033,10 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_catalogs( @@ -1048,7 +1045,7 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: Cursor, - ) -> "ResultSet": + ) -> ResultSet: thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1061,7 +1058,13 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) + + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results return ThriftResultSet( connection=cursor.connection, @@ -1070,6 +1073,10 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_schemas( @@ -1080,7 +1087,9 @@ def get_schemas( cursor: Cursor, catalog_name=None, schema_name=None, - ) -> "ResultSet": + ) -> ResultSet: + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1095,7 +1104,13 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) + + t_row_set = None + if resp.directResults 
and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results return ThriftResultSet( connection=cursor.connection, @@ -1104,6 +1119,10 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_tables( @@ -1116,7 +1135,9 @@ def get_tables( schema_name=None, table_name=None, table_types=None, - ) -> "ResultSet": + ) -> ResultSet: + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1133,7 +1154,13 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) + + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results return ThriftResultSet( connection=cursor.connection, @@ -1142,6 +1169,10 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_columns( @@ -1154,7 +1185,9 @@ def get_columns( schema_name=None, table_name=None, column_name=None, - ) -> "ResultSet": + ) -> ResultSet: + from databricks.sql.result_set import ThriftResultSet + thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1171,7 +1204,13 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) + + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results return ThriftResultSet( connection=cursor.connection, @@ -1180,6 +1219,10 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index ddeac474a..f645fc6d1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils.guid_utils import guid_to_hex_id @@ -80,6 +81,26 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. 
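The same three-step shape now repeats in `execute_command` and every metadata call above; a hypothetical helper (not part of this PR) that makes it explicit:

```python
def _unpack_direct_results(self, resp, cursor):
    # Hypothetical refactor target: unpack the (response, flag) pair, then
    # lift the inline row set out of direct results when the server sent one.
    execute_response, is_direct_results = self._handle_execute_response(resp, cursor)
    t_row_set = None
    if resp.directResults and resp.directResults.resultSet:
        t_row_set = resp.directResults.resultSet.results
    return execute_response, is_direct_results, t_row_set
```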
+ Args: + state: SEA state string + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -389,3 +410,17 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[List[Tuple]] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False + arrow_schema_bytes: Optional[bytes] = None + result_format: Optional[Any] = None diff --git a/src/databricks/sql/backend/utils/__init__.py b/src/databricks/sql/backend/utils/__init__.py index e69de29bb..3d601e5e6 100644 --- a/src/databricks/sql/backend/utils/__init__.py +++ b/src/databricks/sql/backend/utils/__init__.py @@ -0,0 +1,3 @@ +from .guid_utils import guid_to_hex_id + +__all__ = ["guid_to_hex_id"] diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e4166f117..caf6ddc0c 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -26,7 +26,6 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( - ExecuteResponse, ParamEscaper, inject_parameters, transform_paramstyle, @@ -100,6 +99,10 @@ def __init__( Connect to a Databricks SQL endpoint or a Databricks cluster. Parameters: + :param use_sea: `bool`, optional (default is False) + Use the SEA backend instead of the Thrift backend. + :param use_hybrid_disposition: `bool`, optional (default is False) + Use the hybrid disposition instead of the inline disposition. :param server_hostname: Databricks instance host name. :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123) @@ -292,6 +295,7 @@ def read(self) -> Optional[OAuthToken]: driver_connection_params=driver_connection_params, user_agent=self.session.useragent_header, ) + self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): """Valid values are True, False, and "silent" @@ -379,8 +383,14 @@ def cursor( self, arraysize: int = DEFAULT_ARRAY_SIZE, buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, + row_limit: Optional[int] = None, ) -> "Cursor": """ + Args: + arraysize: The maximum number of rows in direct results. + buffer_size_bytes: The maximum number of bytes in direct results. + row_limit: The maximum number of rows in the result. + Return a new Cursor object using the connection. Will throw an Error if the connection has been closed. 
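A usage sketch for the new cursor-level row cap (placeholder credentials; the behaviour matches the e2e tests added later in this diff). Passing `use_sea=True` to `sql.connect` would select the SEA backend instead, whose fetch paths are still stubs at this point in the PR:

```python
from databricks import sql

with sql.connect(
    server_hostname="example.cloud.databricks.com",  # placeholder
    http_path="/sql/warehouses/abc123",              # placeholder
    access_token="dapi...",                          # placeholder
) as conn:
    with conn.cursor(row_limit=1000) as cursor:
        cursor.execute("SELECT * FROM range(2000)")
        rows = cursor.fetchall()
        assert len(rows) == 1000  # capped by row_limit
```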
@@ -396,6 +406,7 @@ def cursor( self.session.backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, + row_limit=row_limit, ) self._cursors.append(cursor) return cursor @@ -434,6 +445,7 @@ def __init__( backend: DatabricksClient, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = DEFAULT_ARRAY_SIZE, + row_limit: Optional[int] = None, ) -> None: """ These objects represent a database cursor, which is used to manage the context of a fetch @@ -443,16 +455,18 @@ def __init__( visible by other cursors or connections. """ - self.connection = connection - self.rowcount = -1 # Return -1 as this is not supported - self.buffer_size_bytes = result_buffer_size_bytes + self.connection: Connection = connection + + self.rowcount: int = -1 # Return -1 as this is not supported + self.buffer_size_bytes: int = result_buffer_size_bytes self.active_result_set: Union[ResultSet, None] = None - self.arraysize = arraysize + self.arraysize: int = arraysize + self.row_limit: Optional[int] = row_limit # Note that Cursor closed => active result set closed, but not vice versa - self.open = True - self.executing_command_id = None - self.backend = backend - self.active_command_id = None + self.open: bool = True + self.executing_command_id: Optional[CommandId] = None + self.backend: DatabricksClient = backend + self.active_command_id: Optional[CommandId] = None self.escaper = ParamEscaper() self.lastrowid = None @@ -853,6 +867,7 @@ def execute( parameters=prepared_params, async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, + row_limit=self.row_limit, ) if self.active_result_set and self.active_result_set.is_staging_operation: @@ -910,6 +925,7 @@ def execute_async( parameters=prepared_params, async_op=True, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, + row_limit=self.row_limit, ) return self diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 074877d32..9627c5977 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,13 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Optional, TYPE_CHECKING +from typing import List, Optional, Any, TYPE_CHECKING import logging import pandas -from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import CommandId, CommandState try: import pyarrow @@ -16,11 +14,14 @@ if TYPE_CHECKING: from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.client import Connection +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.types import Row from databricks.sql.exc import RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -36,31 +37,49 @@ def __init__( self, connection: Connection, backend: DatabricksClient, - command_id: CommandId, - op_state: Optional[CommandState], - has_been_closed_server_side: bool, arraysize: int, buffer_size_bytes: int, + command_id: CommandId, + status: CommandState, + has_been_closed_server_side: bool = False, + is_direct_results: bool = False, + results_queue=None, + description=None, + 
is_staging_operation: bool = False, + lz4_compressed: bool = False, + arrow_schema_bytes: Optional[bytes] = None, ): """ A ResultSet manages the results of a single command. - :param connection: The parent connection that was used to execute this command - :param backend: The specialised backend client to be invoked in the fetch phase - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) + Parameters: + :param connection: The parent connection that was used to execute this command + :param backend: The specialised backend client to be invoked in the fetch phase + :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + :param command_id: The command ID + :param status: The command status + :param has_been_closed_server_side: Whether the command has been closed on the server + :param is_direct_results: Whether the command has more rows + :param results_queue: The results queue + :param description: column description of the results + :param is_staging_operation: Whether the command is a staging operation """ - self.command_id = command_id - self.op_state = op_state - self.has_been_closed_server_side = has_been_closed_server_side self.connection = connection self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 - self.description = None + self.description = description + self.command_id = command_id + self.status = status + self.has_been_closed_server_side = has_been_closed_server_side + self.is_direct_results = is_direct_results + self.results = results_queue + self._is_staging_operation = is_staging_operation + self.lz4_compressed = lz4_compressed + self._arrow_schema_bytes = arrow_schema_bytes def __iter__(self): while True: @@ -75,10 +94,9 @@ def rownumber(self): return self._next_row_index @property - @abstractmethod def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" - pass + return self._is_staging_operation # Define abstract methods that concrete implementations must implement @abstractmethod @@ -118,10 +136,11 @@ def close(self) -> None: If the connection has not been closed, and the result set has not already been closed on the server for some other reason, issue a request to the server to close it. """ - try: + if self.results: + self.results.close() if ( - self.op_state != CommandState.CLOSED + self.status != CommandState.CLOSED and not self.has_been_closed_server_side and self.connection.open ): @@ -131,7 +150,7 @@ def close(self) -> None: logger.info("Operation was canceled by a prior request") finally: self.has_been_closed_server_side = True - self.op_state = CommandState.CLOSED + self.status = CommandState.CLOSED class ThriftResultSet(ResultSet): @@ -145,50 +164,70 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, + t_row_set=None, + max_download_threads: int = 10, + ssl_options=None, + is_direct_results: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
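The refactored `close()` above now drains the results queue and tracks state via `status`; a mocked sketch of the resulting contract, mirroring the unit tests further down in this diff:

```python
from unittest.mock import Mock
from databricks.sql.backend.types import CommandState
from databricks.sql.result_set import ThriftResultSet

backend = Mock()
backend.fetch_results.return_value = (Mock(), False)  # constructor fills the buffer

rs = ThriftResultSet(connection=Mock(), execute_response=Mock(), thrift_client=backend)
rs.close()
assert rs.status == CommandState.CLOSED
assert rs.has_been_closed_server_side is True
```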
- Args: - connection: The parent connection - execute_response: Response from the execute command - thrift_client: The ThriftDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - use_cloud_fetch: Whether to use cloud fetch for retrieving results + Parameters: + :param connection: The parent connection + :param execute_response: Response from the execute command + :param thrift_client: The ThriftDatabricksClient instance for direct access + :param buffer_size_bytes: Buffer size for fetching results + :param arraysize: Default number of rows to fetch + :param use_cloud_fetch: Whether to use cloud fetch for retrieving results + :param t_row_set: The TRowSet containing result data (if available) + :param max_download_threads: Maximum number of download threads for cloud fetch + :param ssl_options: SSL options for cloud fetch + :param is_direct_results: Whether there are more rows to fetch """ - super().__init__( - connection, - thrift_client, - execute_response.command_id, - execute_response.status, - execute_response.has_been_closed_server_side, - arraysize, - buffer_size_bytes, - ) - # Initialize ThriftResultSet-specific attributes - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.lz4_compressed = execute_response.lz4_compressed - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self._is_staging_operation = execute_response.is_staging_operation + self.is_direct_results = is_direct_results + + # Build the results queue if t_row_set is provided + results_queue = None + if t_row_set and execute_response.result_format is not None: + from databricks.sql.utils import ResultSetQueueFactory + + # Create the results queue using the provided format + results_queue = ResultSetQueueFactory.build_queue( + row_set_type=execute_response.result_format, + t_row_set=t_row_set, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + max_download_threads=max_download_threads, + lz4_compressed=execute_response.lz4_compressed, + description=execute_response.description, + ssl_options=ssl_options, + ) + + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=thrift_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + is_direct_results=is_direct_results, + results_queue=results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, + ) - # Initialize results queue - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity + # Initialize results queue if not provided + if not self.results: self._fill_results_buffer() def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available - results, has_more_rows = self.backend.fetch_results( + results, is_direct_results = 
self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -199,7 +238,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -284,7 +323,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -309,7 +348,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -324,7 +363,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows partial_result_chunks = [results] - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -350,7 +389,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -393,11 +432,6 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) - @property - def is_staging_operation(self) -> bool: - """Whether this result set represents a staging operation.""" - return self._is_staging_operation - @staticmethod def _get_schema_description(table_schema_message): """ @@ -414,3 +448,82 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] + + +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection: Connection, + execute_response: ExecuteResponse, + sea_client: SeaDatabricksClient, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + result_data=None, + manifest=None, + ): + """ + Initialize a SeaResultSet with the response from a SEA query execution. 
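At this stage `SeaResultSet` is a typed shell: construction works (all metadata flows through the base class), but every fetch path raises. A mocked probe of that contract:

```python
from unittest.mock import MagicMock
from databricks.sql.result_set import SeaResultSet

# execute_response fields are only read in __init__, so a MagicMock suffices.
rs = SeaResultSet(
    connection=MagicMock(),
    execute_response=MagicMock(),
    sea_client=MagicMock(),
)
try:
    rs.fetchall()
except NotImplementedError as e:
    print(e)  # "fetchall is not implemented for SEA backend"
```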
+ + Args: + connection: The parent connection + execute_response: Response from the execute command + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + result_data: Result data from SEA response (optional) + manifest: Manifest from SEA response (optional) + """ + + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, + ) + + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError( + "_fill_results_buffer is not implemented for SEA backend" + ) + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 251f502df..aafa02a4b 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Tuple, List, Optional, Any +from typing import Dict, Tuple, List, Optional, Any, Type from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions @@ -8,8 +8,9 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId +from databricks.sql.backend.types import SessionId, BackendType logger = logging.getLogger(__name__) @@ -61,6 +62,7 @@ def __init__( self.useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) base_headers = [("User-Agent", self.useragent_header)] + all_headers = (http_headers or []) + base_headers self._ssl_options = SSLOptions( # Double negation is generally a bad thing, but we have to keep backward compatibility @@ -74,19 +76,49 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - self.backend: DatabricksClient = ThriftDatabricksClient( - 
self.host, - self.port, + self.backend = self._create_backend( + server_hostname, http_path, - (http_headers or []) + base_headers, + all_headers, self.auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, - **kwargs, + _use_arrow_native_complex_types, + kwargs, ) self.protocol_version = None + def _create_backend( + self, + server_hostname: str, + http_path: str, + all_headers: List[Tuple[str, str]], + auth_provider, + _use_arrow_native_complex_types: Optional[bool], + kwargs: dict, + ) -> DatabricksClient: + """Create and return the appropriate backend client.""" + use_sea = kwargs.get("use_sea", False) + + databricks_client_class: Type[DatabricksClient] + if use_sea: + logger.debug("Creating SEA backend client") + databricks_client_class = SeaDatabricksClient + else: + logger.debug("Creating Thrift backend client") + databricks_client_class = ThriftDatabricksClient + + common_args = { + "server_hostname": server_hostname, + "port": self.port, + "http_path": http_path, + "http_headers": all_headers, + "auth_provider": auth_provider, + "ssl_options": self._ssl_options, + "_use_arrow_native_complex_types": _use_arrow_native_complex_types, + **kwargs, + } + return databricks_client_class(**common_args) + def open(self): self._session_id = self.backend.open_session( session_configuration=self.session_configuration, diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index a3e3e1dd0..7e8a4fa0c 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Mapping from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Union, Sequence +from typing import Any, Dict, List, Optional, Tuple, Union, Sequence import re import lz4.frame @@ -61,7 +61,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -188,6 +188,7 @@ def __init__( def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """Get upto the next n rows of the Arrow dataframe""" + length = min(num_rows, self.n_valid_rows - self.cur_row_index) # Note that the table.slice API is not the same as Python's slice # The second argument should be length, not end index @@ -215,7 +216,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. 
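The queue factories above now type `description` as `Optional[List[Tuple]]`; each entry is the standard PEP-249 7-tuple, which is also the shape the reworked unit tests construct by hand:

```python
# PEP-249 description entries:
# (name, type_code, display_size, internal_size, precision, scale, null_ok)
description = [
    ("id", "integer", None, None, None, None, None),
    ("amount", "decimal", None, None, 10, 2, True),
]
column_names = [col[0] for col in description]  # ["id", "amount"]
```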
@@ -363,13 +364,6 @@ def close(self): self.download_manager._shutdown_manager() -ExecuteResponse = namedtuple( - "ExecuteResponse", - "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_id arrow_queue arrow_schema_bytes", -) - - def _bound(min_x, max_x, x): """Bound x by [min_x, max_x] diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b5d01a45d..dd509c062 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -326,7 +326,7 @@ def test_retry_abort_close_operation_on_404(self, caplog): with self.connection(extra_params={**self._retry_policy}) as conn: with conn.cursor() as curs: with patch( - "databricks.sql.utils.ExecuteResponse.has_been_closed_server_side", + "databricks.sql.backend.types.ExecuteResponse.has_been_closed_server_side", new_callable=PropertyMock, return_value=False, ): diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 042fcc10a..8f15bccc6 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -30,6 +30,7 @@ OperationalError, RequestError, ) +from databricks.sql.backend.types import CommandState from tests.e2e.common.predicates import ( pysql_has_version, pysql_supports_arrow, @@ -112,10 +113,12 @@ def connection(self, extra_params=()): conn.close() @contextmanager - def cursor(self, extra_params=()): + def cursor(self, extra_params=(), extra_cursor_params=()): with self.connection(extra_params) as conn: cursor = conn.cursor( - arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes + arraysize=self.arraysize, + buffer_size_bytes=self.buffer_size_bytes, + **dict(extra_cursor_params), ) try: yield cursor @@ -808,6 +811,60 @@ def test_catalogs_returns_arrow_table(self): results = cursor.fetchall_arrow() assert isinstance(results, pyarrow.Table) + def test_row_limit_with_larger_result(self): + """Test that row_limit properly constrains results when query would return more rows""" + row_limit = 1000 + with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor: + # Execute a query that returns more than row_limit rows + cursor.execute("SELECT * FROM range(2000)") + rows = cursor.fetchall() + + # Check if the number of rows is limited to row_limit + assert len(rows) == row_limit, f"Expected {row_limit} rows, got {len(rows)}" + + def test_row_limit_with_smaller_result(self): + """Test that row_limit doesn't affect results when query returns fewer rows than limit""" + row_limit = 100 + expected_rows = 50 + with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor: + # Execute a query that returns fewer than row_limit rows + cursor.execute(f"SELECT * FROM range({expected_rows})") + rows = cursor.fetchall() + + # Check if all rows are returned (not limited by row_limit) + assert ( + len(rows) == expected_rows + ), f"Expected {expected_rows} rows, got {len(rows)}" + + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + def test_row_limit_with_arrow_larger_result(self): + """Test that row_limit properly constrains arrow results when query would return more rows""" + row_limit = 800 + with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor: + # Execute a query that returns more than row_limit rows + cursor.execute("SELECT * FROM range(1500)") + arrow_table = cursor.fetchall_arrow() + + # Check if the number of rows in the arrow table is limited to row_limit + assert ( + arrow_table.num_rows == row_limit + ), f"Expected {row_limit} 
rows, got {arrow_table.num_rows}" + + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + def test_row_limit_with_arrow_smaller_result(self): + """Test that row_limit doesn't affect arrow results when query returns fewer rows than limit""" + row_limit = 200 + expected_rows = 100 + with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor: + # Execute a query that returns fewer than row_limit rows + cursor.execute(f"SELECT * FROM range({expected_rows})") + arrow_table = cursor.fetchall_arrow() + + # Check if all rows are returned (not limited by row_limit) + assert ( + arrow_table.num_rows == expected_rows + ), f"Expected {expected_rows} rows, got {arrow_table.num_rows}" + # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep # the 429/503 subsuites separate since they execute under different circumstances. diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index a5db003e7..520a0f377 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -26,9 +26,8 @@ from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState -from databricks.sql.utils import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse -from databricks.sql.utils import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite from tests.unit.test_arrow_queue import ArrowQueueSuite @@ -40,8 +39,6 @@ def new(cls): ThriftBackendMock = Mock(spec=ThriftDatabricksClient) ThriftBackendMock.return_value = ThriftBackendMock - cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) - mock_result_set = Mock(spec=ThriftResultSet) cls.apply_property_to_mock( mock_result_set, @@ -49,7 +46,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - has_more_rows=True, + is_direct_results=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -116,6 +113,9 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that will be used mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None + mock_backend.fetch_results.return_value = (Mock(), False) + + # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend # Create connection and cursor @@ -142,7 +142,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): assert real_result_set.has_been_closed_server_side is True # 2. op_state should always be CLOSED after close() - assert real_result_set.op_state == CommandState.CLOSED + assert real_result_set.status == CommandState.CLOSED # 3. 
Backend close_command should be called appropriately if not closed: @@ -179,12 +179,16 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() + mock_results = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, execute_response=Mock(), thrift_client=mock_backend, ) + result_set.results = mock_results + # Setup session mock on the mock_connection mock_session = Mock() mock_session.open = False @@ -200,20 +204,26 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.has_been_closed_server_side = False mock_connection = Mock() mock_thrift_backend = Mock() + mock_results = Mock() # Setup session mock on the mock_connection mock_session = Mock() mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) + mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( - mock_connection, mock_results_response, mock_thrift_backend + mock_connection, + mock_results_response, + mock_thrift_backend, ) + result_set.results = mock_results result_set.close() mock_thrift_backend.close_command.assert_called_once_with( mock_results_response.command_id ) + mock_results.close.assert_called_once() def test_executing_multiple_commands_uses_the_most_recent_command(self): mock_result_sets = [Mock(), Mock()] @@ -221,6 +231,12 @@ def test_executing_multiple_commands_uses_the_most_recent_command(self): for mock_rs in mock_result_sets: mock_rs.is_staging_operation = False + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_sets + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_sets: + mock_rs.is_staging_operation = False + mock_backend = ThriftDatabricksClientMockFactory.new() mock_backend.execute_command.side_effect = mock_result_sets @@ -249,7 +265,10 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = ThriftResultSet(Mock(), Mock(), Mock()) + mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) + + result_set = ThriftResultSet(Mock(), Mock(), mock_backend) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -457,7 +476,6 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq - mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") @@ -541,7 +559,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): self.assertEqual(instance.close_session.call_count, 0) cursor.close() - @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) + @patch("%s.backend.types.ExecuteResponse" % PACKAGE_NAME) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 030510a64..a649941e1 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -8,7 +8,8 @@ pa = None import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, 
ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ThriftResultSet @@ -39,26 +40,30 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) + + # Create a mock backend that will return the queue when _fill_results_buffer is called + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) + mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + + num_cols = len(initial_results[0]) if initial_results else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=True, - has_more_rows=False, - description=Mock(), - lz4_compressed=Mock(), - command_id=None, - arrow_queue=arrow_queue, - arrow_schema_bytes=schema.serialize().to_pybytes(), + description=description, + lz4_compressed=True, is_staging_operation=False, ), - thrift_client=None, + thrift_client=mock_thrift_backend, + t_row_set=None, ) - num_cols = len(initial_results[0]) if initial_results else 0 - rs.description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] return rs @staticmethod @@ -85,20 +90,19 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=False, - has_more_rows=True, - description=[ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ], - lz4_compressed=Mock(), - command_id=None, - arrow_queue=None, - arrow_schema_bytes=None, + description=description, + lz4_compressed=True, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index b302c00da..e4a9e5cdd 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -10,7 +10,8 @@ import pytest import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") @@ -35,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - has_more_rows=False, + is_direct_results=False, description=Mock(), command_id=None, arrow_queue=arrow_queue, diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py new file mode 100644 index 000000000..975376e13 --- /dev/null +++ b/tests/unit/test_filters.py @@ -0,0 +1,160 @@ +""" +Tests for the ResultSetFilter class. 
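Because `ThriftResultSet` now fills its buffer on construction whenever no `t_row_set` is supplied, every unit test has to stub `fetch_results` on the backend mock. A hypothetical shared helper capturing that setup:

```python
from unittest.mock import Mock
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient

def make_stub_backend(queue=None, has_more=False):
    # Hypothetical test helper (not in this PR): the (queue, flag) pair feeds
    # the constructor's initial _fill_results_buffer() call.
    backend = Mock(spec=ThriftDatabricksClient)
    backend.fetch_results.return_value = (queue or Mock(), has_more)
    return backend
```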
+""" + +import unittest +from unittest.mock import MagicMock, patch + +from databricks.sql.backend.sea.utils.filters import ResultSetFilter + + +class TestResultSetFilter(unittest.TestCase): + """Tests for the ResultSetFilter class.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a mock SeaResultSet + self.mock_sea_result_set = MagicMock() + + # Set up the remaining_rows method on the results attribute + self.mock_sea_result_set.results = MagicMock() + self.mock_sea_result_set.results.remaining_rows.return_value = [ + ["catalog1", "schema1", "table1", "owner1", "2023-01-01", "TABLE", ""], + ["catalog1", "schema1", "table2", "owner1", "2023-01-01", "VIEW", ""], + [ + "catalog1", + "schema1", + "table3", + "owner1", + "2023-01-01", + "SYSTEM TABLE", + "", + ], + [ + "catalog1", + "schema1", + "table4", + "owner1", + "2023-01-01", + "EXTERNAL TABLE", + "", + ], + ] + + # Set up the connection and other required attributes + self.mock_sea_result_set.connection = MagicMock() + self.mock_sea_result_set.backend = MagicMock() + self.mock_sea_result_set.buffer_size_bytes = 1000 + self.mock_sea_result_set.arraysize = 100 + self.mock_sea_result_set.statement_id = "test-statement-id" + self.mock_sea_result_set.lz4_compressed = False + + # Create a mock CommandId + from databricks.sql.backend.types import CommandId, BackendType + + mock_command_id = CommandId(BackendType.SEA, "test-statement-id") + self.mock_sea_result_set.command_id = mock_command_id + + self.mock_sea_result_set.status = MagicMock() + self.mock_sea_result_set.description = [ + ("catalog_name", "string", None, None, None, None, True), + ("schema_name", "string", None, None, None, None, True), + ("table_name", "string", None, None, None, None, True), + ("owner", "string", None, None, None, None, True), + ("creation_time", "string", None, None, None, None, True), + ("table_type", "string", None, None, None, None, True), + ("remarks", "string", None, None, None, None, True), + ] + self.mock_sea_result_set.has_been_closed_server_side = False + self.mock_sea_result_set._arrow_schema_bytes = None + + def test_filter_by_column_values(self): + """Test filtering by column values with various options.""" + # Case 1: Case-sensitive filtering + allowed_values = ["table1", "table3"] + + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + # Call filter_by_column_values on the table_name column (index 2) + result = ResultSetFilter.filter_by_column_values( + self.mock_sea_result_set, 2, allowed_values, case_sensitive=True + ) + + # Verify the filter was applied correctly + mock_sea_result_set_class.assert_called_once() + + # Check the filtered data passed to the constructor + args, kwargs = mock_sea_result_set_class.call_args + result_data = kwargs.get("result_data") + self.assertIsNotNone(result_data) + self.assertEqual(len(result_data.data), 2) + self.assertIn(result_data.data[0][2], allowed_values) + self.assertIn(result_data.data[1][2], allowed_values) + + # Case 2: Case-insensitive filtering + mock_sea_result_set_class.reset_mock() + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + # Call 
+                result = ResultSetFilter.filter_by_column_values(
+                    self.mock_sea_result_set,
+                    2,
+                    ["TABLE1", "TABLE3"],
+                    case_sensitive=False,
+                )
+                mock_sea_result_set_class.assert_called_once()
+
+    def test_filter_tables_by_type(self):
+        """Test filtering tables by type with various options."""
+        # Case 1: Specific table types
+        table_types = ["TABLE", "VIEW"]
+
+        with patch(
+            "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True
+        ):
+            with patch.object(
+                ResultSetFilter, "filter_by_column_values"
+            ) as mock_filter:
+                ResultSetFilter.filter_tables_by_type(
+                    self.mock_sea_result_set, table_types
+                )
+                args, kwargs = mock_filter.call_args
+                self.assertEqual(args[0], self.mock_sea_result_set)
+                self.assertEqual(args[1], 5)  # Table type column index
+                self.assertEqual(args[2], table_types)
+                self.assertEqual(kwargs.get("case_sensitive"), True)
+
+        # Case 2: Default table types (None or empty list)
+        with patch(
+            "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True
+        ):
+            with patch.object(
+                ResultSetFilter, "filter_by_column_values"
+            ) as mock_filter:
+                # Test with None
+                ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None)
+                args, kwargs = mock_filter.call_args
+                self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
+
+                # Test with empty list
+                ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, [])
+                args, kwargs = mock_filter.call_args
+                self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
+
+
+if __name__ == "__main__":
+    unittest.main()
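The assertions above fully pin down the filter's observable behavior: an explicit `case_sensitive` flag for value matching, the table-type column at index 5, and `None` or an empty list falling back to a default type set. A framework-free sketch of that contract, operating on plain row lists instead of real `SeaResultSet` objects:

```python
from typing import Any, List, Optional

def filter_by_column_values(
    rows: List[List[Any]],
    column_index: int,
    allowed_values: List[str],
    case_sensitive: bool = False,
) -> List[List[Any]]:
    # Keep only rows whose value at column_index is in the allowed set.
    if not case_sensitive:
        allowed = {v.upper() for v in allowed_values}
        return [r for r in rows if str(r[column_index]).upper() in allowed]
    return [r for r in rows if r[column_index] in allowed_values]

def filter_tables_by_type(rows, table_types: Optional[List[str]] = None):
    # None or an empty list falls back to the default set, as asserted above.
    types = table_types if table_types else ["TABLE", "VIEW", "SYSTEM TABLE"]
    return filter_by_column_values(rows, 5, types, case_sensitive=True)
```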
+""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.sea.backend import ( + SeaDatabricksClient, + _filter_session_configuration, +) +from databricks.sql.backend.sea.models.base import ServiceError, StatementStatus +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.parameters.native import IntegerParameter, TDbsqlParameter +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.exc import ( + Error, + NotSupportedError, + ProgrammingError, + ServerOperationError, + DatabaseError, +) + + +class TestSeaBackend: + """Test suite for the SeaDatabricksClient class.""" + + @pytest.fixture + def mock_http_client(self): + """Create a mock HTTP client.""" + with patch( + "databricks.sql.backend.sea.backend.SeaHttpClient" + ) as mock_client_class: + mock_client = mock_client_class.return_value + yield mock_client + + @pytest.fixture + def sea_client(self, mock_http_client): + """Create a SeaDatabricksClient instance with mocked dependencies.""" + server_hostname = "test-server.databricks.com" + port = 443 + http_path = "/sql/warehouses/abc123" + http_headers = [("header1", "value1"), ("header2", "value2")] + auth_provider = AuthProvider() + ssl_options = SSLOptions() + + client = SeaDatabricksClient( + server_hostname=server_hostname, + port=port, + http_path=http_path, + http_headers=http_headers, + auth_provider=auth_provider, + ssl_options=ssl_options, + ) + + return client + + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + cursor.buffer_size_bytes = 1000 + cursor.arraysize = 100 + return cursor + + @pytest.fixture + def thrift_session_id(self): + """Create a Thrift session ID (not SEA).""" + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + return SessionId.from_thrift_handle(mock_thrift_handle) + + @pytest.fixture + def thrift_command_id(self): + """Create a Thrift command ID (not SEA).""" + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + return CommandId.from_thrift_handle(mock_thrift_operation_handle) + + def test_initialization(self, mock_http_client): + """Test client initialization and warehouse ID extraction.""" + # Test with warehouses format + client1 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client1.warehouse_id == "abc123" + assert client1.max_download_threads == 10 # Default value + + # Test with endpoints format + client2 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/endpoints/def456", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client2.warehouse_id == "def456" + + # Test with custom max_download_threads + client3 = 
+    def test_initialization(self, mock_http_client):
+        """Test client initialization and warehouse ID extraction."""
+        # Test with warehouses format
+        client1 = SeaDatabricksClient(
+            server_hostname="test-server.databricks.com",
+            port=443,
+            http_path="/sql/warehouses/abc123",
+            http_headers=[],
+            auth_provider=AuthProvider(),
+            ssl_options=SSLOptions(),
+        )
+        assert client1.warehouse_id == "abc123"
+        assert client1.max_download_threads == 10  # Default value
+
+        # Test with endpoints format
+        client2 = SeaDatabricksClient(
+            server_hostname="test-server.databricks.com",
+            port=443,
+            http_path="/sql/endpoints/def456",
+            http_headers=[],
+            auth_provider=AuthProvider(),
+            ssl_options=SSLOptions(),
+        )
+        assert client2.warehouse_id == "def456"
+
+        # Test with custom max_download_threads
+        client3 = SeaDatabricksClient(
+            server_hostname="test-server.databricks.com",
+            port=443,
+            http_path="/sql/warehouses/abc123",
+            http_headers=[],
+            auth_provider=AuthProvider(),
+            ssl_options=SSLOptions(),
+            max_download_threads=5,
+        )
+        assert client3.max_download_threads == 5
+
+        # Test with invalid HTTP path
+        with pytest.raises(ValueError) as excinfo:
+            SeaDatabricksClient(
+                server_hostname="test-server.databricks.com",
+                port=443,
+                http_path="/invalid/path",
+                http_headers=[],
+                auth_provider=AuthProvider(),
+                ssl_options=SSLOptions(),
+            )
+        assert "Could not extract warehouse ID" in str(excinfo.value)
+
+    def test_session_management(self, sea_client, mock_http_client, thrift_session_id):
+        """Test session management methods."""
+        # Test open_session with minimal parameters
+        mock_http_client._make_request.return_value = {"session_id": "test-session-123"}
+        session_id = sea_client.open_session(None, None, None)
+        assert isinstance(session_id, SessionId)
+        assert session_id.backend_type == BackendType.SEA
+        assert session_id.guid == "test-session-123"
+        mock_http_client._make_request.assert_called_with(
+            method="POST", path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"}
+        )
+
+        # Test open_session with all parameters
+        mock_http_client.reset_mock()
+        mock_http_client._make_request.return_value = {"session_id": "test-session-456"}
+        session_config = {
+            "ANSI_MODE": "FALSE",  # Supported parameter
+            "STATEMENT_TIMEOUT": "3600",  # Supported parameter
+            "unsupported_param": "value",  # Unsupported parameter
+        }
+        catalog = "test_catalog"
+        schema = "test_schema"
+        session_id = sea_client.open_session(session_config, catalog, schema)
+        assert session_id.guid == "test-session-456"
+        expected_data = {
+            "warehouse_id": "abc123",
+            "session_confs": {
+                "ansi_mode": "FALSE",
+                "statement_timeout": "3600",
+            },
+            "catalog": catalog,
+            "schema": schema,
+        }
+        mock_http_client._make_request.assert_called_with(
+            method="POST", path=sea_client.SESSION_PATH, data=expected_data
+        )
+
+        # Test open_session error handling
+        mock_http_client.reset_mock()
+        mock_http_client._make_request.return_value = {}
+        with pytest.raises(Error) as excinfo:
+            sea_client.open_session(None, None, None)
+        assert "Failed to create session" in str(excinfo.value)
+
+        # Test close_session with valid ID
+        mock_http_client.reset_mock()
+        session_id = SessionId.from_sea_session_id("test-session-789")
+        sea_client.close_session(session_id)
+        mock_http_client._make_request.assert_called_with(
+            method="DELETE",
+            path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"),
+            data={"session_id": "test-session-789", "warehouse_id": "abc123"},
+        )
+
+        # Test close_session with invalid ID type
+        with pytest.raises(ValueError) as excinfo:
+            sea_client.close_session(thrift_session_id)
+        assert "Not a valid SEA session ID" in str(excinfo.value)
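For orientation, the session assertions above boil down to a small request/response exchange. A hedged sketch of that flow; `SESSION_PATH`, `http`, and `make_request` are placeholders for the mocked internals, not the real client surface:

```python
SESSION_PATH = "/sessions"  # illustrative value only

def open_session(http, warehouse_id, session_confs=None, catalog=None, schema=None):
    # Build the payload the tests assert on: warehouse_id always present,
    # the optional fields only when supplied.
    data = {"warehouse_id": warehouse_id}
    if session_confs:
        data["session_confs"] = session_confs
    if catalog:
        data["catalog"] = catalog
    if schema:
        data["schema"] = schema
    response = http.make_request(method="POST", path=SESSION_PATH, data=data)
    session_id = response.get("session_id")
    if not session_id:
        # Mirrors the "Failed to create session" error case tested above.
        raise RuntimeError("Failed to create session")
    return session_id
```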
+    def test_command_execution_sync(
+        self, sea_client, mock_http_client, mock_cursor, sea_session_id
+    ):
+        """Test synchronous command execution."""
+        # Test synchronous execution
+        execute_response = {
+            "statement_id": "test-statement-123",
+            "status": {"state": "SUCCEEDED"},
+            "manifest": {
+                "schema": [
+                    {
+                        "name": "col1",
+                        "type_name": "STRING",
+                        "type_text": "string",
+                        "nullable": True,
+                    }
+                ],
+                "total_row_count": 1,
+                "total_byte_count": 100,
+            },
+            "result": {"data": [["value1"]]},
+        }
+        mock_http_client._make_request.return_value = execute_response
+
+        with patch.object(
+            sea_client, "_response_to_result_set", return_value="mock_result_set"
+        ) as mock_get_result:
+            result = sea_client.execute_command(
+                operation="SELECT 1",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+            assert result == "mock_result_set"
+
+        # Test with invalid session ID
+        with pytest.raises(ValueError) as excinfo:
+            mock_thrift_handle = MagicMock()
+            mock_thrift_handle.sessionId.guid = b"guid"
+            mock_thrift_handle.sessionId.secret = b"secret"
+            thrift_session_id = SessionId.from_thrift_handle(mock_thrift_handle)
+
+            sea_client.execute_command(
+                operation="SELECT 1",
+                session_id=thrift_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+        assert "Not a valid SEA session ID" in str(excinfo.value)
+
+    def test_command_execution_async(
+        self, sea_client, mock_http_client, mock_cursor, sea_session_id
+    ):
+        """Test asynchronous command execution."""
+        # Test asynchronous execution
+        execute_response = {
+            "statement_id": "test-statement-456",
+            "status": {"state": "PENDING"},
+        }
+        mock_http_client._make_request.return_value = execute_response
+
+        result = sea_client.execute_command(
+            operation="SELECT 1",
+            session_id=sea_session_id,
+            max_rows=100,
+            max_bytes=1000,
+            lz4_compression=False,
+            cursor=mock_cursor,
+            use_cloud_fetch=False,
+            parameters=[],
+            async_op=True,
+            enforce_embedded_schema_correctness=False,
+        )
+        assert result is None
+        assert isinstance(mock_cursor.active_command_id, CommandId)
+        assert mock_cursor.active_command_id.guid == "test-statement-456"
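The synchronous path above blocks until a terminal state, while the async path returns immediately and leaves polling to the caller. A minimal sketch of the poll loop the "advanced" test below exercises; `client.get_state` is an illustrative name:

```python
import time

def wait_for_statement(client, statement_id, poll_interval=1.0):
    # Poll until the statement reaches a terminal state; tests patch
    # time.sleep so this loop runs instantly under test.
    terminal = {"SUCCEEDED", "FAILED", "CANCELED", "CLOSED"}
    while True:
        state = client.get_state(statement_id)
        if state in terminal:
            return state
        time.sleep(poll_interval)
```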
+    def test_command_execution_advanced(
+        self, sea_client, mock_http_client, mock_cursor, sea_session_id
+    ):
+        """Test advanced command execution scenarios."""
+        # Test with polling
+        initial_response = {
+            "statement_id": "test-statement-789",
+            "status": {"state": "RUNNING"},
+        }
+        poll_response = {
+            "statement_id": "test-statement-789",
+            "status": {"state": "SUCCEEDED"},
+            "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0},
+            "result": {"data": []},
+        }
+        mock_http_client._make_request.side_effect = [initial_response, poll_response]
+
+        with patch.object(
+            sea_client, "_response_to_result_set", return_value="mock_result_set"
+        ) as mock_get_result:
+            with patch("time.sleep"):
+                result = sea_client.execute_command(
+                    operation="SELECT * FROM large_table",
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    lz4_compression=False,
+                    cursor=mock_cursor,
+                    use_cloud_fetch=False,
+                    parameters=[],
+                    async_op=False,
+                    enforce_embedded_schema_correctness=False,
+                )
+                assert result == "mock_result_set"
+
+        # Test with parameters
+        mock_http_client.reset_mock()
+        mock_http_client._make_request.side_effect = None  # Reset side_effect
+        execute_response = {
+            "statement_id": "test-statement-123",
+            "status": {"state": "SUCCEEDED"},
+        }
+        mock_http_client._make_request.return_value = execute_response
+        dbsql_param = IntegerParameter(name="param1", value=1)
+        param = dbsql_param.as_tspark_param(named=True)
+
+        with patch.object(sea_client, "_response_to_result_set"):
+            sea_client.execute_command(
+                operation="SELECT * FROM table WHERE col = :param1",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[param],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+            args, kwargs = mock_http_client._make_request.call_args
+            assert "parameters" in kwargs["data"]
+            assert len(kwargs["data"]["parameters"]) == 1
+            assert kwargs["data"]["parameters"][0]["name"] == "param1"
+            assert kwargs["data"]["parameters"][0]["value"] == "1"
+            assert kwargs["data"]["parameters"][0]["type"] == "INT"
+
+        # Test execution failure
+        mock_http_client.reset_mock()
+        error_response = {
+            "statement_id": "test-statement-123",
+            "status": {
+                "state": "FAILED",
+                "error": {
+                    "message": "Syntax error in SQL",
+                    "error_code": "SYNTAX_ERROR",
+                },
+            },
+        }
+        mock_http_client._make_request.return_value = error_response
+
+        with patch("time.sleep"):
+            with patch.object(
+                sea_client, "get_query_state", return_value=CommandState.FAILED
+            ):
+                with pytest.raises(Error) as excinfo:
+                    sea_client.execute_command(
+                        operation="SELECT * FROM nonexistent_table",
+                        session_id=sea_session_id,
+                        max_rows=100,
+                        max_bytes=1000,
+                        lz4_compression=False,
+                        cursor=mock_cursor,
+                        use_cloud_fetch=False,
+                        parameters=[],
+                        async_op=False,
+                        enforce_embedded_schema_correctness=False,
+                    )
+                assert "Command failed" in str(excinfo.value)
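The parameter assertions above define the wire shape exactly: each parameter serializes to a dict of `name`, stringified `value`, and SQL `type`. A self-contained illustration (the dataclass is a stand-in, not the connector's parameter class):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class SqlParameter:
    name: str
    value: object
    sql_type: str

def serialize_parameters(params: List[SqlParameter]) -> List[dict]:
    # Values go over the wire as strings, as the test asserts ("1", not 1).
    return [
        {"name": p.name, "value": str(p.value), "type": p.sql_type}
        for p in params
    ]

assert serialize_parameters([SqlParameter("param1", 1, "INT")]) == [
    {"name": "param1", "value": "1", "type": "INT"}
]
```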
"data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + assert result.command_id.to_sea_statement_id() == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Test get_execution_result with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.get_execution_result(thrift_command_id, mock_cursor) + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_check_command_state(self, sea_client, sea_command_id): + """Test _check_command_not_in_failed_or_closed_state method.""" + # Test with RUNNING state (should not raise) + sea_client._check_command_not_in_failed_or_closed_state( + StatementStatus(state=CommandState.RUNNING), sea_command_id + ) + + # Test with SUCCEEDED state (should not raise) + sea_client._check_command_not_in_failed_or_closed_state( + StatementStatus(state=CommandState.SUCCEEDED), sea_command_id + ) + + # Test with CLOSED state (should raise DatabaseError) + with pytest.raises(DatabaseError) as excinfo: + sea_client._check_command_not_in_failed_or_closed_state( + StatementStatus(state=CommandState.CLOSED), sea_command_id + ) + assert "Command test-statement-123 unexpectedly closed server side" in str( + excinfo.value + ) + + # Test with FAILED state (should raise ServerOperationError) + with pytest.raises(ServerOperationError) as excinfo: + sea_client._check_command_not_in_failed_or_closed_state( + StatementStatus( + state=CommandState.FAILED, + error=ServiceError(message="Test error", error_code="TEST_ERROR"), + ), + sea_command_id, + ) + assert "Command failed" in str(excinfo.value) + + def test_extract_description_from_manifest(self, sea_client): + """Test _extract_description_from_manifest.""" + manifest_obj = MagicMock() + manifest_obj.schema = { + "columns": [ + { + "name": "col1", + "type_name": "STRING", + "type_precision": 10, + "type_scale": 2, + }, + { + "name": "col2", + "type_name": "INT", + "nullable": False, + }, + ] + } + + description = sea_client._extract_description_from_manifest(manifest_obj) + assert description is not None + assert len(description) == 2 + assert description[0][0] == "col1" # name + assert description[0][1] == "string" # type_code + assert description[0][4] == 10 # precision + assert description[0][5] == 2 # scale + assert description[0][6] is None # null_ok + assert description[1][0] == "col2" # name + assert description[1][1] == "int" # type_code + assert description[1][6] is None # null_ok + + def test_filter_session_configuration(self): + """Test that _filter_session_configuration converts all values to strings.""" + session_config = { + "ANSI_MODE": True, + "statement_timeout": 3600, + "TIMEZONE": "UTC", + "enable_photon": False, + "MAX_FILE_PARTITION_BYTES": 128.5, + "unsupported_param": "value", + "ANOTHER_UNSUPPORTED": 42, + } + + result = _filter_session_configuration(session_config) + + # Verify result is not None + assert result is not None + + # Verify all returned values are strings + for key, value in result.items(): + assert isinstance( + value, str + ), f"Value for key '{key}' is not a string: {type(value)}" + + # Verify specific conversions + expected_result = { + "ansi_mode": "True", # boolean True -> "True", key lowercased + "statement_timeout": "3600", # int -> "3600", key lowercased + "timezone": "UTC", # string -> "UTC", key lowercased + "enable_photon": "False", # boolean False -> "False", key lowercased + "max_file_partition_bytes": "128.5", # float -> 
"128.5", key lowercased + } + + assert result == expected_result + + # Test with None input + assert _filter_session_configuration(None) == {} + + # Test with only unsupported parameters + unsupported_config = { + "unsupported_param1": "value1", + "unsupported_param2": 123, + } + result = _filter_session_configuration(unsupported_config) + assert result == {} + + # Test case insensitivity for keys + case_insensitive_config = { + "ansi_mode": "false", # lowercase key + "STATEMENT_TIMEOUT": 7200, # uppercase key + "TiMeZoNe": "America/New_York", # mixed case key + } + result = _filter_session_configuration(case_insensitive_config) + expected_case_result = { + "ansi_mode": "false", + "statement_timeout": "7200", + "timezone": "America/New_York", + } + assert result == expected_case_result + + # Verify all values are strings in case insensitive test + for key, value in result.items(): + assert isinstance( + value, str + ), f"Value for key '{key}' is not a string: {type(value)}" + + def test_results_message_to_execute_response_is_staging_operation(self, sea_client): + """Test that is_staging_operation is correctly set from manifest.is_volume_operation.""" + # Test when is_volume_operation is True + response = MagicMock() + response.statement_id = "test-statement-123" + response.status.state = CommandState.SUCCEEDED + response.manifest.is_volume_operation = True + response.manifest.result_compression = "NONE" + response.manifest.format = "JSON_ARRAY" + + # Mock the _extract_description_from_manifest method to return None + with patch.object( + sea_client, "_extract_description_from_manifest", return_value=None + ): + result = sea_client._results_message_to_execute_response(response) + assert result.is_staging_operation is True + + # Test when is_volume_operation is False + response.manifest.is_volume_operation = False + with patch.object( + sea_client, "_extract_description_from_manifest", return_value=None + ): + result = sea_client._results_message_to_execute_response(response) + assert result.is_staging_operation is False + + def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor): + """Test the get_catalogs method.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call get_catalogs + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify execute_command was called with the correct parameters + mock_execute.assert_called_once_with( + operation="SHOW CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result is correct + assert result == mock_result_set + + def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): + """Test the get_schemas method with various parameter combinations.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Case 1: With catalog name only + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW SCHEMAS IN test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, 
+    def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
+        """Test the get_schemas method with various parameter combinations."""
+        # Mock the execute_command method
+        mock_result_set = Mock()
+        with patch.object(
+            sea_client, "execute_command", return_value=mock_result_set
+        ) as mock_execute:
+            # Case 1: With catalog name only
+            result = sea_client.get_schemas(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+            )
+
+            mock_execute.assert_called_with(
+                operation="SHOW SCHEMAS IN test_catalog",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+
+            # Case 2: With catalog and schema names
+            result = sea_client.get_schemas(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+                schema_name="test_schema",
+            )
+
+            mock_execute.assert_called_with(
+                operation="SHOW SCHEMAS IN test_catalog LIKE 'test_schema'",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+
+            # Case 3: Without catalog name (should raise DatabaseError)
+            with pytest.raises(DatabaseError) as excinfo:
+                sea_client.get_schemas(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                )
+            assert "Catalog name is required for get_schemas" in str(excinfo.value)
+
+    def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
+        """Test the get_tables method with various parameter combinations."""
+        # Mock the execute_command method
+        from databricks.sql.result_set import SeaResultSet
+
+        mock_result_set = Mock(spec=SeaResultSet)
+
+        with patch.object(
+            sea_client, "execute_command", return_value=mock_result_set
+        ) as mock_execute:
+            # Mock the filter_tables_by_type method
+            with patch(
+                "databricks.sql.backend.sea.utils.filters.ResultSetFilter.filter_tables_by_type",
+                return_value=mock_result_set,
+            ) as mock_filter:
+                # Case 1: With catalog name only
+                result = sea_client.get_tables(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                    catalog_name="test_catalog",
+                )
+
+                mock_execute.assert_called_with(
+                    operation="SHOW TABLES IN CATALOG test_catalog",
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    lz4_compression=False,
+                    cursor=mock_cursor,
+                    use_cloud_fetch=False,
+                    parameters=[],
+                    async_op=False,
+                    enforce_embedded_schema_correctness=False,
+                )
+                mock_filter.assert_called_with(mock_result_set, None)
+
+                # Case 2: With all parameters
+                table_types = ["TABLE", "VIEW"]
+                result = sea_client.get_tables(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                    catalog_name="test_catalog",
+                    schema_name="test_schema",
+                    table_name="test_table",
+                    table_types=table_types,
+                )
+
+                mock_execute.assert_called_with(
+                    operation="SHOW TABLES IN CATALOG test_catalog SCHEMA LIKE 'test_schema' LIKE 'test_table'",
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    lz4_compression=False,
+                    cursor=mock_cursor,
+                    use_cloud_fetch=False,
+                    parameters=[],
+                    async_op=False,
+                    enforce_embedded_schema_correctness=False,
+                )
+                mock_filter.assert_called_with(mock_result_set, table_types)
+
+                # Case 3: With wildcard catalog
+                result = sea_client.get_tables(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                    catalog_name="*",
+                )
+
+                mock_execute.assert_called_with(
+                    operation="SHOW TABLES IN ALL CATALOGS",
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    lz4_compression=False,
+                    cursor=mock_cursor,
+                    use_cloud_fetch=False,
+                    parameters=[],
+                    async_op=False,
+                    enforce_embedded_schema_correctness=False,
+                )
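Each metadata call above compiles to a `SHOW ...` statement, with a wildcard catalog mapping to `ALL CATALOGS`. A sketch of the string-building logic those operation assertions imply; the helper is illustrative, not the connector's API:

```python
from typing import Optional

def build_show_tables(
    catalog: str, schema: Optional[str] = None, table: Optional[str] = None
) -> str:
    if catalog == "*":
        return "SHOW TABLES IN ALL CATALOGS"
    operation = f"SHOW TABLES IN CATALOG {catalog}"
    if schema:
        operation += f" SCHEMA LIKE '{schema}'"
    if table:
        operation += f" LIKE '{table}'"
    return operation

assert build_show_tables("test_catalog", "test_schema", "test_table") == (
    "SHOW TABLES IN CATALOG test_catalog SCHEMA LIKE 'test_schema' LIKE 'test_table'"
)
```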
+    def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
+        """Test the get_columns method with various parameter combinations."""
+        # Mock the execute_command method
+        mock_result_set = Mock()
+        with patch.object(
+            sea_client, "execute_command", return_value=mock_result_set
+        ) as mock_execute:
+            # Case 1: With catalog name only
+            result = sea_client.get_columns(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+            )
+
+            mock_execute.assert_called_with(
+                operation="SHOW COLUMNS IN CATALOG test_catalog",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+
+            # Case 2: With all parameters
+            result = sea_client.get_columns(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+                schema_name="test_schema",
+                table_name="test_table",
+                column_name="test_column",
+            )
+
+            mock_execute.assert_called_with(
+                operation="SHOW COLUMNS IN CATALOG test_catalog SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+
+            # Case 3: Without catalog name (should raise DatabaseError)
+            with pytest.raises(DatabaseError) as excinfo:
+                sea_client.get_columns(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                )
+            assert "Catalog name is required for get_columns" in str(excinfo.value)
+""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.is_direct_results = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + return mock_response + + def test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.command_id == execute_response.command_id + assert result_set.status == CommandState.SUCCEEDED + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set.description == execute_response.description + + def test_close(self, mock_connection, mock_sea_client, execute_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, execute_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, execute_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + 
+    def test_close_when_connection_closed(
+        self, mock_connection, mock_sea_client, execute_response
+    ):
+        """Test closing a result set when the connection is closed."""
+        mock_connection.open = False
+        result_set = SeaResultSet(
+            connection=mock_connection,
+            execute_response=execute_response,
+            sea_client=mock_sea_client,
+            buffer_size_bytes=1000,
+            arraysize=100,
+        )
+
+        # Close the result set
+        result_set.close()
+
+        # Verify the backend's close_command was NOT called
+        mock_sea_client.close_command.assert_not_called()
+        assert result_set.has_been_closed_server_side is True
+        assert result_set.status == CommandState.CLOSED
+
+    def test_unimplemented_methods(
+        self, mock_connection, mock_sea_client, execute_response
+    ):
+        """Test that unimplemented methods raise NotImplementedError."""
+        result_set = SeaResultSet(
+            connection=mock_connection,
+            execute_response=execute_response,
+            sea_client=mock_sea_client,
+            buffer_size_bytes=1000,
+            arraysize=100,
+        )
+
+        # Test each unimplemented method individually with specific error messages
+        with pytest.raises(
+            NotImplementedError, match="fetchone is not implemented for SEA backend"
+        ):
+            result_set.fetchone()
+
+        with pytest.raises(
+            NotImplementedError, match="fetchmany is not implemented for SEA backend"
+        ):
+            result_set.fetchmany(10)
+
+        with pytest.raises(
+            NotImplementedError, match="fetchmany is not implemented for SEA backend"
+        ):
+            # Test with default parameter value
+            result_set.fetchmany()
+
+        with pytest.raises(
+            NotImplementedError, match="fetchall is not implemented for SEA backend"
+        ):
+            result_set.fetchall()
+
+        with pytest.raises(
+            NotImplementedError,
+            match="fetchmany_arrow is not implemented for SEA backend",
+        ):
+            result_set.fetchmany_arrow(10)
+
+        with pytest.raises(
+            NotImplementedError,
+            match="fetchall_arrow is not implemented for SEA backend",
+        ):
+            result_set.fetchall_arrow()
+
+        with pytest.raises(
+            NotImplementedError, match="fetchone is not implemented for SEA backend"
+        ):
+            # Test iteration protocol (calls fetchone internally)
+            next(iter(result_set))
+
+        with pytest.raises(
+            NotImplementedError, match="fetchone is not implemented for SEA backend"
+        ):
+            # Test using the result set in a for loop
+            for row in result_set:
+                pass
+
+    def test_fill_results_buffer_not_implemented(
+        self, mock_connection, mock_sea_client, execute_response
+    ):
+        """Test that _fill_results_buffer raises NotImplementedError."""
+        result_set = SeaResultSet(
+            connection=mock_connection,
+            execute_response=execute_response,
+            sea_client=mock_sea_client,
+            buffer_size_bytes=1000,
+            arraysize=100,
+        )
+
+        with pytest.raises(
+            NotImplementedError,
+            match="_fill_results_buffer is not implemented for SEA backend",
+        ):
+            result_set._fill_results_buffer()
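One subtlety the iteration cases above rely on: `__iter__` is conventionally built on `fetchone`, so a single `NotImplementedError` stub guards both direct calls and `for` loops. A toy illustration (this class is a stand-in, not the real `SeaResultSet`):

```python
import pytest

class StubResultSet:
    def fetchone(self):
        raise NotImplementedError("fetchone is not implemented for SEA backend")

    def __iter__(self):
        # Iteration delegates to fetchone, so the stub trips here too.
        while True:
            row = self.fetchone()
            if row is None:
                return
            yield row

with pytest.raises(NotImplementedError, match="fetchone"):
    next(iter(StubResultSet()))
```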
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index a5c751782..6823b1b33 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -62,9 +62,9 @@ def test_auth_args(self, mock_client_class):
         for args in connection_args:
             connection = databricks.sql.connect(**args)
-            host, port, http_path, *_ = mock_client_class.call_args[0]
-            assert args["server_hostname"] == host
-            assert args["http_path"] == http_path
+            call_kwargs = mock_client_class.call_args[1]
+            assert args["server_hostname"] == call_kwargs["server_hostname"]
+            assert args["http_path"] == call_kwargs["http_path"]
             connection.close()

     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
@@ -72,8 +72,8 @@ def test_http_header_passthrough(self, mock_client_class):
         http_headers = [("foo", "bar")]
         databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers)

-        call_args = mock_client_class.call_args[0][3]
-        assert ("foo", "bar") in call_args
+        call_kwargs = mock_client_class.call_args[1]
+        assert ("foo", "bar") in call_kwargs["http_headers"]

     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
     def test_tls_arg_passthrough(self, mock_client_class):
@@ -95,7 +95,8 @@ def test_useragent_header(self, mock_client_class):
         databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)

-        http_headers = mock_client_class.call_args[0][3]
+        call_kwargs = mock_client_class.call_args[1]
+        http_headers = call_kwargs["http_headers"]
         user_agent_header = (
             "User-Agent",
             "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__),
@@ -109,7 +110,8 @@ def test_useragent_header(self, mock_client_class):
                 databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar"
             ),
         )
-        http_headers = mock_client_class.call_args[0][3]
+        call_kwargs = mock_client_class.call_args[1]
+        http_headers = call_kwargs["http_headers"]
         assert user_agent_header_with_entry in http_headers

     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
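The `test_session.py` changes above swap positional `call_args[0]` indexing for keyword lookup, which keeps the assertions stable if the client constructor's parameter order ever changes. A compact illustration:

```python
from unittest.mock import Mock

client_class = Mock()
client_class(
    server_hostname="host", http_path="/sql/1", http_headers=[("foo", "bar")]
)

# call_args[1] is the kwargs dict (same as client_class.call_args.kwargs).
call_kwargs = client_class.call_args[1]
assert call_kwargs["server_hostname"] == "host"
assert ("foo", "bar") in call_kwargs["http_headers"]
```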
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 2cfad7bf4..1b1a7e380 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -624,7 +624,10 @@ def test_handle_execute_response_sets_compression_in_direct_results(
             status=Mock(),
             operationHandle=Mock(),
             directResults=ttypes.TSparkDirectResults(
-                operationStatus=Mock(),
+                operationStatus=ttypes.TGetOperationStatusResp(
+                    status=self.okay_status,
+                    operationState=ttypes.TOperationState.FINISHED_STATE,
+                ),
                 resultSetMetadata=ttypes.TGetResultSetMetadataResp(
                     status=self.okay_status,
                     resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET,
@@ -645,7 +648,7 @@ def test_handle_execute_response_sets_compression_in_direct_results(
             ssl_options=SSLOptions(),
         )

-        execute_response = thrift_backend._handle_execute_response(
+        execute_response, _ = thrift_backend._handle_execute_response(
             t_execute_resp, Mock()
         )
         self.assertEqual(execute_response.lz4_compressed, lz4Compressed)
@@ -838,9 +841,10 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self):
                 thrift_backend._handle_execute_response(error_resp, Mock())
             self.assertIn("this is a bad error", str(cm.exception))

+    @patch("databricks.sql.backend.thrift_backend.ThriftResultSet")
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     def test_handle_execute_response_can_handle_without_direct_results(
-        self, tcli_service_class
+        self, tcli_service_class, mock_result_set
     ):
         tcli_service_instance = tcli_service_class.return_value

@@ -884,11 +888,12 @@ def test_handle_execute_response_can_handle_without_direct_results(
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        results_message_response = thrift_backend._handle_execute_response(
-            execute_resp, Mock()
-        )
+        (
+            execute_response,
+            _,
+        ) = thrift_backend._handle_execute_response(execute_resp, Mock())
         self.assertEqual(
-            results_message_response.status,
+            execute_response.status,
             CommandState.SUCCEEDED,
         )

@@ -952,8 +957,14 @@ def test_use_arrow_schema_if_available(self, tcli_service_class):
         tcli_service_instance.GetResultSetMetadata.return_value = (
             t_get_result_set_metadata_resp
         )
+        tcli_service_instance.GetOperationStatus.return_value = (
+            ttypes.TGetOperationStatusResp(
+                status=self.okay_status,
+                operationState=ttypes.TOperationState.FINISHED_STATE,
+            )
+        )
         thrift_backend = self._make_fake_thrift_backend()
-        execute_response = thrift_backend._handle_execute_response(
+        execute_response, _ = thrift_backend._handle_execute_response(
             t_execute_resp, Mock()
         )

@@ -978,8 +989,14 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
         )
         tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req
+        tcli_service_instance.GetOperationStatus.return_value = (
+            ttypes.TGetOperationStatusResp(
+                status=self.okay_status,
+                operationState=ttypes.TOperationState.FINISHED_STATE,
+            )
+        )
         thrift_backend = self._make_fake_thrift_backend()
-        thrift_backend._handle_execute_response(t_execute_resp, Mock())
+        _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock())

         self.assertEqual(
             hive_schema_mock,
@@ -993,10 +1010,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
     def test_handle_execute_response_reads_has_more_rows_in_direct_results(
         self, tcli_service_class, build_queue
     ):
-        for has_more_rows, resp_type in itertools.product(
+        for is_direct_results, resp_type in itertools.product(
             [True, False], self.execute_response_types
         ):
-            with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type):
+            with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type):
                 tcli_service_instance = tcli_service_class.return_value
                 results_mock = Mock()
                 results_mock.startRowOffset = 0
@@ -1008,7 +1025,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
                     resultSetMetadata=self.metadata_resp,
                     resultSet=ttypes.TFetchResultsResp(
                         status=self.okay_status,
-                        hasMoreRows=has_more_rows,
+                        hasMoreRows=is_direct_results,
                         results=results_mock,
                     ),
                     closeOperation=Mock(),
@@ -1024,11 +1041,12 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
                 )

                 thrift_backend = self._make_fake_thrift_backend()
-                execute_response = thrift_backend._handle_execute_response(
-                    execute_resp, Mock()
-                )
+                (
+                    execute_response,
+                    has_more_rows_result,
+                ) = thrift_backend._handle_execute_response(execute_resp, Mock())

-                self.assertEqual(has_more_rows, execute_response.has_more_rows)
+                self.assertEqual(is_direct_results, has_more_rows_result)

     @patch(
         "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
     )
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
@@ -1037,10 +1055,10 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
         self, tcli_service_class, build_queue
     ):
-        for has_more_rows, resp_type in itertools.product(
+        for is_direct_results, resp_type in itertools.product(
             [True, False], self.execute_response_types
         ):
-            with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type):
+            with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type):
                 tcli_service_instance = tcli_service_class.return_value
                 results_mock = MagicMock()
                 results_mock.startRowOffset = 0
@@ -1053,7 +1071,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(

                 fetch_results_resp = ttypes.TFetchResultsResp(
                     status=self.okay_status,
-                    hasMoreRows=has_more_rows,
+                    hasMoreRows=is_direct_results,
                     results=results_mock,
                     resultSetMetadata=ttypes.TGetResultSetMetadataResp(
                         resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET
@@ -1086,7 +1104,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
                     description=Mock(),
                 )

-                self.assertEqual(has_more_rows, has_more_rows_resp)
+                self.assertEqual(is_direct_results, has_more_rows_resp)

    @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
    def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
@@ -1141,9 +1159,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):

         self.assertEqual(arrow_queue.n_valid_rows, 15 * 10)
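Two mechanical patterns repeat through the remaining hunks: `_handle_execute_response` now returns a pair that every caller unpacks, and each test gains a stacked `@patch`. Since decorators apply bottom-up, the new mock arrives last in the argument list. Both illustrated with stand-ins:

```python
from unittest.mock import Mock, patch

def handle_execute_response(resp, cursor):
    # Stand-in for the new contract: return (execute_response, has_more_rows).
    return {"status": "SUCCEEDED"}, False

execute_response, has_more_rows = handle_execute_response(None, None)
assert has_more_rows is False

@patch("builtins.min")  # outermost patch -> last mock argument
@patch("builtins.max")  # innermost patch -> first mock argument
def check(mock_max, mock_min):
    assert isinstance(mock_max, Mock) and isinstance(mock_min, Mock)

check()
```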
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1157,13 +1176,14 @@ def test_execute_statement_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1175,9 +1195,10 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1191,11 +1212,12 @@ def test_get_catalogs_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1206,9 +1228,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.result_set.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1222,6 +1245,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1233,7 +1257,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1246,9 +1270,10 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.result_set.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1262,6 +1287,7 @@ def test_get_tables_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) 
         thrift_backend._handle_execute_response = Mock()
+        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
         cursor_mock = Mock()

         result = thrift_backend.get_tables(
@@ -1275,7 +1301,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
             table_types=["type1", "type2"],
         )
         # Verify the result is a ResultSet
-        self.assertIsInstance(result, ResultSet)
+        self.assertEqual(result, mock_result_set.return_value)

         # Check call to client
         req = tcli_service_instance.GetTables.call_args[0][0]
@@ -1290,9 +1316,10 @@ def test_get_tables_calls_client_and_handle_execute_response(
             response, cursor_mock
         )

+    @patch("databricks.sql.result_set.ThriftResultSet")
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     def test_get_columns_calls_client_and_handle_execute_response(
-        self, tcli_service_class
+        self, tcli_service_class, mock_result_set
     ):
         tcli_service_instance = tcli_service_class.return_value
         response = Mock()
@@ -1306,6 +1333,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
+        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
         cursor_mock = Mock()

         result = thrift_backend.get_columns(
@@ -1319,7 +1347,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
             column_name="column_pattern",
         )
         # Verify the result is a ResultSet
-        self.assertIsInstance(result, ResultSet)
+        self.assertEqual(result, mock_result_set.return_value)

         # Check call to client
         req = tcli_service_instance.GetColumns.call_args[0][0]
@@ -2208,14 +2236,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
             str(cm.exception),
         )

+    @patch("databricks.sql.backend.thrift_backend.ThriftResultSet")
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     @patch(
         "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response"
     )
     def test_execute_command_sets_complex_type_fields_correctly(
-        self, mock_handle_execute_response, tcli_service_class
+        self, mock_handle_execute_response, tcli_service_class, mock_result_set
     ):
         tcli_service_instance = tcli_service_class.return_value
+        # Set up the mock to return a tuple with two values
+        mock_execute_response = Mock()
+        mock_arrow_schema = Mock()
+        mock_handle_execute_response.return_value = (
+            mock_execute_response,
+            mock_arrow_schema,
+        )
+
+        # Iterate through each possible combination of native types (True, False and unset)
         for complex, timestamp, decimals in itertools.product(
             [True, False, None], [True, False, None], [True, False, None]