From 605960048edc753ffa3abea7c7c274827e758d01 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 3 Apr 2021 05:34:38 -0400 Subject: [PATCH 01/12] initial commit --- funsor/factory.py | 55 ++++++++++++++++++++++++++++++++++---------- test/test_factory.py | 25 +++++++++++++++++--- 2 files changed, 65 insertions(+), 15 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index 127f55b0..e319f6bd 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -10,7 +10,7 @@ import makefun from funsor.instrument import debug_logged -from funsor.terms import Funsor, FunsorMeta, Variable, eager, to_funsor +from funsor.terms import Funsor, FunsorMeta, Subs, Variable, eager, to_funsor from funsor.util import as_callable @@ -67,6 +67,15 @@ class Bound: pass +class BindReturn(Bound): + """ + Type hint for :func:`make_funsor` decorated functions. This provides hints + for bind-return variables (names). + """ + + pass + + class ValueMeta(type): def __getitem__(cls, value_type): return Value(value_type) @@ -177,21 +186,24 @@ def Unflatten( """ input_types = typing.get_type_hints(as_callable(fn)) for name, hint in input_types.items(): - if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has))): + if not (hint in (Funsor, Bound, BindReturn) or isinstance(hint, (Fresh, Value, Has))): raise TypeError(f"Invalid type hint {name}: {hint}") output_type = input_types.pop("return") hints = tuple(input_types.values()) class ResultMeta(FunsorMeta): - def __call__(cls, *args): + def __call__(cls, *args, bind_return=None): args = list(args) + bind_return_names = [] # Compute domains of bound variables. for i, (name, arg) in enumerate(zip(cls._ast_fields, args)): hint = input_types[name] if hint is Funsor or isinstance(hint, Has): # TODO support domains args[i] = to_funsor(arg) - elif hint is Bound: + elif hint in (Bound, BindReturn): + if hint is BindReturn: + bind_return_names.append(arg) for other in args: if isinstance(other, Funsor): domain = other.inputs.get(arg, None) @@ -207,6 +219,9 @@ def __call__(cls, *args): ) args[i] = arg + if bind_return is None: + bind_return = frozenset((name, name) for name in bind_return_names) + args.append(bind_return) # Compute domains of fresh variables. dependent_args = _get_dependent_args(cls._ast_fields, hints, args) for i, (hint, arg) in enumerate(zip(hints, args)): @@ -216,21 +231,24 @@ def __call__(cls, *args): return super().__call__(*args) @makefun.with_signature( - "__init__({})".format(", ".join(["self"] + list(input_types))) + "__init__({})".format(", ".join(["self"] + list(input_types) + ["bind_return"])) ) def __init__(self, **kwargs): + # breakpoint() + bind_return = dict(kwargs["bind_return"]) args = tuple(kwargs[k] for k in self._ast_fields) dependent_args = _get_dependent_args(self._ast_fields, hints, args) output = output_type(**dependent_args) inputs = OrderedDict() bound = {} + fresh = frozenset() for hint, arg, arg_name in zip(hints, args, self._ast_fields): if hint is Funsor: assert isinstance(arg, Funsor) - inputs.update(arg.inputs) + inputs.update((bind_return.get(k, k), v) for k, v in arg.inputs.items()) elif isinstance(hint, Has): assert isinstance(arg, Funsor) - inputs.update(arg.inputs) + inputs.update((bind_return.get(k, k), v) for k, v in arg.inputs.items()) for name in hint.bound: if kwargs[name] not in arg.input_vars: warnings.warn( @@ -241,19 +259,32 @@ def __init__(self, **kwargs): for hint, arg in zip(hints, args): if hint is Bound: bound[arg.name] = inputs.pop(arg.name) + elif hint is BindReturn: + bound[arg.name] = inputs[bind_return[arg.name]] for hint, arg in zip(hints, args): if isinstance(hint, Fresh): for k, d in arg.inputs.items(): if k not in bound: inputs[k] = d - fresh = frozenset() + fresh |= frozenset({k}) + fresh |= frozenset(bind_return.values()) Funsor.__init__(self, inputs, output, fresh, bound) for name, arg in zip(self._ast_fields, args): + if name == "bind_return": + arg = dict(arg) setattr(self, name, arg) def _alpha_convert(self, alpha_subs): + bind_return = frozenset( + (alpha_subs.get(k, k), v) for k, v in self.bind_return.items() + ) alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - return Funsor._alpha_convert(self, alpha_subs) + return Funsor._alpha_convert(self, alpha_subs)[:-1] + (bind_return,) + + def new_fn(*args, **kwargs): + args, bind_return = args[:-1], args[-1] + result = fn(*args, **kwargs) + return Subs(result, bind_return) name = _get_name(fn) ResultMeta.__name__ = f"{name}Meta" @@ -261,9 +292,9 @@ def _alpha_convert(self, alpha_subs): name, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} ) pattern = (Result,) + tuple( - _hint_to_pattern(input_types[k]) for k in Result._ast_fields - ) - eager.register(*pattern)(_erase_types(fn)) + _hint_to_pattern(input_types[k]) for k in Result._ast_fields if k != "bind_return" + ) + (frozenset,) + eager.register(*pattern)(_erase_types(new_fn)) return Result diff --git a/test/test_factory.py b/test/test_factory.py index aff82a9a..961163d1 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -8,13 +8,14 @@ import funsor.ops as ops from funsor.domains import Array, Bint, Real, Reals -from funsor.factory import Bound, Fresh, Has, Value, make_funsor, to_funsor +from funsor.factory import BindReturn, Bound, Fresh, Has, Value, make_funsor, to_funsor from funsor.interpretations import reflect from funsor.interpreter import reinterpret from funsor.tensor import Tensor -from funsor.terms import Cat, Funsor, Lambda, Number, eager +from funsor.terms import Cat, Funsor, Lambda, Number, eager, lazy from funsor.testing import assert_close, check_funsor, random_tensor from funsor.util import get_backend +from funsor.optimizer import apply_optimizer def test_lambda_lambda(): @@ -67,7 +68,9 @@ def Flatten21( inputs["a"] = Bint[3] inputs["b"] = Bint[4] data = random_tensor(inputs, Real) - x = Flatten21(data, "a", "b", "ab") + with lazy: + x = Flatten21(data, "a", "b", "ab") + breakpoint() assert isinstance(x, Tensor) check_funsor(x, {"ab": Bint[12]}, Real, data.data.reshape(-1)) @@ -297,3 +300,19 @@ def MatMul( # To preserve extensionality, should only error on reflect xy = MatMul(x, y, "b") check_funsor(xy, {"a": Bint[3], "c": Bint[4], "d": Bint[3]}, Real) + + +def test_softmax(): + @make_funsor + def Softmax( + x: Funsor, + ax: BindReturn, + ) -> Fresh[lambda x: x]: + y = x - x.reduce(ops.logaddexp, ax) + return y.exp() + + x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4])) + with reflect: + y = Softmax(x, "a") + z = apply_optimizer(y) + check_funsor(y, {"a": Bint[3], "b": Bint[4]}, Real) From 8cb18bcad02f53e73a75582c085e0a53c0ea8966 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 3 Apr 2021 15:54:56 -0400 Subject: [PATCH 02/12] BindReturn --- funsor/factory.py | 77 ++++++++++++++++++++++++++++++-------------- test/test_factory.py | 13 ++++---- 2 files changed, 60 insertions(+), 30 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index e319f6bd..d70e375a 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -67,10 +67,10 @@ class Bound: pass -class BindReturn(Bound): +class BindReturn: """ Type hint for :func:`make_funsor` decorated functions. This provides hints - for bind-return variables (names). + for variables (names) that are bound and returned. """ pass @@ -161,6 +161,7 @@ def make_funsor(fn): - Funsor inputs are typed :class:`~funsor.terms.Funsor`. - Bound variable inputs (names) are typed :class:`Bound`. + - Bind and Return variable inputs (names) are typed :class:`BindReturn`. - Fresh variable inputs (names) are typed :class:`Fresh` together with lambda to compute the dependent domain. - Ground value inputs (e.g. Python ints) are typed :class:`Value` together with @@ -186,7 +187,9 @@ def Unflatten( """ input_types = typing.get_type_hints(as_callable(fn)) for name, hint in input_types.items(): - if not (hint in (Funsor, Bound, BindReturn) or isinstance(hint, (Fresh, Value, Has))): + if not ( + hint in (Funsor, Bound, BindReturn) or isinstance(hint, (Fresh, Value, Has)) + ): raise TypeError(f"Invalid type hint {name}: {hint}") output_type = input_types.pop("return") hints = tuple(input_types.values()) @@ -194,7 +197,12 @@ def Unflatten( class ResultMeta(FunsorMeta): def __call__(cls, *args, bind_return=None): args = list(args) - bind_return_names = [] + + # Bind-and-return variables + if bind_return is None: + bind_return = frozenset( + (arg, arg) for hint, arg in zip(hints, args) if hint is BindReturn + ) # Compute domains of bound variables. for i, (name, arg) in enumerate(zip(cls._ast_fields, args)): @@ -202,8 +210,6 @@ def __call__(cls, *args, bind_return=None): if hint is Funsor or isinstance(hint, Has): # TODO support domains args[i] = to_funsor(arg) elif hint in (Bound, BindReturn): - if hint is BindReturn: - bind_return_names.append(arg) for other in args: if isinstance(other, Funsor): domain = other.inputs.get(arg, None) @@ -219,24 +225,29 @@ def __call__(cls, *args, bind_return=None): ) args[i] = arg - if bind_return is None: - bind_return = frozenset((name, name) for name in bind_return_names) - args.append(bind_return) # Compute domains of fresh variables. dependent_args = _get_dependent_args(cls._ast_fields, hints, args) for i, (hint, arg) in enumerate(zip(hints, args)): if isinstance(hint, Fresh): domain = hint(**dependent_args) args[i] = to_funsor(arg, domain) + + # Append bind_return to args + if bind_return: + args.append(bind_return) return super().__call__(*args) + if BindReturn in hints: + bind_return = ["bind_return"] + else: + bind_return = [] + @makefun.with_signature( - "__init__({})".format(", ".join(["self"] + list(input_types) + ["bind_return"])) + "__init__({})".format(", ".join(["self"] + list(input_types) + bind_return)) ) def __init__(self, **kwargs): - # breakpoint() - bind_return = dict(kwargs["bind_return"]) args = tuple(kwargs[k] for k in self._ast_fields) + bind_return = dict(kwargs.get("bind_return", dict())) dependent_args = _get_dependent_args(self._ast_fields, hints, args) output = output_type(**dependent_args) inputs = OrderedDict() @@ -275,25 +286,43 @@ def __init__(self, **kwargs): setattr(self, name, arg) def _alpha_convert(self, alpha_subs): - bind_return = frozenset( - (alpha_subs.get(k, k), v) for k, v in self.bind_return.items() - ) + if hasattr(self, "bind_return"): + bind_return = frozenset( + (alpha_subs.get(k, k), v) for k, v in self.bind_return.items() + ) + alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} + return Funsor._alpha_convert(self, alpha_subs)[:-1] + (bind_return,) alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - return Funsor._alpha_convert(self, alpha_subs)[:-1] + (bind_return,) - - def new_fn(*args, **kwargs): - args, bind_return = args[:-1], args[-1] - result = fn(*args, **kwargs) - return Subs(result, bind_return) + return Funsor._alpha_convert(self, alpha_subs) + + if BindReturn in hints: + def new_fn(*args, **kwargs): + args, bind_return = args[:-1], args[-1] + result = fn(*args) + if bind_return: + result = Subs(result, bind_return) + return result + else: + new_fn = fn name = _get_name(fn) ResultMeta.__name__ = f"{name}Meta" Result = ResultMeta( name, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} ) - pattern = (Result,) + tuple( - _hint_to_pattern(input_types[k]) for k in Result._ast_fields if k != "bind_return" - ) + (frozenset,) + if BindReturn in hints: + bind_return_pattern = (frozenset,) + else: + bind_return_pattern = () + pattern = ( + (Result,) + + tuple( + _hint_to_pattern(input_types[k]) + for k in Result._ast_fields + if k != "bind_return" + ) + + bind_return_pattern + ) eager.register(*pattern)(_erase_types(new_fn)) return Result diff --git a/test/test_factory.py b/test/test_factory.py index 961163d1..df8afd85 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -11,11 +11,11 @@ from funsor.factory import BindReturn, Bound, Fresh, Has, Value, make_funsor, to_funsor from funsor.interpretations import reflect from funsor.interpreter import reinterpret +from funsor.optimizer import apply_optimizer from funsor.tensor import Tensor from funsor.terms import Cat, Funsor, Lambda, Number, eager, lazy from funsor.testing import assert_close, check_funsor, random_tensor from funsor.util import get_backend -from funsor.optimizer import apply_optimizer def test_lambda_lambda(): @@ -68,9 +68,7 @@ def Flatten21( inputs["a"] = Bint[3] inputs["b"] = Bint[4] data = random_tensor(inputs, Real) - with lazy: - x = Flatten21(data, "a", "b", "ab") - breakpoint() + x = Flatten21(data, "a", "b", "ab") assert isinstance(x, Tensor) check_funsor(x, {"ab": Bint[12]}, Real, data.data.reshape(-1)) @@ -314,5 +312,8 @@ def Softmax( x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4])) with reflect: y = Softmax(x, "a") - z = apply_optimizer(y) - check_funsor(y, {"a": Bint[3], "b": Bint[4]}, Real) + assert y.fresh == frozenset({"a"}) + assert all(bound in y.x.inputs for bound in y.bound) + assert isinstance(apply_optimizer(x), Tensor) + z = reinterpret(y) + check_funsor(z, {"a": Bint[3], "b": Bint[4]}, Real) From 8deba32f082c0d28e2ac5107f87ea33a0c24eeee Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 3 Apr 2021 16:21:51 -0400 Subject: [PATCH 03/12] make BindReturn like Fresh --- funsor/factory.py | 39 +++++++++++++++++++++++++-------------- test/test_factory.py | 2 +- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index d70e375a..bcd088e8 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -67,13 +67,24 @@ class Bound: pass -class BindReturn: +class BindReturnMeta(type): + def __getitem__(cls, fn): + return BindReturn(fn) + + +class BindReturn(metaclass=BindReturnMeta): """ Type hint for :func:`make_funsor` decorated functions. This provides hints for variables (names) that are bound and returned. """ - pass + def __init__(self, fn): + function = type(lambda: None) + self.fn = fn if isinstance(fn, function) else lambda: fn + self.args = inspect.getfullargspec(fn)[0] + + def __call__(self, **kwargs): + return self.fn(*map(kwargs.__getitem__, self.args)) class ValueMeta(type): @@ -146,7 +157,7 @@ def _get_dependent_args(fields, hints, args): return { name: arg if isinstance(hint, Value) else arg.output for name, arg, hint in zip(fields, args, hints) - if hint in (Funsor, Bound) or isinstance(hint, (Has, Value)) + if hint in (Funsor, Bound) or isinstance(hint, (Has, Value, BindReturn)) } @@ -188,12 +199,17 @@ def Unflatten( input_types = typing.get_type_hints(as_callable(fn)) for name, hint in input_types.items(): if not ( - hint in (Funsor, Bound, BindReturn) or isinstance(hint, (Fresh, Value, Has)) + hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has, BindReturn)) ): raise TypeError(f"Invalid type hint {name}: {hint}") output_type = input_types.pop("return") hints = tuple(input_types.values()) + if any(isinstance(hint, BindReturn) for hint in hints): + bind_return = ["bind_return"] + else: + bind_return = [] + class ResultMeta(FunsorMeta): def __call__(cls, *args, bind_return=None): args = list(args) @@ -201,7 +217,7 @@ def __call__(cls, *args, bind_return=None): # Bind-and-return variables if bind_return is None: bind_return = frozenset( - (arg, arg) for hint, arg in zip(hints, args) if hint is BindReturn + (arg, arg) for hint, arg in zip(hints, args) if isinstance(hint, BindReturn) ) # Compute domains of bound variables. @@ -209,7 +225,7 @@ def __call__(cls, *args, bind_return=None): hint = input_types[name] if hint is Funsor or isinstance(hint, Has): # TODO support domains args[i] = to_funsor(arg) - elif hint in (Bound, BindReturn): + elif hint is Bound or isinstance(hint, BindReturn): for other in args: if isinstance(other, Funsor): domain = other.inputs.get(arg, None) @@ -228,7 +244,7 @@ def __call__(cls, *args, bind_return=None): # Compute domains of fresh variables. dependent_args = _get_dependent_args(cls._ast_fields, hints, args) for i, (hint, arg) in enumerate(zip(hints, args)): - if isinstance(hint, Fresh): + if isinstance(hint, (Fresh, BindReturn)): domain = hint(**dependent_args) args[i] = to_funsor(arg, domain) @@ -237,11 +253,6 @@ def __call__(cls, *args, bind_return=None): args.append(bind_return) return super().__call__(*args) - if BindReturn in hints: - bind_return = ["bind_return"] - else: - bind_return = [] - @makefun.with_signature( "__init__({})".format(", ".join(["self"] + list(input_types) + bind_return)) ) @@ -295,7 +306,7 @@ def _alpha_convert(self, alpha_subs): alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} return Funsor._alpha_convert(self, alpha_subs) - if BindReturn in hints: + if bind_return: def new_fn(*args, **kwargs): args, bind_return = args[:-1], args[-1] result = fn(*args) @@ -310,7 +321,7 @@ def new_fn(*args, **kwargs): Result = ResultMeta( name, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} ) - if BindReturn in hints: + if bind_return: bind_return_pattern = (frozenset,) else: bind_return_pattern = () diff --git a/test/test_factory.py b/test/test_factory.py index df8afd85..819e45d8 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -304,7 +304,7 @@ def test_softmax(): @make_funsor def Softmax( x: Funsor, - ax: BindReturn, + ax: BindReturn[lambda ax: ax], ) -> Fresh[lambda x: x]: y = x - x.reduce(ops.logaddexp, ax) return y.exp() From 281fa5d80bc6f118b3c700f686e7891357c360ce Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 3 Apr 2021 21:39:58 -0400 Subject: [PATCH 04/12] test unroll --- funsor/factory.py | 35 ++++++++++++++++++++++------------- test/test_factory.py | 26 +++++++++++++++++++++++--- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index bcd088e8..93123b39 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -10,7 +10,7 @@ import makefun from funsor.instrument import debug_logged -from funsor.terms import Funsor, FunsorMeta, Subs, Variable, eager, to_funsor +from funsor.terms import Funsor, FunsorMeta, Subs, Variable, eager, to_funsor, substitute from funsor.util import as_callable @@ -244,9 +244,13 @@ def __call__(cls, *args, bind_return=None): # Compute domains of fresh variables. dependent_args = _get_dependent_args(cls._ast_fields, hints, args) for i, (hint, arg) in enumerate(zip(hints, args)): - if isinstance(hint, (Fresh, BindReturn)): + if isinstance(hint, Fresh): domain = hint(**dependent_args) args[i] = to_funsor(arg, domain) + elif isinstance(hint, BindReturn): + domain = hint(**dependent_args) + args[i] = to_funsor(arg.name, domain) + # Append bind_return to args if bind_return: @@ -267,12 +271,12 @@ def __init__(self, **kwargs): for hint, arg, arg_name in zip(hints, args, self._ast_fields): if hint is Funsor: assert isinstance(arg, Funsor) - inputs.update((bind_return.get(k, k), v) for k, v in arg.inputs.items()) + inputs.update(arg.inputs) elif isinstance(hint, Has): assert isinstance(arg, Funsor) - inputs.update((bind_return.get(k, k), v) for k, v in arg.inputs.items()) + inputs.update(arg.inputs) for name in hint.bound: - if kwargs[name] not in arg.input_vars: + if kwargs[name].name not in arg.inputs: warnings.warn( f"Argument {arg_name} is missing bound variable {kwargs[name]} from argument {name}." f"Are you sure {name} will always appear in {arg_name}?", @@ -281,8 +285,9 @@ def __init__(self, **kwargs): for hint, arg in zip(hints, args): if hint is Bound: bound[arg.name] = inputs.pop(arg.name) - elif hint is BindReturn: - bound[arg.name] = inputs[bind_return[arg.name]] + elif isinstance(hint, BindReturn): + bound[arg.name] = inputs.pop(arg.name) + inputs[bind_return[arg.name]] = arg.output for hint, arg in zip(hints, args): if isinstance(hint, Fresh): for k, d in arg.inputs.items(): @@ -297,14 +302,18 @@ def __init__(self, **kwargs): setattr(self, name, arg) def _alpha_convert(self, alpha_subs): + result = [] + new_alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} + for hint, field, value in zip(hints, self._ast_fields, self._ast_values): + if isinstance(hint, BindReturn): + result.append(to_funsor(alpha_subs[value.name], value.output)) + else: + result.append(substitute(value, new_alpha_subs)) if hasattr(self, "bind_return"): - bind_return = frozenset( + result.append(frozenset( (alpha_subs.get(k, k), v) for k, v in self.bind_return.items() - ) - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - return Funsor._alpha_convert(self, alpha_subs)[:-1] + (bind_return,) - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - return Funsor._alpha_convert(self, alpha_subs) + )) + return tuple(result) if bind_return: def new_fn(*args, **kwargs): diff --git a/test/test_factory.py b/test/test_factory.py index 819e45d8..818bdcc2 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -300,10 +300,30 @@ def MatMul( check_funsor(xy, {"a": Bint[3], "c": Bint[4], "d": Bint[3]}, Real) +def test_unroll(): + @make_funsor + def Unroll( + x: Has[{"ax"}], + ax: BindReturn[lambda ax, k: Bint[ax.size - k + 1]], + k: Value[int], + kernel: Fresh[lambda k: Bint[k]] + ) -> Fresh[lambda x: x]: + return x(**{ax.name: ax + kernel}) + + x = random_tensor(OrderedDict(a=Bint[5])) + with reflect: + y = Unroll(x, "a", 2, "kernel") + assert y.fresh == frozenset({"a", "kernel"}) + assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound) + z = reinterpret(y) + assert isinstance(z, Tensor) + check_funsor(z, {"a": Bint[4], "kernel": Bint[2]}, Real) + + def test_softmax(): @make_funsor def Softmax( - x: Funsor, + x: Has[{"ax"}], ax: BindReturn[lambda ax: ax], ) -> Fresh[lambda x: x]: y = x - x.reduce(ops.logaddexp, ax) @@ -313,7 +333,7 @@ def Softmax( with reflect: y = Softmax(x, "a") assert y.fresh == frozenset({"a"}) - assert all(bound in y.x.inputs for bound in y.bound) - assert isinstance(apply_optimizer(x), Tensor) + assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound) z = reinterpret(y) + assert isinstance(z, Tensor) check_funsor(z, {"a": Bint[3], "b": Bint[4]}, Real) From 7dd0750e151388bd733e3c14d5b7cbb4a1a985ac Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 3 Apr 2021 21:42:33 -0400 Subject: [PATCH 05/12] lint --- funsor/factory.py | 25 +++++++++++++++++++------ test/test_factory.py | 9 ++++----- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index 93123b39..a0dad037 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -10,7 +10,15 @@ import makefun from funsor.instrument import debug_logged -from funsor.terms import Funsor, FunsorMeta, Subs, Variable, eager, to_funsor, substitute +from funsor.terms import ( + Funsor, + FunsorMeta, + Subs, + Variable, + eager, + substitute, + to_funsor, +) from funsor.util import as_callable @@ -217,7 +225,9 @@ def __call__(cls, *args, bind_return=None): # Bind-and-return variables if bind_return is None: bind_return = frozenset( - (arg, arg) for hint, arg in zip(hints, args) if isinstance(hint, BindReturn) + (arg, arg) + for hint, arg in zip(hints, args) + if isinstance(hint, BindReturn) ) # Compute domains of bound variables. @@ -251,7 +261,6 @@ def __call__(cls, *args, bind_return=None): domain = hint(**dependent_args) args[i] = to_funsor(arg.name, domain) - # Append bind_return to args if bind_return: args.append(bind_return) @@ -310,18 +319,22 @@ def _alpha_convert(self, alpha_subs): else: result.append(substitute(value, new_alpha_subs)) if hasattr(self, "bind_return"): - result.append(frozenset( - (alpha_subs.get(k, k), v) for k, v in self.bind_return.items() - )) + result.append( + frozenset( + (alpha_subs.get(k, k), v) for k, v in self.bind_return.items() + ) + ) return tuple(result) if bind_return: + def new_fn(*args, **kwargs): args, bind_return = args[:-1], args[-1] result = fn(*args) if bind_return: result = Subs(result, bind_return) return result + else: new_fn = fn diff --git a/test/test_factory.py b/test/test_factory.py index 818bdcc2..75354033 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -11,9 +11,8 @@ from funsor.factory import BindReturn, Bound, Fresh, Has, Value, make_funsor, to_funsor from funsor.interpretations import reflect from funsor.interpreter import reinterpret -from funsor.optimizer import apply_optimizer from funsor.tensor import Tensor -from funsor.terms import Cat, Funsor, Lambda, Number, eager, lazy +from funsor.terms import Cat, Funsor, Lambda, Number, eager from funsor.testing import assert_close, check_funsor, random_tensor from funsor.util import get_backend @@ -303,10 +302,10 @@ def MatMul( def test_unroll(): @make_funsor def Unroll( - x: Has[{"ax"}], + x: Has[{"ax"}], # noqa: F821 ax: BindReturn[lambda ax, k: Bint[ax.size - k + 1]], k: Value[int], - kernel: Fresh[lambda k: Bint[k]] + kernel: Fresh[lambda k: Bint[k]], ) -> Fresh[lambda x: x]: return x(**{ax.name: ax + kernel}) @@ -323,7 +322,7 @@ def Unroll( def test_softmax(): @make_funsor def Softmax( - x: Has[{"ax"}], + x: Has[{"ax"}], # noqa: F821 ax: BindReturn[lambda ax: ax], ) -> Fresh[lambda x: x]: y = x - x.reduce(ops.logaddexp, ax) From b2e5d32106f3eeba34aa8c0bb4803087ec5889de Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 3 Apr 2021 21:47:42 -0400 Subject: [PATCH 06/12] reorganize --- funsor/factory.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index a0dad037..f73ff089 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -214,9 +214,19 @@ def Unflatten( hints = tuple(input_types.values()) if any(isinstance(hint, BindReturn) for hint in hints): - bind_return = ["bind_return"] + bind_return_kwarg = ["bind_return"] + bind_return_pattern = (frozenset,) + + def new_fn(*args, **kwargs): + args, bind_return = args[:-1], args[-1] + result = fn(*args) + if bind_return: + result = Subs(result, bind_return) + return result else: - bind_return = [] + bind_return_kwarg = [] + bind_return_pattern = () + new_fn = fn class ResultMeta(FunsorMeta): def __call__(cls, *args, bind_return=None): @@ -267,7 +277,7 @@ def __call__(cls, *args, bind_return=None): return super().__call__(*args) @makefun.with_signature( - "__init__({})".format(", ".join(["self"] + list(input_types) + bind_return)) + "__init__({})".format(", ".join(["self"] + list(input_types) + bind_return_kwarg)) ) def __init__(self, **kwargs): args = tuple(kwargs[k] for k in self._ast_fields) @@ -326,27 +336,11 @@ def _alpha_convert(self, alpha_subs): ) return tuple(result) - if bind_return: - - def new_fn(*args, **kwargs): - args, bind_return = args[:-1], args[-1] - result = fn(*args) - if bind_return: - result = Subs(result, bind_return) - return result - - else: - new_fn = fn - name = _get_name(fn) ResultMeta.__name__ = f"{name}Meta" Result = ResultMeta( name, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} ) - if bind_return: - bind_return_pattern = (frozenset,) - else: - bind_return_pattern = () pattern = ( (Result,) + tuple( From a2a651b75976368ff5aa1992428ce9db41b02012 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 3 Apr 2021 21:56:15 -0400 Subject: [PATCH 07/12] misc --- funsor/factory.py | 2 +- test/test_factory.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index f73ff089..ead5f8d9 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -217,7 +217,7 @@ def Unflatten( bind_return_kwarg = ["bind_return"] bind_return_pattern = (frozenset,) - def new_fn(*args, **kwargs): + def new_fn(*args): args, bind_return = args[:-1], args[-1] result = fn(*args) if bind_return: diff --git a/test/test_factory.py b/test/test_factory.py index 75354033..3195e4f6 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -316,7 +316,7 @@ def Unroll( assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound) z = reinterpret(y) assert isinstance(z, Tensor) - check_funsor(z, {"a": Bint[4], "kernel": Bint[2]}, Real) + check_funsor(z, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real) def test_softmax(): From 3cc3214987733925d0de0bae65468ab2f916fc14 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 3 Apr 2021 22:07:44 -0400 Subject: [PATCH 08/12] lint --- funsor/factory.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/funsor/factory.py b/funsor/factory.py index ead5f8d9..62306488 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -223,6 +223,7 @@ def new_fn(*args): if bind_return: result = Subs(result, bind_return) return result + else: bind_return_kwarg = [] bind_return_pattern = () @@ -277,7 +278,9 @@ def __call__(cls, *args, bind_return=None): return super().__call__(*args) @makefun.with_signature( - "__init__({})".format(", ".join(["self"] + list(input_types) + bind_return_kwarg)) + "__init__({})".format( + ", ".join(["self"] + list(input_types) + bind_return_kwarg) + ) ) def __init__(self, **kwargs): args = tuple(kwargs[k] for k in self._ast_fields) From a21c972039c0b0cfda54c8d999ceb3f71838a1de Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 3 Apr 2021 22:16:15 -0400 Subject: [PATCH 09/12] remove _ast_fields in alpha_convert --- funsor/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/factory.py b/funsor/factory.py index 62306488..0ab47bf3 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -326,7 +326,7 @@ def __init__(self, **kwargs): def _alpha_convert(self, alpha_subs): result = [] new_alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - for hint, field, value in zip(hints, self._ast_fields, self._ast_values): + for hint, value in zip(hints, self._ast_values): if isinstance(hint, BindReturn): result.append(to_funsor(alpha_subs[value.name], value.output)) else: From b876c6da63d23f9b38d8c3fa8bb0bad3d9f6d07c Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 7 Apr 2021 04:47:52 -0400 Subject: [PATCH 10/12] replace Fresh with BindReturn --- funsor/factory.py | 56 ++++++++++++++++---------------------------- test/test_factory.py | 6 ++--- 2 files changed, 23 insertions(+), 39 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index 0ab47bf3..111ed3f1 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -75,26 +75,6 @@ class Bound: pass -class BindReturnMeta(type): - def __getitem__(cls, fn): - return BindReturn(fn) - - -class BindReturn(metaclass=BindReturnMeta): - """ - Type hint for :func:`make_funsor` decorated functions. This provides hints - for variables (names) that are bound and returned. - """ - - def __init__(self, fn): - function = type(lambda: None) - self.fn = fn if isinstance(fn, function) else lambda: fn - self.args = inspect.getfullargspec(fn)[0] - - def __call__(self, **kwargs): - return self.fn(*map(kwargs.__getitem__, self.args)) - - class ValueMeta(type): def __getitem__(cls, value_type): return Value(value_type) @@ -165,7 +145,9 @@ def _get_dependent_args(fields, hints, args): return { name: arg if isinstance(hint, Value) else arg.output for name, arg, hint in zip(fields, args, hints) - if hint in (Funsor, Bound) or isinstance(hint, (Has, Value, BindReturn)) + if hint in (Funsor, Bound) + or isinstance(hint, (Has, Value)) + or (isinstance(hint, Fresh) and name in hint.args) } @@ -180,7 +162,6 @@ def make_funsor(fn): - Funsor inputs are typed :class:`~funsor.terms.Funsor`. - Bound variable inputs (names) are typed :class:`Bound`. - - Bind and Return variable inputs (names) are typed :class:`BindReturn`. - Fresh variable inputs (names) are typed :class:`Fresh` together with lambda to compute the dependent domain. - Ground value inputs (e.g. Python ints) are typed :class:`Value` together with @@ -207,13 +188,16 @@ def Unflatten( input_types = typing.get_type_hints(as_callable(fn)) for name, hint in input_types.items(): if not ( - hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has, BindReturn)) + hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has)) ): raise TypeError(f"Invalid type hint {name}: {hint}") output_type = input_types.pop("return") hints = tuple(input_types.values()) - if any(isinstance(hint, BindReturn) for hint in hints): + if any( + isinstance(hint, Fresh) and arg in hint.args + for arg, hint in input_types.items() + ): bind_return_kwarg = ["bind_return"] bind_return_pattern = (frozenset,) @@ -237,8 +221,8 @@ def __call__(cls, *args, bind_return=None): if bind_return is None: bind_return = frozenset( (arg, arg) - for hint, arg in zip(hints, args) - if isinstance(hint, BindReturn) + for hint, arg, arg_name in zip(hints, args, cls._ast_fields) + if isinstance(hint, Fresh) and arg_name in hint.args ) # Compute domains of bound variables. @@ -246,7 +230,7 @@ def __call__(cls, *args, bind_return=None): hint = input_types[name] if hint is Funsor or isinstance(hint, Has): # TODO support domains args[i] = to_funsor(arg) - elif hint is Bound or isinstance(hint, BindReturn): + elif hint is Bound or (isinstance(hint, Fresh) and name in hint.args): for other in args: if isinstance(other, Funsor): domain = other.inputs.get(arg, None) @@ -264,13 +248,13 @@ def __call__(cls, *args, bind_return=None): # Compute domains of fresh variables. dependent_args = _get_dependent_args(cls._ast_fields, hints, args) - for i, (hint, arg) in enumerate(zip(hints, args)): - if isinstance(hint, Fresh): - domain = hint(**dependent_args) - args[i] = to_funsor(arg, domain) - elif isinstance(hint, BindReturn): + for i, (hint, arg, arg_name) in enumerate(zip(hints, args, cls._ast_fields)): + if isinstance(hint, Fresh) and arg_name in hint.args: domain = hint(**dependent_args) args[i] = to_funsor(arg.name, domain) + elif isinstance(hint, Fresh): + domain = hint(**dependent_args) + args[i] = to_funsor(arg, domain) # Append bind_return to args if bind_return: @@ -304,10 +288,10 @@ def __init__(self, **kwargs): f"Are you sure {name} will always appear in {arg_name}?", SyntaxWarning, ) - for hint, arg in zip(hints, args): + for hint, arg, arg_name in zip(hints, args, self._ast_fields): if hint is Bound: bound[arg.name] = inputs.pop(arg.name) - elif isinstance(hint, BindReturn): + elif isinstance(hint, Fresh) and arg_name in hint.args: bound[arg.name] = inputs.pop(arg.name) inputs[bind_return[arg.name]] = arg.output for hint, arg in zip(hints, args): @@ -326,8 +310,8 @@ def __init__(self, **kwargs): def _alpha_convert(self, alpha_subs): result = [] new_alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - for hint, value in zip(hints, self._ast_values): - if isinstance(hint, BindReturn): + for hint, value, arg_name in zip(hints, self._ast_values, self._ast_fields): + if isinstance(hint, Fresh) and arg_name in hint.args: result.append(to_funsor(alpha_subs[value.name], value.output)) else: result.append(substitute(value, new_alpha_subs)) diff --git a/test/test_factory.py b/test/test_factory.py index 3195e4f6..3efa7908 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -8,7 +8,7 @@ import funsor.ops as ops from funsor.domains import Array, Bint, Real, Reals -from funsor.factory import BindReturn, Bound, Fresh, Has, Value, make_funsor, to_funsor +from funsor.factory import Bound, Fresh, Has, Value, make_funsor, to_funsor from funsor.interpretations import reflect from funsor.interpreter import reinterpret from funsor.tensor import Tensor @@ -303,7 +303,7 @@ def test_unroll(): @make_funsor def Unroll( x: Has[{"ax"}], # noqa: F821 - ax: BindReturn[lambda ax, k: Bint[ax.size - k + 1]], + ax: Fresh[lambda ax, k: Bint[ax.size - k + 1]], k: Value[int], kernel: Fresh[lambda k: Bint[k]], ) -> Fresh[lambda x: x]: @@ -323,7 +323,7 @@ def test_softmax(): @make_funsor def Softmax( x: Has[{"ax"}], # noqa: F821 - ax: BindReturn[lambda ax: ax], + ax: Fresh[lambda ax: ax], ) -> Fresh[lambda x: x]: y = x - x.reduce(ops.logaddexp, ax) return y.exp() From ba295c2552a03e92eef479e296b7091933e549fc Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 7 Apr 2021 23:36:59 -0400 Subject: [PATCH 11/12] simplify --- funsor/factory.py | 49 ++++++++++++++++---------------------------- test/test_factory.py | 4 ++-- 2 files changed, 20 insertions(+), 33 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index 111ed3f1..ef116b3e 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -187,32 +187,25 @@ def Unflatten( """ input_types = typing.get_type_hints(as_callable(fn)) for name, hint in input_types.items(): - if not ( - hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has)) - ): + if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has))): raise TypeError(f"Invalid type hint {name}: {hint}") - output_type = input_types.pop("return") - hints = tuple(input_types.values()) - if any( isinstance(hint, Fresh) and arg in hint.args for arg, hint in input_types.items() ): - bind_return_kwarg = ["bind_return"] - bind_return_pattern = (frozenset,) + input_types["bind_return"] = Value[frozenset] def new_fn(*args): args, bind_return = args[:-1], args[-1] result = fn(*args) - if bind_return: - result = Subs(result, bind_return) - return result + return Subs(result, bind_return) else: - bind_return_kwarg = [] - bind_return_pattern = () new_fn = fn + output_type = input_types.pop("return") + hints = tuple(input_types.values()) + class ResultMeta(FunsorMeta): def __call__(cls, *args, bind_return=None): args = list(args) @@ -248,7 +241,9 @@ def __call__(cls, *args, bind_return=None): # Compute domains of fresh variables. dependent_args = _get_dependent_args(cls._ast_fields, hints, args) - for i, (hint, arg, arg_name) in enumerate(zip(hints, args, cls._ast_fields)): + for i, (hint, arg, arg_name) in enumerate( + zip(hints, args, cls._ast_fields) + ): if isinstance(hint, Fresh) and arg_name in hint.args: domain = hint(**dependent_args) args[i] = to_funsor(arg.name, domain) @@ -262,9 +257,7 @@ def __call__(cls, *args, bind_return=None): return super().__call__(*args) @makefun.with_signature( - "__init__({})".format( - ", ".join(["self"] + list(input_types) + bind_return_kwarg) - ) + "__init__({})".format(", ".join(["self"] + list(input_types))) ) def __init__(self, **kwargs): args = tuple(kwargs[k] for k in self._ast_fields) @@ -313,14 +306,14 @@ def _alpha_convert(self, alpha_subs): for hint, value, arg_name in zip(hints, self._ast_values, self._ast_fields): if isinstance(hint, Fresh) and arg_name in hint.args: result.append(to_funsor(alpha_subs[value.name], value.output)) + elif arg_name == "bind_return": + result.append( + frozenset( + (alpha_subs.get(k, k), v) for k, v in self.bind_return.items() + ) + ) else: result.append(substitute(value, new_alpha_subs)) - if hasattr(self, "bind_return"): - result.append( - frozenset( - (alpha_subs.get(k, k), v) for k, v in self.bind_return.items() - ) - ) return tuple(result) name = _get_name(fn) @@ -328,14 +321,8 @@ def _alpha_convert(self, alpha_subs): Result = ResultMeta( name, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} ) - pattern = ( - (Result,) - + tuple( - _hint_to_pattern(input_types[k]) - for k in Result._ast_fields - if k != "bind_return" - ) - + bind_return_pattern + pattern = (Result,) + tuple( + _hint_to_pattern(input_types[k]) for k in Result._ast_fields ) eager.register(*pattern)(_erase_types(new_fn)) return Result diff --git a/test/test_factory.py b/test/test_factory.py index 3efa7908..74c884ba 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -313,7 +313,7 @@ def Unroll( with reflect: y = Unroll(x, "a", 2, "kernel") assert y.fresh == frozenset({"a", "kernel"}) - assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound) + assert all(bound in y.x.inputs and "__BOUND" in bound for bound in y.bound) z = reinterpret(y) assert isinstance(z, Tensor) check_funsor(z, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real) @@ -332,7 +332,7 @@ def Softmax( with reflect: y = Softmax(x, "a") assert y.fresh == frozenset({"a"}) - assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound) + assert all(bound in y.x.inputs and "__BOUND" in bound for bound in y.bound) z = reinterpret(y) assert isinstance(z, Tensor) check_funsor(z, {"a": Bint[3], "b": Bint[4]}, Real) From 3802f722f8f154efd18a16fb6c597ceb60c79b41 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 8 Apr 2021 01:09:30 -0400 Subject: [PATCH 12/12] misc --- funsor/factory.py | 9 ++++----- test/test_factory.py | 2 ++ 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index ef116b3e..cf2f51a3 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -287,13 +287,12 @@ def __init__(self, **kwargs): elif isinstance(hint, Fresh) and arg_name in hint.args: bound[arg.name] = inputs.pop(arg.name) inputs[bind_return[arg.name]] = arg.output + fresh |= frozenset({bind_return[arg.name]}) for hint, arg in zip(hints, args): if isinstance(hint, Fresh): - for k, d in arg.inputs.items(): - if k not in bound: - inputs[k] = d - fresh |= frozenset({k}) - fresh |= frozenset(bind_return.values()) + if arg.name not in bound: + inputs[arg.name] = arg.output + fresh |= frozenset({arg.name}) Funsor.__init__(self, inputs, output, fresh, bound) for name, arg in zip(self._ast_fields, args): if name == "bind_return": diff --git a/test/test_factory.py b/test/test_factory.py index 74c884ba..887a0eff 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -314,6 +314,7 @@ def Unroll( y = Unroll(x, "a", 2, "kernel") assert y.fresh == frozenset({"a", "kernel"}) assert all(bound in y.x.inputs and "__BOUND" in bound for bound in y.bound) + check_funsor(y, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real) z = reinterpret(y) assert isinstance(z, Tensor) check_funsor(z, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real) @@ -333,6 +334,7 @@ def Softmax( y = Softmax(x, "a") assert y.fresh == frozenset({"a"}) assert all(bound in y.x.inputs and "__BOUND" in bound for bound in y.bound) + check_funsor(y, {"a": Bint[3], "b": Bint[4]}, Real) z = reinterpret(y) assert isinstance(z, Tensor) check_funsor(z, {"a": Bint[3], "b": Bint[4]}, Real)