diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py index 4e4d5b029..617a94295 100644 --- a/elasticsearch/_async/client/__init__.py +++ b/elasticsearch/_async/client/__init__.py @@ -79,6 +79,8 @@ _quote, _rewrite_parameters, client_node_configs, + is_requests_http_auth, + is_requests_node_class, ) from .watcher import WatcherClient from .xpack import XPackClient @@ -309,9 +311,27 @@ def __init__( sniff_callback = default_sniff_callback if _transport is None: + + requests_session_auth = None + if http_auth is not None and http_auth is not DEFAULT: + if is_requests_http_auth(http_auth): + # If we're using custom requests authentication + # then we need to alert the user that they also + # need to use 'node_class=requests'. + if not is_requests_node_class(node_class): + raise ValueError( + "Using a custom 'requests.auth.AuthBase' class for " + "'http_auth' must be used with node_class='requests'" + ) + + # Reset 'http_auth' to DEFAULT so it's not consumed below. + requests_session_auth = http_auth + http_auth = DEFAULT + node_configs = client_node_configs( hosts, cloud_id=cloud_id, + requests_session_auth=requests_session_auth, connections_per_node=connections_per_node, http_compress=http_compress, verify_certs=verify_certs, diff --git a/elasticsearch/_async/client/utils.py b/elasticsearch/_async/client/utils.py index f14b81fbd..ec0a6e4b8 100644 --- a/elasticsearch/_async/client/utils.py +++ b/elasticsearch/_async/client/utils.py @@ -25,6 +25,8 @@ _quote_query, _rewrite_parameters, client_node_configs, + is_requests_http_auth, + is_requests_node_class, ) __all__ = [ @@ -37,4 +39,6 @@ "SKIP_IN_PATH", "client_node_configs", "_rewrite_parameters", + "is_requests_http_auth", + "is_requests_node_class", ] diff --git a/elasticsearch/_sync/client/__init__.py b/elasticsearch/_sync/client/__init__.py index d514e0d28..a26d65ce6 100644 --- a/elasticsearch/_sync/client/__init__.py +++ b/elasticsearch/_sync/client/__init__.py @@ -79,6 +79,8 @@ _quote, _rewrite_parameters, client_node_configs, + is_requests_http_auth, + is_requests_node_class, ) from .watcher import WatcherClient from .xpack import XPackClient @@ -309,9 +311,27 @@ def __init__( sniff_callback = default_sniff_callback if _transport is None: + + requests_session_auth = None + if http_auth is not None and http_auth is not DEFAULT: + if is_requests_http_auth(http_auth): + # If we're using custom requests authentication + # then we need to alert the user that they also + # need to use 'node_class=requests'. + if not is_requests_node_class(node_class): + raise ValueError( + "Using a custom 'requests.auth.AuthBase' class for " + "'http_auth' must be used with node_class='requests'" + ) + + # Reset 'http_auth' to DEFAULT so it's not consumed below. + requests_session_auth = http_auth + http_auth = DEFAULT + node_configs = client_node_configs( hosts, cloud_id=cloud_id, + requests_session_auth=requests_session_auth, connections_per_node=connections_per_node, http_compress=http_compress, verify_certs=verify_certs, diff --git a/elasticsearch/_sync/client/utils.py b/elasticsearch/_sync/client/utils.py index 25d942324..1444adf05 100644 --- a/elasticsearch/_sync/client/utils.py +++ b/elasticsearch/_sync/client/utils.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. - import base64 +import inspect import warnings from datetime import date, datetime from functools import wraps @@ -41,6 +41,7 @@ AsyncTransport, HttpHeaders, NodeConfig, + RequestsHttpNode, SniffOptions, Transport, ) @@ -88,7 +89,10 @@ def client_node_configs( - hosts: Optional[_TYPE_HOSTS], cloud_id: Optional[str], **kwargs: Any + hosts: Optional[_TYPE_HOSTS], + cloud_id: Optional[str], + requests_session_auth: Optional[Any] = None, + **kwargs: Any, ) -> List[NodeConfig]: if cloud_id is not None: if hosts is not None: @@ -108,6 +112,12 @@ def client_node_configs( headers.setdefault("user-agent", USER_AGENT) node_options["headers"] = headers + # If a custom Requests AuthBase is passed we set that via '_extras'. + if requests_session_auth is not None: + node_options.setdefault("_extras", {})[ + "requests.session.auth" + ] = requests_session_auth + def apply_node_options(node_config: NodeConfig) -> NodeConfig: """Needs special handling of headers since .replace() wipes out existing headers""" nonlocal node_options @@ -406,3 +416,28 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return wrapped # type: ignore[return-value] return wrapper + + +def is_requests_http_auth(http_auth: Any) -> bool: + """Detect if an http_auth value is a custom Requests auth object""" + try: + from requests.auth import AuthBase + + return isinstance(http_auth, AuthBase) + except ImportError: + pass + return False + + +def is_requests_node_class(node_class: Any) -> bool: + """Detect if 'RequestsHttpNode' would be used given the setting of 'node_class'""" + return ( + node_class is not None + and node_class is not DEFAULT + and ( + node_class == "requests" + or ( + inspect.isclass(node_class) and issubclass(node_class, RequestsHttpNode) + ) + ) + ) diff --git a/test_elasticsearch/test_client/test_requests_auth.py b/test_elasticsearch/test_client/test_requests_auth.py new file mode 100644 index 000000000..2eb656f5d --- /dev/null +++ b/test_elasticsearch/test_client/test_requests_auth.py @@ -0,0 +1,84 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import warnings + +import pytest +import requests +from elastic_transport import RequestsHttpNode, Urllib3HttpNode +from elastic_transport.client_utils import DEFAULT +from requests.auth import HTTPBasicAuth + +from elasticsearch import AsyncElasticsearch, Elasticsearch + + +class CustomRequestHttpNode(RequestsHttpNode): + pass + + +class CustomUrllib3HttpNode(Urllib3HttpNode): + pass + + +@pytest.mark.parametrize( + "node_class", ["requests", RequestsHttpNode, CustomRequestHttpNode] +) +def test_requests_auth(node_class): + http_auth = HTTPBasicAuth("username", "password") + + with warnings.catch_warnings(record=True) as w: + client = Elasticsearch( + "http://localhost:9200", http_auth=http_auth, node_class=node_class + ) + + # http_auth is deprecated for all other cases except this one. + assert len(w) == 0 + + # Instance should be forwarded directly to requests.Session.auth. + node = client.transport.node_pool.get() + assert isinstance(node, RequestsHttpNode) + assert isinstance(node.session, requests.Session) + assert node.session.auth is http_auth + + +@pytest.mark.parametrize("client_class", [Elasticsearch, AsyncElasticsearch]) +@pytest.mark.parametrize( + "node_class", ["urllib3", "aiohttp", None, DEFAULT, CustomUrllib3HttpNode] +) +def test_error_for_requests_auth_node_class(client_class, node_class): + http_auth = HTTPBasicAuth("username", "password") + + with pytest.raises(ValueError) as e: + client_class( + "http://localhost:9200", http_auth=http_auth, node_class=node_class + ) + assert str(e.value) == ( + "Using a custom 'requests.auth.AuthBase' class for " + "'http_auth' must be used with node_class='requests'" + ) + + +def test_error_for_requests_auth_async(): + http_auth = HTTPBasicAuth("username", "password") + + with pytest.raises(ValueError) as e: + AsyncElasticsearch( + "http://localhost:9200", http_auth=http_auth, node_class="requests" + ) + assert str(e.value) == ( + "Specified 'node_class' is not async, should be async instead" + )