Skip to content

Update the AQUA documentation and the AQUA OpenAI client to support multiple inference endpoints in OCI Model Deployment. #1212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions ads/aqua/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,20 @@ class HttpxOCIAuth(httpx.Auth):

def __init__(self, signer: Optional[oci.signer.Signer] = None):
"""
Initialize the HttpxOCIAuth instance.
Initializes the authentication handler with the given or default OCI signer.

Args:
signer (oci.signer.Signer): The OCI signer to use for signing requests.
Parameters
----------
signer : oci.signer.Signer, optional
The OCI signer instance to use. If None, a default signer will be retrieved.
"""

self.signer = signer or authutil.default_signer().get("signer")
try:
self.signer = signer or authutil.default_signer().get("signer")
if not self.signer:
raise ValueError("OCI signer could not be initialized.")
except Exception as e:
logger.error("Failed to initialize OCI signer: %s", e)
raise

def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
"""
Expand All @@ -80,21 +87,31 @@ def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
httpx.Request: The signed HTTPX request.
"""
# Create a requests.Request object from the HTTPX request
req = requests.Request(
method=request.method,
url=str(request.url),
headers=dict(request.headers),
data=request.content,
)
prepared_request = req.prepare()
try:
req = requests.Request(
method=request.method,
url=str(request.url),
headers=dict(request.headers),
data=request.content,
)
prepared_request = req.prepare()
self.signer.do_request_sign(prepared_request)

# Replace headers on the original HTTPX request with signed headers
request.headers.update(prepared_request.headers)
logger.debug("Successfully signed request to %s", request.url)

# Sign the request using the OCI Signer
self.signer.do_request_sign(prepared_request)
# Fix for GET/DELETE requests that OCI Gateway expects with Content-Length
if (
request.method in ["GET", "DELETE"]
and "content-length" not in request.headers
):
request.headers["content-length"] = "0"

# Update the original HTTPX request with the signed headers
request.headers.update(prepared_request.headers)
except Exception as e:
logger.error("Failed to sign request to %s: %s", request.url, e)
raise

# Proceed with the request
yield request


Expand Down Expand Up @@ -330,8 +347,8 @@ def _prepare_headers(
"Content-Type": "application/json",
"Accept": "text/event-stream" if stream else "application/json",
}
if stream:
default_headers["enable-streaming"] = "true"
# if stream:
# default_headers["enable-streaming"] = "true"
if headers:
default_headers.update(headers)

