[BUG] Bool data handled as tensors -> can't set batch size with 0-dim data. #1199
This is a design decision in TensorDict and I'm not going to hide from it: it is a "historical" thing from when the library started. The official way to do what you want is to pass such values as `NonTensorData`.

Now that being said, I'd be open to implementing a feature by which these transformations do not occur, but then we'd be facing an issue which is to decide what happens down the line:

```python
td = TensorDict(a=0, autocast=False)
td["a"]  # returns an integer
td["b"] = 1  # cast or no cast? i.e., does autocast=False also apply to keys set later?

td = TensorDict(autocast=False)
td["root"] = 0  # no cast
td["a", "b"] = 0  # no cast?
td["a"]["b"] = 0  # no cast?

super_td = TensorDict(td=td, autocast=True)
super_td["td", "a", "b"] = 0  # no cast? (in tensorclass we have the ability to control this)
```
Okay, I see what you mean. Not a straightforward change. However, we can probably postpone this change but still address the initial issue by making a small change in

tensordict/tensordict/utils.py, line 1791 in 0a8638d

to this:

```python
tensor_data = [val for val in source.values() if not is_non_tensor(val) and val.ndim > 0]
```

which will ignore 0-dim tensors (cast or not), considering that these do not have a batch size. Does this sound reasonable, or does it create another problem?
Hmmm, no, because the concept of the batch size is "the common leading dim of all tensors". You can run this under #1213:

```python
from tensordict import UnbatchedTensor, TensorDict
import torch

td = TensorDict(a=UnbatchedTensor(0), b=torch.randn(10, 11)).auto_batch_size_(batch_dims=2)
assert td.shape == (10, 11)
```
Well, it would not be unreasonable to interpret this as "the common leading dim of all tensors that have dims". Still, even using NonTensorData is not enough in general, because the non-tensor type does not transfer when assigning from one tensordict to another (which comes back to your comment about casting again). Meaning:

```python
import torch
from tensordict import TensorDict, NonTensorData

td = TensorDict(a=NonTensorData(True))
assert isinstance(td["a"], bool)  # accessing a NonTensorData returns the underlying type
td["a"] = td["a"]  # but setting a boolean -- imagine this for the non-trivial operation of setting td["a"] from td2["a"]
assert isinstance(td["a"], torch.Tensor)  # ends up with a tensor
```

While returning the underlying type for non-tensor data makes sense (and I like it this way, since the only reason I use NonTensorData is to avoid the casting), it seems weird that a seemingly no-op assignment changes the underlying type and recreates the batch-size problems explained here. I can still do something like:

```python
td["a"] = NonTensorData(td["a"]) if isinstance(td["a"], bool) else td["a"]
```

but this seems overly convoluted just to keep a single boolean piece of metadata inside the tensordict that does not affect the batch size. Why is it preferred to cast non-tensor (and even more so non-sequence) data to tensors in general? (I haven't tested UnbatchedTensor yet to check whether the same thing happens and we end up with a plain tensor on assignment, though it seems quite probable.)
You raise a good point there, that's something I can fix.

Re "the common leading dim of all tensors that have dims": the problem with this is that sometimes we want to set the batch size explicitly, independently of the entries' shapes, which is useful to preallocate data on disk.
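For context, a rough sketch of that preallocation pattern (assuming the usual `expand` + `memmap_` workflow; this is not the snippet from the comment above, and key names are illustrative):

```python
import torch
from tensordict import TensorDict

# Build a template with small / 0-dim entries, expand it to the desired batch
# size, and memory-map the result to preallocate the storage on disk.
template = TensorDict(
    {"obs": torch.zeros(4), "done": torch.zeros((), dtype=torch.bool)},
    batch_size=[],
)
buffer = template.expand(1000).clone()
buffer.memmap_()  # entries become memory-mapped tensors with a leading dim of 1000
```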
I can understand why this would not be possible if the casting of integer to tensor does not happen, however why would this not be possible with the change suggested above?

Also, can you share an insight about this? I mean, apart from the historical fact you mentioned, I can understand casting lists, arrays, etc., but intuitively I would expect the casting to ignore the primitive Python types. We can always write the cast explicitly when a tensor is what we want.
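For instance, a trivial sketch of that explicit opt-in:

```python
import torch
from tensordict import TensorDict

td = TensorDict()
td["count"] = torch.tensor(3)  # an explicit 0-dim tensor, when that is what we want
```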
Describe the bug
When using bool data, it is transformed into a 0-dim tensor internally. As a result, `auto_batch_size_` can't infer the batch size, and setting the batch size explicitly raises an exception.

To Reproduce
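A minimal sketch along the lines of the description (key names are illustrative):

```python
import torch
from tensordict import TensorDict

td = TensorDict(done=True, obs=torch.randn(10, 11))
print(td["done"])         # tensor(True): the bool has been cast to a 0-dim tensor
td.auto_batch_size_()
print(td.batch_size)      # torch.Size([]): the 0-dim entry leaves no common leading dim
td.batch_size = [10, 11]  # raises an exception: "done" has fewer dims than the batch size
```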
Expected behavior
Bools should be handled like strings and other non-tensor data.
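For comparison, a short sketch of how a string behaves today (assuming strings are stored as non-tensor data, as implied above; key names are illustrative):

```python
import torch
from tensordict import TensorDict

td = TensorDict(tag="train", obs=torch.randn(10, 11)).auto_batch_size_()
assert td.batch_size == torch.Size([10, 11])  # the string does not constrain the batch size
assert td["tag"] == "train"                   # and it comes back as a plain str
```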
Reason and Possible fixes
The reason is that bool arguments are internally transformed into tensor data. While I understand that using tensors as an internal representation might be more efficient, maybe we should ignore tensors with `.ndim == 0` when automatically calculating the batch size and also in `_check_new_batch_size`.
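As a rough standalone sketch of that rule (not the library's actual implementation), the batch size would be the common leading dims of the entries, skipping 0-dim tensors:

```python
import torch

def infer_batch_size(values):
    # Common leading dims of all tensor values, ignoring 0-dim tensors.
    tensors = [v for v in values if isinstance(v, torch.Tensor) and v.ndim > 0]
    if not tensors:
        return torch.Size([])
    batch_size = []
    for dims in zip(*(t.shape for t in tensors)):
        if len(set(dims)) != 1:  # stop at the first mismatching dim
            break
        batch_size.append(dims[0])
    return torch.Size(batch_size)

assert infer_batch_size([torch.zeros(()), torch.randn(10, 11)]) == (10, 11)
```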
Checklist