From 7e71e1830943079bdf75f92b35e647fee655aaf6 Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Wed, 25 Jun 2025 00:36:28 +0000 Subject: [PATCH] Prototype, custom auth openai --- src/openai/__init__.py | 2 +- src/openai/_base_client.py | 14 ++++++++++++++ src/openai/_client.py | 8 ++++++++ src/openai/lib/azure.py | 15 ++++++++++++++- 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 92beeb5da1..d3f4429e52 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -84,7 +84,7 @@ from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool from .version import VERSION as VERSION -from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI +from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI, AzureAuth as AzureAuth from .lib._old_api import * from .lib.streaming import ( AssistantEventHandler as AssistantEventHandler, diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 2f87d23aaa..8d787ac230 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -819,6 +819,7 @@ def __init__( max_retries: int = DEFAULT_MAX_RETRIES, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.Client | None = None, + auth: httpx.Auth | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, _strict_response_validation: bool, @@ -856,6 +857,12 @@ def __init__( # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), ) + self._custom_auth = auth + + @property + @override + def custom_auth(self) -> httpx.Auth | None: + return self._custom_auth def is_closed(self) -> bool: return self._client.is_closed @@ -1343,6 +1350,7 @@ def __init__( max_retries: int = DEFAULT_MAX_RETRIES, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.AsyncClient | None = None, + auth: httpx.Auth | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, ) -> None: @@ -1379,7 +1387,13 @@ def __init__( # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), ) + self._custom_auth = auth + @property + @override + def custom_auth(self) -> httpx.Auth | None: + return self._custom_auth + def is_closed(self) -> bool: return self._client.is_closed diff --git a/src/openai/_client.py b/src/openai/_client.py index 4ed9a2f52e..5e6686970f 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -103,6 +103,7 @@ def __init__( # We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: httpx.Client | None = None, + auth: httpx.Auth | None = None, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -149,6 +150,7 @@ def __init__( max_retries=max_retries, timeout=timeout, http_client=http_client, + auth = auth, custom_headers=default_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, @@ -292,6 +294,7 @@ def copy( base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.Client | None = None, + auth: httpx.Auth | None = None, max_retries: int | NotGiven = NOT_GIVEN, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, @@ -329,6 +332,7 @@ def copy( base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, + auth=auth or self.custom_auth, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, @@ -404,6 +408,7 @@ def __init__( # We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details. http_client: httpx.AsyncClient | None = None, + auth: httpx.Auth | None = None, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -450,6 +455,7 @@ def __init__( max_retries=max_retries, timeout=timeout, http_client=http_client, + auth=auth, custom_headers=default_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, @@ -593,6 +599,7 @@ def copy( base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.AsyncClient | None = None, + auth: httpx.Auth | None = None, max_retries: int | NotGiven = NOT_GIVEN, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, @@ -630,6 +637,7 @@ def copy( base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, + auth=auth or self.custom_auth, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 655dd71d4c..efb996ab9a 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -2,7 +2,7 @@ import os import inspect -from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload +from typing import Any, Union, Mapping, TypeVar, Callable, Generator, Awaitable, cast, overload from typing_extensions import Self, override import httpx @@ -85,6 +85,15 @@ def _prepare_url(self, url: str) -> httpx.URL: return super()._prepare_url(url) +class AzureAuth(httpx.Auth): + def __init__(self, credential: Any, *, scope: str = 'https://cognitiveservices.azure.com/.default'): + self.credential = credential + self.scope = scope + + @override + def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]: + request.headers['Authorization'] = 'Bearer ' + self.credential.get_token(self.scope).token + yield request class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI): @overload @@ -254,6 +263,7 @@ def copy( self, *, api_key: str | None = None, + auth: httpx.Auth | None = None, organization: str | None = None, project: str | None = None, websocket_base_url: str | httpx.URL | None = None, @@ -426,6 +436,7 @@ def __init__( azure_deployment: str | None = None, api_version: str | None = None, api_key: str | None = None, + auth: httpx.Auth | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AsyncAzureADTokenProvider | None = None, organization: str | None = None, @@ -528,6 +539,7 @@ def copy( self, *, api_key: str | None = None, + auth: httpx.Auth | None = None, organization: str | None = None, project: str | None = None, websocket_base_url: str | httpx.URL | None = None, @@ -549,6 +561,7 @@ def copy( """ return super().copy( api_key=api_key, + auth=auth, organization=organization, project=project, websocket_base_url=websocket_base_url,