CTCLoss: Fix the hang issue caused by barrier divergence #1087

Merged
merged 5 commits into from
Nov 18, 2024
3 changes: 0 additions & 3 deletions src/ATen/native/xpu/sycl/LossCTCKernels.cpp
@@ -111,7 +111,6 @@ struct CTCLossLogAlphaKernelFunctor {
have_three = false;
}
for (int64_t t = 1; t < max_input_length_; t++) {
- item.barrier(sycl_local_fence);
Contributor

Currently, we depend on workloads to expose this kind of issue. My question is why we could not find it earlier.

Contributor

To catch such issues you typically need to do at least one of two things, preferably both:

  1. Every, or nearly every, code change should be accompanied by a dedicated test. If an existing test is reused, we need to review that it covers all corner cases. For the issue spotted here, we could have checked that the early-exit conditions are actually exercised.
  2. Run real-life tests, preferably from a real-life third-party framework or library. Hugging Face Transformers gives you an excellent way to do this.

The first item is the stronger guarantee that issues won't be missed. The second is a weaker guarantee, but it gives confidence that at least some real-life cases will work. Either approach alone can still miss issues, but using both reduces the probability that something slips through.
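
To illustrate item 1, here is a minimal host-side sketch in plain C++ (not SYCL, and not code from this PR) of how a test could confirm that an early-exit path is exercised and that it would diverge at a barrier inside the time loop. It assumes the divergence comes from work-items that return before the loop because their batch index is out of range; the thread suggests this, but the diff does not show it. The helper names `barrier_arrivals`, `group_is_convergent`, and `exercises_early_exit` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of one work-item: counts how often it would reach a
// barrier placed at the top of the time loop. Assumption: a work-item whose
// batch index b is out of range returns before the loop (early exit).
std::int64_t barrier_arrivals(std::int64_t b,
                              std::int64_t batch_size,
                              std::int64_t max_input_length) {
  if (b >= batch_size) {
    return 0;  // early exit: the loop barrier is never reached
  }
  std::int64_t arrivals = 0;
  for (std::int64_t t = 1; t < max_input_length; ++t) {
    ++arrivals;  // stands in for item.barrier(sycl_local_fence)
  }
  return arrivals;
}

// A work-group is barrier-convergent only if every work-item in it arrives
// at the barrier the same number of times. The group is modeled as covering
// batch indices first_b .. first_b + group_size_y - 1.
bool group_is_convergent(std::int64_t first_b,
                         std::int64_t group_size_y,
                         std::int64_t batch_size,
                         std::int64_t max_input_length) {
  assert(group_size_y > 0);
  const std::int64_t expected =
      barrier_arrivals(first_b, batch_size, max_input_length);
  for (std::int64_t i = 1; i < group_size_y; ++i) {
    if (barrier_arrivals(first_b + i, batch_size, max_input_length) !=
        expected) {
      return false;
    }
  }
  return true;
}

// Assumption: groups tile the batch in chunks of group_size_y, so the last
// group contains out-of-range work-items exactly when the batch size is not
// a multiple of group_size_y.
bool exercises_early_exit(std::int64_t batch_size, std::int64_t group_size_y) {
  assert(group_size_y > 0);
  return batch_size % group_size_y != 0;
}
```

Under these assumptions, `group_is_convergent(4, 4, 6, 10)` is false: the work-items for batch indices 4 and 5 reach the barrier nine times, while those for indices 6 and 7 never reach it. A dedicated test would therefore want at least one configuration for which `exercises_early_exit` is true, for example a batch size of 6 with 4 batch entries per group.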

if ((t < input_length) && (s < 2 * target_length + 1)) {
// only for valid t, s. This is equation (6) and (7), la1, la2, la3
// are the three summands, lamax is the maximum for the logsumexp
@@ -161,7 +160,6 @@ struct CTCLossLogAlphaKernelFunctor {
}
}
}
- item.barrier(sycl_local_fence);

// compute the loss (eq (8))
if (tid_x == 0) {
@@ -490,7 +488,6 @@ struct CTCLossBackwardLogBetaKernelFunctor {
// now go backward in t. Note that we need to skip the last timestep that
// we did above.
for (int64_t t = max_input_length_ - 2; t >= 0; t--) {
- item.barrier(sycl_local_fence);
if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
scalar_t lb1 = log_beta_data_
[lb_batch_offset + lb_input_stride_ * (t + 1) +