Skip to content

Commit 0c9b0f0

Browse files
author
Bowen Yuan
committed
fix: update arn builder method and its usage
1 parent bee2953 commit 0c9b0f0

File tree

6 files changed

+239
-273
lines changed

6 files changed

+239
-273
lines changed

src/stepfunctions/steps/compute.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,32 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
from enum import Enum
1516
from stepfunctions.steps.states import Task
1617
from stepfunctions.steps.fields import Field
17-
from stepfunctions.steps.utils import resource_integration_arn_builder
18-
from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, LambdaApi, GlueApi, BatchApi, EcsApi
18+
from stepfunctions.steps.utils import get_service_integration_arn
19+
from stepfunctions.steps.integration_resources import IntegrationPattern
20+
21+
Lambda = "lambda"
22+
Glue = "glue"
23+
Ecs = "ecs"
24+
Batch = "batch"
25+
26+
27+
class LambdaApi(Enum):
28+
Invoke = "invoke"
29+
30+
31+
class GlueApi(Enum):
32+
StartJobRun = "startJobRun"
33+
34+
35+
class EcsApi(Enum):
36+
RunTask = "runTask"
37+
38+
39+
class BatchApi(Enum):
40+
SubmitJob = "submitJob"
1941

2042

2143
class LambdaStep(Task):
@@ -44,14 +66,16 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
4466
"""
4567
Example resource arn: arn:aws:states:::lambda:invoke.waitForTaskToken
4668
"""
47-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Lambda,
48-
LambdaApi.Invoke,
49-
IntegrationPattern.WaitForTaskToken)
69+
70+
kwargs[Field.Resource.value] = get_service_integration_arn(Lambda,
71+
LambdaApi.Invoke,
72+
IntegrationPattern.WaitForTaskToken)
5073
else:
5174
"""
5275
Example resource arn: arn:aws:states:::lambda:invoke
5376
"""
54-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Lambda, LambdaApi.Invoke)
77+
78+
kwargs[Field.Resource.value] = get_service_integration_arn(Lambda, LambdaApi.Invoke)
5579

5680

5781
super(LambdaStep, self).__init__(state_id, **kwargs)
@@ -82,15 +106,17 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
82106
"""
83107
Example resource arn: arn:aws:states:::glue:startJobRun.sync
84108
"""
85-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Glue,
86-
GlueApi.StartJobRun,
87-
IntegrationPattern.WaitForCompletion)
109+
110+
kwargs[Field.Resource.value] = get_service_integration_arn(Glue,
111+
GlueApi.StartJobRun,
112+
IntegrationPattern.WaitForCompletion)
88113
else:
89114
"""
90115
Example resource arn: arn:aws:states:::glue:startJobRun
91116
"""
92-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Glue,
93-
GlueApi.StartJobRun)
117+
118+
kwargs[Field.Resource.value] = get_service_integration_arn(Glue,
119+
GlueApi.StartJobRun)
94120

95121
super(GlueStartJobRunStep, self).__init__(state_id, **kwargs)
96122

@@ -120,15 +146,17 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
120146
"""
121147
Example resource arn: arn:aws:states:::batch:submitJob.sync
122148
"""
123-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Batch,
124-
BatchApi.SubmitJob,
125-
IntegrationPattern.WaitForCompletion)
149+
150+
kwargs[Field.Resource.value] = get_service_integration_arn(Batch,
151+
BatchApi.SubmitJob,
152+
IntegrationPattern.WaitForCompletion)
126153
else:
127154
"""
128155
Example resource arn: arn:aws:states:::batch:submitJob
129156
"""
130-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Batch,
131-
BatchApi.SubmitJob)
157+
158+
kwargs[Field.Resource.value] = get_service_integration_arn(Batch,
159+
BatchApi.SubmitJob)
132160

133161
super(BatchSubmitJobStep, self).__init__(state_id, **kwargs)
134162

