Add parameter server train & side-car eval on k8s #182
base: master
Conversation
ResNet56 model (with custom training loop) variables are created on parameter server jobs and updated by workers. Evaluation is done by a dedicated job that uses the checkpoints saved during training (side-car evaluation). The model is trained on the CIFAR10 dataset.
The jinja template now turns off side-car evaluation by default so that only the inline distributed evaluation added with this CL is used. README updated. Added efficiency wrappers that will be useful once GPU is supported with ParameterServerStrategy. Moved the Kubernetes jinja template and renderer script to a dedicated subdirectory.
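For readers new to this setup, the sketch below shows roughly how such a custom training loop is structured under parameter server training: variables are created under `strategy.scope()` (so they live on the parameter servers), training steps are dispatched to workers through a `ClusterCoordinator`, and the checkpoints written after each epoch are what the side-car evaluator consumes. This is a hedged sketch, not the PR's exact code; `build_resnet56`, `per_worker_iterator`, `checkpoint_manager`, `NUM_EPOCHS`, and `STEPS_PER_EPOCH` stand in for the corresponding pieces of resnet_cifar_ps_strategy.py.

```python
# Rough sketch of the pattern described above (not the exact PR code).
# Assumes `strategy` is a parameter server strategy built from TF_CONFIG and that
# build_resnet56, per_worker_iterator, checkpoint_manager, NUM_EPOCHS and
# STEPS_PER_EPOCH are defined elsewhere.
import tensorflow as tf

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

with strategy.scope():
  # Variables created here are placed on the parameter server jobs.
  model = build_resnet56()
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

@tf.function
def worker_train_fn(iterator):
  def replica_fn(inputs):
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_accuracy.update_state(labels, logits)
  strategy.run(replica_fn, args=(next(iterator),))

for epoch in range(NUM_EPOCHS):
  for _ in range(STEPS_PER_EPOCH):
    # Steps are dispatched asynchronously to the worker jobs.
    coordinator.schedule(worker_train_fn, args=(per_worker_iterator,))
  coordinator.join()  # Wait for all scheduled steps before checkpointing.
  # The side-car evaluator job restores these checkpoints to run evaluation.
  checkpoint_manager.save()
```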
Thanks for your PR!
Please first read the
[documentation](https://www.tensorflow.org/tutorials/distribute/parameter_server_training)
of Distribution Strategy for parameter server training. We also assume that readers
of this page are familiar with [Google Cloud](https://cloud.google.com/) and
Redundant space
Done.
- kubernetes/template.yaml.jinja: jinja template used for generating Kubernetes manifests
- kubernetes/render_template.py: script for rendering the jinja template
- Dockerfile.resnet_cifar_ps_strategy: a docker file to build the model image
- resnet_cifar_ps_strategy.py: script for running any type of parameter server training task based on `TF_CONFIG` environment variable
"any type of ..." seems too general, maybe just say "a ResNet example using Cifar dataset for parameter server training"
Done.
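For context on the last two items above, rendering the template just means filling in the jinja variables and emitting Kubernetes manifests on stdout, which can then be piped to kubectl. The snippet below is a hypothetical, minimal version of that step; the template variable names (`worker_replicas`, `ps_replicas`) are illustrative and may not match render_template.py's actual interface.

```python
# Hypothetical, minimal rendering step; the real kubernetes/render_template.py
# and its template variables may differ.
import sys

import jinja2

with open("kubernetes/template.yaml.jinja") as f:
    template = jinja2.Template(f.read())

# The rendered manifests can be applied with, e.g.:
#   python kubernetes/render_template.py | kubectl apply -f -
sys.stdout.write(template.render(worker_replicas=2, ps_replicas=1))
```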
BATCH_SIZE = 64
EVAL_BATCH_SIZE = 8

def create_in_process_cluster(num_workers, num_ps):
Could you update the worker_config part according to this tutorial? https://www.tensorflow.org/tutorials/distribute/parameter_server_training#in-process_cluster
Added inter_op threads to the worker config.
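For reference, the in-process cluster helper in the linked tutorial looks roughly like the sketch below; the `inter_op_parallelism_threads` bump on the worker config is the part this comment asked for. This follows the tutorial rather than the PR's final code.

```python
import multiprocessing

import portpicker
import tensorflow as tf


def create_in_process_cluster(num_workers, num_ps):
  """Creates and starts local servers, returning a cluster resolver."""
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {"worker": ["localhost:%s" % port for port in worker_ports]}
  if num_ps > 0:
    cluster_dict["ps"] = ["localhost:%s" % port for port in ps_ports]
  cluster_spec = tf.train.ClusterSpec(cluster_dict)

  # Workers need some inter_op parallelism threads to work properly.
  worker_config = tf.compat.v1.ConfigProto()
  if multiprocessing.cpu_count() < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  for i in range(num_workers):
    tf.distribute.Server(
        cluster_spec, job_name="worker", task_index=i,
        config=worker_config, protocol="grpc")
  for i in range(num_ps):
    tf.distribute.Server(
        cluster_spec, job_name="ps", task_index=i, protocol="grpc")

  return tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer="grpc")
```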
set up distributed training
"""

strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
let's use tf.distribute.experimental.ParameterServerStrategy.
Done.
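With the public API, constructing the strategy looks roughly like the sketch below, assuming `TF_CONFIG` describes the cluster (as it does in the Kubernetes manifests) and `num_ps` matches the number of parameter server replicas; the variable partitioner is optional and shown only for illustration.

```python
import tensorflow as tf

# TF_CONFIG is expected to be set by the rendered Kubernetes manifests.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

# Optional: shard large variables across the parameter servers.
variable_partitioner = (
    tf.distribute.experimental.partitioners.MinSizePartitioner(
        min_shard_bytes=256 << 10, max_shards=num_ps))

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver, variable_partitioner=variable_partitioner)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)
```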
logging.info("Finished joining at epoch %d. Training accuracy: %f.",
             epoch, train_accuracy.result())

for _ in range(STEPS_PER_EPOCH):
Should evaluation use a different steps_per_epoch, since you have a different batch_size for evaluation?
Good point. Introducing EVAL_STEPS_PER_EPOCH and setting it to 88 in the next patch shortly. This gives a probability of 0.99 that any given row in the dataset is evaluated.
logging.info("Finished joining at epoch %d. Training accuracy: %f.",
             epoch, train_accuracy.result())

for _ in range(STEPS_PER_EPOCH):
Could you add a comment here saying that we are running inline distributed evaluation, so an evaluator job is not necessary?
Done.
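The resulting inline evaluation loop looks roughly like the sketch below. `worker_eval_fn`, `per_worker_eval_iterator`, and `EVAL_STEPS_PER_EPOCH` come from the PR; `eval_accuracy` is a stand-in for whatever metric the eval function updates.

```python
# Inline distributed evaluation: eval steps are scheduled onto the same workers
# right after training, so a separate side-car evaluator job is not needed.
eval_accuracy.reset_states()
for _ in range(EVAL_STEPS_PER_EPOCH):
  coordinator.schedule(worker_eval_fn, args=(per_worker_eval_iterator,))
coordinator.join()  # Wait for all scheduled eval steps to finish.
logging.info("Evaluation accuracy: %f.", eval_accuracy.result())
```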
Also addressed the following:
* Added inter_ops for workers
* Replaced parameter_server_strategy_v2.ParameterServerStrategyV2 with tf.distribute.experimental.ParameterServerStrategy
* Clarified the resnet_cifar_ps_strategy.py description
* Indicated that a side-car evaluation job is not needed since we are running inline evaluation
* Removed redundant spaces
flags.DEFINE_string("data_dir", "gs://cifar10_data/",
                    "Directory for Resnet Cifar model input. Follow the "
                    "instruction here to get Cifar10 data: "
                    "https://github.com/tensorflow/models/tree/r1.13.0/official/resnet#cifar-10")
Redundant new line?
Split the help argument into multiple lines for readability; the strings are displayed as concatenated when the --help command-line flag is passed.
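To make the concatenation point concrete: adjacent string literals in Python are merged into a single string, so splitting the help text across source lines only affects readability, not what --help prints.

```python
# Adjacent string literals are concatenated, so --help shows one continuous string.
help_text = ("Directory for Resnet Cifar model input. Follow the "
             "instruction here to get Cifar10 data: "
             "https://github.com/tensorflow/models/tree/r1.13.0/official/resnet#cifar-10")
assert "\n" not in help_text
```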
    parse_record_fn=cifar_preprocessing.parse_record,
    dtype=tf.float32,
    drop_remainder=True)
eval_dataset_fn = lambda _: cifar_preprocessing.input_fn(
Is the eval data shuffled? If not, could you add a comment and a TODO?
Maybe you can just append a shuffle at the end of the dataset?
input_fn already shuffles the training data using process_record_dataset: code link
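If shuffling the eval data does turn out to be needed, one option in line with the suggestion above is to append a `shuffle` to the eval dataset. The sketch below is illustrative only; the `input_fn` arguments mirror the diff context, and the buffer size is an assumption.

```python
# Illustrative only: append a shuffle, per the suggestion above. Because
# input_fn returns batched data, this shuffles whole batches; the buffer size
# of 1000 is an assumption.
eval_dataset_fn = lambda _: cifar_preprocessing.input_fn(
    is_training=False,
    data_dir=FLAGS.data_dir,
    batch_size=EVAL_BATCH_SIZE,
    parse_record_fn=cifar_preprocessing.parse_record,
    dtype=tf.float32,
    drop_remainder=True).shuffle(1000)
```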
# Since we are running inline evaluation below, a side-car evaluator job is not necessary.
for _ in range(EVAL_STEPS_PER_EPOCH):
  coordinator.schedule(worker_eval_fn, args=(per_worker_eval_iterator,))
We can probably build a similar API for DTensor async training. A major difficulty to sort out is what to do if worker_eval_fn (and/or replica_fn) is multi-mesh -- for example, if there is a summary Op that needs to run on the CPU.