Expand Down Expand Up @@ -495,7 +512,7 @@ def generate(
prompt: str,
payload: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
stream: bool = True,
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
"""
Generate text completion for the given prompt.
Expand All @@ -521,7 +538,7 @@ def chat(
messages: List[Dict[str, Any]],
payload: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
stream: bool = True,
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
"""
Perform a chat interaction with the model.
Expand Down
30 changes: 20 additions & 10 deletions ads/aqua/client/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ModelDeploymentBaseEndpoint(ExtendedEnum):
"""Supported base endpoints for model deployments."""

PREDICT = "predict"
PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream"
PREDICT_WITH_RESPONSE_STREAM = "predictWithResponseStream"


class AquaOpenAIMixin:
Expand All @@ -51,9 +51,9 @@ def _patch_route(self, original_path: str) -> str:
Returns:
str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions').
"""
normalized_path = original_path.lower().rstrip("/")
normalized_path = original_path.rstrip("/")

match = re.search(r"/predict(withresponsestream)?", normalized_path)
match = re.search(r"/predict(WithResponseStream)?", normalized_path)
if not match:
logger.debug("Route header cannot be resolved from path: %s", original_path)
return ""
Expand All @@ -71,7 +71,7 @@ def _patch_route(self, original_path: str) -> str:
"Route suffix does not start with a version prefix (e.g., '/v1'). "
"This may lead to compatibility issues with OpenAI-style endpoints. "
"Consider updating the URL to include a version prefix, "
"such as '/predict/v1' or '/predictwithresponsestream/v1'."
"such as '/predict/v1' or '/predictWithResponseStream/v1'."
)
# route_suffix = f"v1/{route_suffix}"

Expand Down Expand Up @@ -124,13 +124,13 @@ def _patch_headers(self, request: httpx.Request) -> None:

def _patch_url(self) -> httpx.URL:
"""
Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path.
Strips any suffixes from the base URL to retain only the `/predict` or `/predictWithResponseStream` path.

Returns:
httpx.URL: The normalized base URL with the correct model deployment path.
"""
base_path = f"{self.base_url.path.lower().rstrip('/')}/"
match = re.search(r"/predict(withresponsestream)?/", base_path)
base_path = f"{self.base_url.path.rstrip('/')}/"
match = re.search(r"/predict(WithResponseStream)?/", base_path)
if match:
trimmed = base_path[: match.end() - 1]
return self.base_url.copy_with(path=trimmed)
Expand All @@ -144,7 +144,7 @@ def _prepare_request_common(self, request: httpx.Request) -> None:

This includes:
- Patching headers with streaming and routing info.
- Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`.
- Normalizing the URL path to include only `/predict` or `/predictWithResponseStream`.

Args:
request (httpx.Request): The outgoing HTTPX request.
Expand Down Expand Up @@ -176,6 +176,7 @@ def __init__(
http_client: Optional[httpx.Client] = None,
http_client_kwargs: Optional[Dict[str, Any]] = None,
_strict_response_validation: bool = False,
patch_headers: bool = False,
**kwargs: Any,
) -> None:
"""
Expand All @@ -196,6 +197,7 @@ def __init__(
http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created.
http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
_strict_response_validation (bool, optional): Enable strict response validation.
patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
**kwargs: Additional keyword arguments passed to the parent __init__.
"""
if http_client is None:
Expand All @@ -207,6 +209,8 @@ def __init__(
logger.debug("API key not provided; using default placeholder for OCI.")
api_key = "OCI"

self.patch_headers = patch_headers

super().__init__(
api_key=api_key,
organization=organization,
Expand All @@ -229,7 +233,8 @@ def _prepare_request(self, request: httpx.Request) -> None:
Args:
request (httpx.Request): The outgoing HTTP request.
"""
self._prepare_request_common(request)
if self.patch_headers:
self._prepare_request_common(request)


class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
Expand All @@ -248,6 +253,7 @@ def __init__(
http_client: Optional[httpx.Client] = None,
http_client_kwargs: Optional[Dict[str, Any]] = None,
_strict_response_validation: bool = False,
patch_headers: bool = False,
**kwargs: Any,
) -> None:
"""
Expand All @@ -269,6 +275,7 @@ def __init__(
http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
_strict_response_validation (bool, optional): Enable strict response validation.
patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
**kwargs: Additional keyword arguments passed to the parent __init__.
"""
if http_client is None:
Expand All @@ -280,6 +287,8 @@ def __init__(
logger.debug("API key not provided; using default placeholder for OCI.")
api_key = "OCI"

self.patch_headers = patch_headers

super().__init__(
api_key=api_key,
organization=organization,
Expand All @@ -302,4 +311,5 @@ async def _prepare_request(self, request: httpx.Request) -> None:
Args:
request (httpx.Request): The outgoing HTTP request.
"""
self._prepare_request_common(request)
if self.patch_headers:
self._prepare_request_common(request)
32 changes: 32 additions & 0 deletions docs/source/user_guide/large_language_model/aqua_client.rst
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,35 @@ The asynchronous client, ``AsyncOpenAI``, extends the AsyncOpenAI client. If no asynchronous HTTP client is provided, one will be created automatically.
print(event)

asyncio.run(test_async())


Using the Native OpenAI Client
------------------------------

If you prefer to use the **original** ``openai.OpenAI`` **client**, you must manually provide:

- A custom HTTP client created via ``ads.aqua.get_httpx_client()``, and
- ``api_key="OCI"`` (required for SDK compatibility).

.. code-block:: python

import ads
from openai import OpenAI

ads.set_auth(auth="security_token")

# Create the patched HTTP client with OCI signer
http_client = ads.aqua.get_httpx_client()

client = OpenAI(
api_key="OCI",
base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<OCID>/predict/v1",
http_client=http_client
)

response = client.chat.completions.create(
model="odsc-llm",
messages=[{"role": "user", "content": "Write a short story about a unicorn."}],
)

print(response)
2 changes: 1 addition & 1 deletion tests/unitary/with_extras/aqua/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def test_prepare_headers_stream(self):
expected_headers = {
"Content-Type": "application/json",
"Accept": "text/event-stream",
"enable-streaming": "true",
# "enable-streaming": "true",
"Custom-Header": "Value",
}
assert result == expected_headers
Expand Down