@@ -158,14 +186,16 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
158186
"""
159187
Example resource arn: arn:aws:states:::ecs:runTask.sync
160188
"""
161-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ECS,
162-
EcsApi.RunTask,
163-
IntegrationPattern.WaitForCompletion)
189+
190+
kwargs[Field.Resource.value] = get_service_integration_arn(Ecs,
191+
EcsApi.RunTask,
192+
IntegrationPattern.WaitForCompletion)
164193
else:
165194
"""
166195
Example resource arn: arn:aws:states:::ecs:runTask
167196
"""
168-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ECS,
169-
EcsApi.RunTask)
197+
198+
kwargs[Field.Resource.value] = get_service_integration_arn(Ecs,
199+
EcsApi.RunTask)
170200

171201
super(EcsRunTaskStep, self).__init__(state_id, **kwargs)

src/stepfunctions/steps/integration_resources.py

Lines changed: 5 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -15,75 +15,16 @@
1515

1616
from enum import Enum
1717

18-
"""
19-
Enum classes for task integration resource arn builder
20-
"""
21-
2218

2319
class IntegrationPattern(Enum):
20+
"""
21+
Integration pattern enum classes for task integration resource arn builder
22+
"""
23+
2424
WaitForTaskToken = "waitForTaskToken"
2525
WaitForCompletion = "sync"
26+
RequestResponse = ""
2627

2728

28-
class IntegrationServices(Enum):
29-
Lambda = "lambda"
30-
SageMaker = "sagemaker"
31-
Glue = "glue"
32-
ECS = "ecs"
33-
Batch = "batch"
34-
DynamoDB = "dynamodb"
35-
SNS = "sns"
36-
SQS = "sqs"
37-
ElasticMapReduce = "elasticmapreduce"
38-
39-
40-
class LambdaApi(Enum):
41-
Invoke = "invoke"
42-
43-
44-
class SageMakerApi(Enum):
45-
CreateTrainingJob = "createTrainingJob"
46-
CreateTransformJob = "createTransformJob"
47-
CreateModel = "createModel"
48-
CreateEndpointConfig = "createEndpointConfig"
49-
UpdateEndpoint = "updateEndpoint"
50-
CreateEndpoint = "createEndpoint"
51-
CreateHyperParameterTuningJob = "createHyperParameterTuningJob"
52-
CreateProcessingJob = "createProcessingJob"
53-
54-
55-
class GlueApi(Enum):
56-
StartJobRun = "startJobRun"
57-
58-
59-
class EcsApi(Enum):
60-
RunTask = "runTask"
61-
62-
63-
class BatchApi(Enum):
64-
SubmitJob = "submitJob"
65-
66-
67-
class DynamoDBApi(Enum):
68-
GetItem = "getItem"
69-
PutItem = "putItem"
70-
DeleteItem = "deleteItem"
71-
UpdateItem = "updateItem"
72-
73-
74-
class SnsApi(Enum):
75-
Publish = "publish"
76-
77-
78-
class SqsApi(Enum):
79-
SendMessage = "sendMessage"
8029

8130

82-
class ElasticMapReduceApi(Enum):
83-
CreateCluster = "createCluster"
84-
TerminateCluster = "terminateCluster"
85-
AddStep = "addStep"
86-
CancelStep = "cancelStep"
87-
SetClusterTerminationProtection = "setClusterTerminationProtection"
88-
ModifyInstanceFleetByName = "modifyInstanceFleetByName"
89-
ModifyInstanceGroupByName = "modifyInstanceGroupByName"

src/stepfunctions/steps/sagemaker.py

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,31 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
from enum import Enum
1516
from stepfunctions.inputs import ExecutionInput, StepInput
1617
from stepfunctions.steps.states import Task
1718
from stepfunctions.steps.fields import Field
18-
from stepfunctions.steps.utils import tags_dict_to_kv_list, resource_integration_arn_builder
19-
from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, SageMakerApi
19+
from stepfunctions.steps.utils import tags_dict_to_kv_list, get_service_integration_arn
20+
from stepfunctions.steps.integration_resources import IntegrationPattern
2021

2122
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
2223
from sagemaker.model import Model, FrameworkModel
2324
from sagemaker.model_monitor import DataCaptureConfig
2425

