diff --git a/examples/README.md b/examples/README.md
index d931e5642..4fbe85279 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -38,3 +38,4 @@ To run all of these examples you can clone the entire repository to your disk. O
 this example the string `ExamplePartnerTag` will be added to the the user agent on every request.
 - **`staging_ingestion.py`** shows how the connector handles Databricks' experimental staging ingestion commands `GET`, `PUT`, and `REMOVE`.
 - **`sqlalchemy.py`** shows a basic example of connecting to Databricks with [SQLAlchemy](https://www.sqlalchemy.org/).
+- **`custom_cred_provider.py`** shows how to pass a custom credentials provider so that authentication is handled by your own code instead of the connector's built-in methods. Install `databricks-sdk` before running this example.
diff --git a/examples/custom_cred_provider.py b/examples/custom_cred_provider.py
new file mode 100644
index 000000000..4c43280fe
--- /dev/null
+++ b/examples/custom_cred_provider.py
@@ -0,0 +1,31 @@
+# Please `pip install databricks-sdk` before running this example.
+
+import os
+
+from databricks import sql
+from databricks.sdk.oauth import OAuthClient
+
+oauth_client = OAuthClient(
+    host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
+    client_id=os.getenv("DATABRICKS_CLIENT_ID"),
+    client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
+    redirect_url=os.getenv("APP_REDIRECT_URL"),
+    scopes=["all-apis", "offline_access"],
+)
+
+# Run the OAuth consent flow in an external browser; the resulting credentials
+# object is handed to the connector as its `credentials_provider`.
+consent = oauth_client.initiate_consent()
+creds = consent.launch_external_browser()
+
+with sql.connect(
+    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
+    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
+    credentials_provider=creds,
+) as connection:
+    for _ in range(1, 5):
+        # The cursor is closed automatically when its `with` block exits.
+        with connection.cursor() as cursor:
+            cursor.execute("SELECT 1+1")
+            for row in cursor.fetchall():
+                print(row)
diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py
index d0a213aa4..b56d8f7f1 100644
--- a/src/databricks/sql/auth/auth.py
+++ b/src/databricks/sql/auth/auth.py
@@ -5,6 +5,7 @@
     AuthProvider,
     AccessTokenAuthProvider,
     BasicAuthProvider,
+    ExternalAuthProvider,
     DatabricksOAuthProvider,
 )
 from databricks.sql.experimental.oauth_persistence import OAuthPersistence
@@ -30,6 +31,7 @@ def __init__(
         use_cert_as_auth: str = None,
         tls_client_cert_file: str = None,
         oauth_persistence=None,
+        credentials_provider=None,
     ):
         self.hostname = hostname
         self.username = username
@@ -42,9 +44,12 @@ def __init__(
         self.use_cert_as_auth = use_cert_as_auth
         self.tls_client_cert_file = tls_client_cert_file
         self.oauth_persistence = oauth_persistence
+        self.credentials_provider = credentials_provider
 
 
 def get_auth_provider(cfg: ClientContext):
+    if cfg.credentials_provider:
+        return ExternalAuthProvider(cfg.credentials_provider)
     if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
         assert cfg.oauth_redirect_port_range is not None
         assert cfg.oauth_client_id is not None
@@ -94,5 +99,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
         if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
         else PYSQL_OAUTH_REDIRECT_PORT_RANGE,
         oauth_persistence=kwargs.get("experimental_oauth_persistence"),
+        credentials_provider=kwargs.get("credentials_provider"),
     )
     return get_auth_provider(cfg)
diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py
index b5b1dfcb3..eb368e1ef 100644
--- a/src/databricks/sql/auth/authenticators.py
+++ b/src/databricks/sql/auth/authenticators.py
@@ -1,6 +1,7 @@
+import abc
 import base64
 import logging
-from typing import Dict, List
+from typing import Callable, Dict, List
 
 from databricks.sql.auth.oauth import OAuthManager
 
@@ -14,6 +15,28 @@ def add_headers(self, request_headers: Dict[str, str]):
         pass
 
 
