Fix torch.Tensor.to
and-ivanov committed Apr 13, 2023
1 parent 547f8e5 commit 7457a4a
Showing 2 changed files with 120 additions and 7 deletions.
66 changes: 59 additions & 7 deletions src/sten/sten.py
@@ -348,6 +348,60 @@ def sparse_tensor_data_set(base_impl, func, types, *args, **kwargs):
    sparse_data_set(lhs, rhs)


def list_get(l, idx, default=None):
    return l[idx] if len(l) > idx else default


# torch.Tensor.to accepts signatures that cannot be distinguished by Python call semantics alone.
# Hence, it requires a custom wrapper.
@implements(torch.Tensor.to)
def sparse_torch_tensor_to(base_impl, func, types, *args, **kwargs):
    # The following signatures are accepted:
    # self, device=None, dtype=None, non_blocking=False, copy=False, *, memory_format=torch.preserve_format
    # self, dtype, non_blocking=False, copy=False, *, memory_format=torch.preserve_format
    # self, other, non_blocking=False, copy=False, *, memory_format=torch.preserve_format
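    # Illustrative calls for each case (examples only; x and other_tensor are placeholders):
    #   x.to("cuda", torch.float16)  -> case 1: device and/or dtype, positionally or as keywords
    #   x.to(torch.float64)          -> case 2: dtype as the first positional argument
    #   x.to(other_tensor)           -> case 3: device and dtype taken from another tensor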

    if isinstance(list_get(args, 1), torch.dtype):
        # case 2
        device = args[0].device
        dtype = args[1]
        tail_args = args[2:]
    elif isinstance(list_get(args, 1), torch.Tensor):
        # case 3
        device = args[1].device
        dtype = args[1].dtype
        tail_args = args[2:]
    else:
        # case 1
        device = list_get(args, 1, kwargs.get("device"))
        dtype = list_get(args, 2, kwargs.get("dtype"))
        tail_args = args[3:]

    non_blocking = list_get(tail_args, 0, kwargs.get("non_blocking", False))
    copy = list_get(tail_args, 1, kwargs.get("copy", False))
    memory_format = kwargs.get("memory_format", torch.preserve_format)

    # return self if device and dtype are unchanged and no copy is requested
    if device == args[0].device and dtype == args[0].dtype and not copy:
        return args[0]

    return sparse_fallback_with_backprop(
        func,
        args=[args[0]],
        kwargs={
            "device": device,
            "dtype": dtype,
            "non_blocking": non_blocking,
            "copy": copy,
            "memory_format": memory_format,
        },
        # in contrast to most purely functional operators,
        # this method is supposed to return a tensor of the *same* sparse format
        out_fmt=[(SameFormatSparsifier(args[0]), args[0].wrapped_tensor.__class__, KeepAll(), args[0].wrapped_tensor.__class__)],
        grad_out_fmt=[args[0].grad_fmt],
    )


def sparse_tensor_builder(
    wrapper_type, wrapped_tensor_container, requires_grad, grad_fmt
):
@@ -493,7 +547,9 @@ def get_op_semantics(op):
    return Semantics.Function


def sparse_fallback_with_backprop(func, args, kwargs):
def sparse_fallback_with_backprop(
    func, args, kwargs, *, out_fmt=None, grad_out_fmt=None
):
    # don't create new instance of sparsified op if it already exists at least for one input
    all_sparse_tensors = flattened_sparse_tensors(args) + flattened_sparse_tensors(
        kwargs
@@ -509,7 +565,7 @@ def sparse_fallback_with_backprop(func, args, kwargs):
            break
    if op is None:
        # if the operator was not created earlier, create it now
        op = sparsified_op(func, None, None)
        op = sparsified_op(func, out_fmt, grad_out_fmt)
    # assign newly created operator to all participating tensors
    for t in all_sparse_tensors:
        if not hasattr(t, "sparsified_ops"):
@@ -1266,11 +1322,7 @@ def my_operator(ctx, grad_outputs, input_sparsifiers):
    torch.Tensor.eq: lambda input, other: -1,
    torch.mul: lambda input, other, *, out=None: -1,
    torch.abs: lambda input, *, out=None: -1,
    torch.Tensor.to: (
        lambda dtype, non_blocking=False, copy=False, *, memory_format=torch.preserve_format: -1,
        lambda device=None, dtype=None, non_blocking=False, copy=False, *, memory_format=torch.preserve_format: -1,
        lambda other, non_blocking=False, *, copy=False: -1,
    ),
    torch.Tensor.to: lambda self, device=None, dtype=None, non_blocking=False, copy=False, *, memory_format=torch.preserve_format: -1,
    torch.addmm: lambda input, mat1, mat2, *, beta=1, alpha=1, out=None: -1,
    torch.zeros_like: lambda input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None: -1,
    torch.ones_like: lambda input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format: -1,
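
For reference, a condensed sketch of the three call forms the new wrapper distinguishes. The setup mirrors the test added in tests/test_dispatch.py below; tensor names are illustrative only.

import torch
import sten
from sten import SparseTensorWrapper

dense = torch.full((3, 3), 2.0, requires_grad=True, dtype=torch.float32)
sparsifier = sten.ScalarFractionSparsifier(0.5)
sx = SparseTensorWrapper.wrapped_from_dense(
    sten.MaskedSparseTensor(
        sten.scalar_mask_sparsify(dense, sparsifier.fraction),
        sparsifier,
    ),
    dense,
    None,
)

sx64 = sx.to(torch.float64)      # case 2: dtype only, device unchanged
sx_cuda = sx.to("cuda")          # case 1: device (and/or dtype); requires a CUDA device
sx_like = sx.to(dense.double())  # case 3: device and dtype taken from another tensor
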
61 changes: 61 additions & 0 deletions tests/test_dispatch.py
@@ -1,6 +1,8 @@
import sten
from sten import SparseTensorWrapper, SparseParameterWrapper, DenseTensor
import torch
import pytest
import os


def test_add():
@@ -232,6 +234,64 @@ def test_ones_like():
    assert torch.allclose(torch.ones_like(sx), torch.ones_like(dx))


def test_to():
    shape = (3, 3)
    dx = torch.full(shape, 2.0, requires_grad=True, dtype=torch.float32)
    sparsifier = sten.ScalarFractionSparsifier(0.5)
    sx = SparseTensorWrapper.wrapped_from_dense(
        sten.MaskedSparseTensor(
            sten.scalar_mask_sparsify(dx, sparsifier.fraction),
            sparsifier,
        ),
        dx,
        None,
    )

    assert sx.device == torch.device("cpu")
    assert sx.dtype == torch.float32
    assert sx.wrapped_tensor.data.device == torch.device("cpu")
    assert sx.wrapped_tensor.data.dtype == torch.float32

    # Signature:
    # tensor.to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format)
    sx64 = sx.to(torch.float64)
    assert sx64.dtype == torch.float64
    assert sx64.wrapped_tensor.data.dtype == torch.float64

    if torch.cuda.device_count() == 0:
        if "PYTEST_CURRENT_TEST" in os.environ:
            pytest.skip("No CUDA-capable device found")
        else:
            return

    # Signature:
    # tensor.to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format)
    sx_cuda = sx.to("cuda")
    assert sx_cuda.device.type == "cuda"
    assert sx_cuda.wrapped_tensor.data.device.type == "cuda"

    # Signature:
    # tensor.to(other, non_blocking=False, copy=False)

    other = torch.full(
        shape, 2.0, requires_grad=True, device=torch.device("cuda"), dtype=torch.float64
    )
    sx_other = sx.to(other)

    assert sx_other.dtype == torch.float64
    assert sx_other.wrapped_tensor.data.dtype == torch.float64
    assert sx_other.device.type == "cuda"
    assert sx_other.wrapped_tensor.data.device.type == "cuda"

    # check correctness of copy semantics
    sx_same = sx.to(dx)
    assert id(sx) == id(sx_same)
    sx_copy = sx.to(dx, copy=True)
    assert id(sx) != id(sx_copy)
    assert id(sx.wrapped_tensor) != id(sx_copy.wrapped_tensor)
    assert id(sx.wrapped_tensor.data) != id(sx_copy.wrapped_tensor.data)


if __name__ == "__main__":
    test_add()
    test_add_()
@@ -245,3 +305,4 @@ def test_ones_like():
    test_sizes()
    test_scalar_mul()
    test_ones_like()
    test_to()
