diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py
index d812f93b9..cf665d5ec 100644
--- a/src/databricks/sql/thrift_backend.py
+++ b/src/databricks/sql/thrift_backend.py
@@ -1,4 +1,5 @@
 from decimal import Decimal
+import errno
 import logging
 import math
 import time
@@ -15,6 +16,9 @@
 
 from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes
 from databricks.sql import *
+from databricks.sql.thrift_api.TCLIService.TCLIService import (
+    Client as TCLIServiceClient,
+)
 from databricks.sql.utils import (
     ArrowQueue,
     ExecuteResponse,
@@ -39,6 +43,7 @@
     "_retry_delay_max": (float, 60, 5, 3600),
     "_retry_stop_after_attempts_count": (int, 30, 1, 60),
     "_retry_stop_after_attempts_duration": (float, 900, 1, 86400),
+    "_retry_delay_default": (float, 5, 1, 60),
 }
 
 
@@ -71,6 +76,8 @@ def __init__(
         # _retry_delay_min                      (default: 1)
         # _retry_delay_max                      (default: 60)
         #   {min,max} pre-retry delay bounds
+        # _retry_delay_default                  (default: 5)
+        #   Only used when GetOperationStatus fails due to a TCP/OS Error.
         # _retry_stop_after_attempts_count      (default: 30)
         #   total max attempts during retry sequence
         # _retry_stop_after_attempts_duration   (default: 900)
@@ -158,7 +165,7 @@ def _initialize_retry_args(self, kwargs):
                 "retry parameter: {} given_or_default {}".format(key, given_or_default)
             )
             if bound != given_or_default:
