Skip to content

Commit

Permalink
Add support for Bring Your Own Service Account to tfc.run
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 353946327
  • Loading branch information
SinaChavoshi authored and Tensorflow Cloud maintainers committed Jan 26, 2021
1 parent db102b3 commit d252cb0
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 8 deletions.
13 changes: 11 additions & 2 deletions src/python/tensorflow_cloud/core/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def deploy_job(
entry_point_args,
enable_stream_logs,
job_labels=None,
service_account=None,
):
"""Deploys job with the given parameters to Google Cloud.
Expand All @@ -50,8 +51,12 @@ def deploy_job(
enable_stream_logs: Boolean flag which when enabled streams logs
back from the cloud job.
job_labels: Dict of str: str. Labels to organize jobs. See
https://cloud.google.com/ai-platform/training/docs/resource-labels.
[resource labels](
https://cloud.google.com/ai-platform/training/docs/resource-labels)
service_account: The email address of a user-managed service account
to be used for training instead of the service account that AI
Platform Training uses by default. See [custom service account](
https://cloud.google.com/ai-platform/training/docs/custom-service-account)
Returns:
ID of the invoked remote Cloud AI Platform job.
Expand All @@ -76,6 +81,7 @@ def deploy_job(
worker_config,
entry_point_args,
job_labels=job_labels or {},
service_account=service_account
)
try:
unused_response = (
Expand All @@ -102,6 +108,7 @@ def _create_request_dict(
worker_config,
entry_point_args,
job_labels,
service_account
):
"""Creates request dictionary for the CAIP training service."""
training_input = {}
Expand Down Expand Up @@ -162,6 +169,8 @@ def _create_request_dict(
request_dict["trainingInput"] = training_input
if job_labels:
request_dict["labels"] = job_labels
if service_account:
training_input["serviceAccount"] = service_account
return request_dict


Expand Down
25 changes: 24 additions & 1 deletion src/python/tensorflow_cloud/core/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ def _get_valid_machine_configurations():
def validate_job_labels(job_labels):
"""Validates job labels conform guidelines.
Ref. https://cloud.google.com/ai-platform/training/docs/resource-labels
Ref. [resource labels](
https://cloud.google.com/ai-platform/training/docs/resource-labels)
Args:
job_labels: String job label to validate.
Expand Down Expand Up @@ -479,3 +480,25 @@ def validate_job_labels(job_labels):
"numeric characters, underscores and dashes."
"Received: {}.".format(v)
)


def validate_service_account(service_account):
    """Validates that service_account is a user-managed service account email.

    Ref. [user managed service accounts](
    https://cloud.google.com/iam/docs/service-accounts#user-managed)

    A falsy value (`None` or an empty string) is accepted without any check.
    Otherwise the entire string must have the form
    `<name>@<project-id>.iam.gserviceaccount.com`, where the project id is
    6 to 30 characters drawn from lowercase letters, digits and dashes.

    Args:
        service_account: String service account email to validate, or None.

    Raises:
        ValueError: If the given service_account is not conformant.
    """
    if not service_account:
        return

    # `fullmatch` anchors the pattern at both ends, so anything trailing
    # `.iam.gserviceaccount.com` (e.g. an extra domain suffix) is rejected.
    if not re.fullmatch(
        r".*@[a-z0-9\-]{6,30}\.iam\.gserviceaccount\.com", service_account
    ):
        raise ValueError(
            "Invalid service_account: service_account should follow "
            "service-account-name@project-id.iam.gserviceaccount.com "
            "Received: {}.".format(service_account)
        )
