@@ -31,12 +31,11 @@ def test_mnist(sagemaker_session, image_uri, instance_type, framework_version):
3131 estimator = TensorFlow (
3232 entry_point = script ,
3333 role = "SageMakerRole" ,
34- train_instance_type = instance_type ,
35- train_instance_count = 1 ,
34+ instance_type = instance_type ,
35+ instance_count = 1 ,
3636 sagemaker_session = sagemaker_session ,
37- image_name = image_uri ,
37+ image_uri = image_uri ,
3838 framework_version = framework_version ,
39- script_mode = True ,
4039 )
4140 inputs = estimator .sagemaker_session .upload_data (
4241 path = os .path .join (resource_path , "mnist" , "data" ), key_prefix = "scriptmode/mnist"
@@ -51,12 +50,11 @@ def test_distributed_mnist_no_ps(sagemaker_session, image_uri, instance_type, fr
5150 estimator = TensorFlow (
5251 entry_point = script ,
5352 role = "SageMakerRole" ,
54- train_instance_count = 2 ,
55- train_instance_type = instance_type ,
53+ instance_count = 2 ,
54+ instance_type = instance_type ,
5655 sagemaker_session = sagemaker_session ,
57- image_name = image_uri ,
56+ image_uri = image_uri ,
5857 framework_version = framework_version ,
59- script_mode = True ,
6058 )
6159 inputs = estimator .sagemaker_session .upload_data (
6260 path = os .path .join (resource_path , "mnist" , "data" ), key_prefix = "scriptmode/mnist"
@@ -72,12 +70,11 @@ def test_distributed_mnist_ps(sagemaker_session, image_uri, instance_type, frame
7270 entry_point = script ,
7371 role = "SageMakerRole" ,
7472 hyperparameters = {"sagemaker_parameter_server_enabled" : True },
75- train_instance_count = 2 ,
76- train_instance_type = instance_type ,
73+ instance_count = 2 ,
74+ instance_type = instance_type ,
7775 sagemaker_session = sagemaker_session ,
78- image_name = image_uri ,
76+ image_uri = image_uri ,
7977 framework_version = framework_version ,
80- script_mode = True ,
8178 )
8279 inputs = estimator .sagemaker_session .upload_data (
8380 path = os .path .join (resource_path , "mnist" , "data-distributed" ),
@@ -95,10 +92,10 @@ def test_tuning(sagemaker_session, image_uri, instance_type, framework_version):
9592 estimator = TensorFlow (
9693 entry_point = script ,
9794 role = "SageMakerRole" ,
98- train_instance_type = instance_type ,
99- train_instance_count = 1 ,
95+ instance_type = instance_type ,
96+ instance_count = 1 ,
10097 sagemaker_session = sagemaker_session ,
101- image_name = image_uri ,
98+ image_uri = image_uri ,
10299 framework_version = framework_version ,
103100 script_mode = True ,
104101 )
0 commit comments