-                logger.warn(
+                logger.warning(
                     "Override out of policy retry parameter: "
                     + "{} given {}, restricted to {}".format(
                         key, given_or_default, bound
@@ -243,7 +250,9 @@ def _handle_request_error(self, error_info, attempt, elapsed):
     # FUTURE: Consider moving to https://github.com/litl/backoff or
     # https://github.com/jd/tenacity for retry logic.
     def make_request(self, method, request):
-        """Execute given request, attempting retries when receiving HTTP 429/503.
+        """Execute given request, attempting retries when
+            1. Receiving HTTP 429/503 from server
+            2. An OSError is raised during a GetOperationStatus request
 
         For delay between attempts, honor the given Retry-After header, but with bounds.
         Use lower bound of expontial-backoff based on _retry_delay_min,
@@ -260,6 +269,13 @@ def make_request(self, method, request):
         def get_elapsed():
             return time.time() - t0
 
+        def bound_retry_delay(attempt, proposed_delay):
+            """bound delay (seconds) by [min_delay*1.5^(attempt-1), max_delay]"""
+            # No int() here: it would truncate a fractional _retry_delay_default.
+            floor = self._retry_delay_min * math.pow(1.5, attempt - 1)
+            delay = max(proposed_delay, floor)
+            return min(delay, self._retry_delay_max)
+
         def extract_retry_delay(attempt):
             # encapsulate retry checks, returns None || delay-in-secs
             # Retry IFF 429/503 code + Retry-After header set
@@ -267,10 +283,7 @@ def extract_retry_delay(attempt):
             retry_after = getattr(self._transport, "headers", {}).get("Retry-After")
             if http_code in [429, 503] and retry_after:
                 # bound delay (seconds) by [min_delay*1.5^(attempt-1), max_delay]
-                delay = int(retry_after)
-                delay = max(delay, self._retry_delay_min * math.pow(1.5, attempt - 1))
-                delay = min(delay, self._retry_delay_max)
-                return delay
+                return bound_retry_delay(attempt, int(retry_after))
             return None
 
         def attempt_request(attempt):
@@ -279,24 +292,63 @@ def attempt_request(attempt):
             # - non-None method_return -> success, return and be done
             # - non-None retry_delay -> sleep delay before retry
             # - error, error_message always set when available
+
+            error, error_message, retry_delay = None, None, None
             try:
                 logger.debug("Sending request: {}".format(request))
                 response = method(request)
                 logger.debug("Received response: {}".format(response))
                 return response
-            except Exception as error:
+            except OSError as err:
+                # OSError: only GetOperationStatus gets a (bounded, default) retry delay.
+                error = err
+                error_message = str(err)
+
+                gos_name = TCLIServiceClient.GetOperationStatus.__name__
+                will_retry = method.__name__ == gos_name
+                if will_retry:
+                    retry_delay = bound_retry_delay(attempt, self._retry_delay_default)
+
+                # fmt: off
+                # The built-in errno package encapsulates OSError codes, which are OS-specific.
+                # log.info for errors we believe are not unusual or unexpected. log.warning
+                # for others like EEXIST, EBADF, ERANGE which are not expected in this context.
+                #
+                # I manually tested this retry behaviour using mitmweb and confirmed that
+                # GetOperationStatus requests are retried when I forced network connection
+                # interruptions / timeouts / reconnects. See #24 for more info.
+                                                    # | Debian | Darwin |
+                info_errs = [                       # |--------|--------|
+                    errno.EPIPE,                    # |   32   |   32   |
+                    errno.EAFNOSUPPORT,             # |   97   |   47   |
+                    errno.ECONNRESET,               # |   104  |   54   |
+                    errno.ESHUTDOWN,                # |   108  |   58   |
+                    errno.ETIMEDOUT,                # |   110  |   60   |
+                ]
+
+                # fmt: on
+                # Name the method that actually failed, and only promise a retry
+                # when a retry delay was set above.
+                outcome = "will attempt to retry" if will_retry else "will not be retried"
+                log_string = f"{method.__name__} failed with code {err.errno} and {outcome}"
+                if err.errno in info_errs:
+                    logger.info(log_string)
+                else:
+                    logger.warning(log_string)
+            except Exception as err:
+                error = err
                 retry_delay = extract_retry_delay(attempt)
                 error_message = ThriftBackend._extract_error_message_from_headers(
                     getattr(self._transport, "headers", {})
                 )
-                return RequestErrorInfo(
-                    error=error,
-                    error_message=error_message,
-                    retry_delay=retry_delay,
-                    http_code=getattr(self._transport, "code", None),
-                    method=method.__name__,
-                    request=request,
-                )
+            return RequestErrorInfo(
+                error=error,
+                error_message=error_message,
+                retry_delay=retry_delay,
+                http_code=getattr(self._transport, "code", None),
+                method=method.__name__,
+                request=request,
+            )
 
         # The real work:
         # - for each available attempt:
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index d411df76d..e8c5a727f 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -19,6 +19,7 @@ def retry_policy_factory():
         "_retry_delay_max": (float, 60, None, None),
         "_retry_stop_after_attempts_count": (int, 30, None, None),
         "_retry_stop_after_attempts_duration": (float, 900, None, None),
+        "_retry_delay_default": (float, 5, 1, 60),
     }
 
 
@@ -968,6 +969,62 @@ def test_handle_execute_response_sets_active_op_handle(self):
self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) + @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") + @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + def test_make_request_will_retry_GetOperationStatus( + self, mock_retry_policy, mock_GetOperationStatus, t_transport_class): + + import thrift, errno + from databricks.sql.thrift_api.TCLIService.TCLIService import Client + from databricks.sql.exc import RequestError + from databricks.sql.utils import NoRetryReason + + this_gos_name = "GetOperationStatus" + mock_GetOperationStatus.__name__ = this_gos_name + mock_GetOperationStatus.side_effect = OSError(errno.ETIMEDOUT, "Connection timed out") + + protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(t_transport_class) + client = Client(protocol) + + req = ttypes.TGetOperationStatusReq( + operationHandle=self.operation_handle, + getProgressUpdate=False, + ) + + EXPECTED_RETRIES = 2 + + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", [], + _retry_stop_after_attempts_count=EXPECTED_RETRIES, + _retry_delay_default=1) + + + with self.assertRaises(RequestError) as cm: + thrift_backend.make_request(client.GetOperationStatus, req) + + self.assertEqual(NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"]) + self.assertEqual(f'{EXPECTED_RETRIES}/{EXPECTED_RETRIES}', cm.exception.context["attempt"]) + + # Unusual OSError code + mock_GetOperationStatus.side_effect = OSError(errno.EEXIST, "File does not exist") + + with self.assertLogs("databricks.sql.thrift_backend", level=logging.WARNING) as cm: + with self.assertRaises(RequestError): + thrift_backend.make_request(client.GetOperationStatus, req) + + # There should be two warning log messages: one for each retry + self.assertEqual(len(cm.output), EXPECTED_RETRIES) + + # The warnings should be identical + self.assertEqual(cm.output[1], 
cm.output[0]) + + # The warnings should include this text + self.assertIn(f"{this_gos_name} failed with code {errno.EEXIST} and will attempt to retry", cm.output[0]) + + @patch("thrift.transport.THttpClient.THttpClient") def test_make_request_wont_retry_if_headers_not_present(self, t_transport_class): t_transport_instance = t_transport_class.return_value