Skip to content

Commit bee2953

Browse files
author
Bowen Yuan
committed
fix: use boto3 mock for utils test and create arn builder for task integration resource
1 parent 723214e commit bee2953

13 files changed

+546
-79
lines changed

src/stepfunctions/steps/compute.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from stepfunctions.steps.states import Task
1616
from stepfunctions.steps.fields import Field
17-
from stepfunctions.steps.utils import get_aws_partition
17+
from stepfunctions.steps.utils import resource_integration_arn_builder
18+
from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, LambdaApi, GlueApi, BatchApi, EcsApi
1819

1920

2021
class LambdaStep(Task):
@@ -38,10 +39,20 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
3839
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
3940
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
4041
"""
42+
4143
if wait_for_callback:
42-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke.waitForTaskToken'
44+
"""
45+
Example resource arn: arn:aws:states:::lambda:invoke.waitForTaskToken
46+
"""
47+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Lambda,
48+
LambdaApi.Invoke,
49+
IntegrationPattern.WaitForTaskToken)
4350
else:
44-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke'
51+
"""
52+
Example resource arn: arn:aws:states:::lambda:invoke
53+
"""
54+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Lambda, LambdaApi.Invoke)
55+
4556

4657
super(LambdaStep, self).__init__(state_id, **kwargs)
4758

@@ -68,9 +79,18 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
6879
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
6980
"""
7081
if wait_for_completion:
71-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun.sync'
82+
"""
83+
Example resource arn: arn:aws:states:::glue:startJobRun.sync
84+
"""
85+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Glue,
86+
GlueApi.StartJobRun,
87+
IntegrationPattern.WaitForCompletion)
7288
else:
73-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun'
89+
"""
90+
Example resource arn: arn:aws:states:::glue:startJobRun
91+
"""
92+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Glue,
93+
GlueApi.StartJobRun)
7494

7595
super(GlueStartJobRunStep, self).__init__(state_id, **kwargs)
7696

@@ -97,9 +117,18 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
97117
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
98118
"""
99119
if wait_for_completion:
100-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob.sync'
120+
"""
121+
Example resource arn: arn:aws:states:::batch:submitJob.sync
122+
"""
123+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Batch,
124+
BatchApi.SubmitJob,
125+
IntegrationPattern.WaitForCompletion)
101126
else:
102-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob'
127+
"""
128+
Example resource arn: arn:aws:states:::batch:submitJob
129+
"""
130+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Batch,
131+
BatchApi.SubmitJob)
103132

104133
super(BatchSubmitJobStep, self).__init__(state_id, **kwargs)
105134

@@ -126,8 +155,17 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
126155
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
127156
"""
128157
if wait_for_completion:
129-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask.sync'
158+
"""
159+
Example resource arn: arn:aws:states:::ecs:runTask.sync
160+
"""
161+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ECS,
162+
EcsApi.RunTask,
163+
IntegrationPattern.WaitForCompletion)
130164
else:
131-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask'
165+
"""
166+
Example resource arn: arn:aws:states:::ecs:runTask
167+
"""
168+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ECS,
169+
EcsApi.RunTask)
132170

133171
super(EcsRunTaskStep, self).__init__(state_id, **kwargs)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
14+
from __future__ import absolute_import
15+
16+
from enum import Enum
17+
18+
"""
19+
Enum classes for task integration resource arn builder
20+
"""
21+
22+
23+
class IntegrationPattern(Enum):
24+
WaitForTaskToken = "waitForTaskToken"
25+
WaitForCompletion = "sync"
26+
27+
28+
class IntegrationServices(Enum):
29+
Lambda = "lambda"
30+
SageMaker = "sagemaker"
31+
Glue = "glue"
32+
ECS = "ecs"
33+
Batch = "batch"
34+
DynamoDB = "dynamodb"
35+
SNS = "sns"
36+
SQS = "sqs"
37+
ElasticMapReduce = "elasticmapreduce"
38+
39+
40+
class LambdaApi(Enum):
41+
Invoke = "invoke"
42+
43+
44+
class SageMakerApi(Enum):
45+
CreateTrainingJob = "createTrainingJob"
46+
CreateTransformJob = "createTransformJob"
47+
CreateModel = "createModel"
48+
CreateEndpointConfig = "createEndpointConfig"
49+
UpdateEndpoint = "updateEndpoint"
50+
CreateEndpoint = "createEndpoint"
51+
CreateHyperParameterTuningJob = "createHyperParameterTuningJob"
52+
CreateProcessingJob = "createProcessingJob"
53+
54+
55+
class GlueApi(Enum):
56+
StartJobRun = "startJobRun"
57+
58+
59+
class EcsApi(Enum):
60+
RunTask = "runTask"
61+
62+
63+
class BatchApi(Enum):
64+
SubmitJob = "submitJob"
65+
66+
67+
class DynamoDBApi(Enum):
68+
GetItem = "getItem"
69+
PutItem = "putItem"
70+
DeleteItem = "deleteItem"
71+
UpdateItem = "updateItem"
72+
73+
74+
class SnsApi(Enum):
75+
Publish = "publish"
76+
77+
78+
class SqsApi(Enum):
79+
SendMessage = "sendMessage"
80+
81+
82+
class ElasticMapReduceApi(Enum):
83+
CreateCluster = "createCluster"
84+
TerminateCluster = "terminateCluster"
85+
AddStep = "addStep"
86+
CancelStep = "cancelStep"
87+
SetClusterTerminationProtection = "setClusterTerminationProtection"
88+
ModifyInstanceFleetByName = "modifyInstanceFleetByName"
89+
ModifyInstanceGroupByName = "modifyInstanceGroupByName"

