From 7428c7b137886c443a3c988112d2a669f913bc71 Mon Sep 17 00:00:00 2001 From: Eric Lippert Date: Tue, 18 Oct 2022 16:03:34 -0700 Subject: [PATCH] Refactoring and minor bug fixing in requirements fixer Summary: The requirements-fixing code is not easy to understand, and some small errors and inefficiencies have crept in. I'm going to fix them in this diff to make it easier to add a broadcasting fixer in an upcoming diff. In this diff: * Elementwise multiplication now correctly creates the requirement that the input types be identical to the output type. (As noted, in an upcoming diff we'll take care of the situation where inputs of dissimilar shape can be fixed by broadcasting.) * Fixed a misspelling of "requirements" in a method name. * Methods `_meet_real_matrix_requirement_type` and `_meet_real_matrix_requirement` were misordered and redundant with other logic in the operator requirements fixer. By more carefully ordering the cases in the operator requirements fixer, we can simply delete these redundant methods. * As noted, `_meet_operator_requirement` was too long and confusing. I've broken it up into many smaller cases, each with its own logic for applying a simple fix. * Making these changes had the beneficial side effect of fixing a very small bug in a test case. In test_fix_vectorized_models_8 we have a multiplication of a matrix of probabilities and a constant matrix with all positive real values. The previous code converted both operands to a real matrix, but we can legally convert them to a positive real matrix, a more specific type. We now do so.
Reviewed By: AishwaryaSivaraman Differential Revision: D40459965 fbshipit-source-id: 6da742006341fef204d9fd9fc81a723e145b98b6 --- .../ppl/compiler/bmg_requirements.py | 24 +- .../ppl/compiler/fix_requirements.py | 254 +++++++++++------- src/beanmachine/ppl/compiler/lattice_typer.py | 5 + tests/ppl/compiler/broadcast_test.py | 1 + .../compiler/fix_vectorized_models_test.py | 2 +- 5 files changed, 170 insertions(+), 116 deletions(-) diff --git a/src/beanmachine/ppl/compiler/bmg_requirements.py b/src/beanmachine/ppl/compiler/bmg_requirements.py index 41dd2c8f03..3e09a2871e 100644 --- a/src/beanmachine/ppl/compiler/bmg_requirements.py +++ b/src/beanmachine/ppl/compiler/bmg_requirements.py @@ -35,16 +35,6 @@ # what their inputs are. _known_requirements: Dict[type, List[bt.Requirement]] = { - # TODO: This is wrong in several ways. - # First, RealMatrix does not meet the contract of a requirement; - # in particular, it cannot be printed out by the requirement diagnostic - # in gen_to_dot. - # Second, it is too strict; the requirement on matrix add is actually - # that the two operands be any double matrix (real, neg real, - # pos real or probability). - # Third, this requirement is too weak; we are missing the requirement - # that the operands have the same element type and shape. - bn.ElementwiseMultiplyNode: [bt.RealMatrix, bt.RealMatrix], bn.Observation: [bt.any_requirement], bn.Query: [bt.any_requirement], # Distributions @@ -75,6 +65,7 @@ # don't check them. bn.LogisticNode: [bt.Real], bn.Log1mexpNode: [bt.NegativeReal], + # TODO: Check the dimensions. Consider broadcasting if possible. 
bn.MatrixMultiplicationNode: [bt.any_real_matrix, bt.any_real_matrix], bn.MatrixExpNode: [bt.any_real_matrix], bn.MatrixLogNode: [bt.any_pos_real_matrix], @@ -115,6 +106,7 @@ def __init__(self, typer: LatticeTyper) -> None: bn.ChoiceNode: self._requirements_choice, bn.ColumnIndexNode: self._requirements_column_index, bn.ComplementNode: self._same_as_output, + bn.ElementwiseMultiplyNode: self._requirements_elementwise_mult, bn.ExpM1Node: self._same_as_output, bn.ExpNode: self._requirements_exp_neg, bn.IfThenElseNode: self._requirements_if, @@ -124,7 +116,7 @@ def __init__(self, typer: LatticeTyper) -> None: bn.LogSumExpVectorNode: self._requirements_logsumexp_vector, # TODO: bn.MatrixMultiplyNode: self._requirements_matrix_multiply, # see comment above - bn.MatrixComplementNode: self._requrirements_matrix_complement, + bn.MatrixComplementNode: self._requirements_matrix_complement, bn.MatrixAddNode: self._requirements_matrix_add, bn.MatrixScaleNode: self._requirements_matrix_scale, bn.MultiplicationNode: self._requirements_multiplication, @@ -428,7 +420,7 @@ def _requirements_multiplication( assert it in {bt.Probability, bt.PositiveReal, bt.Real} return [it] * len(node.inputs) # pyre-ignore - def _requrirements_matrix_complement( + def _requirements_matrix_complement( self, node: bn.MatrixComplementNode ) -> List[bt.Requirement]: it = self.typer[node] @@ -446,6 +438,14 @@ def _requrirements_matrix_complement( req = [bt.SimplexMatrix] return req + def _requirements_elementwise_mult( + self, node: bn.ElementwiseMultiplyNode + ) -> List[bt.Requirement]: + # Elementwise multiply requires that both operands be the same as the output type. + it = self.typer[node] + assert isinstance(it, bt.BMGMatrixType) + return [it, it] + def _requirements_matrix_add(self, node: bn.MatrixAddNode) -> List[bt.Requirement]: # Matrix add requires that both operands be the same as the output type. 
it = self.typer[node] diff --git a/src/beanmachine/ppl/compiler/fix_requirements.py b/src/beanmachine/ppl/compiler/fix_requirements.py index 953f4802fa..1e97136f5c 100644 --- a/src/beanmachine/ppl/compiler/fix_requirements.py +++ b/src/beanmachine/ppl/compiler/fix_requirements.py @@ -11,7 +11,7 @@ returned.""" -from typing import Tuple +from typing import Optional import beanmachine.ppl.compiler.bmg_nodes as bn import beanmachine.ppl.compiler.bmg_types as bt @@ -160,9 +160,7 @@ def _meet_constant_requirement( result = self.bmg.add_constant_of_matrix_type(node.value, required_type) else: result = self.bmg.add_constant_of_type(node.value, required_type) - assert self._node_meets_requirement( - result, requirement - ), f"{str(result)} {str(requirement)} {str(required_type)} {str(self._typer[result])} {str(self._type_meets_requirement(self._typer[result], requirement))}" + assert self._node_meets_requirement(result, requirement) return result # We cannot convert this node to any type that meets the requirement. @@ -320,47 +318,103 @@ def _can_force_to_neg_real( or requirement == bt.upper_bound(bt.NegativeReal) ) and node_type == bt.Real - def _meet_real_matrix_requirement_type( - self, node: bn.OperatorNode, node_dim: Tuple[int, int] - ) -> bn.BMGNode: - if node_dim[0] == 1 and node_dim[1] == 1: - result = self.bmg.add_to_real(node) - else: - result = self.bmg.add_to_real_matrix(node) + def _try_to_meet_any_real_matrix_requirement( + self, + node: bn.OperatorNode, + requirement: bt.Requirement, + ) -> Optional[bn.BMGNode]: + + assert not self._node_meets_requirement(node, requirement) + + # Is the requirement that we have a real-valued matrix, but we haven't got + # a real-valued matrix? Every value can be converted to a real-valued matrix, + # so just insert the conversion node. 
+ + if requirement is not bt.any_real_matrix: + return None + + result = self.bmg.add_to_real_matrix(node) + assert self._node_meets_requirement(result, requirement) return result - def _meet_real_matrix_requirement( + def _try_to_meet_any_pos_real_matrix_requirement( self, node: bn.OperatorNode, - dim_req: Tuple[int, int], - node_dim: Tuple[int, int], + requirement: bt.Requirement, + ) -> Optional[bn.BMGNode]: + + assert not self._node_meets_requirement(node, requirement) + + # Is the requirement that we have a pos-real-valued matrix? Anything that + # is not known to be negative can be a positive real matrix. + + if requirement is not bt.any_pos_real_matrix: + return None + + node_type = self._typer[node] + if isinstance(node_type, bt.NegativeRealMatrix): + return None + + result = self.bmg.add_to_positive_real_matrix(node) + assert self._node_meets_requirement(result, requirement) + return result + + def _try_to_meet_upper_bound_requirement( + self, + node: bn.OperatorNode, + requirement: bt.Requirement, consumer: bn.BMGNode, edge: str, - ) -> bn.BMGNode: - result = None - req_rows, req_cols = dim_req - node_rows, node_cols = node_dim - node_is_scalar = node_rows == 1 and node_cols == 1 - requires_scalar = req_rows == 1 and req_cols == 1 - if requires_scalar and node_is_scalar: - result = self.bmg.add_to_real(node) - elif node_rows == req_rows and node_cols == req_cols: - result = self.bmg.add_to_real_matrix(node) - - if result is None: - self.errors.add_error( - Violation( - node, - self._typer[node], - bt.RealMatrix(req_rows, req_cols), - consumer, - edge, - self.bmg.execution_context.node_locations(consumer), - ) + ) -> Optional[bn.BMGNode]: + + assert not self._node_meets_requirement(node, requirement) + + node_type = self._typer[node] + if not self._type_meets_requirement(node_type, bt.upper_bound(requirement)): + return None + + # If we got here then the node did NOT meet the requirement, + # but its type DID meet an upper bound requirement, which + # 
implies that the requirement was not an upper bound requirement. + assert not isinstance(requirement, bt.UpperBound) + + # We definitely can meet the requirement by inserting some sort + # of conversion logic. We have different helper methods for + # the atomic type and matrix type cases. + if bt.must_be_matrix(requirement): + result = self._convert_operator_to_matrix_type( + node, requirement, consumer, edge ) - return node + else: + assert isinstance(requirement, bt.BMGLatticeType) + result = self._convert_operator_to_atomic_type( + node, requirement, consumer, edge + ) + assert self._node_meets_requirement(result, requirement) return result + def _try_to_force_to_prob(self, node, requirement) -> Optional[bn.BMGNode]: + # We cannot make the node meet the requirement "implicitly". We can + # "explicitly" meet a requirement of probability if we have a + # real or pos real. + + node_type = self._typer[node] + if not self._can_force_to_prob(node_type, requirement): + return None + assert node_type == bt.Real or node_type == bt.PositiveReal + assert self._node_meets_requirement(node, node_type) + return self.bmg.add_to_probability(node) + + def _try_to_force_to_neg_real(self, node, requirement) -> Optional[bn.BMGNode]: + # We cannot make the node meet the requirement "implicitly". We can + # "explicitly" meet a requirement of neg real if we have a value we do + # not know is positive. + node_type = self._typer[node] + if not self._can_force_to_neg_real(node_type, requirement): + return None + + return self.bmg.add_to_negative_real(node) + def _meet_operator_requirement( self, node: bn.OperatorNode, @@ -368,79 +422,73 @@ def _meet_operator_requirement( consumer: bn.BMGNode, edge: str, ) -> bn.BMGNode: - # If the operator node already meets the requirement, we're done. + # We should not have called this function if the input node already meets + # the requirement on the edge. 
+ assert not self._node_meets_requirement(node, requirement) - # It does not meet the requirement. Can we convert this thing to a node - # whose type does meet the requirement? The lattice type is the - # smallest type that this node is convertible to, so if the lattice type - # meets an upper bound requirement, then the conversion we want exists. + # ---- + # + # TODO: Is the problem that we have a scalar but we need a matrix full + # of that value? Generate a matrix fill operation. + # + # TODO: Is the problem that we have a row or column matrix but we need + # a rectangular matrix? Generate a broadcast operation. + # + # TODO: Note that in either of these cases, we might *also* need to + # generate a type conversion, so we might not meet the requirement + # after introducing the fill / broadcast node. + # + # ---- + + # Is the requirement that we have a real-valued matrix, but we haven't got + # a real-valued matrix? Every value can be converted to a real-valued matrix, + # so that's the easiest case. Knock it out first. + + result = self._try_to_meet_any_real_matrix_requirement(node, requirement) + if result is not None: + return result + + # Is the requirement that we have any positive real-valued matrix? Every value + # except negative real scalars and matrices can be converted to a positive real + # matrix. + + result = self._try_to_meet_any_pos_real_matrix_requirement(node, requirement) + if result is not None: + return result + + # If we weaken the requirement to an upper bound requirement, do we meet it? If so, + # then there is a conversion node we can add.
+ + result = self._try_to_meet_upper_bound_requirement( + node, requirement, consumer, edge + ) + if result is not None: + return result + + result = self._try_to_force_to_prob(node, requirement) + if result is not None: + return result node_type = self._typer[node] - if isinstance(node_type, bt.BMGMatrixType): - rows = node_type.rows - columns = node_type.columns - else: - rows = 1 - columns = 1 - if isinstance(requirement, bt.RealMatrix): - return self._meet_real_matrix_requirement( + + result = self._try_to_force_to_neg_real(node, requirement) + if result is not None: + return result + + # Those are the only techniques we have to make an operator meet a requirement. + # We have no way to make the conversion we need, so add an error. + self.errors.add_error( + Violation( node, - dim_req=(requirement.rows, requirement.columns), - node_dim=(rows, columns), - consumer=consumer, - edge=edge, - ) - elif requirement == bt.RealMatrix: - return self._meet_real_matrix_requirement_type(node, (rows, columns)) - elif requirement is bt.any_real_matrix: - result = self.bmg.add_to_real_matrix(node) - elif requirement is bt.any_pos_real_matrix: - result = self.bmg.add_to_positive_real_matrix(node) - elif self._type_meets_requirement(node_type, bt.upper_bound(requirement)): - # If we got here then the node did NOT meet the requirement, - # but its type DID meet an upper bound requirement, which - # implies that the requirement was not an upper bound requirement. - assert not isinstance(requirement, bt.UpperBound) - - # We definitely can meet the requirement by inserting some sort - # of conversion logic. We have different helper methods for - # the atomic type and matrix type cases. 
- if bt.must_be_matrix(requirement): - result = self._convert_operator_to_matrix_type( - node, requirement, consumer, edge - ) - else: - assert isinstance(requirement, bt.BMGLatticeType) - result = self._convert_operator_to_atomic_type( - node, requirement, consumer, edge - ) - elif self._can_force_to_prob(node_type, requirement): - # We cannot make the node meet the requirement "implicitly". We can - # "explicitly" meet a requirement of probability if we have a - # real or pos real. - assert node_type == bt.Real or node_type == bt.PositiveReal - assert self._node_meets_requirement(node, node_type) - result = self.bmg.add_to_probability(node) - elif self._can_force_to_neg_real(node_type, requirement): - # Similarly if we have a real but need a negative real - result = self.bmg.add_to_negative_real(node) - else: - # We have no way to make the conversion we need, so add an error. - self.errors.add_error( - Violation( - node, - node_type, - requirement, - consumer, - edge, - self.bmg.execution_context.node_locations(consumer), - ) + node_type, + requirement, + consumer, + edge, + self.bmg.execution_context.node_locations(consumer), ) - return node - - assert self._node_meets_requirement(result, requirement) - return result + ) + return node def _check_requirement_validity( self, diff --git a/src/beanmachine/ppl/compiler/lattice_typer.py b/src/beanmachine/ppl/compiler/lattice_typer.py index 4ab1d1d220..488b8e3b5a 100644 --- a/src/beanmachine/ppl/compiler/lattice_typer.py +++ b/src/beanmachine/ppl/compiler/lattice_typer.py @@ -238,6 +238,11 @@ def _lattice_type_for_element_type( def _type_binary_elementwise_op( self, node: bn.BinaryOperatorNode ) -> bt.BMGLatticeType: + # Elementwise multiplication and addition require that the operands be + # of the same type and size, and that's the resulting type. Rather than + # enforcing that here, find the supremum of the element types and a size + # where both operands can be broadcast to that size. 
We'll then add the + # appropriate broadcast nodes in the requirements fixer. left_type = self[node.left] right_type = self[node.right] assert isinstance(left_type, bt.BMGMatrixType) diff --git a/tests/ppl/compiler/broadcast_test.py b/tests/ppl/compiler/broadcast_test.py index 2542137bb7..a0a9699920 100644 --- a/tests/ppl/compiler/broadcast_test.py +++ b/tests/ppl/compiler/broadcast_test.py @@ -32,6 +32,7 @@ def broadcast_add(): class BroadcastTest(unittest.TestCase): + # TODO: Test broadcast multiplication as well. def test_broadcast_add(self) -> None: self.maxDiff = None observations = {} diff --git a/tests/ppl/compiler/fix_vectorized_models_test.py b/tests/ppl/compiler/fix_vectorized_models_test.py index ae2262ba50..eda81cff0e 100644 --- a/tests/ppl/compiler/fix_vectorized_models_test.py +++ b/tests/ppl/compiler/fix_vectorized_models_test.py @@ -878,7 +878,7 @@ def test_fix_vectorized_models_8(self) -> None: N07[label=2]; N08[label=1]; N09[label=ToMatrix]; - N10[label=ToRealMatrix]; + N10[label=ToPosRealMatrix]; N11[label="[5.0,6.0]"]; N12[label=ElementwiseMult]; N13[label=Query];