26+
SageMaker = "sagemaker"
27+
28+
29+
class SageMakerApi(Enum):
30+
CreateTrainingJob = "createTrainingJob"
31+
CreateTransformJob = "createTransformJob"
32+
CreateModel = "createModel"
33+
CreateEndpointConfig = "createEndpointConfig"
34+
UpdateEndpoint = "updateEndpoint"
35+
CreateEndpoint = "createEndpoint"
36+
CreateHyperParameterTuningJob = "createHyperParameterTuningJob"
37+
CreateProcessingJob = "createProcessingJob"
38+
39+
2540
class TrainingStep(Task):
2641

2742
"""
@@ -62,15 +77,17 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
6277
"""
6378
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync
6479
"""
65-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
66-
SageMakerApi.CreateTrainingJob,
67-
IntegrationPattern.WaitForCompletion)
80+
81+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
82+
SageMakerApi.CreateTrainingJob,
83+
IntegrationPattern.WaitForCompletion)
6884
else:
6985
"""
7086
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob
7187
"""
72-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
73-
SageMakerApi.CreateTrainingJob)
88+
89+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
90+
SageMakerApi.CreateTrainingJob)
7491

7592
if isinstance(job_name, str):
7693
parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
@@ -154,15 +171,17 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
154171
"""
155172
Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync
156173
"""
157-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
158-
SageMakerApi.CreateTransformJob,
159-
IntegrationPattern.WaitForCompletion)
174+
175+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
176+
SageMakerApi.CreateTransformJob,
177+
IntegrationPattern.WaitForCompletion)
160178
else:
161179
"""
162180
Example resource arn: arn:aws:states:::sagemaker:createTransformJob
163181
"""
164-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
165-
SageMakerApi.CreateTransformJob)
182+
183+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
184+
SageMakerApi.CreateTransformJob)
166185

167186
if isinstance(job_name, str):
168187
parameters = transform_config(
@@ -248,8 +267,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
248267
"""
249268
Example resource arn: arn:aws:states:::sagemaker:createModel
250269
"""
251-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
252-
SageMakerApi.CreateModel)
270+
271+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
272+
SageMakerApi.CreateModel)
253273

254274
super(ModelStep, self).__init__(state_id, **kwargs)
255275

@@ -293,8 +313,9 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
293313
"""
294314
Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig
295315
"""
296-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
297-
SageMakerApi.CreateEndpointConfig)
316+
317+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
318+
SageMakerApi.CreateEndpointConfig)
298319

299320
kwargs[Field.Parameters.value] = parameters
300321

@@ -330,14 +351,16 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
330351
"""
331352
Example resource arn: arn:aws:states:::sagemaker:updateEndpoint
332353
"""
333-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
334-
SageMakerApi.UpdateEndpoint)
354+
355+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
356+
SageMakerApi.UpdateEndpoint)
335357
else:
336358
"""
337359
Example resource arn: arn:aws:states:::sagemaker:createEndpoint
338360
"""
339-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
340-
SageMakerApi.CreateEndpoint)
361+
362+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
363+
SageMakerApi.CreateEndpoint)
341364

342365
kwargs[Field.Parameters.value] = parameters
343366

@@ -378,15 +401,17 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
378401
"""
379402
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync
380403
"""
381-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
382-
SageMakerApi.CreateHyperParameterTuningJob,
383-
IntegrationPattern.WaitForCompletion)
404+
405+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
406+
SageMakerApi.CreateHyperParameterTuningJob,
407+
IntegrationPattern.WaitForCompletion)
384408
else:
385409
"""
386410
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob
387411
"""
388-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
389-
SageMakerApi.CreateHyperParameterTuningJob)
412+
413+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
414+
SageMakerApi.CreateHyperParameterTuningJob)
390415

391416
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
392417

@@ -436,15 +461,17 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
436461
"""
437462
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync
438463
"""
439-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
440-
SageMakerApi.CreateProcessingJob,
441-
IntegrationPattern.WaitForCompletion)
464+
465+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
466+
SageMakerApi.CreateProcessingJob,
467+
IntegrationPattern.WaitForCompletion)
442468
else:
443469
"""
444470
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob
445471
"""
446-
kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker,
447-
SageMakerApi.CreateProcessingJob)
472+
473+
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
474+
SageMakerApi.CreateProcessingJob)
448475

449476
if isinstance(job_name, str):
450477
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)

0 commit comments

Comments
 (0)