Skip to content

Commit 162f9db

Browse files
author
Bowen Yuan
committed
fix: move arn builder method into integration_resources module
1 parent 0c9b0f0 commit 162f9db

File tree

6 files changed

+73
-77
lines changed

6 files changed

+73
-77
lines changed

src/stepfunctions/steps/compute.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
from enum import Enum
1616
from stepfunctions.steps.states import Task
1717
from stepfunctions.steps.fields import Field
18-
from stepfunctions.steps.utils import get_service_integration_arn
19-
from stepfunctions.steps.integration_resources import IntegrationPattern
18+
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn
2019

21-
Lambda = "lambda"
22-
Glue = "glue"
23-
Ecs = "ecs"
24-
Batch = "batch"
20+
LAMBDA_SERVICE_NAME = "lambda"
21+
GLUE_SERVICE_NAME = "glue"
22+
ECS_SERVICE_NAME = "ecs"
23+
BATCH_SERVICE_NAME = "batch"
2524

2625

2726
class LambdaApi(Enum):
@@ -67,15 +66,15 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
6766
Example resource arn: arn:aws:states:::lambda:invoke.waitForTaskToken
6867
"""
6968

70-
kwargs[Field.Resource.value] = get_service_integration_arn(Lambda,
71-
LambdaApi.Invoke,
72-
IntegrationPattern.WaitForTaskToken)
69+
kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME,
70+
LambdaApi.Invoke,
71+
IntegrationPattern.WaitForTaskToken)
7372
else:
7473
"""
7574
Example resource arn: arn:aws:states:::lambda:invoke
7675
"""
7776

78-
kwargs[Field.Resource.value] = get_service_integration_arn(Lambda, LambdaApi.Invoke)
77+
kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME, LambdaApi.Invoke)
7978

8079

8180
super(LambdaStep, self).__init__(state_id, **kwargs)
@@ -107,15 +106,15 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
107106
Example resource arn: arn:aws:states:::glue:startJobRun.sync
108107
"""
109108

110-
kwargs[Field.Resource.value] = get_service_integration_arn(Glue,
109+
kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME,
111110
GlueApi.StartJobRun,
112111
IntegrationPattern.WaitForCompletion)
113112
else:
114113
"""
115114
Example resource arn: arn:aws:states:::glue:startJobRun
116115
"""
117116

118-
kwargs[Field.Resource.value] = get_service_integration_arn(Glue,
117+
kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME,
119118
GlueApi.StartJobRun)
120119

121120
super(GlueStartJobRunStep, self).__init__(state_id, **kwargs)
@@ -147,15 +146,15 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
147146
Example resource arn: arn:aws:states:::batch:submitJob.sync
148147
"""
149148

150-
kwargs[Field.Resource.value] = get_service_integration_arn(Batch,
149+
kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME,
151150
BatchApi.SubmitJob,
152151
IntegrationPattern.WaitForCompletion)
153152
else:
154153
"""
155154
Example resource arn: arn:aws:states:::batch:submitJob
156155
"""
157156

158-
kwargs[Field.Resource.value] = get_service_integration_arn(Batch,
157+
kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME,
159158
BatchApi.SubmitJob)
160159

161160
super(BatchSubmitJobStep, self).__init__(state_id, **kwargs)
@@ -187,15 +186,15 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
187186
Example resource arn: arn:aws:states:::ecs:runTask.sync
188187
"""
189188

190-
kwargs[Field.Resource.value] = get_service_integration_arn(Ecs,
189+
kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME,
191190
EcsApi.RunTask,
192191
IntegrationPattern.WaitForCompletion)
193192
else:
194193
"""
195194
Example resource arn: arn:aws:states:::ecs:runTask
196195
"""
197196

198-
kwargs[Field.Resource.value] = get_service_integration_arn(Ecs,
197+
kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME,
199198
EcsApi.RunTask)
200199

201200
super(EcsRunTaskStep, self).__init__(state_id, **kwargs)

src/stepfunctions/steps/integration_resources.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from enum import Enum
17+
from stepfunctions.steps.utils import get_aws_partition
1718

1819

