Skip to content

Commit 24f1b61

Browse files
committed
Merging changes from upstream master. Targeting v2+
2 parents a8053db + b9b0371 commit 24f1b61

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
6767
else:
6868
parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size)
6969

70-
if estimator.debugger_hook_config != None:
70+
if estimator.debugger_hook_config != None and estimator.debugger_hook_config is not False:
7171
parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict()
7272

7373
if estimator.rules != None:

tests/unit/test_sagemaker_steps.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,34 @@ def pca_estimator_with_debug_hook():
111111

112112
return pca
113113

114+
115+
@pytest.fixture
116+
def pca_estimator_with_falsy_debug_hook():
117+
s3_output_location = 's3://sagemaker/models'
118+
119+
pca = sagemaker.estimator.Estimator(
120+
PCA_IMAGE,
121+
role=EXECUTION_ROLE,
122+
train_instance_count=1,
123+
train_instance_type='ml.c4.xlarge',
124+
output_path=s3_output_location,
125+
debugger_hook_config = False
126+
)
127+
128+
pca.set_hyperparameters(
129+
feature_dim=50000,
130+
num_components=10,
131+
subtract_mean=True,
132+
algorithm_mode='randomized',
133+
mini_batch_size=200
134+
)
135+
136+
pca.sagemaker_session = MagicMock()
137+
pca.sagemaker_session.boto_region_name = 'us-east-1'
138+
pca.sagemaker_session._default_bucket = 'sagemaker'
139+
140+
return pca
141+
114142
@pytest.fixture
115143
def pca_model():
116144
model_data = 's3://sagemaker/models/pca.tar.gz'
@@ -287,6 +315,43 @@ def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):
287315
'End': True
288316
}
289317

318+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
319+
def test_training_step_creation_with_falsy_debug_hook(pca_estimator_with_falsy_debug_hook):
320+
step = TrainingStep('Training',
321+
estimator=pca_estimator_with_falsy_debug_hook,
322+
job_name='TrainingJob')
323+
assert step.to_dict() == {
324+
'Type': 'Task',
325+
'Parameters': {
326+
'AlgorithmSpecification': {
327+
'TrainingImage': PCA_IMAGE,
328+
'TrainingInputMode': 'File'
329+
},
330+
'OutputDataConfig': {
331+
'S3OutputPath': 's3://sagemaker/models'
332+
},
333+
'StoppingCondition': {
334+
'MaxRuntimeInSeconds': 86400
335+
},
336+
'ResourceConfig': {
337+
'InstanceCount': 1,
338+
'InstanceType': 'ml.c4.xlarge',
339+
'VolumeSizeInGB': 30
340+
},
341+
'RoleArn': EXECUTION_ROLE,
342+
'HyperParameters': {
343+
'feature_dim': '50000',
344+
'num_components': '10',
345+
'subtract_mean': 'True',
346+
'algorithm_mode': 'randomized',
347+
'mini_batch_size': '200'
348+
},
349+
'TrainingJobName': 'TrainingJob'
350+
},
351+
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
352+
'End': True
353+
}
354+
290355
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
291356
def test_training_step_creation_with_model(pca_estimator):
292357
training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')

0 commit comments

Comments
 (0)