Skip to content

Commit 7a38877

Browse files
author
Bowen Yuan
committed
fix: make arns of all task resources aws-partition aware
1 parent f97fb20 commit 7a38877

File tree

7 files changed

+130
-53
lines changed

7 files changed

+130
-53
lines changed

src/stepfunctions/steps/compute.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from stepfunctions.steps.states import Task
1616
from stepfunctions.steps.fields import Field
17+
from stepfunctions.steps.utils import get_aws_partition
1718

1819

1920
class LambdaStep(Task):
@@ -38,9 +39,9 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
3839
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
3940
"""
4041
if wait_for_callback:
41-
kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke.waitForTaskToken'
42+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke.waitForTaskToken'
4243
else:
43-
kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke'
44+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke'
4445

4546
super(LambdaStep, self).__init__(state_id, **kwargs)
4647

@@ -67,9 +68,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
6768
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
6869
"""
6970
if wait_for_completion:
70-
kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun.sync'
71+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun.sync'
7172
else:
72-
kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun'
73+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun'
7374

7475
super(GlueStartJobRunStep, self).__init__(state_id, **kwargs)
7576

@@ -96,9 +97,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
9697
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
9798
"""
9899
if wait_for_completion:
99-
kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob.sync'
100+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob.sync'
100101
else:
101-
kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob'
102+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob'
102103

103104
super(BatchSubmitJobStep, self).__init__(state_id, **kwargs)
104105

@@ -125,8 +126,8 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
125126
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
126127
"""
127128
if wait_for_completion:
128-
kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask.sync'
129+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask.sync'
129130
else:
130-
kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask'
131+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask'
131132

132133
super(EcsRunTaskStep, self).__init__(state_id, **kwargs)

src/stepfunctions/steps/sagemaker.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from stepfunctions.inputs import ExecutionInput, StepInput
1616
from stepfunctions.steps.states import Task
1717
from stepfunctions.steps.fields import Field
18-
from stepfunctions.steps.utils import tags_dict_to_kv_list
18+
from stepfunctions.steps.utils import tags_dict_to_kv_list, get_aws_partition
1919

2020
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
2121
from sagemaker.model import Model, FrameworkModel
@@ -58,9 +58,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
5858
self.job_name = job_name
5959

6060
if wait_for_completion:
61-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob.sync'
61+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob.sync'
6262
else:
63-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob'
63+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob'
6464

6565
if isinstance(job_name, str):
6666
parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
@@ -141,9 +141,9 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
141141
join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None.
142142
"""
143143
if wait_for_completion:
144-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync'
144+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob.sync'
145145
else:
146-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob'
146+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob'
147147

148148
if isinstance(job_name, str):
149149
parameters = transform_config(
@@ -225,7 +225,7 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
225225
parameters['Tags'] = tags_dict_to_kv_list(tags)
226226

227227
kwargs[Field.Parameters.value] = parameters
228-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel'
228+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createModel'
229229

230230
super(ModelStep, self).__init__(state_id, **kwargs)
231231

@@ -266,7 +266,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
266266
if tags:
267267
parameters['Tags'] = tags_dict_to_kv_list(tags)
268268

269-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpointConfig'
269+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpointConfig'
270270
kwargs[Field.Parameters.value] = parameters
271271

272272
super(EndpointConfigStep, self).__init__(state_id, **kwargs)
@@ -298,9 +298,9 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
298298
parameters['Tags'] = tags_dict_to_kv_list(tags)
299299

300300
if update:
301-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:updateEndpoint'
301+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:updateEndpoint'
302302
else:
303-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpoint'
303+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpoint'
304304

305305
kwargs[Field.Parameters.value] = parameters
306306

@@ -338,9 +338,9 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
338338
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
339339
"""
340340
if wait_for_completion:
341-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync'
341+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob.sync'
342342
else:
343-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob'
343+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob'
344344

345345
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
346346

@@ -387,9 +387,9 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
387387
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
388388
"""
389389
if wait_for_completion:
390-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob.sync'
390+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob.sync'
391391
else:
392-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob'
392+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob'
393393

394394
if isinstance(job_name, str):
395395
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)

src/stepfunctions/steps/service.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from stepfunctions.steps.states import Task
1616
from stepfunctions.steps.fields import Field
17+
from stepfunctions.steps.utils import get_aws_partition
1718

1819

1920
class DynamoDBGetItemStep(Task):
@@ -35,7 +36,7 @@ def __init__(self, state_id, **kwargs):
3536
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
3637
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
3738
"""
38-
kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:getItem'
39+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:getItem'
3940
super(DynamoDBGetItemStep, self).__init__(state_id, **kwargs)
4041

4142