1920
class IntegrationPattern(Enum):
@@ -26,5 +27,22 @@ class IntegrationPattern(Enum):
2627
RequestResponse = ""
2728

2829

30+
def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.RequestResponse):
31+
32+
"""
33+
ARN builder for task integration
34+
Args:
35+
service(str): name of the task resource service
36+
api(<Service>Api): api to be integrated of the task resource service
37+
integration_pattern(IntegrationPattern, optional): integration pattern for the task resource.
38+
Default as request response.
39+
"""
40+
arn = ""
41+
if integration_pattern == IntegrationPattern.RequestResponse:
42+
arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}"
43+
else:
44+
arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}.{integration_pattern.value}"
45+
return arn
46+
2947

3048

src/stepfunctions/steps/sagemaker.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
from stepfunctions.inputs import ExecutionInput, StepInput
1717
from stepfunctions.steps.states import Task
1818
from stepfunctions.steps.fields import Field
19-
from stepfunctions.steps.utils import tags_dict_to_kv_list, get_service_integration_arn
20-
from stepfunctions.steps.integration_resources import IntegrationPattern
19+
from stepfunctions.steps.utils import tags_dict_to_kv_list
20+
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn
2121

2222
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
2323
from sagemaker.model import Model, FrameworkModel
2424
from sagemaker.model_monitor import DataCaptureConfig
2525

26-
SageMaker = "sagemaker"
26+
SAGEMAKER_SERVICE_NAME = "sagemaker"
2727

2828

2929
class SageMakerApi(Enum):
@@ -78,15 +78,15 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
7878
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync
7979
"""
8080

81-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
81+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
8282
SageMakerApi.CreateTrainingJob,
8383
IntegrationPattern.WaitForCompletion)
8484
else:
8585
"""
8686
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob
8787
"""
8888

89-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
89+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
9090
SageMakerApi.CreateTrainingJob)
9191

9292
if isinstance(job_name, str):
@@ -172,15 +172,15 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
172172
Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync
173173
"""
174174

175-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
175+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
176176
SageMakerApi.CreateTransformJob,
177177
IntegrationPattern.WaitForCompletion)
178178
else:
179179
"""
180180
Example resource arn: arn:aws:states:::sagemaker:createTransformJob
181181
"""
182182

183-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
183+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
184184
SageMakerApi.CreateTransformJob)
185185

186186
if isinstance(job_name, str):
@@ -268,7 +268,7 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
268268
Example resource arn: arn:aws:states:::sagemaker:createModel
269269
"""
270270

271-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
271+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
272272
SageMakerApi.CreateModel)
273273

274274
super(ModelStep, self).__init__(state_id, **kwargs)
@@ -314,7 +314,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
314314
Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig
315315
"""
316316

317-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
317+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
318318
SageMakerApi.CreateEndpointConfig)
319319

320320
kwargs[Field.Parameters.value] = parameters
@@ -352,14 +352,14 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
352352
Example resource arn: arn:aws:states:::sagemaker:updateEndpoint
353353
"""
354354

355-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
355+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
356356
SageMakerApi.UpdateEndpoint)
357357
else:
358358
"""
359359
Example resource arn: arn:aws:states:::sagemaker:createEndpoint
360360
"""
361361

362-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
362+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
363363
SageMakerApi.CreateEndpoint)
364364

365365
kwargs[Field.Parameters.value] = parameters
@@ -402,15 +402,15 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
402402
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync
403403
"""
404404

405-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
405+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
406406
SageMakerApi.CreateHyperParameterTuningJob,
407407
IntegrationPattern.WaitForCompletion)
408408
else:
409409
"""
410410
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob
411411
"""
412412

413-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
413+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
414414
SageMakerApi.CreateHyperParameterTuningJob)
415415

416416
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
@@ -462,15 +462,15 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
462462
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync
463463
"""
464464

465-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
465+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
466466
SageMakerApi.CreateProcessingJob,
467467
IntegrationPattern.WaitForCompletion)
468468
else:
469469
"""
470470
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob
471471
"""
472472

473-
kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
473+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
474474
SageMakerApi.CreateProcessingJob)
475475

476476
if isinstance(job_name, str):

0 commit comments

Comments
 (0)