Implement ScalarLoop in torch backend (pymc-devs#958)
* Add for loop based scalar loop

* Pass all loop tests

* Fetch constants from op

* Add while loop test

* Fix while loop and nasty stack over dtypes

* Disable compile here based on CI result

* Fix mypy signature

* Remove unnecessary torch stack

* Only call .cpu when necessary

* Recursive false for torch compiler

* Add elemwise test

* Late import torch

* Do iteration instead of vmap for elemwise

* Clean up and add description

* Add unit test to verify iteration

* Refactor to ravel method

* Fix unpacking

Co-authored-by: Ricardo Vieira <[email protected]>

* Fix comment

* Remove extra return

* Update test

* Add single carry test

* Remove compiler disable

* Better name

* Lint

* Better docstring

* Address PR comments

---------

Co-authored-by: Ian Schweer <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
3 people authored Dec 8, 2024
1 parent 07bd48d commit 9858b33
Showing 4 changed files with 165 additions and 9 deletions.
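
In brief: this commit teaches the PyTorch backend to execute ScalarLoop by plain Python iteration rather than vmap, both for scalar graphs and when the loop is wrapped in Elemwise. A minimal usage sketch, mirroring the tests added below; the predefined "PYTORCH" compile mode name is an assumption (the tests use an equivalent pytorch_mode fixture):

import numpy as np

from pytensor.compile.function import function
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop

n_steps = int64("n_steps")
x0 = float64("x0")
const = float64("const")

# One loop step computes x <- x + const; run it n_steps times.
op = ScalarLoop(init=[x0], constant=[const], update=[x0 + const])
x = op(n_steps, x0, const)

fn = function([n_steps, x0, const], x, mode="PYTORCH")  # assumed mode name
np.testing.assert_allclose(fn(5, 0, 1), 5)
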
39 changes: 39 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -3,6 +3,7 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -11,6 +12,7 @@
@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
    scalar_op = op.scalar_op

    base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

    def check_special_scipy(func_name):
@@ -33,6 +35,9 @@ def elemwise_fn(*inputs):
            Elemwise._check_runtime_broadcast(node, inputs)
            return base_fn(*inputs)

    elif isinstance(scalar_op, ScalarLoop):
        return elemwise_ravel_fn(base_fn, op, node, **kwargs)

    else:

        def elemwise_fn(*inputs):
@@ -176,3 +181,37 @@ def softmax_grad(dy, sm):
        return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

    return softmax_grad


def elemwise_ravel_fn(base_fn, op, node, **kwargs):
    """Elemwise dispatch that iterates over raveled inputs.

    Ops whose scalar dispatch relies on `.item()` (here, a ScalarLoop inside
    an Elemwise) cannot be vmapped in torch; see
    https://github.com/pymc-devs/pytensor/issues/1031. Instead, broadcast
    the inputs, ravel them, and run the scalar function one element at a time.
    """

    n_outputs = len(node.outputs)

    def elemwise_fn(*inputs):
        bcasted_inputs = torch.broadcast_tensors(*inputs)
        raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

        out_shape = bcasted_inputs[0].size()
        out_size = out_shape.numel()
        raveled_outputs = [torch.empty(out_size) for _ in node.outputs]

        for i in range(out_size):
            core_outs = base_fn(*(inp[i] for inp in raveled_inputs))
            if n_outputs == 1:
                raveled_outputs[0][i] = core_outs
            else:
                for o in range(n_outputs):
                    raveled_outputs[o][i] = core_outs[o]

        outputs = tuple(out.view(out_shape) for out in raveled_outputs)
        if n_outputs == 1:
            return outputs[0]
        else:
            return outputs

    return elemwise_fn
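
As a sanity check on the ravel trick above, here is a small self-contained sketch in plain torch (no PyTensor; looped_add is a hypothetical stand-in for a per-element scalar loop body) showing that broadcast, ravel, iterate, and view reproduces an elementwise result:

import torch

def looped_add(a, b):
    # Stand-in for a scalar loop body applied per element.
    return a + b

a = torch.arange(6.0).reshape(3, 2)
b = torch.tensor([10.0, 20.0])  # broadcasts against `a`

bcast_a, bcast_b = torch.broadcast_tensors(a, b)
flat_a, flat_b = bcast_a.ravel(), bcast_b.ravel()
out = torch.empty(flat_a.numel())
for i in range(flat_a.numel()):
    out[i] = looped_add(flat_a[i], flat_b[i])

assert torch.equal(out.view(bcast_a.size()), a + b)
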
35 changes: 35 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -7,6 +7,7 @@
    Cast,
    ScalarOp,
)
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Softplus


@@ -62,3 +63,37 @@ def cast(x):
@pytorch_funcify.register(Softplus)
def pytorch_funcify_Softplus(op, node, **kwargs):
    return torch.nn.Softplus()


@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node, **kwargs):
    update = pytorch_funcify(op.fgraph, **kwargs)
    state_length = op.nout
    if op.is_while:

        def scalar_loop(steps, *start_and_constants):
            carry, constants = (
                start_and_constants[:state_length],
                start_and_constants[state_length:],
            )
            done = True
            for _ in range(steps):
                *carry, done = update(*carry, *constants)
                if torch.any(done):
                    break
            return *carry, done

    else:

        def scalar_loop(steps, *start_and_constants):
            carry, constants = (
                start_and_constants[:state_length],
                start_and_constants[state_length:],
            )
            for _ in range(steps):
                carry = update(*carry, *constants)
            if len(node.outputs) == 1:
                return carry[0]
            else:
                return carry

    return scalar_loop
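
To illustrate the carry/constants calling convention the dispatcher assumes, here is a hedged, PyTensor-free sketch of the non-while branch; the update closure that pytorch_funcify(op.fgraph) would return is mocked by a plain function:

import torch

state_length = 1  # number of carried states (op.nout in the real dispatcher)

def update(x, const):
    # Mock of the compiled inner graph: one carried state, one constant.
    return (x + const,)

def scalar_loop(steps, *start_and_constants):
    carry = start_and_constants[:state_length]
    constants = start_and_constants[state_length:]
    for _ in range(steps):
        carry = update(*carry, *constants)
    return carry[0]

assert scalar_loop(5, torch.tensor(0.0), torch.tensor(2.0)) == 10.0
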
14 changes: 5 additions & 9 deletions pytensor/link/pytorch/linker.py
@@ -54,34 +54,30 @@ def __init__(self, fn, gen_functors):
                 self.fn = torch.compile(fn)
                 self.gen_functors = gen_functors.copy()
 
-            def __call__(self, *args, **kwargs):
+            def __call__(self, *inputs, **kwargs):
                 import pytensor.link.utils
 
                 # set attrs
                 for n, fn in self.gen_functors:
                     setattr(pytensor.link.utils, n[1:], fn)
 
-                res = self.fn(*args, **kwargs)
+                # Torch does not accept numpy inputs and may return GPU objects
+                outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
 
                 # unset attrs
                 for n, _ in self.gen_functors:
                     if getattr(pytensor.link.utils, n[1:], False):
                         delattr(pytensor.link.utils, n[1:])
 
-                return res
+                return tuple(out.cpu().numpy() for out in outs)
 
             def __del__(self):
                 del self.gen_functors
 
         inner_fn = wrapper(fn, self.gen_functors)
         self.gen_functors = []
 
-        # Torch does not accept numpy inputs and may return GPU objects
-        def fn(*inputs, inner_fn=inner_fn):
-            outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
-            return tuple(out.cpu().numpy() for out in outs)
-
-        return fn
+        return inner_fn
 
     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
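
The round trip the new __call__ performs (numpy in, torch inside, numpy out) can be sketched in isolation. This is an illustrative mock rather than PyTensor's code: pytorch_typify is stood in for by torch.as_tensor, and compiled_fn stands in for the torch-compiled graph:

import numpy as np
import torch

def compiled_fn(x):
    return (x * 2,)  # stands in for the torch-compiled graph

inp = np.arange(3.0)
# In: numpy -> torch (for ndarrays, assume typify reduces to torch.as_tensor).
outs = compiled_fn(torch.as_tensor(inp))
# Out: move off any accelerator, then back to numpy for PyTensor's storage.
res = tuple(out.cpu().numpy() for out in outs)
assert np.array_equal(res[0], inp * 2)
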
86 changes: 86 additions & 0 deletions tests/link/pytorch/test_basic.py
@@ -4,6 +4,7 @@
import numpy as np
import pytest

import pytensor.tensor as pt
import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
@@ -17,7 +18,10 @@
from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.type import matrices, matrix, scalar, vector


@@ -385,3 +389,85 @@ def test_pytorch_softplus():
    out = softplus(x)
    f = FunctionGraph([x], [out])
    compare_pytorch_and_py(f, [np.random.rand(3)])


def test_ScalarLoop():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    const = float64("const")
    x = x0 + const

    op = ScalarLoop(init=[x0], constant=[const], update=[x])
    x = op(n_steps, x0, const)

    fn = function([n_steps, x0, const], x, mode=pytorch_mode)
    np.testing.assert_allclose(fn(5, 0, 1), 5)
    np.testing.assert_allclose(fn(5, 0, 2), 10)
    np.testing.assert_allclose(fn(4, 3, -1), -1)


def test_ScalarLoop_while():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    x = x0 + 1
    until = x >= 10

    op = ScalarLoop(init=[x0], update=[x], until=until)
    fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
    for res, expected in zip(
        [fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
        [[10, True], [10, True], [6, False]],
        strict=True,
    ):
        np.testing.assert_allclose(res[0], np.array(expected[0]))
        np.testing.assert_allclose(res[1], np.array(expected[1]))


def test_ScalarLoop_Elemwise_single_carries():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    x = x0 * 2
    until = x >= 10

    scalarop = ScalarLoop(init=[x0], update=[x], until=until)
    op = Elemwise(scalarop)

    n_steps = pt.scalar("n_steps", dtype="int32")
    x0 = pt.vector("x0", dtype="float32")
    state, done = op(n_steps, x0)

    f = FunctionGraph([n_steps, x0], [state, done])
    args = [
        np.array(10).astype("int32"),
        np.arange(0, 5).astype("float32"),
    ]
    compare_pytorch_and_py(
        f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
    )


def test_ScalarLoop_Elemwise_multi_carries():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    x1 = float64("x1")
    x = x0 * 2
    x1_n = x1 * 3
    until = x >= 10

    scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until)
    op = Elemwise(scalarop)

    n_steps = pt.scalar("n_steps", dtype="int32")
    x0 = pt.vector("x0", dtype="float32")
    x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
    *states, done = op(n_steps, x0, x1)

    f = FunctionGraph([n_steps, x0, x1], [*states, done])
    args = [
        np.array(10).astype("int32"),
        np.arange(0, 5).astype("float32"),
        np.random.rand(7, 3, 1).astype("float32"),
    ]
    compare_pytorch_and_py(
        f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
    )
