Improve and refine MLP tests for extensibility and A/B testing #8561
Conversation
@tengyifei @ManfeiBai I found myself having to substantially improve and extend the MLP tests, since I want to reuse them for A/B convergence validation. PTAL.
Synced offline. Let's add a --skip-gradient-checkpointing CLI arg to the training test script (or similar) to skip gradient checkpointing on CPU, in order to avoid the confusing test_train_spmd_linear_model_grad_checkpointing name.
I was just about to write this, done!
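The suggested flag can be sketched as follows. This is a minimal, hypothetical illustration of the pattern (the real script, flag wiring, and training body live in the PR and may differ); train_step here is a stand-in for the actual training step:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser for the SPMD linear model training test script.
    parser = argparse.ArgumentParser(description="SPMD linear model training test")
    parser.add_argument(
        "--skip-gradient-checkpointing",
        action="store_true",
        help="Run the model without gradient checkpointing (e.g. on CPU).",
    )
    return parser


def train_step(use_grad_checkpointing: bool) -> str:
    # Stand-in for the real training step: the flag selects whether the
    # forward pass wraps layers in torch.utils.checkpoint or runs plainly.
    return "checkpointed" if use_grad_checkpointing else "plain"


if __name__ == "__main__":
    args = build_parser().parse_args()
    mode = train_step(not args.skip_gradient_checkpointing)
    print(f"training mode: {mode}")
```

With this, the CPU CI job can pass --skip-gradient-checkpointing instead of needing a separately named test.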
Looks like the test file is duplicated under test/utils?
I don't see it. The core of the functionality was moved to utils - it's a standalone training script.
Hi @rpsilva-aws, thank you for the explanation. I still don't get the A/B testing part, but if the body can be reused in other tests then it's fine.
Thanks. You can see it in this PR here: https://github.com/pytorch/xla/pull/8561/files#diff-09c6a280d1c6fc8053d5a17e919e29fe51a75fd593f6c2012496b6e970312c25R44 Essentially, it allows us to A/B test different functionality. It adds a bit more value than running each standalone test and just checking that the losses/outputs are not 0. Alternatively, we'd need to assert against a hardcoded set of expected values. We want to reuse this for testing gradient accumulation with and without XLA's while loop.
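The A/B pattern described above can be sketched roughly as follows. This is an illustrative example, not the PR's actual code: train stands in for the reusable training script body, and assert_ab_equivalent is a hypothetical helper that compares per-step losses between two variants instead of against hardcoded expected values:

```python
from typing import Callable, List


def train(loss_scale: float, steps: int = 5) -> List[float]:
    # Stand-in for the real training body; returns per-step losses.
    loss = 1.0
    losses = []
    for _ in range(steps):
        loss *= 0.5 * loss_scale
        losses.append(loss)
    return losses


def assert_ab_equivalent(run_a: Callable[[], List[float]],
                         run_b: Callable[[], List[float]],
                         tol: float = 1e-6) -> None:
    # A/B check: both variants (e.g. gradient accumulation with and
    # without XLA's while loop) must produce matching losses.
    losses_a, losses_b = run_a(), run_b()
    assert len(losses_a) == len(losses_b)
    for a, b in zip(losses_a, losses_b):
        assert abs(a - b) <= tol, f"loss mismatch: {a} vs {b}"


if __name__ == "__main__":
    assert_ab_equivalent(lambda: train(1.0), lambda: train(1.0))
    print("A/B runs match")
```

The point is that a behavior-preserving change is validated against the baseline run itself, so the test needs no hardcoded golden values.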
In this PR, we include various fixes, improvements and extensions, namely: