Skip to content

Commit 22ef5e9

Browse files
authored
Merge branch 'main' into ODSC-73602/mmd_ft_weights
2 parents fa8f705 + be35fa7 commit 22ef5e9

File tree

9 files changed

+177
-57
lines changed

9 files changed

+177
-57
lines changed

ads/aqua/modeldeployment/constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
"""
@@ -8,3 +8,6 @@
88
99
This module contains constants used in Aqua Model Deployment.
1010
"""
11+
12+
DEFAULT_WAIT_TIME = 12000
13+
DEFAULT_POLL_INTERVAL = 10

ads/aqua/modeldeployment/deployment.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
56
import json
67
import shlex
8+
import threading
79
from datetime import datetime, timedelta
810
from typing import Dict, List, Optional
911

@@ -56,6 +58,7 @@
5658
ModelDeploymentConfigSummary,
5759
MultiModelDeploymentConfigLoader,
5860
)
61+
from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME
5962
from ads.aqua.modeldeployment.entities import (
6063
AquaDeployment,
6164
AquaDeploymentDetail,
@@ -65,10 +68,13 @@
6568
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
6669
from ads.common.object_storage_details import ObjectStorageDetails
6770
from ads.common.utils import UNKNOWN, get_log_links
71+
from ads.common.work_request import DataScienceWorkRequest
6872
from ads.config import (
6973
AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME,
7074
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
7175
AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME,
76+
AQUA_TELEMETRY_BUCKET,
77+
AQUA_TELEMETRY_BUCKET_NS,
7278
COMPARTMENT_OCID,
7379
PROJECT_OCID,
7480
)
@@ -531,6 +537,9 @@ def _create(
531537
if key not in env_var:
532538
env_var.update(env)
533539

540+
env_var.update({"AQUA_TELEMETRY_BUCKET_NS": AQUA_TELEMETRY_BUCKET_NS})
541+
env_var.update({"AQUA_TELEMETRY_BUCKET": AQUA_TELEMETRY_BUCKET})
542+
534543
logger.info(f"Env vars used for deploying {aqua_model.id} :{env_var}")
535544

536545
tags = {**tags, **(create_deployment_details.freeform_tags or {})}
@@ -744,8 +753,20 @@ def _create_deployment(
744753

745754
deployment_id = deployment.id
746755
logger.info(
747-
f"Aqua model deployment {deployment_id} created for model {aqua_model_id}."
756+
f"Aqua model deployment {deployment_id} created for model {aqua_model_id}. Work request Id is {deployment.dsc_model_deployment.workflow_req_id}"
757+
)
758+
759+
progress_thread = threading.Thread(
760+
target=self.get_deployment_status,
761+
args=(
762+
deployment_id,
763+
deployment.dsc_model_deployment.workflow_req_id,
764+
model_type,
765+
model_name,
766+
),
767+
daemon=True,
748768
)
769+
progress_thread.start()
749770

750771
# we arbitrarily choose last 8 characters of OCID to identify MD in telemetry
751772
telemetry_kwargs = {"ocid": get_ocid_substring(deployment_id, key_len=8)}
@@ -1234,3 +1255,62 @@ def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
12341255
)
12351256
for oci_shape in oci_shapes
12361257
]
1258+
1259+
def get_deployment_status(
    self,
    model_deployment_id: str,
    work_request_id: str,
    model_type: str,
    model_name: str,
) -> None:
    """Waits for the model deployment work request to finish and records the outcome in telemetry.

    A `SUCCEEDED` event is recorded when the wait completes without raising.
    A `FAILED` event is recorded when the wait raises; its detail carries the
    work request error messages when any were collected, otherwise the text
    of the raised exception.

    Parameters
    ----------
    model_deployment_id: str
        The id of the deployed aqua model. Only its last 8 characters are
        sent to telemetry.
    work_request_id: str
        The work request Id of the model deployment.
    model_type: str
        The type of aqua model to be deployed. Allowed values are: `custom`, `service` and `multi_model`.
    model_name: str
        Reported as the `value` of the telemetry event.

    Returns
    -------
    None
    """
    telemetry_kwargs = {"ocid": get_ocid_substring(model_deployment_id, key_len=8)}

    data_science_work_request: DataScienceWorkRequest = DataScienceWorkRequest(
        work_request_id
    )

    try:
        data_science_work_request.wait_work_request(
            progress_bar_description="Creating model deployment",
            max_wait_time=DEFAULT_WAIT_TIME,
            poll_interval=DEFAULT_POLL_INTERVAL,
        )
    except Exception as ex:
        # `_error_message` may be empty or unset when the failure did not come
        # from a FAILED work request, so always build a detail string instead
        # of leaving `error_str` undefined.
        errors = getattr(data_science_work_request, "_error_message", None) or []
        # Each message keeps its leading space so the detail format matches
        # what was previously recorded.
        error_str = "".join(f" {error.message}" for error in errors)
        if not error_str:
            error_str = str(ex)

        self.telemetry.record_event(
            category=f"aqua/{model_type}/deployment/status",
            action="FAILED",
            detail=error_str,
            value=model_name,
            **telemetry_kwargs,
        )
    else:
        self.telemetry.record_event_async(
            category=f"aqua/{model_type}/deployment/status",
            action="SUCCEEDED",
            value=model_name,
            **telemetry_kwargs,
        )

ads/common/work_request.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
3-
4-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
53
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
64

75
import logging
@@ -12,6 +10,7 @@
1210
import oci
1311
from oci import Signer
1412
from tqdm.auto import tqdm
13+
1514
from ads.common.oci_datascience import OCIDataScienceMixin
1615

1716
logger = logging.getLogger(__name__)
@@ -20,10 +19,10 @@
2019
DEFAULT_WAIT_TIME = 1200
2120
DEFAULT_POLL_INTERVAL = 10
2221
WORK_REQUEST_PERCENTAGE = 100
23-
# default tqdm progress bar format:
22+
# default tqdm progress bar format:
2423
# {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
2524
# customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
26-
DEFAULT_BAR_FORMAT = '{l_bar}{bar}| [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]'
25+
DEFAULT_BAR_FORMAT = "{l_bar}{bar}| [{elapsed}<{remaining}, " "{rate_fmt}{postfix}]"
2726

2827

2928
class DataScienceWorkRequest(OCIDataScienceMixin):
@@ -32,13 +31,13 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
3231
"""
3332

3433
def __init__(
35-
self,
36-
id: str,
34+
self,
35+
id: str,
3736
description: str = "Processing",
38-
config: dict = None,
39-
signer: Signer = None,
40-
client_kwargs: dict = None,
41-
**kwargs
37+
config: dict = None,
38+
signer: Signer = None,
39+
client_kwargs: dict = None,
40+
**kwargs,
4241
) -> None:
4342
"""Initializes ADSWorkRequest object.
4443
@@ -49,41 +48,43 @@ def __init__(
4948
description: str
5049
Progress bar initial step description (Defaults to `Processing`).
5150
config : dict, optional
52-
OCI API key config dictionary to initialize
51+
OCI API key config dictionary to initialize
5352
oci.data_science.DataScienceClient (Defaults to None).
5453
signer : oci.signer.Signer, optional
55-
OCI authentication signer to initialize
54+
OCI authentication signer to initialize
5655
oci.data_science.DataScienceClient (Defaults to None).
5756
client_kwargs : dict, optional
58-
Additional client keyword arguments to initialize
57+
Additional client keyword arguments to initialize
5958
oci.data_science.DataScienceClient (Defaults to None).
6059
kwargs:
61-
Additional keyword arguments to initialize
60+
Additional keyword arguments to initialize
6261
oci.data_science.DataScienceClient.
6362
"""
6463
self.id = id
6564
self._description = description
6665
self._percentage = 0
6766
self._status = None
67+
self._error_message = ""
6868
super().__init__(config, signer, client_kwargs, **kwargs)
69-
7069

7170
def _sync(self):
7271
"""Fetches the latest work request information to ADSWorkRequest object."""
7372
work_request = self.client.get_work_request(self.id).data
74-
work_request_logs = self.client.list_work_request_logs(
75-
self.id
76-
).data
73+
work_request_logs = self.client.list_work_request_logs(self.id).data
7774

78-
self._percentage= work_request.percent_complete
75+
self._percentage = work_request.percent_complete
7976
self._status = work_request.status
80-
self._description = work_request_logs[-1].message if work_request_logs else "Processing"
77+
self._description = (
78+
work_request_logs[-1].message if work_request_logs else "Processing"
79+
)
80+
if work_request.status == "FAILED":
81+
self._error_message = self.client.list_work_request_errors(self.id).data
8182

8283
def watch(
83-
self,
84+
self,
8485
progress_callback: Callable,
85-
max_wait_time: int=DEFAULT_WAIT_TIME,
86-
poll_interval: int=DEFAULT_POLL_INTERVAL,
86+
max_wait_time: int = DEFAULT_WAIT_TIME,
87+
poll_interval: int = DEFAULT_POLL_INTERVAL,
8788
):
8889
"""Updates the progress bar with realtime message and percentage until the process is completed.
8990
@@ -92,10 +93,10 @@ def watch(
9293
progress_callback: Callable
9394
Progress bar callback function.
9495
It must accept `(percent_change, description)` where `percent_change` is the
95-
work request percent complete and `description` is the latest work request log message.
96+
work request percent complete and `description` is the latest work request log message.
9697
max_wait_time: int
9798
Maximum amount of time to wait in seconds (Defaults to 1200).
98-
Negative implies infinite wait time.
99+
Negative implies infinite wait time.
99100
poll_interval: int
100101
Poll interval in seconds (Defaults to 10).
101102
@@ -107,7 +108,6 @@ def watch(
107108

108109
start_time = time.time()
109110
while self._percentage < 100:
110-
111111
seconds_since = time.time() - start_time
112112
if max_wait_time > 0 and seconds_since >= max_wait_time:
113113
logger.error(f"Exceeded max wait time of {max_wait_time} seconds.")
@@ -124,12 +124,14 @@ def watch(
124124
percent_change = self._percentage - previous_percent_complete
125125
previous_percent_complete = self._percentage
126126
progress_callback(
127-
percent_change=percent_change,
128-
description=self._description
127+
percent_change=percent_change, description=self._description
129128
)
130129

131130
if self._status in WORK_REQUEST_STOP_STATE:
132-
if self._status != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED:
131+
if (
132+
self._status
133+
!= oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED
134+
):
133135
if self._description:
134136
raise Exception(self._description)
135137
else:
@@ -145,12 +147,12 @@ def watch(
145147

146148
def wait_work_request(
147149
self,
148-
progress_bar_description: str="Processing",
149-
max_wait_time: int=DEFAULT_WAIT_TIME,
150-
poll_interval: int=DEFAULT_POLL_INTERVAL
150+
progress_bar_description: str = "Processing",
151+
max_wait_time: int = DEFAULT_WAIT_TIME,
152+
poll_interval: int = DEFAULT_POLL_INTERVAL,
151153
):
152154
"""Waits for the work request progress bar to be completed.
153-
155+
154156
Parameters
155157
----------
156158
progress_bar_description: str
@@ -160,7 +162,7 @@ def wait_work_request(
160162
Negative implies infinite wait time.
161163
poll_interval: int
162164
Poll interval in seconds (Defaults to 10).
163-
165+
164166
Returns
165167
-------
166168
None
@@ -172,7 +174,7 @@ def wait_work_request(
172174
mininterval=0,
173175
file=sys.stdout,
174176
desc=progress_bar_description,
175-
bar_format=DEFAULT_BAR_FORMAT
177+
bar_format=DEFAULT_BAR_FORMAT,
176178
) as pbar:
177179

178180
def progress_callback(percent_change, description):
@@ -184,6 +186,5 @@ def progress_callback(percent_change, description):
184186
self.watch(
185187
progress_callback=progress_callback,
186188
max_wait_time=max_wait_time,
187-
poll_interval=poll_interval
189+
poll_interval=poll_interval,
188190
)
189-

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,9 @@ def activate(
185185
self.id,
186186
)
187187

188+
189+
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
188190
if wait_for_completion:
189-
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
190-
191191
try:
192192
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
193193
progress_bar_description="Activating model deployment",
@@ -233,11 +233,9 @@ def create(
233233
response = self.client.create_model_deployment(create_model_deployment_details)
234234
self.update_from_oci_model(response.data)
235235
logger.info(f"Creating model deployment `{self.id}`.")
236-
print(f"Model Deployment OCID: {self.id}")
237236

237+
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
238238
if wait_for_completion:
239-
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
240-
241239
try:
242240
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
243241
progress_bar_description="Creating model deployment",
@@ -287,10 +285,8 @@ def deactivate(
287285
response = self.client.deactivate_model_deployment(
288286
self.id,
289287
)
290-
288+
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
291289
if wait_for_completion:
292-
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
293-
294290
try:
295291
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
296292
progress_bar_description="Deactivating model deployment",
@@ -355,10 +351,9 @@ def delete(
355351
response = self.client.delete_model_deployment(
356352
self.id,
357353
)
358-
354+
355+
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
359356
if wait_for_completion:
360-
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
361-
362357
try:
363358
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
364359
progress_bar_description="Deleting model deployment",

0 commit comments

Comments
 (0)