-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for Bring Your Own Service Account to tfc.run
PiperOrigin-RevId: 353946327
- Loading branch information
1 parent
db102b3
commit d252cb0
Showing
8 changed files
with
131 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -409,7 +409,8 @@ def _get_valid_machine_configurations(): | |
def validate_job_labels(job_labels): | ||
"""Validates job labels conform guidelines. | ||
Ref. https://cloud.google.com/ai-platform/training/docs/resource-labels | ||
Ref. [resource labels]( | ||
https://cloud.google.com/ai-platform/training/docs/resource-labels) | ||
Args: | ||
job_labels: String job label to validate. | ||
|
@@ -479,3 +480,25 @@ def validate_job_labels(job_labels): | |
"numeric characters, underscores and dashes." | ||
"Received: {}.".format(v) | ||
) | ||
|
||
|
||
def validate_service_account(service_account): | ||
"""Validates service_account conform guidelines. | ||
Ref.[user managed service accounts]( | ||
https://cloud.google.com/iam/docs/service-accounts#user-managed) | ||
Args: | ||
service_account: String service account to validate. | ||
Raises: | ||
ValueError if the given service_account is not conformant. | ||
""" | ||
|
||
if service_account and not re.match( | ||
r"^.*@([a-z0-9\-]){6,30}\.iam\.gserviceaccount\.com", | ||
service_account): | ||
raise ValueError( | ||
"Invalid service_account: service_account should follow " | ||
"[email protected] " | ||
"Received: {}.".format(service_account) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ def setUp(self): | |
self.docker_img = "custom-image-tag" | ||
self.entry_point_args = ["1000"] | ||
self.stream_logs = False | ||
self.service_account = "[email protected]" | ||
|
||
self.expected_request_dict = { | ||
"jobId": self.mock_job_id, | ||
|
@@ -53,6 +54,7 @@ def setUp(self): | |
"scaleTier": "custom", | ||
"region": self.region, | ||
"args": self.entry_point_args, | ||
"serviceAccount": self.service_account, | ||
"masterType": "n1-standard-16", | ||
"workerType": "n1-standard-8", | ||
"workerCount": str(self.worker_count), | ||
|
@@ -103,6 +105,7 @@ def test_deploy_job(self, mock_stdout): | |
self.worker_config, | ||
self.entry_point_args, | ||
self.stream_logs, | ||
service_account=self.service_account | ||
) | ||
|
||
self.assertEqual(job_name, self.mock_job_id) | ||
|
@@ -154,6 +157,35 @@ def test_deploy_job(self, mock_stdout): | |
), | ||
) | ||
|
||
def test_deploy_job_with_default_service_account_has_no_serviceaccount_key( | ||
self): | ||
# If user does not provide a service account (i.e. service_account=None, | ||
# the service account key should not be included in the request dict as | ||
# AI Platform will treat None as the name of the service account. | ||
_ = deploy.deploy_job( | ||
self.docker_img, | ||
self.chief_config, | ||
self.worker_count, | ||
self.worker_config, | ||
self.entry_point_args, | ||
self.stream_logs, | ||
) | ||
build_ret_val = self._mock_discovery_build.return_value | ||
proj_ret_val = build_ret_val.projects.return_value | ||
jobs_ret_val = proj_ret_val.jobs.return_value | ||
|
||
del self.expected_request_dict["trainingInput"]["serviceAccount"] | ||
|
||
# Verify job creation args | ||
_, kwargs = jobs_ret_val.create.call_args | ||
self.assertDictEqual( | ||
kwargs, | ||
{ | ||
"parent": "projects/" + self.mock_project_name, | ||
"body": self.expected_request_dict, | ||
}, | ||
) | ||
|
||
def test_request_dict_without_workers(self): | ||
worker_count = 0 | ||
|
||
|
@@ -164,6 +196,7 @@ def test_request_dict_without_workers(self): | |
None, | ||
self.entry_point_args, | ||
self.stream_logs, | ||
service_account=self.service_account | ||
) | ||
build_ret_val = self._mock_discovery_build.return_value | ||
proj_ret_val = build_ret_val.projects.return_value | ||
|
@@ -192,6 +225,7 @@ def test_request_dict_without_user_args(self): | |
self.worker_config, | ||
None, | ||
self.stream_logs, | ||
service_account=self.service_account | ||
) | ||
build_ret_val = self._mock_discovery_build.return_value | ||
proj_ret_val = build_ret_val.projects.return_value | ||
|
@@ -221,6 +255,7 @@ def test_request_dict_with_tpu_worker(self): | |
worker_config, | ||
self.entry_point_args, | ||
self.stream_logs, | ||
service_account=self.service_account | ||
) | ||
build_ret_val = self._mock_discovery_build.return_value | ||
proj_ret_val = build_ret_val.projects.return_value | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -185,6 +185,39 @@ def test_validate_invalid_job_label(self): | |
"val{}".format(i) for i in range(65)} | ||
) | ||
|
||
def testValidateServiceAccount_NotEndWithCorrectDomain_RaisesValueError( | ||
self): | ||
with self.assertRaisesRegex(ValueError, r"Invalid service_account"): | ||
# must end with .iam.gserviceaccount.com | ||
gcp.validate_service_account( | ||
"[email protected]_domain.com") | ||
|
||
def testValidateServiceAccount_ShortProjectId_RaisesValueError(self): | ||
with self.assertRaisesRegex(ValueError, r"Invalid service_account"): | ||
# Project id must be greater than 6 characters | ||
short_project_id = "a" * 5 | ||
gcp.validate_service_account( | ||
f"test_sa_name@{short_project_id}.iam.gserviceaccount.com") | ||
|
||
def testValidateServiceAccount_LongProjectId_RaisesValueError(self): | ||
with self.assertRaisesRegex(ValueError, r"Invalid service_account"): | ||
# Project id must be less than 30 characters | ||
long_project_id = "a" * 31 | ||
gcp.validate_service_account( | ||
f"test_sa_name@{long_project_id}.iam.gserviceaccount.com") | ||
|
||
def testValidateServiceAccount_ProjectIdWithDot_RaisesValueError(self): | ||
with self.assertRaisesRegex(ValueError, r"Invalid service_account"): | ||
# Project id can not contain . | ||
gcp.validate_service_account( | ||
"[email protected]") | ||
|
||
def testValidateServiceAccount_ProjectIdWithUnderScore_RaisesValueError( | ||
self): | ||
with self.assertRaisesRegex(ValueError, r"Invalid service_account"): | ||
# Project id can not contain _ | ||
gcp.validate_service_account( | ||
"test_sa_name@test_projectid.iam.gserviceaccount.com") | ||
|
||
if __name__ == "__main__": | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters