Skip to content

Add external auth provider #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ To run all of these examples you can clone the entire repository to your disk. O
this example the string `ExamplePartnerTag` will be added to the user agent on every request.
- **`staging_ingestion.py`** shows how the connector handles Databricks' experimental staging ingestion commands `GET`, `PUT`, and `REMOVE`.
- **`sqlalchemy.py`** shows a basic example of connecting to Databricks with [SQLAlchemy](https://www.sqlalchemy.org/).
- **`custom_cred_provider.py`** shows how to pass a custom credential provider to bypass connector authentication. Please install databricks-sdk prior to running this example.
29 changes: 29 additions & 0 deletions examples/custom_cred_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Please `pip install databricks-sdk` prior to running this example.
#
# Demonstrates passing an externally obtained credentials provider to
# `sql.connect` through the `credentials_provider` argument.

from databricks import sql
from databricks.sdk.oauth import OAuthClient
import os

# All connection and OAuth settings are read from environment variables.
oauth_client = OAuthClient(
    host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
    client_id=os.getenv("DATABRICKS_CLIENT_ID"),
    client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
    redirect_url=os.getenv("APP_REDIRECT_URL"),
    scopes=["all-apis", "offline_access"],
)

consent = oauth_client.initiate_consent()

# NOTE(review): presumably opens a browser for the interactive consent flow
# and returns the resulting credentials object -- confirm against the SDK docs.
creds = consent.launch_external_browser()

with sql.connect(
    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
    credentials_provider=creds,
) as connection:
    # Run the same trivial query a few times over the one connection.
    for _ in range(1, 5):
        # The cursor is a context manager, so it is closed even when
        # execute()/fetchall() raises.
        with connection.cursor() as cursor:
            cursor.execute("SELECT 1+1")
            for row in cursor.fetchall():
                print(row)
# No explicit connection.close(): leaving the `with` block closes it.
6 changes: 6 additions & 0 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
AuthProvider,
AccessTokenAuthProvider,
BasicAuthProvider,
ExternalAuthProvider,
DatabricksOAuthProvider,
)
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
Expand All @@ -30,6 +31,7 @@ def __init__(
use_cert_as_auth: str = None,
tls_client_cert_file: str = None,
oauth_persistence=None,
credentials_provider=None,
):
self.hostname = hostname
self.username = username
Expand All @@ -42,9 +44,12 @@ def __init__(
self.use_cert_as_auth = use_cert_as_auth
self.tls_client_cert_file = tls_client_cert_file
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider


def get_auth_provider(cfg: ClientContext):
if cfg.credentials_provider:
return ExternalAuthProvider(cfg.credentials_provider)
if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
assert cfg.oauth_redirect_port_range is not None
assert cfg.oauth_client_id is not None
Expand Down Expand Up @@ -94,5 +99,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
else PYSQL_OAUTH_REDIRECT_PORT_RANGE,
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
)
return get_auth_provider(cfg)
29 changes: 28 additions & 1 deletion src/databricks/sql/auth/authenticators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import base64
import logging
from typing import Dict, List
from typing import Callable, Dict, List

from databricks.sql.auth.oauth import OAuthManager

Expand All @@ -14,6 +15,22 @@ def add_headers(self, request_headers: Dict[str, str]):
pass


# A zero-argument callable that yields the headers to attach to a request.
HeaderFactory = Callable[[], Dict[str, str]]


# This interface is kept in this shape for compatibility with the SDK.
class CredentialsProvider(abc.ABC):
    """Call-side interface for authenticating requests to Databricks REST APIs.

    Concrete providers must implement both ``auth_type`` and ``__call__``;
    the class cannot be instantiated until they do.
    """

    @abc.abstractmethod
    def auth_type(self) -> str:
        """Return the name of the authentication type, as a string."""

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> HeaderFactory:
        """Return a ``HeaderFactory`` that produces the request headers."""


# Private API: this is an evolving interface and it will change in the future.
# Please do not depend on it in your applications.
class AccessTokenAuthProvider(AuthProvider):
Expand Down Expand Up @@ -120,3 +137,13 @@ def _update_token_if_expired(self):
except Exception as e:
logging.error(f"unexpected error in oauth token update", e, exc_info=True)
raise e


class ExternalAuthProvider(AuthProvider):
    """Auth provider that takes its headers from an external credentials provider."""

    def __init__(self, credentials_provider: CredentialsProvider) -> None:
        # The provider is invoked once here; the callable it returns is kept
        # and called again on every add_headers() call.
        header_factory = credentials_provider()
        self._header_factory = header_factory

    def add_headers(self, request_headers: Dict[str, str]):
        """Merge the headers from the stored factory into ``request_headers``.

        The mapping is modified in place; entries with the same key are
        overwritten by the factory's values.
        """
        request_headers.update(self._header_factory())
37 changes: 36 additions & 1 deletion tests/unit/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest

from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider
from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider, ExternalAuthProvider
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory


class Auth(unittest.TestCase):
Expand Down Expand Up @@ -37,6 +38,22 @@ def test_noop_auth_provider(self):
self.assertEqual(len(http_request.keys()), 1)
self.assertEqual(http_request['myKey'], 'myVal')

def test_external_provider(self):
    """ExternalAuthProvider adds the provider's headers and keeps existing ones."""

    class MyProvider(CredentialsProvider):
        # Minimal concrete provider that always yields one fixed header.
        def auth_type(self) -> str:
            return "mine"

        def __call__(self, *args, **kwargs) -> HeaderFactory:
            def header_factory():
                return {"foo": "bar"}

            return header_factory

    request = {'myKey': 'myVal'}
    provider = ExternalAuthProvider(MyProvider())
    provider.add_headers(request)

    # The new header is present, and the original entry is untouched.
    self.assertEqual(request['foo'], 'bar')
    self.assertEqual(len(request), 2)
    self.assertEqual(request['myKey'], 'myVal')

def test_get_python_sql_connector_auth_provider_access_token(self):
hostname = "moderakh-test.cloud.databricks.com"
kwargs = {'access_token': 'dpi123'}
Expand All @@ -47,6 +64,24 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
auth_provider.add_headers(headers)
self.assertEqual(headers['Authorization'], 'Bearer dpi123')

def test_get_python_sql_connector_auth_provider_external(self):
    """Passing ``credentials_provider`` selects ExternalAuthProvider and its headers."""

    class MyProvider(CredentialsProvider):
        # Minimal concrete provider that always yields one fixed header.
        def auth_type(self) -> str:
            return "mine"

        def __call__(self, *args, **kwargs) -> HeaderFactory:
            return lambda: {"foo": "bar"}

    hostname = "moderakh-test.cloud.databricks.com"
    kwargs = {'credentials_provider': MyProvider()}
    auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
    # assertEqual, not assertTrue: assertTrue(x, msg) treats the second
    # argument as the failure message, so it passed for any provider type.
    self.assertEqual(type(auth_provider).__name__, "ExternalAuthProvider")

    headers = {}
    auth_provider.add_headers(headers)
    self.assertEqual(headers['foo'], 'bar')

def test_get_python_sql_connector_auth_provider_username_password(self):
username = "moderakh"
password = "Elevate Databricks 123!!!"
Expand Down