Skip to content

Commit 9f9801a

Browse files
andrefurlan-dbJesse Whitehouse
authored andcommitted
Add external auth provider + example (#101)
Signed-off-by: Andre Furlan <[email protected]> Signed-off-by: Jesse Whitehouse <[email protected]> Co-authored-by: Jesse Whitehouse <[email protected]> Signed-off-by: Sai Shree Pradhan <[email protected]>
1 parent f9d4566 commit 9f9801a

File tree

5 files changed

+100
-2
lines changed

5 files changed

+100
-2
lines changed

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@ To run all of these examples you can clone the entire repository to your disk. O
3838
this example the string `ExamplePartnerTag` will be added to the the user agent on every request.
3939
- **`staging_ingestion.py`** shows how the connector handles Databricks' experimental staging ingestion commands `GET`, `PUT`, and `REMOVE`.
4040
- **`sqlalchemy.py`** shows a basic example of connecting to Databricks with [SQLAlchemy](https://www.sqlalchemy.org/).
41+
- **`custom_cred_provider.py`** shows how to pass a custom credential provider to bypass connector authentication. Please install databricks-sdk prior to running this example.

examples/custom_cred_provider.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# please pip install databricks-sdk prior to running this example.
2+
3+
from databricks import sql
4+
from databricks.sdk.oauth import OAuthClient
5+
import os
6+
7+
oauth_client = OAuthClient(host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
8+
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
9+
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
10+
redirect_url=os.getenv("APP_REDIRECT_URL"),
11+
scopes=['all-apis', 'offline_access'])
12+
13+
consent = oauth_client.initiate_consent()
14+
15+
creds = consent.launch_external_browser()
16+
17+
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
18+
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
19+
credentials_provider=creds) as connection:
20+
21+
for x in range(1, 5):
22+
cursor = connection.cursor()
23+
cursor.execute('SELECT 1+1')
24+
result = cursor.fetchall()
25+
for row in result:
26+
print(row)
27+
cursor.close()
28+
29+
connection.close()

src/databricks/sql/auth/auth.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
AuthProvider,
66
AccessTokenAuthProvider,
77
BasicAuthProvider,
8+
ExternalAuthProvider,
89
DatabricksOAuthProvider,
910
)
1011
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
@@ -30,6 +31,7 @@ def __init__(
3031
use_cert_as_auth: str = None,
3132
tls_client_cert_file: str = None,
3233
oauth_persistence=None,
34+
credentials_provider=None,
3335
):
3436
self.hostname = hostname
3537
self.username = username
@@ -42,9 +44,12 @@ def __init__(
4244
self.use_cert_as_auth = use_cert_as_auth
4345
self.tls_client_cert_file = tls_client_cert_file
4446
self.oauth_persistence = oauth_persistence
47+
self.credentials_provider = credentials_provider
4548

4649

4750
def get_auth_provider(cfg: ClientContext):
51+
if cfg.credentials_provider:
52+
return ExternalAuthProvider(cfg.credentials_provider)
4853
if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
4954
assert cfg.oauth_redirect_port_range is not None
5055
assert cfg.oauth_client_id is not None
@@ -94,5 +99,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
9499
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
95100
else PYSQL_OAUTH_REDIRECT_PORT_RANGE,
96101
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
102+
credentials_provider=kwargs.get("credentials_provider"),
97103
)
98104
return get_auth_provider(cfg)

src/databricks/sql/auth/authenticators.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import abc
12
import base64
23
import logging
3-
from typing import Dict, List
4+
from typing import Callable, Dict, List
45

56
from databricks.sql.auth.oauth import OAuthManager
67

@@ -14,6 +15,22 @@ def add_headers(self, request_headers: Dict[str, str]):
1415
pass
1516

1617

18+
HeaderFactory = Callable[[], Dict[str, str]]
19+
20+
# In order to keep compatibility with SDK
21+
class CredentialsProvider(abc.ABC):
22+
"""CredentialsProvider is the protocol (call-side interface)
23+
for authenticating requests to Databricks REST APIs"""
24+
25+
@abc.abstractmethod
26+
def auth_type(self) -> str:
27+
...
28+
29+
@abc.abstractmethod
30+
def __call__(self, *args, **kwargs) -> HeaderFactory:
31+
...
32+
33+
1734
# Private API: this is an evolving interface and it will change in the future.
1835
# Please must not depend on it in your applications.
1936
class AccessTokenAuthProvider(AuthProvider):
@@ -120,3 +137,13 @@ def _update_token_if_expired(self):
120137
except Exception as e:
121138
logging.error(f"unexpected error in oauth token update", e, exc_info=True)
122139
raise e
140+
141+
142+
class ExternalAuthProvider(AuthProvider):
143+
def __init__(self, credentials_provider: CredentialsProvider) -> None:
144+
self._header_factory = credentials_provider()
145+
146+
def add_headers(self, request_headers: Dict[str, str]):
147+
headers = self._header_factory()
148+
for k, v in headers.items():
149+
request_headers[k] = v

tests/unit/test_auth.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import unittest
22

3-
from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider
3+
from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider, ExternalAuthProvider
44
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
5+
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
56

67

78
class Auth(unittest.TestCase):
@@ -37,6 +38,22 @@ def test_noop_auth_provider(self):
3738
self.assertEqual(len(http_request.keys()), 1)
3839
self.assertEqual(http_request['myKey'], 'myVal')
3940

41+
def test_external_provider(self):
42+
class MyProvider(CredentialsProvider):
43+
def auth_type(self) -> str:
44+
return "mine"
45+
46+
def __call__(self, *args, **kwargs) -> HeaderFactory:
47+
return lambda: {"foo": "bar"}
48+
49+
auth = ExternalAuthProvider(MyProvider())
50+
51+
http_request = {'myKey': 'myVal'}
52+
auth.add_headers(http_request)
53+
self.assertEqual(http_request['foo'], 'bar')
54+
self.assertEqual(len(http_request.keys()), 2)
55+
self.assertEqual(http_request['myKey'], 'myVal')
56+
4057
def test_get_python_sql_connector_auth_provider_access_token(self):
4158
hostname = "moderakh-test.cloud.databricks.com"
4259
kwargs = {'access_token': 'dpi123'}
@@ -47,6 +64,24 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
4764
auth_provider.add_headers(headers)
4865
self.assertEqual(headers['Authorization'], 'Bearer dpi123')
4966

67+
def test_get_python_sql_connector_auth_provider_external(self):
68+
69+
class MyProvider(CredentialsProvider):
70+
def auth_type(self) -> str:
71+
return "mine"
72+
73+
def __call__(self, *args, **kwargs) -> HeaderFactory:
74+
return lambda: {"foo": "bar"}
75+
76+
hostname = "moderakh-test.cloud.databricks.com"
77+
kwargs = {'credentials_provider': MyProvider()}
78+
auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
79+
self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider")
80+
81+
headers = {}
82+
auth_provider.add_headers(headers)
83+
self.assertEqual(headers['foo'], 'bar')
84+
5085
def test_get_python_sql_connector_auth_provider_username_password(self):
5186
username = "moderakh"
5287
password = "Elevate Databricks 123!!!"

0 commit comments

Comments
 (0)