Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Bring Your Own Service Account to tfc.run #285

Merged
merged 1 commit into from
Jan 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/python/tensorflow_cloud/core/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def deploy_job(
entry_point_args,
enable_stream_logs,
job_labels=None,
service_account=None,
):
"""Deploys job with the given parameters to Google Cloud.

Expand All @@ -50,8 +51,12 @@ def deploy_job(
enable_stream_logs: Boolean flag which when enabled streams logs
back from the cloud job.
job_labels: Dict of str: str. Labels to organize jobs. See
https://cloud.google.com/ai-platform/training/docs/resource-labels.

[resource labels](
https://cloud.google.com/ai-platform/training/docs/resource-labels)
service_account: The email address of a user-managed service account
to be used for training instead of the service account that AI
Platform Training uses by default. See [custom service account](
https://cloud.google.com/ai-platform/training/docs/custom-service-account)
Returns:
ID of the invoked remote Cloud AI Platform job.

Expand All @@ -76,6 +81,7 @@ def deploy_job(
worker_config,
entry_point_args,
job_labels=job_labels or {},
service_account=service_account
)
try:
unused_response = (
Expand All @@ -102,6 +108,7 @@ def _create_request_dict(
worker_config,
entry_point_args,
job_labels,
service_account
):
"""Creates request dictionary for the CAIP training service."""
training_input = {}
Expand Down Expand Up @@ -162,6 +169,8 @@ def _create_request_dict(
request_dict["trainingInput"] = training_input
if job_labels:
request_dict["labels"] = job_labels
if service_account:
training_input["serviceAccount"] = service_account
return request_dict


Expand Down
25 changes: 24 additions & 1 deletion src/python/tensorflow_cloud/core/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ def _get_valid_machine_configurations():
def validate_job_labels(job_labels):
"""Validates job labels conform guidelines.

Ref. https://cloud.google.com/ai-platform/training/docs/resource-labels
Ref. [resource labels](
https://cloud.google.com/ai-platform/training/docs/resource-labels)

Args:
job_labels: String job label to validate.
Expand Down Expand Up @@ -479,3 +480,25 @@ def validate_job_labels(job_labels):
"numeric characters, underscores and dashes."
"Received: {}.".format(v)
)


def validate_service_account(service_account):
"""Validates service_account conform guidelines.

Ref.[user managed service accounts](
https://cloud.google.com/iam/docs/service-accounts#user-managed)

Args:
service_account: String service account to validate.
Raises:
ValueError if the given service_account is not conformant.
"""

