Skip to content

Commit fa4baef

Browse files
committed
Support passing requests.auth.AuthBase via http_auth
1 parent 4bb16d8 commit fa4baef

File tree

5 files changed

+165
-2
lines changed

5 files changed

+165
-2
lines changed

elasticsearch/_async/client/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@
7979
_quote,
8080
_rewrite_parameters,
8181
client_node_configs,
82+
is_requests_http_auth,
83+
is_requests_node_class,
8284
)
8385
from .watcher import WatcherClient
8486
from .xpack import XPackClient
@@ -309,9 +311,27 @@ def __init__(
309311
sniff_callback = default_sniff_callback
310312

311313
if _transport is None:
314+
315+
requests_session_auth = None
316+
if http_auth is not None and http_auth is not DEFAULT:
317+
if is_requests_http_auth(http_auth):
318+
# If we're using custom requests authentication
319+
# then we need to alert the user that they also
320+
# need to use 'node_class=requests'.
321+
if not is_requests_node_class(node_class):
322+
raise ValueError(
323+
"Using a custom 'requests.auth.AuthBase' class for "
324+
"'http_auth' must be used with node_class='requests'"
325+
)
326+
327+
# Reset 'http_auth' to DEFAULT so it's not consumed below.
328+
requests_session_auth = http_auth
329+
http_auth = DEFAULT
330+
312331
node_configs = client_node_configs(
313332
hosts,
314333
cloud_id=cloud_id,
334+
requests_session_auth=requests_session_auth,
315335
connections_per_node=connections_per_node,
316336
http_compress=http_compress,
317337
verify_certs=verify_certs,

elasticsearch/_async/client/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
_quote_query,
2626
_rewrite_parameters,
2727
client_node_configs,
28+
is_requests_http_auth,
29+
is_requests_node_class,
2830
)
2931

3032
__all__ = [
@@ -37,4 +39,6 @@
3739
"SKIP_IN_PATH",
3840
"client_node_configs",
3941
"_rewrite_parameters",
42+
"is_requests_http_auth",
43+
"is_requests_node_class",
4044
]

elasticsearch/_sync/client/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@
7979
_quote,
8080
_rewrite_parameters,
8181
client_node_configs,
82+
is_requests_http_auth,
83+
is_requests_node_class,
8284
)
8385
from .watcher import WatcherClient
8486
from .xpack import XPackClient
@@ -309,9 +311,27 @@ def __init__(
309311
sniff_callback = default_sniff_callback
310312

311313
if _transport is None:
314+
315+
requests_session_auth = None
316+
if http_auth is not None and http_auth is not DEFAULT:
317+
if is_requests_http_auth(http_auth):
318+
# If we're using custom requests authentication
319+
# then we need to alert the user that they also
320+
# need to use 'node_class=requests'.
321+
if not is_requests_node_class(node_class):
322+
raise ValueError(
323+
"Using a custom 'requests.auth.AuthBase' class for "
324+
"'http_auth' must be used with node_class='requests'"
325+
)
326+
327+
# Reset 'http_auth' to DEFAULT so it's not consumed below.
328+
requests_session_auth = http_auth
329+
http_auth = DEFAULT
330+
312331
node_configs = client_node_configs(
313332
hosts,
314333
cloud_id=cloud_id,
334+
requests_session_auth=requests_session_auth,
315335
connections_per_node=connections_per_node,
316336
http_compress=http_compress,
317337
verify_certs=verify_certs,

elasticsearch/_sync/client/utils.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
1918
import base64
19+
import inspect
2020
import warnings
2121
from datetime import date, datetime
2222
from functools import wraps
@@ -41,6 +41,7 @@
4141
AsyncTransport,
4242
HttpHeaders,
4343
NodeConfig,
44+
RequestsHttpNode,
4445
SniffOptions,
4546
Transport,
4647
)
@@ -88,7 +89,10 @@
8889

8990

9091
def client_node_configs(
91-
hosts: Optional[_TYPE_HOSTS], cloud_id: Optional[str], **kwargs: Any
92+
hosts: Optional[_TYPE_HOSTS],
93+
cloud_id: Optional[str],
94+
requests_session_auth: Optional[Any] = None,
95+
**kwargs: Any,
9296
) -> List[NodeConfig]:
9397
if cloud_id is not None:
9498
if hosts is not None:
@@ -108,6 +112,12 @@ def client_node_configs(
108112
headers.setdefault("user-agent", USER_AGENT)
109113
node_options["headers"] = headers
110114

115+
# If a custom Requests AuthBase is passed we set that via '_extras'.
116+
if requests_session_auth is not None:
117+
node_options.setdefault("_extras", {})[
118+
"requests.session.auth"
119+
] = requests_session_auth
120+
111121
def apply_node_options(node_config: NodeConfig) -> NodeConfig:
112122
"""Needs special handling of headers since .replace() wipes out existing headers"""
113123
nonlocal node_options
@@ -406,3 +416,28 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
406416
return wrapped # type: ignore[return-value]
407417

408418
return wrapper
419+
420+
421+
def is_requests_http_auth(http_auth: Any) -> bool:
422+
"""Detect if an http_auth value is a custom Requests auth object"""
423+
try:
424+
from requests.auth import AuthBase
425+
426+
return isinstance(http_auth, AuthBase)
427+
except ImportError:
428+
pass
429+
return False
430+
431+
432+
def is_requests_node_class(node_class: Any) -> bool:
433+
"""Detect if 'RequestsHttpNode' would be used given the setting of 'node_class'"""
434+
return (
435+
node_class is not None
436+
and node_class is not DEFAULT
437+
and (
438+
node_class == "requests"
439+
or (
440+
inspect.isclass(node_class) and issubclass(node_class, RequestsHttpNode)
441+
)
442+
)
443+
)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import warnings
19+
20+
import pytest
21+
import requests
22+
from elastic_transport import RequestsHttpNode, Urllib3HttpNode
23+
from elastic_transport.client_utils import DEFAULT
24+
from requests.auth import HTTPBasicAuth
25+
26+
from elasticsearch import AsyncElasticsearch, Elasticsearch
27+
28+
29+
class CustomRequestHttpNode(RequestsHttpNode):
30+
pass
31+
32+
33+
class CustomUrllib3HttpNode(Urllib3HttpNode):
34+
pass
35+
36+
37+
@pytest.mark.parametrize(
38+
"node_class", ["requests", RequestsHttpNode, CustomRequestHttpNode]
39+
)
40+
def test_requests_auth(node_class):
41+
http_auth = HTTPBasicAuth("username", "password")
42+
43+
with warnings.catch_warnings(record=True) as w:
44+
client = Elasticsearch(
45+
"http://localhost:9200", http_auth=http_auth, node_class=node_class
46+
)
47+
48+
# http_auth is deprecated for all other cases except this one.
49+
assert len(w) == 0
50+
51+
# Instance should be forwarded directly to requests.Session.auth.
52+
node = client.transport.node_pool.get()
53+
assert isinstance(node, RequestsHttpNode)
54+
assert isinstance(node.session, requests.Session)
55+
assert node.session.auth is http_auth
56+
57+
58+
@pytest.mark.parametrize("client_class", [Elasticsearch, AsyncElasticsearch])
59+
@pytest.mark.parametrize(
60+
"node_class", ["urllib3", "aiohttp", None, DEFAULT, CustomUrllib3HttpNode]
61+
)
62+
def test_error_for_requests_auth_node_class(client_class, node_class):
63+
http_auth = HTTPBasicAuth("username", "password")
64+
65+
with pytest.raises(ValueError) as e:
66+
client_class(
67+
"http://localhost:9200", http_auth=http_auth, node_class=node_class
68+
)
69+
assert str(e.value) == (
70+
"Using a custom 'requests.auth.AuthBase' class for "
71+
"'http_auth' must be used with node_class='requests'"
72+
)
73+
74+
75+
def test_error_for_requests_auth_async():
76+
http_auth = HTTPBasicAuth("username", "password")
77+
78+
with pytest.raises(ValueError) as e:
79+
AsyncElasticsearch(
80+
"http://localhost:9200", http_auth=http_auth, node_class="requests"
81+
)
82+
assert str(e.value) == (
83+
"Specified 'node_class' is not async, should be async instead"
84+
)

0 commit comments

Comments
 (0)