diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 3d23344b..7ecb2a56 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -41,7 +41,6 @@ DeleteSessionRequest, StatementParameter, ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -324,7 +323,7 @@ def _extract_description_from_manifest( return columns def _results_message_to_execute_response( - self, response: GetStatementResponse + self, response: ExecuteStatementResponse ) -> ExecuteResponse: """ Convert a SEA response to an ExecuteResponse and extract result data. @@ -358,6 +357,28 @@ def _results_message_to_execute_response( return execute_response + def _response_to_result_set( + self, response: ExecuteStatementResponse, cursor: Cursor + ) -> SeaResultSet: + """ + Convert a SEA response to a SeaResultSet. + """ + + # Create and return a SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet + + execute_response = self._results_message_to_execute_response(response) + + return SeaResultSet( + connection=cursor.connection, + execute_response=execute_response, + sea_client=self, + result_data=response.result, + manifest=response.manifest, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + ) + def _check_command_not_in_failed_or_closed_state( self, state: CommandState, command_id: CommandId ) -> None: @@ -378,7 +399,7 @@ def _check_command_not_in_failed_or_closed_state( def _wait_until_command_done( self, response: ExecuteStatementResponse - ) -> CommandState: + ) -> ExecuteStatementResponse: """ Wait until a command is done. """ @@ -388,11 +409,12 @@ def _wait_until_command_done( while state in [CommandState.PENDING, CommandState.RUNNING]: time.sleep(self.POLL_INTERVAL_SECONDS) - state = self.get_query_state(command_id) + response = self._poll_query(command_id) + state = response.status.state self._check_command_not_in_failed_or_closed_state(state, command_id) - return state + return response def execute_command( self, @@ -494,8 +516,12 @@ def execute_command( if async_op: return None - self._wait_until_command_done(response) - return self.get_execution_result(command_id, cursor) + if response.status.state == CommandState.SUCCEEDED: + # if the response succeeded within the wait_timeout, return the results immediately + return self._response_to_result_set(response, cursor) + + response = self._wait_until_command_done(response) + return self._response_to_result_set(response, cursor) def cancel_command(self, command_id: CommandId) -> None: """ @@ -547,18 +573,9 @@ def close_command(self, command_id: CommandId) -> None: data=request.to_dict(), ) - def get_query_state(self, command_id: CommandId) -> CommandState: + def _poll_query(self, command_id: CommandId) -> ExecuteStatementResponse: """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ProgrammingError: If the command ID is invalid + Poll for the current command info. """ if command_id.backend_type != BackendType.SEA: @@ -574,9 +591,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState: path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), data=request.to_dict(), ) + response = ExecuteStatementResponse.from_dict(response_data) - # Parse the response - response = GetStatementResponse.from_dict(response_data) + return response + + def get_query_state(self, command_id: CommandId) -> CommandState: + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ProgrammingError: If the command ID is invalid + """ + + response = self._poll_query(command_id) return response.status.state def get_execution_result( @@ -598,38 +631,8 @@ def get_execution_result( ValueError: If the command ID is invalid """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - if sea_statement_id is None: - raise ValueError("Not a valid SEA command ID") - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - response = GetStatementResponse.from_dict(response_data) - - # Create and return a SeaResultSet - from databricks.sql.backend.sea.result_set import SeaResultSet - - execute_response = self._results_message_to_execute_response(response) - - return SeaResultSet( - connection=cursor.connection, - execute_response=execute_response, - sea_client=self, - result_data=response.result, - manifest=response.manifest, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, - ) + response = self._poll_query(command_id) + return self._response_to_result_set(response, cursor) # == Metadata Operations == diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b7c8bd39..e591f0fd 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -25,7 +25,6 @@ from databricks.sql.backend.sea.models.responses import ( ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -47,6 +46,5 @@ "DeleteSessionRequest", # Response models "ExecuteStatementResponse", - "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 302b32d0..01526e3a 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -124,26 +124,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": ) -@dataclass -class GetStatementResponse: - """Representation of the response from getting information about a statement.""" - - statement_id: str - status: StatementStatus - manifest: ResultManifest - result: ResultData - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": - """Create a GetStatementResponse from a dictionary.""" - return cls( - statement_id=data.get("statement_id", ""), - status=_parse_status(data), - manifest=_parse_manifest(data), - result=_parse_result(data), - ) - - @dataclass class CreateSessionResponse: """Representation of the response from creating a new session.""" diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index da45b429..101eb6b4 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -227,7 +227,7 @@ def test_command_execution_sync( mock_http_client._make_request.return_value = execute_response with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" + sea_client, "_response_to_result_set", return_value="mock_result_set" ) as mock_get_result: result = sea_client.execute_command( operation="SELECT 1", @@ -242,9 +242,6 @@ def test_command_execution_sync( enforce_embedded_schema_correctness=False, ) assert result == "mock_result_set" - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" # Test with invalid session ID with pytest.raises(ValueError) as excinfo: @@ -332,7 +329,7 @@ def test_command_execution_advanced( mock_http_client._make_request.side_effect = [initial_response, poll_response] with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" + sea_client, "_response_to_result_set", return_value="mock_result_set" ) as mock_get_result: with patch("time.sleep"): result = sea_client.execute_command( @@ -360,7 +357,7 @@ def test_command_execution_advanced( dbsql_param = IntegerParameter(name="param1", value=1) param = dbsql_param.as_tspark_param(named=True) - with patch.object(sea_client, "get_execution_result"): + with patch.object(sea_client, "_response_to_result_set"): sea_client.execute_command( operation="SELECT * FROM table WHERE col = :param1", session_id=sea_session_id,