diff --git a/pyiceberg/catalog/rest/auth.py b/pyiceberg/catalog/rest/auth.py index 89395f1158..cc692a7165 100644 --- a/pyiceberg/catalog/rest/auth.py +++ b/pyiceberg/catalog/rest/auth.py @@ -17,9 +17,12 @@ import base64 import importlib +import threading +import time from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Type +import requests from requests import HTTPError, PreparedRequest, Session from requests.auth import AuthBase @@ -42,11 +45,15 @@ def auth_header(self) -> Optional[str]: class NoopAuthManager(AuthManager): + """Auth Manager implementation with no auth.""" + def auth_header(self) -> Optional[str]: return None class BasicAuthManager(AuthManager): + """AuthManager implementation that supports basic password auth.""" + def __init__(self, username: str, password: str): credentials = f"{username}:{password}" self._token = base64.b64encode(credentials.encode()).decode() @@ -56,6 +63,12 @@ def auth_header(self) -> str: class LegacyOAuth2AuthManager(AuthManager): + """Legacy OAuth2 AuthManager implementation. + + This class exists for backward compatibility, and will be removed in + PyIceberg 1.0.0 in favor of OAuth2AuthManager. + """ + _session: Session _auth_url: Optional[str] _token: Optional[str] @@ -109,6 +122,80 @@ def auth_header(self) -> str: return f"Bearer {self._token}" +class OAuth2TokenProvider: + """Thread-safe OAuth2 token provider with token refresh support.""" + + client_id: str + client_secret: str + token_url: str + scope: Optional[str] + refresh_margin: int + expires_in: Optional[int] + + _token: Optional[str] + _expires_at: int + _lock: threading.Lock + + def __init__( + self, + client_id: str, + client_secret: str, + token_url: str, + scope: Optional[str] = None, + refresh_margin: int = 60, + expires_in: Optional[int] = None, + ): + self.client_id = client_id + self.client_secret = client_secret + self.token_url = token_url + self.scope = scope + self.refresh_margin = refresh_margin + self.expires_in = expires_in + + self._token = None + self._expires_at = 0 + self._lock = threading.Lock() + + def _refresh_token(self) -> None: + data = { + "grant_type": "client_credentials", + "client_id": self.client_id, + "client_secret": self.client_secret, + } + if self.scope: + data["scope"] = self.scope + + response = requests.post(self.token_url, data=data) + response.raise_for_status() + result = response.json() + + self._token = result["access_token"] + expires_in = result.get("expires_in", self.expires_in) + if expires_in is None: + raise ValueError( + "The expiration time of the Token must be provided by the Server in the Access Token Response in `expired_in` field, or by the PyIceberg Client." + ) + self._expires_at = time.time() + expires_in - self.refresh_margin + + def get_token(self) -> str: + with self._lock: + if not self._token or time.time() >= self._expires_at: + self._refresh_token() + if self._token is None: + raise ValueError("Authorization token is None after refresh") + return self._token + + +class OAuth2AuthManager(AuthManager): + """Auth Manager implementation that supports OAuth2 as defined in IETF RFC6749.""" + + def __init__(self, token_provider: OAuth2TokenProvider): + self.token_provider = token_provider + + def auth_header(self) -> str: + return f"Bearer {self.token_provider.get_token()}" + + class AuthManagerAdapter(AuthBase): """A `requests.auth.AuthBase` adapter that integrates an `AuthManager` into a `requests.Session` to automatically attach the appropriate Authorization header to every request.