Skip to content

Commit 8b97898

Browse files
committed
Additional fixes for SKLearn and Tensorflow Estimators
* Removed sagemaker_session for SKLearn * Moved checkpoint_path into hyper parameters (https://sagemaker.readthedocs.io/en/v2.0.0.rc0/frameworks/tensorflow/upgrade_from_legacy.html) * Added framework_version and py_version * Update entry_point and renamed image_name to image_uri for TensorFlow
1 parent 2273ffa commit 8b97898

File tree

4 files changed

+18
-7
lines changed

4 files changed

+18
-7
lines changed

tests/integ/test_inference_pipeline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def sklearn_preprocessor(sagemaker_role_arn, sagemaker_session):
4343
'one_p_mnist',
4444
'sklearn_mnist_preprocessor.py')
4545
sklearn_preprocessor = SKLearn(
46+
framework_version='0.20.0',
47+
py_version='py3',
4648
entry_point=script_path,
4749
role=sagemaker_role_arn,
4850
instance_type="ml.m5.large",
@@ -58,6 +60,8 @@ def sklearn_estimator(sagemaker_role_arn, sagemaker_session):
5860
'one_p_mnist',
5961
'sklearn_mnist_estimator.py')
6062
sklearn_estimator = SKLearn(
63+
framework_version='0.20.0',
64+
py_version='py3',
6165
entry_point=script_path,
6266
role=sagemaker_role_arn,
6367
instance_type="ml.m5.large",

tests/integ/test_training_pipeline_framework_estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
def torch_estimator(sagemaker_role_arn):
3434
script_path = os.path.join(DATA_DIR, "pytorch_mnist", "mnist.py")
3535
return PyTorch(
36+
py_version='py3',
3637
entry_point=script_path,
3738
role=sagemaker_role_arn,
3839
framework_version='1.2.0',
@@ -48,11 +49,12 @@ def torch_estimator(sagemaker_role_arn):
4849
def sklearn_estimator(sagemaker_role_arn):
4950
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
5051
return SKLearn(
52+
framework_version='0.20.0',
53+
py_version='py3',
5154
entry_point=script_path,
5255
role=sagemaker_role_arn,
5356
instance_count=1,
5457
instance_type='ml.m5.large',
55-
framework_version='0.20.0',
5658
hyperparameters={
5759
"epochs": 1
5860
}

tests/unit/test_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,12 @@ def sklearn_preprocessor():
6060
sagemaker_session.boto_region_name = 'us-east-1'
6161

6262
sklearn_preprocessor = SKLearn(
63+
framework_version='0.20.0',
64+
py_version='py3',
6365
entry_point=script_path,
6466
role=SAGEMAKER_EXECUTION_ROLE,
6567
instance_type="ml.c4.xlarge",
6668
source_dir=source_dir,
67-
sagemaker_session=sagemaker_session
6869
)
6970

7071
sklearn_preprocessor.debugger_hook_config = DebuggerHookConfig(

tests/unit/test_sagemaker_steps.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,17 +135,21 @@ def tensorflow_estimator():
135135
s3_output_location = 's3://sagemaker/models'
136136
s3_source_location = 's3://sagemaker/source'
137137

138-
estimator = TensorFlow(entry_point='tf_train.py',
138+
estimator = TensorFlow(
139+
entry_point='tf_train.py',
139140
role=EXECUTION_ROLE,
140141
framework_version='1.13',
141-
training_steps=1000,
142-
evaluation_steps=100,
143142
instance_count=1,
144143
instance_type='ml.p2.xlarge',
145144
output_path=s3_output_location,
146145
source_dir=s3_source_location,
147-
image_name=TENSORFLOW_IMAGE,
148-
checkpoint_path='s3://sagemaker/models/sagemaker-tensorflow/checkpoints'
146+
image_uri=TENSORFLOW_IMAGE,
147+
model_dir=False,
148+
hyperparameters={
149+
'training_steps': 1000,
150+
'evaluation_steps': 100,
151+
'checkpoint_path': 's3://sagemaker/models/sagemaker-tensorflow/checkpoints',
152+
}
149153
)
150154

151155
estimator.debugger_hook_config = DebuggerHookConfig(

0 commit comments

Comments
 (0)