if service_account and not re.match(
r"^.*@([a-z0-9\-]){6,30}\.iam\.gserviceaccount\.com",
service_account):
raise ValueError(
"Invalid service_account: service_account should follow "
"[email protected] "
"Received: {}.".format(service_account)
)
10 changes: 9 additions & 1 deletion src/python/tensorflow_cloud/core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def run(
entry_point_args=None,
stream_logs=False,
job_labels=None,
service_account=None,
**kwargs
):
"""Runs your Tensorflow code in Google Cloud Platform.
Expand Down Expand Up @@ -116,7 +117,12 @@ def run(
job_labels: Dict of str: str. Labels to organize jobs. You can specify
up to 64 key-value pairs in lowercase letters and numbers, where
the first character must be lowercase letter. For more details see
https://cloud.google.com/ai-platform/training/docs/resource-labels.
[resource-labels](
https://cloud.google.com/ai-platform/training/docs/resource-labels)
service_account: The email address of a user-managed service account
to be used for training instead of the service account that AI
Platform Training uses by default. see [custom-service-account](
https://cloud.google.com/ai-platform/training/docs/custom-service-account)
**kwargs: Additional keyword arguments.

Returns:
Expand Down Expand Up @@ -184,6 +190,7 @@ def run(
docker_config.image_build_bucket,
called_from_notebook,
job_labels=job_labels or {},
service_account=service_account,
docker_parent_image=docker_config.parent_image,
)

Expand Down Expand Up @@ -245,6 +252,7 @@ def run(
entry_point_args,
stream_logs,
job_labels=job_labels,
service_account=service_account,
)

# Call `exit` to prevent training the Keras model in the local env.
Expand Down
35 changes: 35 additions & 0 deletions src/python/tensorflow_cloud/core/tests/unit/deploy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def setUp(self):
self.docker_img = "custom-image-tag"
self.entry_point_args = ["1000"]
self.stream_logs = False
self.service_account = "[email protected]"

self.expected_request_dict = {
"jobId": self.mock_job_id,
Expand All @@ -53,6 +54,7 @@ def setUp(self):
"scaleTier": "custom",
"region": self.region,
"args": self.entry_point_args,
"serviceAccount": self.service_account,
"masterType": "n1-standard-16",
"workerType": "n1-standard-8",
"workerCount": str(self.worker_count),
Expand Down Expand Up @@ -103,6 +105,7 @@ def test_deploy_job(self, mock_stdout):
self.worker_config,
self.entry_point_args,
self.stream_logs,
service_account=self.service_account
)

self.assertEqual(job_name, self.mock_job_id)
Expand Down Expand Up @@ -154,6 +157,35 @@ def test_deploy_job(self, mock_stdout):
),
)

def test_deploy_job_with_default_service_account_has_no_serviceaccount_key(
self):
# If user does not provide a service account (i.e. service_account=None,
# the service account key should not be included in the request dict as
# AI Platform will treat None as the name of the service account.
_ = deploy.deploy_job(
self.docker_img,
self.chief_config,
self.worker_count,
self.worker_config,
self.entry_point_args,
self.stream_logs,
)
build_ret_val = self._mock_discovery_build.return_value
proj_ret_val = build_ret_val.projects.return_value
jobs_ret_val = proj_ret_val.jobs.return_value

del self.expected_request_dict["trainingInput"]["serviceAccount"]

# Verify job creation args
_, kwargs = jobs_ret_val.create.call_args
self.assertDictEqual(
kwargs,
{
"parent": "projects/" + self.mock_project_name,
"body": self.expected_request_dict,
},
)

def test_request_dict_without_workers(self):
worker_count = 0

Expand All @@ -164,6 +196,7 @@ def test_request_dict_without_workers(self):
None,
self.entry_point_args,
self.stream_logs,
service_account=self.service_account
)
build_ret_val = self._mock_discovery_build.return_value
proj_ret_val = build_ret_val.projects.return_value
Expand Down Expand Up @@ -192,6 +225,7 @@ def test_request_dict_without_user_args(self):
self.worker_config,
None,
self.stream_logs,
service_account=self.service_account
)
build_ret_val = self._mock_discovery_build.return_value
proj_ret_val = build_ret_val.projects.return_value
Expand Down Expand Up @@ -221,6 +255,7 @@ def test_request_dict_with_tpu_worker(self):
worker_config,
self.entry_point_args,
self.stream_logs,
service_account=self.service_account
)
build_ret_val = self._mock_discovery_build.return_value
proj_ret_val = build_ret_val.projects.return_value
Expand Down
33 changes: 33 additions & 0 deletions src/python/tensorflow_cloud/core/tests/unit/gcp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,39 @@ def test_validate_invalid_job_label(self):
"val{}".format(i) for i in range(65)}
)

def testValidateServiceAccount_NotEndWithCorrectDomain_RaisesValueError(
self):
with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
# must end with .iam.gserviceaccount.com
gcp.validate_service_account(
"[email protected]_domain.com")

def testValidateServiceAccount_ShortProjectId_RaisesValueError(self):
with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
# Project id must be greater than 6 characters
short_project_id = "a" * 5
gcp.validate_service_account(
f"test_sa_name@{short_project_id}.iam.gserviceaccount.com")

def testValidateServiceAccount_LongProjectId_RaisesValueError(self):
with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
# Project id must be less than 30 characters
long_project_id = "a" * 31
gcp.validate_service_account(
f"test_sa_name@{long_project_id}.iam.gserviceaccount.com")

def testValidateServiceAccount_ProjectIdWithDot_RaisesValueError(self):
with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
# Project id can not contain .
gcp.validate_service_account(
"[email protected]")

def testValidateServiceAccount_ProjectIdWithUnderScore_RaisesValueError(
self):
with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
# Project id can not contain _
gcp.validate_service_account(
"test_sa_name@test_projectid.iam.gserviceaccount.com")

if __name__ == "__main__":
absltest.main()
14 changes: 13 additions & 1 deletion src/python/tensorflow_cloud/core/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def validate(
docker_image_build_bucket,
called_from_notebook,
job_labels=None,
service_account=None,
docker_parent_image=None,
):
"""Validates the inputs.
Expand Down Expand Up @@ -61,7 +62,12 @@ def validate(
called_from_notebook: Boolean. True if the API is run in a
notebook environment.
job_labels: Dict of str: str. Labels to organize jobs. See
https://cloud.google.com/ai-platform/training/docs/resource-labels.
[resource-labels](
https://cloud.google.com/ai-platform/training/docs/resource-labels).
service_account: The email address of a user-managed service account
to be used for training instead of the service account that AI
Platform Training uses by default. see [custom-service-account](
https://cloud.google.com/ai-platform/training/docs/custom-service-account)
docker_parent_image: Optional parent Docker image to use.
Defaults to None.

Expand All @@ -80,6 +86,7 @@ def validate(
docker_image_build_bucket,
called_from_notebook,
)
_validate_service_account(service_account)


def _validate_files(entry_point, requirements_txt):
Expand Down Expand Up @@ -179,6 +186,11 @@ def _validate_job_labels(job_labels):
gcp.validate_job_labels(job_labels)


def _validate_service_account(service_account):
"""Validates service_account."""
gcp.validate_service_account(service_account)


def _validate_other_args(
args, stream_logs, docker_image_build_bucket, called_from_notebook
):
Expand Down
6 changes: 4 additions & 2 deletions src/python/tensorflow_cloud/tuner/tests/unit/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,8 @@ def test_get_job_spec_with_default_config(
worker_count=worker_count,
worker_config=worker_config,
entry_point_args=None,
job_labels=None)
job_labels=None,
service_account=None)

@mock.patch.object(super_tuner.Tuner, "__init__", autospec=True)
@mock.patch.object(deploy, "_create_request_dict", autospec=True)
Expand Down Expand Up @@ -742,7 +743,8 @@ def test_get_job_spec_with_default_with_custom_config(
worker_count=worker_count,
worker_config=replica_config,
entry_point_args=None,
job_labels=None)
job_labels=None,
service_account=None)

if __name__ == "__main__":
tf.test.main()
3 changes: 2 additions & 1 deletion src/python/tensorflow_cloud/tuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,8 @@ def _get_job_spec_from_config(self, job_id: Text) -> Dict[Text, Any]:
worker_count=worker_count,
worker_config=worker_config,
entry_point_args=None,
job_labels=None)
job_labels=None,
service_account=None)

def _get_remote_training_metrics(
self,
Expand Down