
Ensure explicit output dtype for pad_across_processes #3219

Open · wants to merge 2 commits into main

Conversation

mariusarvinte (Contributor)

What does this PR do?

Fixes #3218.

The current implementation casts torch.bool tensors to torch.int64 because of the + pad_index operation, where pad_index defaults to 0:

new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index

Adds a test case checking that a torch.bool input is returned with the same dtype.
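The promotion can be reproduced in isolation (a minimal sketch; the tensor name is illustrative):

```python
import torch

t = torch.ones(2, dtype=torch.bool)

# new_zeros preserves the input dtype, but adding the Python int
# pad_index triggers type promotion: bool + int -> int64
promoted = t.new_zeros((4,)) + 0
print(promoted.dtype)  # torch.int64

# the proposed fix casts the result back to the input dtype
fixed = (t.new_zeros((4,)) + 0).to(t.dtype)
print(fixed.dtype)  # torch.bool
```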

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @BenjaminBossan @SunMarc

@@ -669,7 +669,7 @@ def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
old_size = tensor.shape
new_size = list(old_size)
new_size[dim] = max_size
- new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
+ new_tensor = (tensor.new_zeros(tuple(new_size)) + pad_index).to(tensor.dtype)
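For context, the local padding step can be sketched as a standalone function with the fix applied (pad_along_dim is a hypothetical name; the real _pad_across_processes also gathers max_size across processes, which is omitted here):

```python
import torch

def pad_along_dim(tensor, max_size, dim=0, pad_index=0, pad_first=False):
    # hypothetical standalone version of the padding logic
    old_size = tensor.shape
    new_size = list(old_size)
    new_size[dim] = max_size
    # the fix: cast back, since `+ pad_index` promotes bool tensors to int64
    new_tensor = (tensor.new_zeros(tuple(new_size)) + pad_index).to(tensor.dtype)
    if pad_first:
        indices = tuple(
            slice(max_size - old_size[dim], max_size) if i == dim else slice(None)
            for i in range(len(new_size))
        )
    else:
        indices = tuple(
            slice(0, old_size[dim]) if i == dim else slice(None)
            for i in range(len(new_size))
        )
    # the original data is copied back in, so only the pad region
    # ever holds pad_index
    new_tensor[indices] = tensor
    return new_tensor
```

With a bool input and the default pad_index = 0, the output stays bool and is padded with False.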
Member

Hmm, I'm wondering how safe this is in case the tensor dtype cannot represent the new data. E.g., when pad_index is not 0 or 1, casting to bool will result in a loss of information.

Member

+1, pad_index in the context of LLMs is usually -100.

Contributor Author — @mariusarvinte, Nov 5, 2024

I don't think there's any loss of useful information per se, given that the original data is retained downstream:

new_tensor[indices] = tensor

What this does change though is the actual pad value. Any non-zero pad_index (e.g., -100) will result in padding with True for bool.

Does it make sense to always pad with False for bool? In our use case, we directly manipulated bool tensors across devices and left pad_index = 0 by default. Not sure if bool actually appears in LLMs.
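The pad-value change described above can be verified directly (a minimal sketch):

```python
import torch

t = torch.zeros(2, dtype=torch.bool)

# with the fix, a non-zero pad_index is cast to bool: -100 becomes True,
# so the pad region is padded with True instead of the literal pad_index
padded = (t.new_zeros((4,)) + -100).to(t.dtype)
print(padded)  # tensor([True, True, True, True])
```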

Development

Successfully merging this pull request may close these issues.

Incorrect type in output of utils.pad_across_processes when input is torch.bool
4 participants