10 changes: 9 additions & 1 deletion src/python/tensorflow_cloud/core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def run(
entry_point_args=None,
stream_logs=False,
job_labels=None,
service_account=None,
**kwargs
):
"""Runs your Tensorflow code in Google Cloud Platform.
Expand Down Expand Up @@ -116,7 +117,12 @@ def run(
job_labels: Dict of str: str. Labels to organize jobs. You can specify
up to 64 key-value pairs in lowercase letters and numbers, where
the first character must be lowercase letter. For more details see
https://cloud.google.com/ai-platform/training/docs/resource-labels.
[resource-labels](
https://cloud.google.com/ai-platform/training/docs/resource-labels)
service_account: The email address of a user-managed service account
to be used for training instead of the service account that AI
Platform Training uses by default. see [custom-service-account](
https://cloud.google.com/ai-platform/training/docs/custom-service-account)
**kwargs: Additional keyword arguments.
Returns:
Expand Down Expand Up @@ -184,6 +190,7 @@ def run(
docker_config.image_build_bucket,
called_from_notebook,
job_labels=job_labels or {},
service_account=service_account,
docker_parent_image=docker_config.parent_image,
)

Expand Down Expand Up @@ -245,6 +252,7 @@ def run(
entry_point_args,
stream_logs,
job_labels=job_labels,
service_account=service_account,
)

# Call `exit` to prevent training the Keras model in the local env.
Expand Down
35 changes: 35 additions & 0 deletions src/python/tensorflow_cloud/core/tests/unit/deploy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def setUp(self):
self.docker_img = "custom-image-tag"
self.entry_point_args = ["1000"]
self.stream_logs = False
self.service_account = "[email protected]"

self.expected_request_dict = {
"jobId": self.mock_job_id,
Expand All @@ -53,6 +54,7 @@ def setUp(self):
"scaleTier": "custom",
"region": self.region,
"args": self.entry_point_args,
"serviceAccount": self.service_account,
"masterType": "n1-standard-16",
"workerType": "n1-standard-8",
"workerCount": str(self.worker_count),
Expand Down Expand Up @@ -103,6 +105,7 @@ def test_deploy_job(self, mock_stdout):
self.worker_config,
self.entry_point_args,
self.stream_logs,
service_account=self.service_account
)

self.assertEqual(job_name, self.mock_job_id)
Expand Down Expand Up @@ -154,6 +157,35 @@ def test_deploy_job(self, mock_stdout):
),
)

def test_deploy_job_with_default_service_account_has_no_serviceaccount_key(
        self):
    """Omitting service_account keeps `serviceAccount` out of the request."""
    # When the caller passes no service account (service_account=None), the
    # request must not carry a serviceAccount entry at all: AI Platform
    # would otherwise treat None as the name of the service account.
    deploy.deploy_job(
        self.docker_img,
        self.chief_config,
        self.worker_count,
        self.worker_config,
        self.entry_point_args,
        self.stream_logs,
    )

    # Remove the key from the fixture so it mirrors the expected body.
    self.expected_request_dict["trainingInput"].pop("serviceAccount")

    jobs_api = (
        self._mock_discovery_build.return_value
        .projects.return_value
        .jobs.return_value
    )

    # Verify the keyword arguments used to create the job.
    create_kwargs = jobs_api.create.call_args[1]
    self.assertDictEqual(
        create_kwargs,
        {
            "parent": "projects/" + self.mock_project_name,
            "body": self.expected_request_dict,
        },
    )

def test_request_dict_without_workers(self):
worker_count = 0

Expand All @@ -164,6 +196,7 @@ def test_request_dict_without_workers(self):
None,
self.entry_point_args,
self.stream_logs,
service_account=self.service_account
)
build_ret_val = self._mock_discovery_build.return_value
proj_ret_val = build_ret_val.projects.return_value
Expand Down Expand Up @@ -192,6 +225,7 @@ def test_request_dict_without_user_args(self):
self.worker_config,
None,
self.stream_logs,
service_account=self.service_account
)
build_ret_val = self._mock_discovery_build.return_value
proj_ret_val = build_ret_val.projects.return_value
Expand Down Expand Up @@ -221,6 +255,7 @@ def test_request_dict_with_tpu_worker(self):
worker_config,
self.entry_point_args,
self.stream_logs,
service_account=self.service_account
)
build_ret_val = self._mock_discovery_build.return_value
proj_ret_val = build_ret_val.projects.return_value
Expand Down
33 changes: 33 additions & 0 deletions src/python/tensorflow_cloud/core/tests/unit/gcp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,39 @@ def test_validate_invalid_job_label(self):
"val{}".format(i) for i in range(65)}
)

def testValidateServiceAccount_NotEndWithCorrectDomain_RaisesValueError(
        self):
    """An address outside `.iam.gserviceaccount.com` is rejected."""
    # The project id part is well formed; only the domain suffix is wrong,
    # so the failure can be attributed to the domain check alone.
    invalid_account = "test_sa_name@test-projectid.iam.incorrect_domain.com"
    with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
        gcp.validate_service_account(invalid_account)

def testValidateServiceAccount_ShortProjectId_RaisesValueError(self):
    """A project id shorter than six characters is rejected."""
    # Five characters: one below the minimum accepted project id length.
    short_project_id = "a" * 5
    account = f"test_sa_name@{short_project_id}.iam.gserviceaccount.com"
    with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
        gcp.validate_service_account(account)

def testValidateServiceAccount_LongProjectId_RaisesValueError(self):
    """A project id longer than thirty characters is rejected."""
    # Thirty-one characters: one above the maximum accepted length.
    long_project_id = "a" * 31
    account = f"test_sa_name@{long_project_id}.iam.gserviceaccount.com"
    with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
        gcp.validate_service_account(account)

def testValidateServiceAccount_ProjectIdWithDot_RaisesValueError(self):
    """A project id containing a dot is rejected."""
    # Mirrors the underscore case: only the `.` inside the project id
    # makes this address invalid.
    invalid_account = "test_sa_name@test.projectid.iam.gserviceaccount.com"
    with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
        gcp.validate_service_account(invalid_account)

def testValidateServiceAccount_ProjectIdWithUnderScore_RaisesValueError(
        self):
    """A project id containing an underscore is rejected."""
    # Only the `_` inside the project id makes this address invalid.
    invalid_account = "test_sa_name@test_projectid.iam.gserviceaccount.com"
    with self.assertRaisesRegex(ValueError, r"Invalid service_account"):
        gcp.validate_service_account(invalid_account)

if __name__ == "__main__":
absltest.main()
14 changes: 13 additions & 1 deletion src/python/tensorflow_cloud/core/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def validate(
docker_image_build_bucket,
called_from_notebook,
job_labels=None,
service_account=None,
docker_parent_image=None,
):
"""Validates the inputs.
Expand Down Expand Up @@ -61,7 +62,12 @@ def validate(
called_from_notebook: Boolean. True if the API is run in a
notebook environment.
job_labels: Dict of str: str. Labels to organize jobs. See
https://cloud.google.com/ai-platform/training/docs/resource-labels.
[resource-labels](
https://cloud.google.com/ai-platform/training/docs/resource-labels).
service_account: The email address of a user-managed service account
to be used for training instead of the service account that AI
Platform Training uses by default. see [custom-service-account](
https://cloud.google.com/ai-platform/training/docs/custom-service-account)
docker_parent_image: Optional parent Docker image to use.
Defaults to None.
Expand All @@ -80,6 +86,7 @@ def validate(
docker_image_build_bucket,
called_from_notebook,
)
_validate_service_account(service_account)


def _validate_files(entry_point, requirements_txt):
Expand Down Expand Up @@ -179,6 +186,11 @@ def _validate_job_labels(job_labels):
gcp.validate_job_labels(job_labels)


def _validate_service_account(service_account):
    """Validates service_account.

    Delegates to `gcp.validate_service_account`; nothing is returned here.

    Args:
        service_account: Value forwarded unchanged to the `gcp` validator.
    """
    gcp.validate_service_account(service_account)


def _validate_other_args(
args, stream_logs, docker_image_build_bucket, called_from_notebook
):
Expand Down
6 changes: 4 additions & 2 deletions src/python/tensorflow_cloud/tuner/tests/unit/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,8 @@ def test_get_job_spec_with_default_config(
worker_count=worker_count,
worker_config=worker_config,
entry_point_args=None,
job_labels=None)
job_labels=None,
service_account=None)

@mock.patch.object(super_tuner.Tuner, "__init__", autospec=True)
@mock.patch.object(deploy, "_create_request_dict", autospec=True)
Expand Down Expand Up @@ -742,7 +743,8 @@ def test_get_job_spec_with_default_with_custom_config(
worker_count=worker_count,
worker_config=replica_config,
entry_point_args=None,
job_labels=None)
job_labels=None,
service_account=None)

if __name__ == "__main__":
tf.test.main()
3 changes: 2 additions & 1 deletion src/python/tensorflow_cloud/tuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,8 @@ def _get_job_spec_from_config(self, job_id: Text) -> Dict[Text, Any]:
worker_count=worker_count,
worker_config=worker_config,
entry_point_args=None,
job_labels=None)
job_labels=None,
service_account=None)

def _get_remote_training_metrics(
self,
Expand Down

0 comments on commit d252cb0

Please sign in to comment.