diff --git a/src/beanmachine/ppl/compiler/bmg_requirements.py b/src/beanmachine/ppl/compiler/bmg_requirements.py index 41dd2c8f03..3e09a2871e 100644 --- a/src/beanmachine/ppl/compiler/bmg_requirements.py +++ b/src/beanmachine/ppl/compiler/bmg_requirements.py @@ -35,16 +35,6 @@ # what their inputs are. _known_requirements: Dict[type, List[bt.Requirement]] = { - # TODO: This is wrong in several ways. - # First, RealMatrix does not meet the contract of a requirement; - # in particular, it cannot be printed out by the requirement diagnostic - # in gen_to_dot. - # Second, it is too strict; the requirement on matrix add is actually - # that the two operands be any double matrix (real, neg real, - # pos real or probability). - # Third, this requirement is too weak; we are missing the requirement - # that the operands have the same element type and shape. - bn.ElementwiseMultiplyNode: [bt.RealMatrix, bt.RealMatrix], bn.Observation: [bt.any_requirement], bn.Query: [bt.any_requirement], # Distributions @@ -75,6 +65,7 @@ # don't check them. bn.LogisticNode: [bt.Real], bn.Log1mexpNode: [bt.NegativeReal], + # TODO: Check the dimensions. Consider broadcasting if possible. 
bn.MatrixMultiplicationNode: [bt.any_real_matrix, bt.any_real_matrix], bn.MatrixExpNode: [bt.any_real_matrix], bn.MatrixLogNode: [bt.any_pos_real_matrix], @@ -115,6 +106,7 @@ def __init__(self, typer: LatticeTyper) -> None: bn.ChoiceNode: self._requirements_choice, bn.ColumnIndexNode: self._requirements_column_index, bn.ComplementNode: self._same_as_output, + bn.ElementwiseMultiplyNode: self._requirements_elementwise_mult, bn.ExpM1Node: self._same_as_output, bn.ExpNode: self._requirements_exp_neg, bn.IfThenElseNode: self._requirements_if, @@ -124,7 +116,7 @@ def __init__(self, typer: LatticeTyper) -> None: bn.LogSumExpVectorNode: self._requirements_logsumexp_vector, # TODO: bn.MatrixMultiplyNode: self._requirements_matrix_multiply, # see comment above - bn.MatrixComplementNode: self._requrirements_matrix_complement, + bn.MatrixComplementNode: self._requirements_matrix_complement, bn.MatrixAddNode: self._requirements_matrix_add, bn.MatrixScaleNode: self._requirements_matrix_scale, bn.MultiplicationNode: self._requirements_multiplication, @@ -428,7 +420,7 @@ def _requirements_multiplication( assert it in {bt.Probability, bt.PositiveReal, bt.Real} return [it] * len(node.inputs) # pyre-ignore - def _requrirements_matrix_complement( + def _requirements_matrix_complement( self, node: bn.MatrixComplementNode ) -> List[bt.Requirement]: it = self.typer[node] @@ -446,6 +438,14 @@ def _requrirements_matrix_complement( req = [bt.SimplexMatrix] return req + def _requirements_elementwise_mult( + self, node: bn.ElementwiseMultiplyNode + ) -> List[bt.Requirement]: + # Elementwise multiply requires that both operands be the same as the output type. + it = self.typer[node] + assert isinstance(it, bt.BMGMatrixType) + return [it, it] + def _requirements_matrix_add(self, node: bn.MatrixAddNode) -> List[bt.Requirement]: # Matrix add requires that both operands be the same as the output type. 
it = self.typer[node] diff --git a/src/beanmachine/ppl/compiler/fix_requirements.py b/src/beanmachine/ppl/compiler/fix_requirements.py index 953f4802fa..1e97136f5c 100644 --- a/src/beanmachine/ppl/compiler/fix_requirements.py +++ b/src/beanmachine/ppl/compiler/fix_requirements.py @@ -11,7 +11,7 @@ returned.""" -from typing import Tuple +from typing import Optional import beanmachine.ppl.compiler.bmg_nodes as bn import beanmachine.ppl.compiler.bmg_types as bt @@ -160,9 +160,7 @@ def _meet_constant_requirement( result = self.bmg.add_constant_of_matrix_type(node.value, required_type) else: result = self.bmg.add_constant_of_type(node.value, required_type) - assert self._node_meets_requirement( - result, requirement - ), f"{str(result)} {str(requirement)} {str(required_type)} {str(self._typer[result])} {str(self._type_meets_requirement(self._typer[result], requirement))}" + assert self._node_meets_requirement(result, requirement) return result # We cannot convert this node to any type that meets the requirement. @@ -320,47 +318,103 @@ def _can_force_to_neg_real( or requirement == bt.upper_bound(bt.NegativeReal) ) and node_type == bt.Real - def _meet_real_matrix_requirement_type( - self, node: bn.OperatorNode, node_dim: Tuple[int, int] - ) -> bn.BMGNode: - if node_dim[0] == 1 and node_dim[1] == 1: - result = self.bmg.add_to_real(node) - else: - result = self.bmg.add_to_real_matrix(node) + def _try_to_meet_any_real_matrix_requirement( + self, + node: bn.OperatorNode, + requirement: bt.Requirement, + ) -> Optional[bn.BMGNode]: + + assert not self._node_meets_requirement(node, requirement) + + # Is the requirement that we have a real-valued matrix, but we haven't got + # a real-valued matrix? Every value can be converted to a real-valued matrix, + # so just insert the conversion node. 
+ + if requirement is not bt.any_real_matrix: + return None + + result = self.bmg.add_to_real_matrix(node) + assert self._node_meets_requirement(result, requirement) return result - def _meet_real_matrix_requirement( + def _try_to_meet_any_pos_real_matrix_requirement( self, node: bn.OperatorNode, - dim_req: Tuple[int, int], - node_dim: Tuple[int, int], + requirement: bt.Requirement, + ) -> Optional[bn.BMGNode]: + + assert not self._node_meets_requirement(node, requirement) + + # Is the requirement that we have a pos-real-valued matrix? Anything that + # is not known to be negative can be a positive real matrix. + + if requirement is not bt.any_pos_real_matrix: + return None + + node_type = self._typer[node] + if isinstance(node_type, bt.NegativeRealMatrix): + return None + + result = self.bmg.add_to_positive_real_matrix(node) + assert self._node_meets_requirement(result, requirement) + return result + + def _try_to_meet_upper_bound_requirement( + self, + node: bn.OperatorNode, + requirement: bt.Requirement, consumer: bn.BMGNode, edge: str, - ) -> bn.BMGNode: - result = None - req_rows, req_cols = dim_req - node_rows, node_cols = node_dim - node_is_scalar = node_rows == 1 and node_cols == 1 - requires_scalar = req_rows == 1 and req_cols == 1 - if requires_scalar and node_is_scalar: - result = self.bmg.add_to_real(node) - elif node_rows == req_rows and node_cols == req_cols: - result = self.bmg.add_to_real_matrix(node) - - if result is None: - self.errors.add_error( - Violation( - node, - self._typer[node], - bt.RealMatrix(req_rows, req_cols), - consumer, - edge, - self.bmg.execution_context.node_locations(consumer), - ) + ) -> Optional[bn.BMGNode]: + + assert not self._node_meets_requirement(node, requirement) + + node_type = self._typer[node] + if not self._type_meets_requirement(node_type, bt.upper_bound(requirement)): + return None + + # If we got here then the node did NOT meet the requirement, + # but its type DID meet an upper bound requirement, which + # 
implies that the requirement was not an upper bound requirement. + assert not isinstance(requirement, bt.UpperBound) + + # We definitely can meet the requirement by inserting some sort + # of conversion logic. We have different helper methods for + # the atomic type and matrix type cases. + if bt.must_be_matrix(requirement): + result = self._convert_operator_to_matrix_type( + node, requirement, consumer, edge ) - return node + else: + assert isinstance(requirement, bt.BMGLatticeType) + result = self._convert_operator_to_atomic_type( + node, requirement, consumer, edge + ) + assert self._node_meets_requirement(result, requirement) return result + def _try_to_force_to_prob(self, node, requirement) -> Optional[bn.BMGNode]: + # We cannot make the node meet the requirement "implicitly". We can + # "explicitly" meet a requirement of probability if we have a + # real or pos real. + + node_type = self._typer[node] + if not self._can_force_to_prob(node_type, requirement): + return None + assert node_type == bt.Real or node_type == bt.PositiveReal + assert self._node_meets_requirement(node, node_type) + return self.bmg.add_to_probability(node) + + def _try_to_force_to_neg_real(self, node, requirement) -> Optional[bn.BMGNode]: + # We cannot make the node meet the requirement "implicitly". We can + # "explicitly" meet a requirement of neg real if we have a value we do + # not know is positive. + node_type = self._typer[node] + if not self._can_force_to_neg_real(node_type, requirement): + return None + + return self.bmg.add_to_negative_real(node) + def _meet_operator_requirement( self, node: bn.OperatorNode, @@ -368,79 +422,73 @@ def _meet_operator_requirement( consumer: bn.BMGNode, edge: str, ) -> bn.BMGNode: - # If the operator node already meets the requirement, we're done. + # We should not have called this function if the input node already meets + # the requirement on the edge. 
+ assert not self._node_meets_requirement(node, requirement) - # It does not meet the requirement. Can we convert this thing to a node - # whose type does meet the requirement? The lattice type is the - # smallest type that this node is convertible to, so if the lattice type - # meets an upper bound requirement, then the conversion we want exists. + # ---- + # + # TODO: Is the problem that we have a scalar but we need a matrix full + # of that value? Generate a matrix fill operation. + # + # TODO: Is the problem that we have a row or column matrix but we need + # a rectangular matrix? Generate a broadcast operation. + # + # TODO: Note that in either of these cases, we might *also* need to + # generate a type conversion, so we might not meet the requirement even + # after introducing the fill / broadcast node. + # + # ---- + + # Is the requirement that we have a real-valued matrix, but we haven't got + # a real-valued matrix? Every value can be converted to a real-valued matrix, + # so that's the easiest case. Knock it out first. + + result = self._try_to_meet_any_real_matrix_requirement(node, requirement) + if result is not None: + return result + + # Is the requirement that we have any positive real-valued matrix? Every value + # except negative real scalars and matrices can be converted to a positive real + # matrix. 
+ + result = self._try_to_meet_upper_bound_requirement( + node, requirement, consumer, edge + ) + if result is not None: + return result + + result = self._try_to_force_to_prob(node, requirement) + if result is not None: + return result node_type = self._typer[node] - if isinstance(node_type, bt.BMGMatrixType): - rows = node_type.rows - columns = node_type.columns - else: - rows = 1 - columns = 1 - if isinstance(requirement, bt.RealMatrix): - return self._meet_real_matrix_requirement( + + result = self._try_to_force_to_neg_real(node, requirement) + if result is not None: + return result + + # Those are the only techniques we have to make an operator meet a requirement. + # We have no way to make the conversion we need, so add an error. + self.errors.add_error( + Violation( node, - dim_req=(requirement.rows, requirement.columns), - node_dim=(rows, columns), - consumer=consumer, - edge=edge, - ) - elif requirement == bt.RealMatrix: - return self._meet_real_matrix_requirement_type(node, (rows, columns)) - elif requirement is bt.any_real_matrix: - result = self.bmg.add_to_real_matrix(node) - elif requirement is bt.any_pos_real_matrix: - result = self.bmg.add_to_positive_real_matrix(node) - elif self._type_meets_requirement(node_type, bt.upper_bound(requirement)): - # If we got here then the node did NOT meet the requirement, - # but its type DID meet an upper bound requirement, which - # implies that the requirement was not an upper bound requirement. - assert not isinstance(requirement, bt.UpperBound) - - # We definitely can meet the requirement by inserting some sort - # of conversion logic. We have different helper methods for - # the atomic type and matrix type cases. 
- if bt.must_be_matrix(requirement): - result = self._convert_operator_to_matrix_type( - node, requirement, consumer, edge - ) - else: - assert isinstance(requirement, bt.BMGLatticeType) - result = self._convert_operator_to_atomic_type( - node, requirement, consumer, edge - ) - elif self._can_force_to_prob(node_type, requirement): - # We cannot make the node meet the requirement "implicitly". We can - # "explicitly" meet a requirement of probability if we have a - # real or pos real. - assert node_type == bt.Real or node_type == bt.PositiveReal - assert self._node_meets_requirement(node, node_type) - result = self.bmg.add_to_probability(node) - elif self._can_force_to_neg_real(node_type, requirement): - # Similarly if we have a real but need a negative real - result = self.bmg.add_to_negative_real(node) - else: - # We have no way to make the conversion we need, so add an error. - self.errors.add_error( - Violation( - node, - node_type, - requirement, - consumer, - edge, - self.bmg.execution_context.node_locations(consumer), - ) + node_type, + requirement, + consumer, + edge, + self.bmg.execution_context.node_locations(consumer), ) - return node - - assert self._node_meets_requirement(result, requirement) - return result + ) + return node def _check_requirement_validity( self, diff --git a/src/beanmachine/ppl/compiler/lattice_typer.py b/src/beanmachine/ppl/compiler/lattice_typer.py index 4ab1d1d220..488b8e3b5a 100644 --- a/src/beanmachine/ppl/compiler/lattice_typer.py +++ b/src/beanmachine/ppl/compiler/lattice_typer.py @@ -238,6 +238,11 @@ def _lattice_type_for_element_type( def _type_binary_elementwise_op( self, node: bn.BinaryOperatorNode ) -> bt.BMGLatticeType: + # Elementwise multiplication and addition require that the operands be + # of the same type and size, and that's the resulting type. Rather than + # enforcing that here, find the supremum of the element types and a size + # where both operands can be broadcast to that size. 
We'll then add the + # appropriate broadcast nodes in the requirements fixer. left_type = self[node.left] right_type = self[node.right] assert isinstance(left_type, bt.BMGMatrixType) diff --git a/tests/ppl/compiler/broadcast_test.py b/tests/ppl/compiler/broadcast_test.py index 2542137bb7..a0a9699920 100644 --- a/tests/ppl/compiler/broadcast_test.py +++ b/tests/ppl/compiler/broadcast_test.py @@ -32,6 +32,7 @@ def broadcast_add(): class BroadcastTest(unittest.TestCase): + # TODO: Test broadcast multiplication as well. def test_broadcast_add(self) -> None: self.maxDiff = None observations = {} diff --git a/tests/ppl/compiler/fix_vectorized_models_test.py b/tests/ppl/compiler/fix_vectorized_models_test.py index ae2262ba50..eda81cff0e 100644 --- a/tests/ppl/compiler/fix_vectorized_models_test.py +++ b/tests/ppl/compiler/fix_vectorized_models_test.py @@ -878,7 +878,7 @@ def test_fix_vectorized_models_8(self) -> None: N07[label=2]; N08[label=1]; N09[label=ToMatrix]; - N10[label=ToRealMatrix]; + N10[label=ToPosRealMatrix]; N11[label="[5.0,6.0]"]; N12[label=ElementwiseMult]; N13[label=Query];