|
12 | 12 | # permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
| 15 | +from enum import Enum |
15 | 16 | from stepfunctions.inputs import ExecutionInput, StepInput
|
16 | 17 | from stepfunctions.steps.states import Task
|
17 | 18 | from stepfunctions.steps.fields import Field
|
18 |
| -from stepfunctions.steps.utils import tags_dict_to_kv_list, resource_integration_arn_builder |
19 |
| -from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, SageMakerApi |
| 19 | +from stepfunctions.steps.utils import tags_dict_to_kv_list, get_service_integration_arn |
| 20 | +from stepfunctions.steps.integration_resources import IntegrationPattern |
20 | 21 |
|
21 | 22 | from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
|
22 | 23 | from sagemaker.model import Model, FrameworkModel
|
23 | 24 | from sagemaker.model_monitor import DataCaptureConfig
|
24 | 25 |
|
| 26 | +SageMaker = "sagemaker" |
| 27 | + |
| 28 | + |
| 29 | +class SageMakerApi(Enum): |
| 30 | + CreateTrainingJob = "createTrainingJob" |
| 31 | + CreateTransformJob = "createTransformJob" |
| 32 | + CreateModel = "createModel" |
| 33 | + CreateEndpointConfig = "createEndpointConfig" |
| 34 | + UpdateEndpoint = "updateEndpoint" |
| 35 | + CreateEndpoint = "createEndpoint" |
| 36 | + CreateHyperParameterTuningJob = "createHyperParameterTuningJob" |
| 37 | + CreateProcessingJob = "createProcessingJob" |
| 38 | + |
| 39 | + |
25 | 40 | class TrainingStep(Task):
|
26 | 41 |
|
27 | 42 | """
|
@@ -62,15 +77,17 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
|
62 | 77 | """
|
63 | 78 | Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync
|
64 | 79 | """
|
65 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
66 |
| - SageMakerApi.CreateTrainingJob, |
67 |
| - IntegrationPattern.WaitForCompletion) |
| 80 | + |
| 81 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 82 | + SageMakerApi.CreateTrainingJob, |
| 83 | + IntegrationPattern.WaitForCompletion) |
68 | 84 | else:
|
69 | 85 | """
|
70 | 86 | Example resource arn: arn:aws:states:::sagemaker:createTrainingJob
|
71 | 87 | """
|
72 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
73 |
| - SageMakerApi.CreateTrainingJob) |
| 88 | + |
| 89 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 90 | + SageMakerApi.CreateTrainingJob) |
74 | 91 |
|
75 | 92 | if isinstance(job_name, str):
|
76 | 93 | parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
|
@@ -154,15 +171,17 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
|
154 | 171 | """
|
155 | 172 | Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync
|
156 | 173 | """
|
157 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
158 |
| - SageMakerApi.CreateTransformJob, |
159 |
| - IntegrationPattern.WaitForCompletion) |
| 174 | + |
| 175 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 176 | + SageMakerApi.CreateTransformJob, |
| 177 | + IntegrationPattern.WaitForCompletion) |
160 | 178 | else:
|
161 | 179 | """
|
162 | 180 | Example resource arn: arn:aws:states:::sagemaker:createTransformJob
|
163 | 181 | """
|
164 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
165 |
| - SageMakerApi.CreateTransformJob) |
| 182 | + |
| 183 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 184 | + SageMakerApi.CreateTransformJob) |
166 | 185 |
|
167 | 186 | if isinstance(job_name, str):
|
168 | 187 | parameters = transform_config(
|
@@ -248,8 +267,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
|
248 | 267 | """
|
249 | 268 | Example resource arn: arn:aws:states:::sagemaker:createModel
|
250 | 269 | """
|
251 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
252 |
| - SageMakerApi.CreateModel) |
| 270 | + |
| 271 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 272 | + SageMakerApi.CreateModel) |
253 | 273 |
|
254 | 274 | super(ModelStep, self).__init__(state_id, **kwargs)
|
255 | 275 |
|
@@ -293,8 +313,9 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
|
293 | 313 | """
|
294 | 314 | Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig
|
295 | 315 | """
|
296 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
297 |
| - SageMakerApi.CreateEndpointConfig) |
| 316 | + |
| 317 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 318 | + SageMakerApi.CreateEndpointConfig) |
298 | 319 |
|
299 | 320 | kwargs[Field.Parameters.value] = parameters
|
300 | 321 |
|
@@ -330,14 +351,16 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
|
330 | 351 | """
|
331 | 352 | Example resource arn: arn:aws:states:::sagemaker:updateEndpoint
|
332 | 353 | """
|
333 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
334 |
| - SageMakerApi.UpdateEndpoint) |
| 354 | + |
| 355 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 356 | + SageMakerApi.UpdateEndpoint) |
335 | 357 | else:
|
336 | 358 | """
|
337 | 359 | Example resource arn: arn:aws:states:::sagemaker:createEndpoint
|
338 | 360 | """
|
339 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
340 |
| - SageMakerApi.CreateEndpoint) |
| 361 | + |
| 362 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 363 | + SageMakerApi.CreateEndpoint) |
341 | 364 |
|
342 | 365 | kwargs[Field.Parameters.value] = parameters
|
343 | 366 |
|
@@ -378,15 +401,17 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
|
378 | 401 | """
|
379 | 402 | Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync
|
380 | 403 | """
|
381 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
382 |
| - SageMakerApi.CreateHyperParameterTuningJob, |
383 |
| - IntegrationPattern.WaitForCompletion) |
| 404 | + |
| 405 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 406 | + SageMakerApi.CreateHyperParameterTuningJob, |
| 407 | + IntegrationPattern.WaitForCompletion) |
384 | 408 | else:
|
385 | 409 | """
|
386 | 410 | Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob
|
387 | 411 | """
|
388 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
389 |
| - SageMakerApi.CreateHyperParameterTuningJob) |
| 412 | + |
| 413 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 414 | + SageMakerApi.CreateHyperParameterTuningJob) |
390 | 415 |
|
391 | 416 | parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
|
392 | 417 |
|
@@ -436,15 +461,17 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
|
436 | 461 | """
|
437 | 462 | Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync
|
438 | 463 | """
|
439 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
440 |
| - SageMakerApi.CreateProcessingJob, |
441 |
| - IntegrationPattern.WaitForCompletion) |
| 464 | + |
| 465 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 466 | + SageMakerApi.CreateProcessingJob, |
| 467 | + IntegrationPattern.WaitForCompletion) |
442 | 468 | else:
|
443 | 469 | """
|
444 | 470 | Example resource arn: arn:aws:states:::sagemaker:createProcessingJob
|
445 | 471 | """
|
446 |
| - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
447 |
| - SageMakerApi.CreateProcessingJob) |
| 472 | + |
| 473 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 474 | + SageMakerApi.CreateProcessingJob) |
448 | 475 |
|
449 | 476 | if isinstance(job_name, str):
|
450 | 477 | parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
|
|
0 commit comments