@@ -59,7 +60,7 @@ def __init__(self, state_id, **kwargs):
5960
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
6061
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
6162
"""
62-
kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:putItem'
63+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:putItem'
6364
super(DynamoDBPutItemStep, self).__init__(state_id, **kwargs)
6465

6566

@@ -83,7 +84,7 @@ def __init__(self, state_id, **kwargs):
8384
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
8485
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
8586
"""
86-
kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:deleteItem'
87+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:deleteItem'
8788
super(DynamoDBDeleteItemStep, self).__init__(state_id, **kwargs)
8889

8990

@@ -107,7 +108,7 @@ def __init__(self, state_id, **kwargs):
107108
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
108109
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
109110
"""
110-
kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:updateItem'
111+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:updateItem'
111112
super(DynamoDBUpdateItemStep, self).__init__(state_id, **kwargs)
112113

113114

@@ -133,9 +134,9 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
133134
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
134135
"""
135136
if wait_for_callback:
136-
kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish.waitForTaskToken'
137+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sns:publish.waitForTaskToken'
137138
else:
138-
kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish'
139+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sns:publish'
139140

140141
super(SnsPublishStep, self).__init__(state_id, **kwargs)
141142

@@ -162,9 +163,9 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
162163
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
163164
"""
164165
if wait_for_callback:
165-
kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage.waitForTaskToken'
166+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sqs:sendMessage.waitForTaskToken'
166167
else:
167-
kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage'
168+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sqs:sendMessage'
168169

169170
super(SqsSendMessageStep, self).__init__(state_id, **kwargs)
170171

@@ -190,9 +191,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
190191
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
191192
"""
192193
if wait_for_completion:
193-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster.sync'
194+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:createCluster.sync'
194195
else:
195-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster'
196+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:createCluster'
196197

197198
super(EmrCreateClusterStep, self).__init__(state_id, **kwargs)
198199

@@ -218,9 +219,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
218219
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
219220
"""
220221
if wait_for_completion:
221-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster.sync'
222+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:terminateCluster.sync'
222223
else:
223-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster'
224+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:terminateCluster'
224225

225226
super(EmrTerminateClusterStep, self).__init__(state_id, **kwargs)
226227

@@ -246,9 +247,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
246247
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
247248
"""
248249
if wait_for_completion:
249-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep.sync'
250+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:addStep.sync'
250251
else:
251-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep'
252+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:addStep'
252253

253254
super(EmrAddStepStep, self).__init__(state_id, **kwargs)
254255

@@ -272,7 +273,7 @@ def __init__(self, state_id, **kwargs):
272273
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
273274
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
274275
"""
275-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:cancelStep'
276+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:cancelStep'
276277

277278
super(EmrCancelStepStep, self).__init__(state_id, **kwargs)
278279

@@ -296,7 +297,7 @@ def __init__(self, state_id, **kwargs):
296297
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
297298
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
298299
"""
299-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:setClusterTerminationProtection'
300+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:setClusterTerminationProtection'
300301

301302
super(EmrSetClusterTerminationProtectionStep, self).__init__(state_id, **kwargs)
302303

@@ -320,7 +321,7 @@ def __init__(self, state_id, **kwargs):
320321
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
321322
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
322323
"""
323-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceFleetByName'
324+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:modifyInstanceFleetByName'
324325

325326
super(EmrModifyInstanceFleetByNameStep, self).__init__(state_id, **kwargs)
326327

@@ -344,7 +345,7 @@ def __init__(self, state_id, **kwargs):
344345
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
345346
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
346347
"""
347-
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceGroupByName'
348+
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:modifyInstanceGroupByName'
348349

349350
super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs)
350351

src/stepfunctions/steps/utils.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,31 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import boto3
16+
import logging
17+
18+
logger = logging.getLogger('stepfunctions')
19+
20+
1521
def tags_dict_to_kv_list(tags_dict):
16-
kv_list = [{"Key": k, "Value": v} for k,v in tags_dict.items()]
17-
return kv_list
22+
kv_list = [{"Key": k, "Value": v} for k, v in tags_dict.items()]
23+
return kv_list
24+
25+
26+
# Obtain matching aws partition name based on region
27+
def get_aws_partition():
28+
partitions = boto3.session.Session().get_available_partitions()
29+
cur_region = boto3.session.Session().region_name
30+
cur_partition = "aws"
31+
32+
if cur_region is None:
33+
logger.warning("No region detected for the session, will use default partition: aws")
34+
return cur_partition
35+
36+
for partition in partitions:
37+
regions = boto3.session.Session().get_available_regions("stepfunctions", partition)
38+
if cur_region in regions:
39+
cur_partition = partition
40+
return cur_partition
41+
42+
return cur_partition

0 commit comments

Comments
 (0)