src/stepfunctions/steps/sagemaker.py

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from stepfunctions.inputs import ExecutionInput, StepInput
1616
from stepfunctions.steps.states import Task
1717
from stepfunctions.steps.fields import Field
18-
from stepfunctions.steps.utils import tags_dict_to_kv_list, get_aws_partition
18+
from stepfunctions.steps.utils import tags_dict_to_kv_list, resource_integration_arn_builder
19+
from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, SageMakerApi
1920

2021
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
2122
from sagemaker.model import Model, FrameworkModel
@@ -58,9 +59,18 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
5859
self.job_name = job_name
5960

6061
if wait_for_completion:
61-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob.sync'
62+
"""
63+
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync
64+
"""
65+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
66+
SageMakerApi.CreateTrainingJob,
67+
IntegrationPattern.WaitForCompletion)
6268
else:
63-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob'
69+
"""
70+
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob
71+
"""
72+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
73+
SageMakerApi.CreateTrainingJob)
6474

6575
if isinstance(job_name, str):
6676
parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
@@ -141,9 +151,18 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
141151
join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None.
142152
"""
143153
if wait_for_completion:
144-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob.sync'
154+
"""
155+
Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync
156+
"""
157+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
158+
SageMakerApi.CreateTransformJob,
159+
IntegrationPattern.WaitForCompletion)
145160
else:
146-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob'
161+
"""
162+
Example resource arn: arn:aws:states:::sagemaker:createTransformJob
163+
"""
164+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
165+
SageMakerApi.CreateTransformJob)
147166

148167
if isinstance(job_name, str):
149168
parameters = transform_config(
@@ -225,7 +244,12 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
225244
parameters['Tags'] = tags_dict_to_kv_list(tags)
226245

227246
kwargs[Field.Parameters.value] = parameters
228-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createModel'
247+
248+
"""
249+
Example resource arn: arn:aws:states:::sagemaker:createModel
250+
"""
251+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
252+
SageMakerApi.CreateModel)
229253

230254
super(ModelStep, self).__init__(state_id, **kwargs)
231255

@@ -266,7 +290,12 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
266290
if tags:
267291
parameters['Tags'] = tags_dict_to_kv_list(tags)
268292

269-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpointConfig'
293+
"""
294+
Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig
295+
"""
296+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
297+
SageMakerApi.CreateEndpointConfig)
298+
270299
kwargs[Field.Parameters.value] = parameters
271300

272301
super(EndpointConfigStep, self).__init__(state_id, **kwargs)
@@ -298,9 +327,17 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
298327
parameters['Tags'] = tags_dict_to_kv_list(tags)
299328

300329
if update:
301-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:updateEndpoint'
330+
"""
331+
Example resource arn: arn:aws:states:::sagemaker:updateEndpoint
332+
"""
333+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
334+
SageMakerApi.UpdateEndpoint)
302335
else:
303-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpoint'
336+
"""
337+
Example resource arn: arn:aws:states:::sagemaker:createEndpoint
338+
"""
339+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
340+
SageMakerApi.CreateEndpoint)
304341

305342
kwargs[Field.Parameters.value] = parameters
306343

@@ -338,9 +375,18 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
338375
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
339376
"""
340377
if wait_for_completion:
341-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob.sync'
378+
"""
379+
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync
380+
"""
381+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
382+
SageMakerApi.CreateHyperParameterTuningJob,
383+
IntegrationPattern.WaitForCompletion)
342384
else:
343-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob'
385+
"""
386+
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob
387+
"""
388+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
389+
SageMakerApi.CreateHyperParameterTuningJob)
344390

345391
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
346392

@@ -387,10 +433,19 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
387433
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
388434
"""
389435
if wait_for_completion:
390-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob.sync'
436+
"""
437+
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync
438+
"""
439+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
440+
SageMakerApi.CreateProcessingJob,
441+
IntegrationPattern.WaitForCompletion)
391442
else:
392-
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob'
393-
443+
"""
444+
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob
445+
"""
446+
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
447+
SageMakerApi.CreateProcessingJob)
448+
394449
if isinstance(job_name, str):
395450
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
396451
else:

0 commit comments

Comments
 (0)