+# Zero-argument callable that returns the headers to attach, as a str -> str dict.
+HeaderFactory = Callable[[], Dict[str, str]]
+
+
+# In order to keep compatibility with SDK
+class CredentialsProvider(abc.ABC):
+    """CredentialsProvider is the protocol (call-side interface)
+    for authenticating requests to Databricks REST APIs.
+
+    Calling an instance returns a ``HeaderFactory``; calling that factory
+    returns the headers to add to a request.
+    """
+
+    @abc.abstractmethod
+    def auth_type(self) -> str:
+        """Return this provider's auth type as a string."""
+
+    @abc.abstractmethod
+    def __call__(self, *args, **kwargs) -> HeaderFactory:
+        """Return the ``HeaderFactory`` used to build request headers."""
+
+
 # Private API: this is an evolving interface and it will change in the future.
 # Please must not depend on it in your applications.
 class AccessTokenAuthProvider(AuthProvider):
@@ -120,3 +143,16 @@ def _update_token_if_expired(self):
         except Exception as e:
             logging.error(f"unexpected error in oauth token update", e, exc_info=True)
             raise e
+
+
+class ExternalAuthProvider(AuthProvider):
+    """Auth provider backed by a caller-supplied CredentialsProvider."""
+
+    def __init__(self, credentials_provider: CredentialsProvider) -> None:
+        # The provider is called once here; the header factory it returns is
+        # then invoked on every add_headers() call.
+        self._header_factory = credentials_provider()
+
+    def add_headers(self, request_headers: Dict[str, str]):
+        """Copy the headers produced by the header factory into ``request_headers``."""
+        request_headers.update(self._header_factory())
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py
index 59660f17c..c52f9790e 100644
--- a/tests/unit/test_auth.py
+++ b/tests/unit/test_auth.py
@@ -1,7 +1,8 @@
 import unittest
 
-from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider
+from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider, ExternalAuthProvider
 from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
+from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
 
 
 class Auth(unittest.TestCase):
@@ -37,6 +38,22 @@ def test_noop_auth_provider(self):
         self.assertEqual(len(http_request.keys()), 1)
         self.assertEqual(http_request['myKey'], 'myVal')
 
+    def test_external_provider(self):
+        class MyProvider(CredentialsProvider):
+            def auth_type(self) -> str:
+                return "mine"
+
+            def __call__(self, *args, **kwargs) -> HeaderFactory:
+                return lambda: {"foo": "bar"}
+
+        auth = ExternalAuthProvider(MyProvider())
+
+        http_request = {'myKey': 'myVal'}
+        auth.add_headers(http_request)
+        self.assertEqual(http_request['foo'], 'bar')
+        self.assertEqual(len(http_request.keys()), 2)
+        self.assertEqual(http_request['myKey'], 'myVal')
+
     def test_get_python_sql_connector_auth_provider_access_token(self):
         hostname = "moderakh-test.cloud.databricks.com"
         kwargs = {'access_token': 'dpi123'}
@@ -47,6 +64,24 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
         auth_provider.add_headers(headers)
         self.assertEqual(headers['Authorization'], 'Bearer dpi123')
 
+    def test_get_python_sql_connector_auth_provider_external(self):
+
+        class MyProvider(CredentialsProvider):
+            def auth_type(self) -> str:
+                return "mine"
+
+            def __call__(self, *args, **kwargs) -> HeaderFactory:
+                return lambda: {"foo": "bar"}
+
+        hostname = "moderakh-test.cloud.databricks.com"
+        kwargs = {'credentials_provider': MyProvider()}
+        auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
+        self.assertIsInstance(auth_provider, ExternalAuthProvider)
+
+        headers = {}
+        auth_provider.add_headers(headers)
+        self.assertEqual(headers['foo'], 'bar')
+
     def test_get_python_sql_connector_auth_provider_username_password(self):
         username = "moderakh"
         password = "Elevate Databricks 123!!!"